Custom Exporter

The exporter implemented in this package is built upon a classic architecture with two main classes:

  • a GraphBuilder: a container for the created nodes and initializers. It stores additional information such as shapes, types, and constants, provides methods to easily create nodes, and guarantees unique names.

  • a DynamoInterpreter: this class goes through the model described as a GraphModule and calls the appropriate converting functions to translate every call into an equivalent ONNX graph.

Both classes are usually not seen by the user. They are called either through the function to_onnx, which converts a model into an ONNX graph, or through a custom backend.

One of these backends calls the reference implementation through the class ExtendedReferenceEvaluator. This class extends ReferenceEvaluator from the package onnx.

One objective: SPEED

The only objective for this exporter is speed: it must remain fast as the size of the model to convert grows. The exporter may be one piece of a backend calling onnxruntime. This single objective implies a few constraints.

multi-opset support

The converter must support conversion to different opsets, to avoid relying on onnx.version_converter.convert_version(), which does not fully work when the model includes other domains than the main one.
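For instance, assuming the first positional argument of GraphBuilder selects the main opset (as in the examples below), the same graph can be produced directly for an older opset. A minimal sketch, where ir_version=8 is the IR version matching opset 17:

from onnx import TensorProto
from experimental_experiment.xbuilder import GraphBuilder

# Sketch: request opset 17 directly instead of converting an opset 18
# model afterwards with onnx.version_converter.convert_version.
gr = GraphBuilder(17, ir_version=8)
gr.make_tensor_input("X", TensorProto.FLOAT, ("a", "b"), is_dimension=False)
gr.make_node("Identity", ["X"], ["Y"], name="id")
gr.make_tensor_output("Y", TensorProto.FLOAT, ("a", "b"), indexed=False, is_dimension=False)
onx = gr.to_onnx()
print(gr.get_opset(""))  # opset of the main domain, expected to be 17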

use shape and type information

The GraphModule comes with the shape and type information of the tensors it manipulates. This information must be used to optimize the onnx graph as it is being built, rather than running an optimizer after the conversion happens.

no decorators, no code interpretation

Writing efficient code is easier when the code you see is the code you get. A decorator hides logic a developer must take into account to avoid writing inefficient code. In the same way, translating Python code into ONNX requires extra logic the developer does not control.

no fallback

The implementation fails if it cannot find a way to convert the model into ONNX. There are ways to work around that (see the Fallback section below), but they are not enabled by default. The user must know when the exporter follows a different path to produce the model.

GraphBuilder

A GraphBuilder starts from an empty graph or takes an existing one as input. In the latter case, the builder is usually used by an optimizer.

Internal containers

Besides the onnx structure, the builder holds information about the requested opsets and the dynamic shapes. During the conversion, it stores information in the following containers (a short inspection sketch follows the list):

  • _unique_names: names already taken for results

  • _unique_node_names: names already taken for nodes

  • _known_names: existing names

  • _known_types: the known type of every result; it must exist

  • _known_shapes: the known shape of every result; either the shape or the rank is known

  • _known_ranks: declared ranks

  • _known_value_shape: results known to be shapes; the implementation tries to capture the logic with strings, sympy could be used instead
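These containers can be inspected directly on a builder; a small sketch, whose expected values come from the example and debug outputs shown later on this page:

from onnx import TensorProto
from experimental_experiment.xbuilder import GraphBuilder

gr = GraphBuilder(18, ir_version=9)
gr.make_tensor_input("X", TensorProto.FLOAT, ("a", "b"), is_dimension=False)
# the input registers its shape and its named dimensions become dynamic objects
print(gr._known_shapes)    # expected: {'X': ('a', 'b')}
print(gr.dynamic_objects)  # expected: {'a': 'a', 'b': 'b'}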

The model stores some constants; the builder assumes every node taking only constants as inputs produces a new constant.

  • constants_: constant values

  • constants_computed_: computed constant values, i.e. constants built from other constants; every computed constant is cached

The builder tries to minimize the number of initializers to create. It stores a unique value for the small ones (see the sketch after this list):

  • _values: cache initializer value to merge those which are equal
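A hedged sketch of that behaviour, assuming make_initializer returns the name of the existing initializer when an equal value was already added:

import numpy as np
from experimental_experiment.xbuilder import GraphBuilder

gr = GraphBuilder(18, ir_version=9)
value = np.array([0.4, 0.5, 0.6], dtype=np.float32)
n1 = gr.make_initializer("", value)
n2 = gr.make_initializer("", value.copy())
# assumption: both names are equal because _values cached the first value
print(n1, n2)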

The forward/backward graphs may have dynamic dimensions as inputs. Some results are reshaped based on these inputs. The following containers keep track of this information.

  • dynamic_objects: list of dynamic dimensions coming as inputs

  • dynamic_objects_rev: reverse dictionary to speed up lookups

  • _dynamic_alias: used when the user gives a different name to the dynamic shapes

The next container stores dynamic shapes.

  • _cache_shape: caches the concatenation of shapes

API

The following methods are used to add onnx elements to the graph.

  • get_opset: gets the opset value for a domain

  • make_tensor_input: adds an input to the graph; is_dimension specifies whether this input is a dynamic dimension, i.e. a single integer

  • make_tensor_output: adds an output to the graph; is_dimension specifies whether this output is a dynamic dimension, i.e. a single integer

  • make_initializer: adds an initializer to the graph

  • make_node: adds a node to the graph

  • to_onnx: produces the final ONNX model

Some needs are very common and deserve a dedicated method.

  • make_nodes: adds many nodes in one call, renaming the intermediate results if needed

  • get_attribute: retrieves an attribute from a NodeProto

  • make_shape_from_results: makes a shape from a tuple containing integers, strings, or torch.SymInt

It is important to update the shape information as soon as it is available. Methods prefixed with get_ retrieve it (get_type for instance); methods prefixed with set_ declare it (set_type, set_rank), as illustrated below.
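A minimal sketch, using only the calls that appear in the examples on this page (the exact behaviour of set_rank on a result whose shape is already known is an assumption):

import numpy as np
from onnx import TensorProto
from experimental_experiment.xbuilder import GraphBuilder

gr = GraphBuilder(18, ir_version=9)
gr.make_tensor_input("X", TensorProto.FLOAT, ("a", "b"), is_dimension=False)
w = gr.make_initializer("", np.array([[0.4], [0.5], [0.6]], dtype=np.float32))
mm = gr.make_node("MatMul", ["X", w], name="n")
gr.set_type(mm, TensorProto.FLOAT)  # declares the type of the new result
gr.set_rank(mm, 2)                  # declares its rank when the full shape is unknown
print(gr.get_type(mm))              # retrieves the declared type (TensorProto.FLOAT == 1)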

The method get_debug_msg provides context to the user and is called in most error messages:

assert name in self._known_ranks, (
  f"Rank is unknown for result {name!r}, "
  f"known_ranks={self._known_ranks}{self.get_debug_msg()}"
)

Example

<<<

import numpy as np
from onnx import TensorProto
from experimental_experiment.xbuilder import GraphBuilder
from experimental_experiment.reference import ExtendedReferenceEvaluator
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot


gr = GraphBuilder(18, ir_version=9)
gr.make_tensor_input("X", TensorProto.FLOAT, ("a", "b"), is_dimension=False)
weight = gr.make_initializer("", np.array([[0.4, 0.5, 0.6]], dtype=np.float32).T)
bias = gr.make_initializer("", np.array([[0.4, 0.5, 0.6]], dtype=np.float32))
mm = gr.make_node("MatMul", ["X", weight], name="n")
out = gr.make_node("Add", [mm, bias], ["Y"], name="n")
gr.make_tensor_output(out, TensorProto.FLOAT, ("a",), indexed=False, is_dimension=False)
onx = gr.to_onnx()

ref = ExtendedReferenceEvaluator(onx)
x = np.random.rand(5, 3).astype(np.float32)
y = ref.run(None, {"X": x})[0]
print(y)

print(onnx_simple_text_plot(onx))

print("Without any information, the known shapes are:")
print(gr._known_shapes)

print("Without any information, the known shapes are:")
print(gr.constants_)

print("The constant are not converted into TensorProto until the very end:")
print(gr.initializers_dict)

>>>

    [[1.383 1.483 1.583]
     [1.052 1.152 1.252]
     [0.915 1.015 1.115]
     [1.327 1.427 1.527]
     [0.962 1.062 1.162]]
    opset: domain='' version=18
    input: name='X' type=dtype('float32') shape=['a', 'b']
    init: name='init1_s3x1_' type=dtype('float32') shape=(3, 1) -- array([0.4, 0.5, 0.6], dtype=float32)
    init: name='init1_s1x3_' type=dtype('float32') shape=(1, 3) -- array([0.4, 0.5, 0.6], dtype=float32)
    MatMul(X, init1_s3x1_) -> _onx_matmul0
      Add(_onx_matmul0, init1_s1x3_) -> Y
    output: name='Y' type=dtype('float32') shape=['a']
    Without any information, the known shapes are:
    {'X': ('a', 'b'), 'init1_s3x1_': (3, 1), 'init1_s1x3_': (1, 3), '_onx_matmul0': ('a', 1), 'Y': ('a', 3)}
    Without any information, the known constants are:
    {'init1_s3x1_': None, 'init1_s1x3_': None}
    The constants are not converted into TensorProto until the very end:
    {'init1_s3x1_': array([[0.4],
           [0.5],
           [0.6]], dtype=float32), 'init1_s1x3_': array([[0.4, 0.5, 0.6]], dtype=float32)}

The constants are only computed on demand. Their conversion to TensorProto only happens when the method to_onnx is called.

Debugging

An exception is raised when an error is detected; the message includes the result of get_debug_msg.

<<<

import numpy as np
from onnx import TensorProto
from experimental_experiment.xbuilder import GraphBuilder
from experimental_experiment.reference import ExtendedReferenceEvaluator
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot


gr = GraphBuilder(18, ir_version=9)
gr.make_tensor_input("X", TensorProto.FLOAT, ("a", "b"), is_dimension=False)
weight = gr.make_initializer("", np.array([[0.4, 0.5, 0.6]], dtype=np.float32).T)
bias = gr.make_initializer("", np.array([[0.4, 0.5, 0.6]], dtype=np.float32))
mm = gr.make_node("MatMul", ["X", weight], name="N1")
out = gr.make_node("Add", [mm, "bias"], ["Y"], name="N2")
gr.make_tensor_output(out, TensorProto.FLOAT, ("a",), indexed=False, is_dimension=False)
onx = gr.to_onnx()

>>>

    
    [runpythonerror]
    
    Traceback (most recent call last):
        exec(obj, globs, loc)
      File "", line 20, in <module>
      File "", line 17, in run_python_script_140594843379776
      File "/home/xadupre/github/experimental-experiment/experimental_experiment/xbuilder/graph_builder.py", line 1894, in make_node
        assert self.has_name(i), (
    AssertionError: Input 'bias' does not exist for operator 'Add' (NKQ)
    --DEBUG--
    --SHAPE--
    dynamic_objects={'a': 'a', 'b': 'b'}
    dynamic_objects_rev={'a': ['a'], 'b': ['b']}
    dynamic_alias={}
    dynamic_shapes=None
    _known_value_shape={}
    _known_shapes={'X': ('a', 'b'),
     '_onx_matmul0': ('a', 1),
     'init1_s1x3_': (1, 3),
     'init1_s3x1_': (3, 1)}
    _known_ranks={}
    --TORCH-SHAPES--
    --ONNX--
    --
    [GraphBuilder-NKQ.make_tensor_input] X[1:axb]
    [GraphBuilder-NKQ.make_initializer] init1_s3x1_[float32:float32:[0.4000000059604645, 0.5, 0.6000000238418579]]
    [GraphBuilder-NKQ.make_initializer] init1_s1x3_[float32:float32:[0.4000000059604645, 0.5, 0.6000000238418579]]
    [GraphBuilder-NKQ.make_node] N1              [##:#  ] MatMul:['X', 'init1_s3x1_']->['_onx_matmul0']
    [GraphBuilder-NKQ] Message completed, there are 2 initializers, 1 nodes, 1 inputs, 1 outputs.

It shows the information available while the model is being built. At the end, the following line appears.

[GraphBuilder-EAQ.make_node] N1              [##:-  ] MatMul:['X', 'init1_s3x1_']->['_onx_matmul0']

It says one node named N1 was created. ## means the shape and type are known for both of its inputs; - means nothing is known for the output. When the type is specified, it shows the following:

<<<

import numpy as np
from onnx import TensorProto
from experimental_experiment.xbuilder import GraphBuilder
from experimental_experiment.reference import ExtendedReferenceEvaluator
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot


gr = GraphBuilder(18, ir_version=9)
gr.make_tensor_input("X", TensorProto.FLOAT, ("a", "b"), is_dimension=False)
weight = gr.make_initializer("", np.array([[0.4, 0.5, 0.6]], dtype=np.float32).T)
bias = gr.make_initializer("", np.array([[0.4, 0.5, 0.6]], dtype=np.float32))
mm = gr.make_node("MatMul", ["X", weight], name="N1")
gr.set_type(mm, TensorProto.FLOAT)
out = gr.make_node("Add", [mm, "bias"], ["Y"], name="N2")
gr.make_tensor_output(out, TensorProto.FLOAT, ("a",), indexed=False, is_dimension=False)
onx = gr.to_onnx()

>>>

    
    [runpythonerror]
    
    Traceback (most recent call last):
        exec(obj, globs, loc)
      File "", line 21, in <module>
      File "", line 18, in run_python_script_140594621299840
      File "/home/xadupre/github/experimental-experiment/experimental_experiment/xbuilder/graph_builder.py", line 1894, in make_node
        assert self.has_name(i), (
    AssertionError: Input 'bias' does not exist for operator 'Add' (QRO)
    --DEBUG--
    --SHAPE--
    dynamic_objects={'a': 'a', 'b': 'b'}
    dynamic_objects_rev={'a': ['a'], 'b': ['b']}
    dynamic_alias={}
    dynamic_shapes=None
    _known_value_shape={}
    _known_shapes={'X': ('a', 'b'),
     '_onx_matmul0': ('a', 1),
     'init1_s1x3_': (1, 3),
     'init1_s3x1_': (3, 1)}
    _known_ranks={}
    --TORCH-SHAPES--
    --ONNX--
    --
    [GraphBuilder-QRO.make_tensor_input] X[1:axb]
    [GraphBuilder-QRO.make_initializer] init1_s3x1_[float32:float32:[0.4000000059604645, 0.5, 0.6000000238418579]]
    [GraphBuilder-QRO.make_initializer] init1_s1x3_[float32:float32:[0.4000000059604645, 0.5, 0.6000000238418579]]
    [GraphBuilder-QRO.make_node] N1              [##:#  ] MatMul:['X', 'init1_s3x1_']->['_onx_matmul0']
    [GraphBuilder-QRO] Message completed, there are 2 initializers, 1 nodes, 1 inputs, 1 outputs.

It shows U when the type and rank are known, # if the type and shape are known.

[GraphBuilder-MJG.make_node] N1              [##:U  ] MatMul:['X', 'init1_s3x1_']->['_onx_matmul0']

Simplified API

For the most common nodes, a shortcut exists to make the syntax shorter.

<<<

import numpy as np
from onnx import TensorProto
from experimental_experiment.xbuilder import GraphBuilder
from experimental_experiment.reference import ExtendedReferenceEvaluator
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot


gr = GraphBuilder(18, ir_version=9)
gr.make_tensor_input("X", TensorProto.FLOAT, ("a", "b"), is_dimension=False)
mm = gr.op.MatMul("X", np.array([[0.4, 0.5, 0.6]], dtype=np.float32).T)
out = gr.op.Add(mm, np.array([0.4, 0.5, 0.6], dtype=np.float32), outputs=["Y"])
gr.make_tensor_output(out, TensorProto.FLOAT, ("a",), indexed=False, is_dimension=False)
onx = gr.to_onnx()

ref = ExtendedReferenceEvaluator(onx)
x = np.random.rand(5, 3).astype(np.float32)
y = ref.run(None, {"X": x})[0]

print(y)

print(onnx_simple_text_plot(onx))

>>>

    [[0.801 0.901 1.001]
     [0.994 1.094 1.194]
     [1.246 1.346 1.446]
     [1.108 1.208 1.308]
     [1.22  1.32  1.42 ]]
    opset: domain='' version=18
    input: name='X' type=dtype('float32') shape=['a', 'b']
    init: name='init1_s3x1_' type=dtype('float32') shape=(3, 1) -- array([0.4, 0.5, 0.6], dtype=float32)
    init: name='init1_s3_' type=dtype('float32') shape=(3,) -- array([0.4, 0.5, 0.6], dtype=float32)
    MatMul(X, init1_s3x1_) -> _onx_matmul0
      Add(_onx_matmul0, init1_s3_) -> Y
    output: name='Y' type=dtype('float32') shape=['a']

Optimizations

GraphBuilder implements three basic optimization algorithms which do not rely on patterns. Except for constant folding, they run by default.

DynamoInterpreter

The class DynamoInterpreter walks through a graph module and selects the best translation for every part. The graph module itself is a sequence of calls to internal functions, the aten functions. It looks like the following:

<<<

import torch


class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int, n_targets: int):
        super(Neuron, self).__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))


x = torch.rand(5, 3)
model = Neuron(3, 1)
graph = torch.export.export(model, (x,))
print(graph)

>>>

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_linear_weight: "f32[1, 3]", p_linear_bias: "f32[1]", x: "f32[5, 3]"):
                # 
                t: "f32[3, 1]" = torch.ops.aten.t.default(p_linear_weight);  p_linear_weight = None
                addmm: "f32[5, 1]" = torch.ops.aten.addmm.default(p_linear_bias, x, t);  p_linear_bias = x = t = None
                sigmoid: "f32[5, 1]" = torch.ops.aten.sigmoid.default(addmm);  addmm = None
                return (sigmoid,)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_weight'), target='linear.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_bias'), target='linear.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sigmoid'), target=None)])
    Range constraints: {}

Called functions such as torch.ops.aten.addmm.default are well identified but cannot be translated into ONNX as they are. The interpreter maps each of these names to a function building the equivalent onnx implementation, here aten_addmm, inside a dispatcher run_node which includes the following piece of code:

if node.op == "placeholder":
    return self.placeholder(node)
if node.op == "call_function":
    return self.call_function(node)
if node.op == "output":
    return self.output(node)
if node.op == "call_module":
    return self.call_module(node)
if node.op == "get_attr":
    return self.get_attr(node)
if node.op == "call_method":
    return self.call_method(node)

A converting function

Let’s consider the following, simple converting function.

def aten_addmm(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    a: T,
    b: T,
    c: T,
    beta: float = 1.0,
    alpha: float = 1.0,
) -> T:
    "gemm"
    res = g.op.Gemm(
        b, c, a, alpha=float(alpha), beta=float(beta), outputs=outputs, name="addmm"
    )
    if not sts:
        g.set_type(res, g.get_type(b))
        g.set_rank(res, 2)
    return res

The first three arguments are the GraphBuilder, the argument sts which tells the function whether it must set the shape and type information itself (it does when sts is empty), and the output names, which ensure the results get the same names as in the graph provided by torch. It helps debugging.

Shapes And Types

The function can assume the type information is always filled in. The shapes should be set as well but in this case, only the rank is provided. It is not mandatory but it helps the following functions take the right decisions. GraphBuilder itself sets the type and shape for a limited number of operator types such as Identity. This should improve in future versions. Some helpers were already implemented to set shapes or types, as shown in the following function.

def aten_asin(g: GraphBuilder, sts: Optional[Dict[str, Any]], outputs: List[str], x: T) -> T:
    "asin"
    res = g.make_node("Asin", [x], outputs)
    if not sts:
        set_type_shape_unary_op(g, outputs[0], x)
    return res

The argument sts is None when the graph given by torch contains no information about shapes and types. Otherwise, the interpreter gives this information to the graph builder through sts.

Different Implementations

In the following case, the function adds a node Identity or CastLike depending on the types. CastLike is only needed when the types are different, and the graph builder later removes the Identity node.

def aten_copy(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    src: T,
    non_blocking: bool = False,
) -> T:
    "identity"
    assert not non_blocking, "copy not implemented when non_blocking is True"
    if g.get_type(x) == g.get_type(src):
        return g.op.Identity(src, name="copy")
    return g.op.CastLike(src, x, name="copy")

Conventions

Nodes should be given names based on the aten functions they come from. Doing so helps the developer find where a failing node comes from.

Functions

All the available functions are listed in one of those three pages:

Every function added to these modules is automatically added to the list of known converter functions.
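A simplified illustration of that registration, not the actual implementation: the converter modules can be scanned for functions named aten_<op>, registered under the key format searched for in the error messages below (aten::<op>):

import inspect
from experimental_experiment.torch_interpreter import _aten_functions

# Illustration only: collect every function named ``aten_<op>`` and
# register it under the key ``aten::<op>``.
registry = {
    "aten::" + name[len("aten_"):]: fn
    for name, fn in inspect.getmembers(_aten_functions, inspect.isfunction)
    if name.startswith("aten_")
}
print("aten::addmm" in registry)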

Practice

Example

<<<

import torch
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from experimental_experiment.torch_interpreter import to_onnx


class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int, n_targets: int):
        super(Neuron, self).__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))


model = Neuron(3, 1)

x = torch.rand(5, 3)

onx = to_onnx(model, (x,), input_names=["x"])

print(onnx_simple_text_plot(onx))

>>>

    opset: domain='' version=18
    input: name='x' type=dtype('float32') shape=[5, 3]
    init: name='p_linear_weight' type=dtype('float32') shape=(1, 3) -- array([-0.431,  0.572, -0.398], dtype=float32)
    init: name='p_linear_bias' type=dtype('float32') shape=(1,) -- array([0.075], dtype=float32)
    Gemm(x, p_linear_weight, p_linear_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm
      Sigmoid(addmm) -> output_0
    output: name='output_0' type=dtype('float32') shape=[5, 1]

And visually:

digraph{
  size=7;
  orientation=portrait;
  nodesep=0.05;
  ranksep=0.25;

  x [shape=box color=red label="x\nTensorProto.FLOAT\nshape=[5, 3]" fontsize=10];

  output_0 [shape=box color=green label="output_0\nTensorProto.FLOAT\nshape=[5, 1]" fontsize=10];

  p_linear_weight [shape=box label="p_linear_weight\nfloat32((1, 3))\n[[-0.33710057 -0.36762246  0.2774214 ]]" fontsize=10];
  p_linear_bias [shape=box label="p_linear_bias\nfloat32((1,))\n[0.5049715]" fontsize=10];

  addmm [shape=box label="addmm" fontsize=10];
  TransposeMatMulPattern__addmm [shape=box style="filled,rounded" color=orange label="Gemm\nalpha=1.0\nbeta=1.0\ntransA=0\ntransB=1" fontsize=10];
  x -> TransposeMatMulPattern__addmm;
  p_linear_weight -> TransposeMatMulPattern__addmm;
  p_linear_bias -> TransposeMatMulPattern__addmm;
  TransposeMatMulPattern__addmm -> addmm;

  Opset [shape=box style="filled,rounded" color=orange label="Sigmoid" fontsize=10];
  addmm -> Opset;
  Opset -> output_0;
}

Debugging

There is no fallback by default. The converter fails if the conversion to ONNX cannot happen. In that case, it tries to give you some information about why it failed. (The example below might succeed in the future.)

<<<

import torch
from experimental_experiment.torch_interpreter import to_onnx


class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int, n_targets: int):
        super(Neuron, self).__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        return torch.celu(self.linear(x))


x = torch.rand(5, 3)
model = Neuron(3, 1)


onx = to_onnx(model, (x,), input_names=["x"])

>>>

    
    [runpythonerror]
    
    Traceback (most recent call last):
        exec(obj, globs, loc)
      File "", line 20, in <module>
      File "", line 19, in run_python_script_140591153436352
      File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 341, in to_onnx
        builder.process(graph_module, interpreter)
      File "/home/xadupre/github/experimental-experiment/experimental_experiment/xbuilder/graph_builder.py", line 2546, in process
        interpreter.run_node(node)
      File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/interpreter.py", line 98, in run_node
        return self.call_function(node)
      File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/interpreter.py", line 782, in call_function
        raise FunctionNotFoundError(
    experimental_experiment.torch_interpreter._exceptions.FunctionNotFoundError: Unable to interpret function <class 'torch._ops.OpOverload'>: <OpOverload(op='aten.celu', overload='default')>, searched for ['aten::celu', 'celu_default'] and attributes ['__qualname__', '__name__'], args=(addmm,), kwargs={}
    --DEBUG--
    --SHAPE--
    dynamic_objects={}
    dynamic_objects_rev={}
    dynamic_alias={}
    dynamic_shapes=None
    _known_value_shape={}
    _known_shapes={'addmm': (5, 1),
     'p_linear_bias': (1,),
     'p_linear_weight': (1, 3),
     't': (3, 1),
     'x': (5, 3)}
    _known_ranks={}
    --TORCH-SHAPES--
    p_linear_weight: ('run_node', (('example_value', torch.float32, torch.Size([5, 1])), ('val', torch.float32, torch.Size([1, 3])))) --- 1:2:(1, 3):
    p_linear_bias: ('run_node', (('example_value', torch.float32, torch.Size([5, 1])), ('val', torch.float32, torch.Size([1])))) --- 1:1:(1,):
    x: ('run_node', ('', ('val', torch.float32, torch.Size([5, 3])))) --- 1:2:(5, 3):
    t: ('run_node', ('', ('val', torch.float32, torch.Size([3, 1])))) --- 1:2:(3, 1):
    addmm: ('run_node', ('', ('val', torch.float32, torch.Size([5, 1])))) --- 1:2:(5, 1):
    celu: ('run_node', ('', ('val', torch.float32, torch.Size([5, 1])))) --- :::
    --ONNX--
    -- process.graph_module --
    graph():
        %p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
        %p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
        %x : [num_users=1] = placeholder[target=x]
        %t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%p_linear_weight,), kwargs = {})
        %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%p_linear_bias, %x, %t), kwargs = {})
        %celu : [num_users=1] = call_function[target=torch.ops.aten.celu.default](args = (%addmm,), kwargs = {})
        return (celu,)
    -- process.progress --
    node 5/7 
    --
    [GraphBuilder-VKQ.make_tensor_input] x[1:5x3]
    [GraphBuilder-VKQ.make_initializer] p_linear_weight[torch.float32:torch.float32:[-0.1595292091369629, -0.5146317481994629, -0.3734068274497986]]
    [GraphBuilder-VKQ.make_initializer] p_linear_bias[torch.float32:torch.float32:[-0.0630064532160759]]
    [GraphBuilder-VKQ.make_node] t               [#:#   ] Transpose:['p_linear_weight']->['t']
    [GraphBuilder-VKQ.make_node] addmm           [###:# ] Gemm:['x', 't', 'p_linear_bias']->['addmm']
    [GraphBuilder-VKQ] Message completed, there are 2 initializers, 2 nodes, 1 inputs, 1 outputs.

In particular, look at the first line of the error message. It tells you there is currently no known conversion for function aten.celu: a function aten_celu must be added to the file experimental_experiment.torch_interpreter._aten_functions.

Unable to interpret function <class 'torch._ops.OpOverload'>: <OpOverload(op='aten.celu', overload='default')>,
searched for ['aten::celu', 'celu_default'] and attributes ['__qualname__', '__name__'], args=(addmm,), kwargs={}

Below is the graph module:

-- process.graph_module --
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%arg0_1,), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %t), kwargs = {})
    %celu : [num_users=1] = call_function[target=torch.ops.aten.celu.default](args = (%addmm,), kwargs = {})
    return (celu,)
-- process.progress --
node 5/7

The last lines tell you the process stopped at node 5/7, which helps finding which functions were called before. Next comes the information about all nodes added so far. We can see that, except for this function, everything looks good and all shapes are known.

[GraphBuilder-BQU.make_tensor_input] x[1:5x3]
[GraphBuilder-BQU.make_initializer] arg0_1[torch.float32:torch.Size([1, 3]):[-0.44980645179748535, 0.29780903458595276, -0.32629191875457764]]
[GraphBuilder-BQU.make_initializer] arg1_1[torch.float32:torch.Size([1]):[0.2905656397342682]]
[GraphBuilder-BQU.make_node]                 [#:#   ] Identity:['x']->['arg2_1']
[GraphBuilder-BQU.make_node] t               [#:#   ] Transpose:['arg0_1']->['t']
[GraphBuilder-BQU.make_node] addmm           [###:# ] Gemm:['arg2_1', 't', 'arg1_1']->['addmm']

There is also the section starting with --TORCH-SHAPES-- which shows the shapes given by torch.

--TORCH-SHAPES--
arg0_1: ('run_node', (('example_value', torch.float32, torch.Size([5, 1])), ('val', torch.float32, torch.Size([1, 3])))) --- 1:2:(1, 3):
arg1_1: ('run_node', (('example_value', torch.float32, torch.Size([5, 1])), ('val', torch.float32, torch.Size([1])))) --- 1:1:(1,):
arg2_1: ('run_node', ('', ('val', torch.float32, torch.Size([5, 3])))) --- 1:2:(5, 3):
t: ('run_node', ('', ('val', torch.float32, torch.Size([3, 1])))) --- 1:2:(3, 1):
addmm: ('run_node', ('', ('val', torch.float32, torch.Size([5, 1])))) --- 1:2:(5, 1):
celu: ('run_node', ('', ('val', torch.float32, torch.Size([5, 1])))) --- :::
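Following the pattern of aten_addmm and aten_asin above, a converter for celu could look like the sketch below. This is an assumption, not the library's implementation; ONNX defines a Celu operator with an alpha attribute since opset 12, and aten.celu takes an optional alpha argument.

def aten_celu(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    alpha: float = 1.0,
) -> T:
    "celu (sketch)"
    res = g.op.Celu(x, alpha=float(alpha), outputs=outputs, name="celu")
    if not sts:
        # celu is a unary operator: it keeps the type and shape of its input
        set_type_shape_unary_op(g, outputs[0], x)
    return res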

Dynamic Shapes

Dynamic shapes just need to be given to function to_onnx through the argument dynamic_shapes={"x": {0: torch.export.Dim("batch")}}.

<<<

import torch
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from experimental_experiment.torch_interpreter import to_onnx


class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int, n_targets: int):
        super(Neuron, self).__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))


model = Neuron(3, 1)

x = torch.rand(5, 3)

onx = to_onnx(
    model,
    (x,),
    input_names=["x"],
    dynamic_shapes={"x": {0: torch.export.Dim("batch")}},
)

print(onnx_simple_text_plot(onx))

>>>

    opset: domain='' version=18
    input: name='x' type=dtype('float32') shape=['batch', 3]
    init: name='p_linear_weight' type=dtype('float32') shape=(1, 3) -- array([ 0.267, -0.471,  0.145], dtype=float32)
    init: name='p_linear_bias' type=dtype('float32') shape=(1,) -- array([-0.361], dtype=float32)
    Gemm(x, p_linear_weight, p_linear_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm
      Sigmoid(addmm) -> output_0
    output: name='output_0' type=dtype('float32') shape=['batch', 1]

Fallback

The current library does not always have a converting function for every aten function implemented in torch. A mechanism exists to intercept the function call missed by the interpreter and replace it with a function coming from another source such as onnx-script.

<<<

import torch
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.torch_interpreter.oxs_dispatcher import OxsDispatcher


class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int, n_targets: int):
        super(Neuron, self).__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        return torch.celu(self.linear(x))


x = torch.rand(5, 3)
model = Neuron(3, 1)

onx = to_onnx(model, (x,), input_names=["x"], dispatcher=OxsDispatcher(verbose=2))

print(onnx_simple_text_plot(onx))

>>>

    opset: domain='' version=18
    input: name='x' type=dtype('float32') shape=[5, 3]
    init: name='p_linear_weight' type=dtype('float32') shape=(1, 3) -- array([ 0.505, -0.358, -0.055], dtype=float32)
    init: name='p_linear_bias' type=dtype('float32') shape=(1,) -- array([-0.335], dtype=float32)
    Gemm(x, p_linear_weight, p_linear_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm
      Celu(addmm, alpha=1.00) -> output_0
    output: name='output_0' type=dtype('float32') shape=[5, 1]