Custom Exporter¶
The exporter implemented in this package is built upon a classic architecture with two main classes:
GraphBuilder: a container for the created nodes and initializers; it stores additional information such as shapes, types and constants, provides methods to easily create nodes, and generates unique names,
DynamoInterpreter: walks through the model described as a GraphModule and calls the appropriate converting functions to translate every call into an equivalent ONNX graph.
Both classes are usually hidden from the user. They are called either by function to_onnx, which converts a model into an ONNX graph, or through a custom backend:
onnx_custom_backend: this backend leverages onnxruntime to run the inference; it is fast,
onnx_debug_backend: a backend using numpy, meant for debugging; it is slow.
This second backend calls the reference implementation through class ExtendedReferenceEvaluator, which extends ReferenceEvaluator from package onnx.
One objective: SPEED¶
The only objective of this exporter is speed. It must be fast because the size of the models to convert grows quickly. The exporter may be one piece of a backend calling onnxruntime. This single objective implies a few constraints.
multi-opset support
The converter must support conversion to different opsets to avoid calling onnx.version_converter.convert_version(), which does not fully work when the model includes domains other than the main one.
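A converter targeting several opsets picks the node layout matching the requested version. The sketch below illustrates the idea with ReduceSum, whose axes moved from an attribute to a second input in opset 13. Nodes are plain dictionaries and the helper make_reduce_sum is hypothetical; it is not part of the package's API.

```python
from typing import Any, Dict, List


def make_reduce_sum(opset: int, x: str, axes_name: str, axes: List[int], out: str) -> Dict[str, Any]:
    """Describes a ReduceSum node for the requested opset (hypothetical helper)."""
    if opset >= 13:
        # Opset 13 and later: axes is a second input, an int64 tensor named
        # axes_name that the caller is expected to create as an initializer.
        return {"op_type": "ReduceSum", "inputs": [x, axes_name], "outputs": [out], "attributes": {}}
    # Before opset 13: axes is an attribute of the node.
    return {"op_type": "ReduceSum", "inputs": [x], "outputs": [out], "attributes": {"axes": list(axes)}}


old = make_reduce_sum(11, "X", "axes", [1], "Y")
new = make_reduce_sum(18, "X", "axes", [1], "Y")
print(old["inputs"], old["attributes"])  # ['X'] {'axes': [1]}
print(new["inputs"], new["attributes"])  # ['X', 'axes'] {}
```

Producing the right layout directly is what makes a later call to the version converter unnecessary.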
use shape and type information
The GraphModule comes with the shape and type information of the tensors it manipulates. This information must be used to optimize the ONNX graph during the conversion rather than relying on an optimizer once the conversion is done.
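As an illustration of using known shapes during the conversion, the hypothetical helper below emits a Reshape only when the known shape differs from the target, and an Identity otherwise, which a later pass can drop. Shapes are tuples of integers or dimension names; the shape initializer a real Reshape needs is not created in this sketch.

```python
from typing import Dict, List, Tuple, Union

Dim = Union[int, str]


def reshape_or_identity(known_shapes: Dict[str, Tuple[Dim, ...]], x: str,
                        target: Tuple[Dim, ...], out: str) -> List[dict]:
    """Skips a useless Reshape when the shape of x is already known to match."""
    current = known_shapes.get(x)
    if current is not None and tuple(current) == tuple(target):
        # Same shape: an Identity is enough, an optimizer pass can remove it.
        node = {"op_type": "Identity", "inputs": [x], "outputs": [out]}
    else:
        # Unknown or different shape: a real Reshape is needed
        # (the initializer named out + "_shape" is not created here).
        node = {"op_type": "Reshape", "inputs": [x, f"{out}_shape"], "outputs": [out]}
    # Record the shape of the new result for the next converting functions.
    known_shapes[out] = tuple(target)
    return [node]


shapes = {"X": ("batch", 4)}
print(reshape_or_identity(shapes, "X", ("batch", 4), "Y")[0]["op_type"])     # Identity
print(reshape_or_identity(shapes, "X", ("batch", 2, 2), "Z")[0]["op_type"])  # Reshape
print(shapes["Z"])  # ('batch', 2, 2)
```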
no decorators, no code interpretation
Writing efficient code is easier when the code you see is the code you get. A decorator hides some logic a developer must take into account to avoid writing inefficient code. Likewise, translating Python code into ONNX requires extra logic the developer does not control.
no fallback
The implementation fails if it cannot find a way to convert the model into ONNX. There are some ways to work around that, but they are not enabled by default. The user must know when the exporter takes a different path to produce the model.
GraphBuilder¶
GraphBuilder starts from an empty graph or takes an existing graph as input.
In the latter case, the builder is usually used by an optimizer.
Internal containers¶
Besides the ONNX structure, the builder holds information about the requested opsets and the dynamic shapes. During the conversion, it stores information about:
_unique_names: names already taken for results
_unique_node_names: names already taken for nodes
_known_names: existing names
_known_types: known type for every result, it must exist
_known_shapes: known shape for every result, either shape or rank is known
_known_ranks: declared ranks
_known_value_shape: results known to be shapes; the implementation tries to capture the logic with strings (sympy could be used)
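Containers such as _unique_names guarantee that every new result gets a fresh name. Since the examples on this page show generated names such as _onx_matmul0, a counter appended to a prefix is a natural scheme; the sketch below assumes it, but the exact rule used by GraphBuilder may differ and the class UniqueNames is hypothetical.

```python
from typing import Dict, Set


class UniqueNames:
    """Hands out result names never used before (illustrative sketch)."""

    def __init__(self) -> None:
        self._unique_names: Set[str] = set()
        self._next_index: Dict[str, int] = {}  # avoids rescanning from 0 every time

    def reserve(self, name: str) -> None:
        # Names coming from the model (inputs, initializers) are taken as they are.
        self._unique_names.add(name)

    def unique_name(self, prefix: str) -> str:
        i = self._next_index.get(prefix, 0)
        while f"{prefix}{i}" in self._unique_names:
            i += 1
        name = f"{prefix}{i}"
        self._unique_names.add(name)
        self._next_index[prefix] = i + 1
        return name


names = UniqueNames()
names.reserve("_onx_matmul0")
print(names.unique_name("_onx_matmul"))  # _onx_matmul1
print(names.unique_name("_onx_matmul"))  # _onx_matmul2
print(names.unique_name("_onx_add"))     # _onx_add0
```

Keeping the next index per prefix keeps name generation cheap, which matters given the speed objective.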
The model stores some constants. The builder assumes every node taking only constants as inputs produces a new constant.
constants_: constant values
constants_computed_: computed constant values, that is constants built from other constants; every computed constant is cached
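Under the assumption stated above, a result is constant when every input of the node producing it is constant, and its value can be computed once and cached. The sketch below uses Python lists instead of tensors and a tiny table of element-wise kernels; ConstantTracker and KERNELS are hypothetical names, only constants_ and constants_computed_ come from this page.

```python
from typing import Callable, Dict, List, Tuple

# Tiny element-wise kernels standing in for real ONNX operators.
KERNELS: Dict[str, Callable[..., List[float]]] = {
    "Add": lambda a, b: [x + y for x, y in zip(a, b)],
    "Mul": lambda a, b: [x * y for x, y in zip(a, b)],
    "Neg": lambda a: [-x for x in a],
}


class ConstantTracker:
    def __init__(self) -> None:
        self.constants_: Dict[str, List[float]] = {}           # declared constants
        self.constants_computed_: Dict[str, List[float]] = {}  # cache of computed ones
        self.producers: Dict[str, Tuple[str, List[str]]] = {}  # result -> (op_type, inputs)
        self.evaluations = 0

    def add_node(self, op_type: str, inputs: List[str], output: str) -> None:
        self.producers[output] = (op_type, inputs)

    def is_constant(self, name: str) -> bool:
        if name in self.constants_:
            return True
        if name not in self.producers:
            return False  # a graph input, never constant
        _, inputs = self.producers[name]
        return all(self.is_constant(i) for i in inputs)

    def get_constant(self, name: str) -> List[float]:
        if name in self.constants_:
            return self.constants_[name]
        if name in self.constants_computed_:
            return self.constants_computed_[name]  # computed once, reused afterwards
        op_type, inputs = self.producers[name]
        value = KERNELS[op_type](*(self.get_constant(i) for i in inputs))
        self.evaluations += 1
        self.constants_computed_[name] = value
        return value


t = ConstantTracker()
t.constants_["a"] = [1.0, 2.0]
t.constants_["b"] = [3.0, 4.0]
t.add_node("Add", ["a", "b"], "c")
t.add_node("Neg", ["c"], "d")
t.add_node("Mul", ["d", "X"], "e")  # X is a graph input
print(t.is_constant("d"), t.is_constant("e"))  # True False
print(t.get_constant("d"))  # [-4.0, -6.0]
t.get_constant("d")
print(t.evaluations)  # 2
```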
The builder tries to minimize the number of initializers to create. It stores a unique value for the small ones:
_values: caches initializer values to merge those which are equal
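A minimal sketch of such a cache, assuming the key is made of the element type, the shape and the raw bytes, and assuming a size threshold (max_size) chosen arbitrarily here. Values are Python lists stored as float32 and InitializerCache is a hypothetical class. The shape belongs to the key because the same three values with shapes (3, 1) and (1, 3) are different tensors.

```python
import struct
from typing import Dict, List, Optional, Tuple

Key = Tuple[str, Tuple[int, ...], bytes]


class InitializerCache:
    """Merges small initializers holding the same values (illustrative sketch)."""

    def __init__(self, max_size: int = 16) -> None:
        self.max_size = max_size
        self._values: Dict[Key, str] = {}
        self.initializers: Dict[str, List[float]] = {}

    def make_initializer(self, name: str, values: List[float], shape: Tuple[int, ...]) -> str:
        key: Optional[Key] = None
        if len(values) <= self.max_size:
            # float32 bytes plus the shape identify the tensor exactly.
            key = ("float32", shape, struct.pack(f"<{len(values)}f", *values))
            if key in self._values:
                return self._values[key]  # reuse the existing initializer
        self.initializers[name] = values
        if key is not None:
            self._values[key] = name
        return name


cache = InitializerCache()
a = cache.make_initializer("bias1", [0.4, 0.5, 0.6], (1, 3))
b = cache.make_initializer("bias2", [0.4, 0.5, 0.6], (1, 3))
c = cache.make_initializer("w", [0.4, 0.5, 0.6], (3, 1))
print(a, b, c)  # bias1 bias1 w
print(sorted(cache.initializers))  # ['bias1', 'w']
```

Large tensors skip the cache: hashing their bytes would cost time for little gain.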
The forward/backward graphs may take dynamic dimensions as inputs. Some results are reshaped based on these inputs. The following containers keep track of this information.
dynamic_objects: list of dynamic dimensions coming as inputs
dynamic_objects_rev: reverse dictionary to speed up lookups
_dynamic_alias: used when the user gives a different name to the dynamic shapes
The next container stores dynamic shapes.
_cache_shape: cache concatenation of shapes
API¶
The following methods are used to add ONNX elements to the graph.
get_opset: gets the opset version for a domain,
make_tensor_input: adds an input to the graph; is_dimension specifies whether this input is a dynamic dimension, a single integer,
make_tensor_output: adds an output to the graph; is_dimension specifies whether this output is a dynamic dimension, a single integer,
make_initializer: adds an initializer to the graph,
make_node: adds a node to the graph,
to_onnx: produces the final ONNX model.
Some needs are very common and deserve a dedicated method.
make_nodes: adds many nodes in a row and renames the intermediate results if needed,
get_attribute: retrieves an attribute from a NodeProto,
make_shape_from_results: makes a shape from a tuple of integers, strings, or torch.SymInt.
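make_shape_from_results has to mix static and dynamic dimensions. A minimal sketch of one way to do it, assuming each dynamic dimension is already available as a one-element int64 result named after the dimension: consecutive integers are grouped into a constant and the pieces are concatenated. It also caches the result, in the spirit of the _cache_shape container. torch.SymInt is left out and ShapeMaker, like every generated name, is hypothetical.

```python
from typing import Dict, List, Tuple, Union

Dim = Union[int, str]


class ShapeMaker:
    """Builds a shape tensor from static and dynamic dimensions (sketch)."""

    def __init__(self) -> None:
        self.nodes: List[Dict] = []
        self.constants: Dict[str, List[int]] = {}
        self._cache_shape: Dict[Tuple[Dim, ...], str] = {}

    def _constant(self, values: List[int]) -> str:
        name = f"cst_{len(self.constants)}"
        self.constants[name] = list(values)
        return name

    def make_shape_from_results(self, shape: Tuple[Dim, ...]) -> str:
        if shape in self._cache_shape:
            return self._cache_shape[shape]  # this concatenation was built before
        pieces: List[str] = []
        static: List[int] = []
        for d in shape:
            if isinstance(d, int):
                static.append(d)
                continue
            if static:
                pieces.append(self._constant(static))
                static = []
            # Assumption: dimension d exists as a 1D int64 result named d.
            pieces.append(d)
        if static:
            pieces.append(self._constant(static))
        if len(pieces) == 1:
            name = pieces[0]  # a single piece needs no Concat
        else:
            name = f"shape_{len(self.nodes)}"
            self.nodes.append({"op_type": "Concat", "inputs": pieces, "outputs": [name], "axis": 0})
        self._cache_shape[shape] = name
        return name


sm = ShapeMaker()
print(sm.make_shape_from_results(("batch", 3, 4)))  # shape_0
print(sm.nodes[0]["inputs"], sm.constants)  # ['batch', 'cst_0'] {'cst_0': [3, 4]}
print(sm.make_shape_from_results(("batch", 3, 4)), len(sm.nodes))  # shape_0 1
print(sm.make_shape_from_results((2, 3)))  # cst_1
```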
It is important to update the shape whenever the information is available.
Get the information:
Set the information:
Method get_debug_msg provides information to the user and is called in most error messages, as in the following check:
assert name in self._known_ranks, (
    f"Rank is unknown for result {name!r}, "
    f"known_ranks={self._known_ranks}{self.get_debug_msg()}"
)
Example¶
<<<
import numpy as np
from onnx import TensorProto
from experimental_experiment.xbuilder import GraphBuilder
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.helpers import pretty_onnx
gr = GraphBuilder(18, ir_version=9)
gr.make_tensor_input("X", TensorProto.FLOAT, ("a", "b"), is_dimension=False)
weight = gr.make_initializer("", np.array([[0.4, 0.5, 0.6]], dtype=np.float32).T)
bias = gr.make_initializer("", np.array([[0.4, 0.5, 0.6]], dtype=np.float32))
mm = gr.make_node("MatMul", ["X", weight], name="n")
out = gr.make_node("Add", [mm, bias], ["Y"], name="n")
gr.make_tensor_output(out, TensorProto.FLOAT, ("a",), indexed=False, is_dimension=False)
onx = gr.to_onnx()
ref = ExtendedReferenceEvaluator(onx)
x = np.random.rand(5, 3).astype(np.float32)
y = ref.run(None, {"X": x})[0]
print(y)
print(pretty_onnx(onx))
print("Without any information, the known shapes are:")
print(gr._known_shapes)
print("The known constants are:")
print(gr.constants_)
print("The constants are not converted into TensorProto until the very end:")
print(gr.initializers_dict)
>>>
[[1.481 1.581 1.681]
[0.942 1.042 1.142]
[1.647 1.747 1.847]
[1.178 1.278 1.378]
[1.587 1.687 1.787]]
opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='X' type=dtype('float32') shape=['a', 'b']
init: name='init1_s3x1_' type=dtype('float32') shape=(3, 1) -- array([0.4, 0.5, 0.6], dtype=float32)
init: name='init1_s1x3_' type=dtype('float32') shape=(1, 3) -- array([0.4, 0.5, 0.6], dtype=float32)
MatMul(X, init1_s3x1_) -> _onx_matmul0
Add(_onx_matmul0, init1_s1x3_) -> Y
output: name='Y' type=dtype('float32') shape=['a']
Without any information, the known shapes are:
{'X': ('a', 'b'), 'init1_s3x1_': (3, 1), 'init1_s1x3_': (1, 3), '_onx_matmul0': ('a', 1), 'Y': ('a', 3)}
The known constants are:
{'init1_s3x1_': None, 'init1_s1x3_': None}
The constants are not converted into TensorProto until the very end:
{'init1_s3x1_': array([[0.4],
[0.5],
[0.6]], dtype=float32), 'init1_s1x3_': array([[0.4, 0.5, 0.6]], dtype=float32)}
Constants are only computed on demand. Their conversion to TensorProto only happens when method to_onnx is called.
Debugging¶
An exception is raised when an error is detected, and it displays the result of get_debug_msg.
<<<
import numpy as np
from onnx import TensorProto
from experimental_experiment.xbuilder import GraphBuilder
from experimental_experiment.reference import ExtendedReferenceEvaluator
gr = GraphBuilder(18, ir_version=9)
gr.make_tensor_input("X", TensorProto.FLOAT, ("a", "b"), is_dimension=False)
weight = gr.make_initializer("", np.array([[0.4, 0.5, 0.6]], dtype=np.float32).T)
bias = gr.make_initializer("", np.array([[0.4, 0.5, 0.6]], dtype=np.float32))
mm = gr.make_node("MatMul", ["X", weight], name="N1")
out = gr.make_node("Add", [mm, "bias"], ["Y"], name="N2")
gr.make_tensor_output(out, TensorProto.FLOAT, ("a",), indexed=False, is_dimension=False)
onx = gr.to_onnx()
>>>
[runpythonerror]
Traceback (most recent call last):
exec(obj, globs, loc)
File "", line 19, in <module>
File "", line 16, in run_python_script_139640045716992
File "/home/xadupre/github/experimental-experiment/experimental_experiment/xbuilder/graph_builder.py", line 3074, in make_node
assert self.has_name(i), (
AssertionError: Input 'bias' does not exist for operator 'Add', inputs=['_onx_matmul0', 'bias'], outputs=['Y'], name='N2' (CUQ)
--DEBUG--
[GraphBuilder-CUQ] Message starts, there are 2 initializers, 1 nodes, 1 inputs, 1 outputs.
--LOCAL FUNCTIONS--
--PARAMETERS--
dynamic_examples=
--SHAPE--
dynamic_examples=
dynamic_objects=
a = 'a'
b = 'b'
dynamic_objects_rev=
dynamic_dimensions_source={}
dynamic_alias={}
dynamic_shapes=None
_known_value_shape={}
_known_types={'X': 1, '_onx_matmul0': 1, 'init1_s1x3_': 1, 'init1_s3x1_': 1}
_known_shapes={'X': ('a', 'b'),
'_onx_matmul0': ('a', 1),
'init1_s1x3_': (1, 3),
'init1_s3x1_': (3, 1)}
_known_constants=['init1_s1x3_', 'init1_s3x1_']
_known_ranks={}
--TORCH-USERS--
--TORCH-SHAPES--
--ONNX--
--
[GraphBuilder-CUQ.make_tensor_input] X[1:axb]
[GraphBuilder-CUQ.make_initializer] init1_s3x1_[float32:float32:[0.4000000059604645, 0.5, 0.6000000238418579]]
[GraphBuilder-CUQ.make_initializer] init1_s1x3_[float32:float32:[0.4000000059604645, 0.5, 0.6000000238418579]]
[GraphBuilder-CUQ.make_node] N1 [##:# ] MatMul:['X', 'init1_s3x1_']->['_onx_matmul0']
[GraphBuilder-CUQ] Message completed, there are 2 initializers, 1 nodes, 1 inputs, 1 outputs.
It shows the information currently available while building the model. At the end, the following line appears.
[GraphBuilder-EAQ.make_node] N1 [##:- ] MatMul:['X', 'init1_s3x1_']->['_onx_matmul0']
It says one node named N1 was created. ## means the shape and type are known for its two inputs. - means nothing is known for the output.
When the type is specified, it shows the following:
<<<
import numpy as np
from onnx import TensorProto
from experimental_experiment.xbuilder import GraphBuilder
from experimental_experiment.reference import ExtendedReferenceEvaluator
gr = GraphBuilder(18, ir_version=9)
gr.make_tensor_input("X", TensorProto.FLOAT, ("a", "b"), is_dimension=False)
weight = gr.make_initializer("", np.array([[0.4, 0.5, 0.6]], dtype=np.float32).T)
bias = gr.make_initializer("", np.array([[0.4, 0.5, 0.6]], dtype=np.float32))
mm = gr.make_node("MatMul", ["X", weight], name="N1")
gr.set_type(mm, TensorProto.FLOAT)
out = gr.make_node("Add", [mm, "bias"], ["Y"], name="N2")
gr.make_tensor_output(out, TensorProto.FLOAT, ("a",), indexed=False, is_dimension=False)
onx = gr.to_onnx()
>>>
[runpythonerror]
Traceback (most recent call last):
exec(obj, globs, loc)
File "", line 20, in <module>
File "", line 17, in run_python_script_139640045691520
File "/home/xadupre/github/experimental-experiment/experimental_experiment/xbuilder/graph_builder.py", line 3074, in make_node
assert self.has_name(i), (
AssertionError: Input 'bias' does not exist for operator 'Add', inputs=['_onx_matmul0', 'bias'], outputs=['Y'], name='N2' (DIE)
--DEBUG--
[GraphBuilder-DIE] Message starts, there are 2 initializers, 1 nodes, 1 inputs, 1 outputs.
--LOCAL FUNCTIONS--
--PARAMETERS--
dynamic_examples=
--SHAPE--
dynamic_examples=
dynamic_objects=
a = 'a'
b = 'b'
dynamic_objects_rev=
dynamic_dimensions_source={}
dynamic_alias={}
dynamic_shapes=None
_known_value_shape={}
_known_types={'X': 1, '_onx_matmul0': 1, 'init1_s1x3_': 1, 'init1_s3x1_': 1}
_known_shapes={'X': ('a', 'b'),
'_onx_matmul0': ('a', 1),
'init1_s1x3_': (1, 3),
'init1_s3x1_': (3, 1)}
_known_constants=['init1_s1x3_', 'init1_s3x1_']
_known_ranks={}
--TORCH-USERS--
--TORCH-SHAPES--
--ONNX--
--
[GraphBuilder-DIE.make_tensor_input] X[1:axb]
[GraphBuilder-DIE.make_initializer] init1_s3x1_[float32:float32:[0.4000000059604645, 0.5, 0.6000000238418579]]
[GraphBuilder-DIE.make_initializer] init1_s1x3_[float32:float32:[0.4000000059604645, 0.5, 0.6000000238418579]]
[GraphBuilder-DIE.make_node] N1 [##:# ] MatMul:['X', 'init1_s3x1_']->['_onx_matmul0']
[GraphBuilder-DIE] Message completed, there are 2 initializers, 1 nodes, 1 inputs, 1 outputs.
It shows U when the type and rank are known, and # when the type and shape are known.
[GraphBuilder-MJG.make_node] N1 [##:U ] MatMul:['X', 'init1_s3x1_']->['_onx_matmul0']
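The markers can be derived from the containers listed in the debug message (_known_types, _known_shapes, _known_ranks). The sketch below encodes only the three cases described here: # when type and shape are known, U when type and rank are known, - when nothing is known. Any other combination returns ?, a placeholder chosen for this sketch, not a marker the builder is documented to print. The functions status and node_summary are hypothetical.

```python
from typing import Dict, List, Tuple


def status(name: str, known_types: Dict[str, int], known_shapes: Dict[str, Tuple],
           known_ranks: Dict[str, int]) -> str:
    """One character summarizing what is known about a result."""
    has_type = name in known_types
    if has_type and name in known_shapes:
        return "#"  # type and shape known
    if has_type and name in known_ranks:
        return "U"  # type and rank known
    if not has_type and name not in known_shapes and name not in known_ranks:
        return "-"  # nothing known
    return "?"  # other combinations: placeholder for this sketch only


def node_summary(name: str, op_type: str, inputs: List[str], outputs: List[str],
                 types: Dict[str, int], shapes: Dict[str, Tuple], ranks: Dict[str, int]) -> str:
    left = "".join(status(i, types, shapes, ranks) for i in inputs)
    right = "".join(status(o, types, shapes, ranks) for o in outputs)
    return f"{name} [{left}:{right} ] {op_type}:{inputs}->{outputs}"


types = {"X": 1, "W": 1, "Y": 1}
shapes = {"X": ("a", "b"), "W": (3, 1)}
ranks = {"Y": 2}
print(node_summary("N1", "MatMul", ["X", "W"], ["Y"], types, shapes, ranks))
# N1 [##:U ] MatMul:['X', 'W']->['Y']
```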
Simplified API¶
For the most common operators, a shortcut makes the syntax shorter: gr.op.<OpType>(...) creates the node, and numpy arrays given as inputs become initializers.
<<<
import numpy as np
from onnx import TensorProto
from experimental_experiment.xbuilder import GraphBuilder
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.helpers import pretty_onnx
gr = GraphBuilder(18, ir_version=9)
gr.make_tensor_input("X", TensorProto.FLOAT, ("a", "b"), is_dimension=False)
mm = gr.op.MatMul("X", np.array([[0.4, 0.5, 0.6]], dtype=np.float32).T)
out = gr.op.Add(mm, np.array([0.4, 0.5, 0.6], dtype=np.float32), outputs=["Y"])
gr.make_tensor_output(out, TensorProto.FLOAT, ("a",), indexed=False, is_dimension=False)
onx = gr.to_onnx()
ref = ExtendedReferenceEvaluator(onx)
x = np.random.rand(5, 3).astype(np.float32)
y = ref.run(None, {"X": x})[0]
print(y)
print(pretty_onnx(onx))
>>>
[[0.756 0.856 0.956]
[1.511 1.611 1.711]
[0.796 0.896 0.996]
[0.901 1.001 1.101]
[1.167 1.267 1.367]]
opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='X' type=dtype('float32') shape=['a', 'b']
init: name='init1_s3x1_' type=dtype('float32') shape=(3, 1) -- array([0.4, 0.5, 0.6], dtype=float32)
init: name='init1_s3_' type=dtype('float32') shape=(3,) -- array([0.4, 0.5, 0.6], dtype=float32)
MatMul(X, init1_s3x1_) -> _onx_matmul0
Add(_onx_matmul0, init1_s3_) -> Y
output: name='Y' type=dtype('float32') shape=['a']
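A shortcut such as gr.op.MatMul can be built with __getattr__, which turns any attribute name into an operator type. In the sketch below, ToyBuilder and _OpShortcut are hypothetical classes: they store nodes as dictionaries and turn every non-string input into an initializer, mirroring what the example above does with numpy arrays. Python lists stand in for arrays and the naming scheme is arbitrary.

```python
from typing import Any, Dict, List, Optional


class ToyBuilder:
    def __init__(self) -> None:
        self.nodes: List[Dict[str, Any]] = []
        self.initializers: Dict[str, Any] = {}
        self.op = _OpShortcut(self)
        self._counter = 0

    def _new_name(self, prefix: str) -> str:
        name = f"{prefix}{self._counter}"
        self._counter += 1
        return name

    def make_node(self, op_type: str, inputs: List[Any],
                  outputs: Optional[List[str]] = None, **attributes: Any) -> str:
        names = []
        for i in inputs:
            if isinstance(i, str):
                names.append(i)  # an existing result
            else:
                init = self._new_name("init")  # a raw value becomes an initializer
                self.initializers[init] = i
                names.append(init)
        if outputs is None:
            outputs = [self._new_name(f"_onx_{op_type.lower()}")]
        self.nodes.append({"op_type": op_type, "inputs": names,
                           "outputs": outputs, "attributes": attributes})
        return outputs[0]


class _OpShortcut:
    def __init__(self, builder: ToyBuilder) -> None:
        self._builder = builder

    def __getattr__(self, op_type: str):
        # Called for any unknown attribute: gr.op.Add -> make_node("Add", ...).
        def make(*inputs: Any, outputs: Optional[List[str]] = None, **attributes: Any) -> str:
            return self._builder.make_node(op_type, list(inputs), outputs, **attributes)
        return make


gr = ToyBuilder()
mm = gr.op.MatMul("X", [[0.4], [0.5], [0.6]])
out = gr.op.Add(mm, [0.4, 0.5, 0.6], outputs=["Y"])
print([n["op_type"] for n in gr.nodes], out)  # ['MatMul', 'Add'] Y
```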
Optimizations¶
GraphBuilder implements three basic optimization algorithms that do not rely on patterns.
All of them except constant folding are called by default.
remove_unused: removes unused nodes,
remove_identity_nodes: removes identity nodes,
constant_folding: replaces constant expressions by their values whenever it is possible and it makes sense.
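The two passes called by default can be sketched on a topologically sorted list of dictionaries. This is an illustration under that assumption, not the code of GraphBuilder: remove_identity_nodes rewires the consumers of an Identity to its input (an Identity producing a graph output is kept so the output name survives), and remove_unused walks backward from the outputs and keeps only the nodes whose results are needed.

```python
from typing import Dict, List, Set


def remove_identity_nodes(nodes: List[Dict], graph_outputs: Set[str]) -> List[Dict]:
    """Drops Identity nodes and rewires their consumers to the Identity input."""
    rename: Dict[str, str] = {}
    kept: List[Dict] = []
    for node in nodes:
        inputs = [rename.get(i, i) for i in node["inputs"]]
        if node["op_type"] == "Identity" and node["outputs"][0] not in graph_outputs:
            # Consumers of the output will read the input instead.
            rename[node["outputs"][0]] = inputs[0]
            continue
        kept.append({**node, "inputs": inputs})
    return kept


def remove_unused(nodes: List[Dict], graph_outputs: Set[str]) -> List[Dict]:
    """Keeps only the nodes contributing to at least one graph output."""
    needed = set(graph_outputs)
    kept: List[Dict] = []
    for node in reversed(nodes):  # nodes are assumed topologically sorted
        if any(o in needed for o in node["outputs"]):
            kept.append(node)
            needed.update(node["inputs"])
    kept.reverse()
    return kept


nodes = [
    {"op_type": "Identity", "inputs": ["X"], "outputs": ["a"]},
    {"op_type": "Relu", "inputs": ["a"], "outputs": ["b"]},
    {"op_type": "Neg", "inputs": ["X"], "outputs": ["unused"]},
    {"op_type": "Identity", "inputs": ["b"], "outputs": ["Y"]},
]
opt = remove_unused(remove_identity_nodes(nodes, {"Y"}), {"Y"})
print([(n["op_type"], n["inputs"], n["outputs"]) for n in opt])
# [('Relu', ['X'], ['b']), ('Identity', ['b'], ['Y'])]
```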
DynamoInterpreter¶
Class DynamoInterpreter walks through a graph module and selects the best translation
for every part. The graph module is a sequence of calls to internal functions
called aten functions. It looks like the following:
<<<
import torch
class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int, n_targets: int):
        super(Neuron, self).__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
    def forward(self, x):
        return torch.sigmoid(self.linear(x))
x = torch.rand(5, 3)
model = Neuron(3, 1)
graph = torch.export.export(model, (x,))
print(graph)
>>>
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_linear_weight: "f32[1, 3]", p_linear_bias: "f32[1]", x: "f32[5, 3]"):
#
linear: "f32[5, 1]" = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias); x = p_linear_weight = p_linear_bias = None
sigmoid: "f32[5, 1]" = torch.ops.aten.sigmoid.default(linear); linear = None
return (sigmoid,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_weight'), target='linear.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_bias'), target='linear.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sigmoid'), target=None)])
Range constraints: {}
The called functions, such as torch.ops.aten.addmm.default, are well
identified and can be converted into ONNX.
The interpreter maps each of these names to a function creating
the ONNX implementation, aten_addmm in this case, inside a dispatcher
run_node which includes the following piece of code:
if node.op == "placeholder":
    return self.placeholder(node)
if node.op == "call_function":
    return self.call_function(node)
if node.op == "output":
    return self.output(node)
if node.op == "call_module":
    return self.call_module(node)
if node.op == "get_attr":
    return self.get_attr(node)
if node.op == "call_method":
    return self.call_method(node)
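This dispatch can be reproduced on a toy scale. In the sketch below, nodes are SimpleNamespace objects carrying the attributes op, name, target and args that run_node inspects, and call_function looks the target up in a plain dictionary of converters. ToyInterpreter and toy_sigmoid are hypothetical names, only three node kinds are handled, and a missing converter raises an error, in line with the no-fallback policy.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, List


class ToyInterpreter:
    """Dispatches fx-like nodes on node.op, as run_node does (simplified sketch)."""

    def __init__(self, converters: Dict[str, Callable[..., str]]) -> None:
        self.converters = converters
        self.log: List[str] = []

    def run_node(self, node: Any) -> str:
        if node.op == "placeholder":
            return self.placeholder(node)
        if node.op == "call_function":
            return self.call_function(node)
        if node.op == "output":
            return self.output(node)
        raise NotImplementedError(f"node.op={node.op!r} is not handled in this sketch")

    def placeholder(self, node: Any) -> str:
        self.log.append(f"input {node.name}")
        return node.name

    def call_function(self, node: Any) -> str:
        fct = self.converters.get(node.target)
        if fct is None:
            # No fallback: fail loudly instead of guessing.
            raise RuntimeError(f"Unable to interpret function {node.target!r}")
        # The node name becomes the output name, as in the exported graph.
        return fct(self.log, [node.name], *node.args)

    def output(self, node: Any) -> str:
        self.log.append(f"output {node.args[0]}")
        return node.args[0]


def toy_sigmoid(log: List[str], outputs: List[str], x: str) -> str:
    log.append(f"Sigmoid({x}) -> {outputs[0]}")
    return outputs[0]


graph = [
    SimpleNamespace(op="placeholder", name="x", target="x", args=()),
    SimpleNamespace(op="call_function", name="sigmoid", target="aten.sigmoid.default", args=("x",)),
    SimpleNamespace(op="output", name="output", target="output", args=("sigmoid",)),
]
interp = ToyInterpreter({"aten.sigmoid.default": toy_sigmoid})
for n in graph:
    interp.run_node(n)
print(interp.log)  # ['input x', 'Sigmoid(x) -> sigmoid', 'output sigmoid']
```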
A converting function¶
Let’s consider the following simple converting function.
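from typing import Any, Dict, List, Optional
from experimental_experiment.xbuilder import GraphBuilder
T = str  # assumption: results are referred to by their names
def aten_addmm(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    a: T,
    b: T,
    c: T,
    beta: float = 1.0,
    alpha: float = 1.0,
) -> T:
    "gemm"
    res = g.op.Gemm(
        b, c, a, alpha=float(alpha), beta=float(beta), outputs=outputs, name="addmm"
    )
    if not sts:
        g.set_type(res, g.get_type(b))
        g.set_rank(res, 2)
    return res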
The first three arguments are the GraphBuilder;
sts, which tells the function whether it must set the output type and rank itself (it does when sts is empty);
and the output names, which make sure the result names are the same as the ones in the graph
provided by torch. This helps debugging.
Shapes And Types¶
The function can assume the input types are always known.
The output shape should be set, but in this case only the rank is provided.
It is not mandatory but it helps the following functions make
the right decision.
GraphBuilder sets the type and shape itself only for a limited number of operator types
such as Identity. This should improve in future versions.
Some helpers were already implemented to set shape or types
as shown in this function.
def aten_asin(g: GraphBuilder, sts: Optional[Dict[str, Any]], outputs: List[str], x: T) -> T:
    "asin"
    res = g.make_node("Asin", [x], outputs)
    if not sts:
        set_type_shape_unary_op(g, outputs[0], x)
    return res
Argument sts is None when the graph given by torch contains
no information about shape and type. Otherwise, the interpreter
gives that information to the graph builder through sts.
Different Implementations¶
In the following case, the function adds an Identity or a CastLike node
depending on the types. CastLike is only needed when the types
differ, and the graph builder later removes the Identity node.
def aten_copy(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    src: T,
    non_blocking: bool = False,
) -> T:
    "identity"
    assert not non_blocking, "copy is not implemented when non_blocking is True"
    if g.get_type(x) == g.get_type(src):
        return g.op.Identity(src, name="copy")
    return g.op.CastLike(src, x, name="copy")
Conventions¶
Nodes should be named after the aten function they are part of. This helps the developer find where a failing node comes from.
Functions¶
All the available functions are listed in one of these three pages:
experimental_experiment.torch_interpreter._aten_functions: functions
experimental_experiment.torch_interpreter._aten_methods: methods
experimental_experiment.torch_interpreter._prims_functions: primitives
Every function added to these modules is automatically added to the list of known converter functions.
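One way to obtain this automatic registration is to scan a module for functions whose name starts with aten_. The sketch below assumes that a function aten_<name> is registered under key 'aten::<name>' and that a lookup tries the operator name without and then with its overload suffix, loosely inspired by the candidate names listed in the error message later on this page. The package's real naming rules may differ, and a ModuleType built on the fly stands in for _aten_functions.

```python
import types
from typing import Callable, Dict, List, Optional


def build_registry(module: types.ModuleType) -> Dict[str, Callable]:
    """Registers every function named aten_<name> under key 'aten::<name>'."""
    registry: Dict[str, Callable] = {}
    for attr, value in vars(module).items():
        if attr.startswith("aten_") and callable(value):
            registry[f"aten::{attr[len('aten_'):]}"] = value
    return registry


def find_converter(registry: Dict[str, Callable], op: str, overload: str) -> Optional[Callable]:
    """Tries 'aten::<name>' then 'aten::<name>_<overload>'."""
    name = op.split(".")[-1]
    candidates: List[str] = [f"aten::{name}", f"aten::{name}_{overload}"]
    for candidate in candidates:
        if candidate in registry:
            return registry[candidate]
    return None


def aten_asin(g, sts, outputs, x):
    return "Asin"


def aten_celu_default(g, sts, outputs, x, alpha=1.0):
    return "Celu"


fake = types.ModuleType("fake_aten_functions")
fake.aten_asin = aten_asin
fake.aten_celu_default = aten_celu_default
fake.helper = lambda: None  # not registered: no aten_ prefix
reg = build_registry(fake)
print(sorted(reg))  # ['aten::asin', 'aten::celu_default']
print(find_converter(reg, "aten.celu", "default") is aten_celu_default)  # True
print(find_converter(reg, "aten.erf", "default"))  # None
```

Adding a function to the module is then enough to make it visible to the interpreter, with no decorator involved.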
Practice¶
First Example¶
<<<
import torch
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx
class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int, n_targets: int):
        super(Neuron, self).__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
    def forward(self, x):
        return torch.sigmoid(self.linear(x))
model = Neuron(3, 1)
x = torch.rand(5, 3)
onx = to_onnx(model, (x,), input_names=["x"])
print(pretty_onnx(onx))
>>>
opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='x' type=dtype('float32') shape=[5, 3]
init: name='p_linear_weight' type=dtype('float32') shape=(1, 3) -- array([ 0.426, 0.058, -0.153], dtype=float32)
init: name='p_linear_bias' type=dtype('float32') shape=(1,) -- array([-0.112], dtype=float32)
Gemm(x, p_linear_weight, p_linear_bias, transB=1) -> linear
Sigmoid(linear) -> output_0
output: name='output_0' type=dtype('float32') shape=[5, 1]
And visually:
Debugging¶
There is no fallback by default. The converter fails if the conversion to ONNX cannot happen. In that case, it tries to give you some information about why it failed. (The example might succeed in the future.)
<<<
import torch
from experimental_experiment.torch_interpreter import to_onnx
class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int, n_targets: int):
        super(Neuron, self).__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
    def forward(self, x):
        return torch.celu(self.linear(x))
x = torch.rand(5, 3)
model = Neuron(3, 1)
onx = to_onnx(model, (x,), input_names=["x"])
>>>
[runpythonerror]
Traceback (most recent call last):
exec(obj, globs, loc)
File "", line 20, in <module>
File "", line 19, in run_python_script_139640044739328
File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 820, in to_onnx
builder.process(graph_module, interpreter)
File "/home/xadupre/github/experimental-experiment/experimental_experiment/xbuilder/graph_builder.py", line 4024, in process
interpreter.run_node(node)
File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/interpreter.py", line 194, in run_node
res = self.call_function(node)
File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/interpreter.py", line 1126, in call_function
raise FunctionNotFoundError(
experimental_experiment.torch_interpreter._exceptions.FunctionNotFoundError: Unable to interpret function <class 'torch._ops.OpOverload'>: <OpOverload(op='aten.celu', overload='default')>, searched for ['aten::celu', 'celu_default'] and attributes ['__qualname__', '__name__'], args=(linear,), kwargs={}
--DEBUG--
[GraphBuilder-MMI] Message starts, there are 2 initializers, 3 nodes, 1 inputs, 1 outputs.
--LOCAL FUNCTIONS--
--PARAMETERS--
dynamic_examples=
--SHAPE--
dynamic_examples=
dynamic_objects=
dynamic_objects_rev=
dynamic_dimensions_source={}
dynamic_alias={}
dynamic_shapes=None
_known_value_shape={}
_known_types={'_onx_matmul0': 1,
'_onx_transpose0': 1,
'linear': 1,
'p_linear_bias': 1,
'p_linear_weight': 1,
'x': 1}
_known_shapes={'_onx_matmul0': (5, 1),
'_onx_transpose0': (3, 1),
'linear': (5, 1),
'p_linear_bias': (1,),
'p_linear_weight': (1, 3),
'x': (5, 3)}
_known_constants=['_onx_transpose0', 'p_linear_bias', 'p_linear_weight']
_known_ranks={}
--TORCH-USERS--
celu -> {output}
linear -> {celu}
p_linear_bias -> {linear}
p_linear_weight -> {linear}
x -> {linear}
--TORCH-SHAPES--
p_linear_weight: ('run_node', (('example_value', torch.float32, torch.Size([5, 1])), ('val', torch.float32, torch.Size([1, 3])))) --- 1:2:(1, 3):
p_linear_bias: ('run_node', (('example_value', torch.float32, torch.Size([5, 1])), ('val', torch.float32, torch.Size([1])))) --- 1:1:(1,):
x: ('run_node', ('', ('val', torch.float32, torch.Size([5, 3])))) --- 1:2:(5, 3):
linear: ('run_node', ('', ('val', torch.float32, torch.Size([5, 1])))) --- 1:2:(5, 1):
celu: ('run_node', ('', ('val', torch.float32, torch.Size([5, 1])))) --- :::
--ONNX--
-- process.graph_module --
graph():
%p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
%p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
%x : [num_users=1] = placeholder[target=x]
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
%celu : [num_users=1] = call_function[target=torch.ops.aten.celu.default](args = (%linear,), kwargs = {})
return (celu,)
-- process.progress --
node 4/6 target=aten.celu.default
--
[GraphBuilder-MMI.make_tensor_input] x[1:5x3]
[GraphBuilder-MMI.make_initializer] p_linear_weight[torch.float32:torch.float32:[0.1179031953215599, -0.44609642028808594, -0.32514604926109314]]
[GraphBuilder-MMI.make_initializer] p_linear_bias[torch.float32:torch.float32:[-0.27352070808410645]]
[GraphBuilder-MMI.make_node] linear [#:# ] Transpose:['p_linear_weight']->['_onx_transpose0']
[GraphBuilder-MMI.make_node] Opset [##:# ] MatMul:['x', '_onx_transpose0']->['_onx_matmul0']
[GraphBuilder-MMI.make_node] Opset2 [##:# ] Add:['_onx_matmul0', 'p_linear_bias']->['linear']
[GraphBuilder-MMI] Message completed, there are 2 initializers, 3 nodes, 1 inputs, 1 outputs.
Look in particular at the first line of the error message. It tells you there is currently no known
conversion for function aten.celu. A function aten_celu must be added
to module experimental_experiment.torch_interpreter._aten_functions.
Unable to interpret function <class 'torch._ops.OpOverload'>: <OpOverload(op='aten.celu', overload='default')>,
searched for ['aten::celu', 'celu_default'] and attributes ['__qualname__', '__name__'], args=(addmm,), kwargs={}
Below is the graph module:
-- process.graph_module --
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=1] = placeholder[target=arg2_1]
%t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%arg0_1,), kwargs = {})
%addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %t), kwargs = {})
%celu : [num_users=1] = call_function[target=torch.ops.aten.celu.default](args = (%addmm,), kwargs = {})
return (celu,)
-- process.progress --
node 5/7
The last line tells you the conversion stopped at node 5/7, which helps to find which functions were called before. Next comes the information about all nodes added so far. Except for this function, everything looks good and all shapes are known.
[GraphBuilder-BQU.make_tensor_input] x[1:5x3]
[GraphBuilder-BQU.make_initializer] arg0_1[torch.float32:torch.Size([1, 3]):[-0.44980645179748535, 0.29780903458595276, -0.32629191875457764]]
[GraphBuilder-BQU.make_initializer] arg1_1[torch.float32:torch.Size([1]):[0.2905656397342682]]
[GraphBuilder-BQU.make_node] [#:# ] Identity:['x']->['arg2_1']
[GraphBuilder-BQU.make_node] t [#:# ] Transpose:['arg0_1']->['t']
[GraphBuilder-BQU.make_node] addmm [###:# ] Gemm:['arg2_1', 't', 'arg1_1']->['addmm']
There is also a section starting with --TORCH-SHAPES--,
which shows the shapes given by torch.
--TORCH-SHAPES--
arg0_1: ('run_node', (('example_value', torch.float32, torch.Size([5, 1])), ('val', torch.float32, torch.Size([1, 3])))) --- 1:2:(1, 3):
arg1_1: ('run_node', (('example_value', torch.float32, torch.Size([5, 1])), ('val', torch.float32, torch.Size([1])))) --- 1:1:(1,):
arg2_1: ('run_node', ('', ('val', torch.float32, torch.Size([5, 3])))) --- 1:2:(5, 3):
t: ('run_node', ('', ('val', torch.float32, torch.Size([3, 1])))) --- 1:2:(3, 1):
addmm: ('run_node', ('', ('val', torch.float32, torch.Size([5, 1])))) --- 1:2:(5, 1):
celu: ('run_node', ('', ('val', torch.float32, torch.Size([5, 1])))) --- :::
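Following the pattern of aten_asin, a converting function for aten.celu could look like the sketch below. It is an assumption, not the package's implementation: torch.celu takes an optional alpha (default 1.0) and ONNX defines a Celu operator (opset 12 and later) with the same attribute. To run standalone, StubBuilder and the set_type_shape_unary_op stub only mimic the few calls shown on this page; in the package, the function would receive a GraphBuilder and live in experimental_experiment.torch_interpreter._aten_functions.

```python
from typing import Any, Dict, List, Optional

T = str  # assumption: results are referred to by their names


class _OpRecorder:
    """Turns g.op.<OpType>(...) into a recorded node (stub of the shortcut)."""

    def __init__(self, builder: "StubBuilder") -> None:
        self._builder = builder

    def __getattr__(self, op_type: str):
        def make(*inputs: str, outputs: Optional[List[str]] = None,
                 name: str = "", **attributes: Any) -> str:
            self._builder.nodes.append({"op_type": op_type, "inputs": list(inputs),
                                        "outputs": outputs, "name": name,
                                        "attributes": attributes})
            return outputs[0]  # the stub expects explicit output names
        return make


class StubBuilder:
    """Stands in for GraphBuilder: only the calls used below are mimicked."""

    def __init__(self) -> None:
        self.nodes: List[Dict[str, Any]] = []
        self.types: Dict[str, int] = {"linear": 1}  # 1 is the ONNX code for float32
        self.ranks: Dict[str, int] = {"linear": 2}
        self.op = _OpRecorder(self)

    def get_type(self, name: str) -> int:
        return self.types[name]

    def set_type(self, name: str, itype: int) -> None:
        self.types[name] = itype


def set_type_shape_unary_op(g: StubBuilder, name: str, x: str) -> None:
    # Stub of the helper used by aten_asin: copies type and rank from x.
    g.set_type(name, g.get_type(x))
    g.ranks[name] = g.ranks[x]


def aten_celu(g, sts: Optional[Dict[str, Any]], outputs: List[str], x: T, alpha: float = 1.0) -> T:
    "celu"
    res = g.op.Celu(x, alpha=float(alpha), outputs=outputs, name="celu")
    if not sts:
        set_type_shape_unary_op(g, outputs[0], x)
    return res


g = StubBuilder()
res = aten_celu(g, None, ["celu"], "linear", alpha=1.0)
print(res, g.nodes[0]["op_type"], g.nodes[0]["attributes"])  # celu Celu {'alpha': 1.0}
print(g.types["celu"], g.ranks["celu"])  # 1 2
```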
Dynamic Shapes¶
Dynamic shapes only need to be declared when calling function to_onnx, for example:
dynamic_shapes={"x": {0: torch.export.Dim("batch")}}.
<<<
import torch
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx
class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int, n_targets: int):
        super(Neuron, self).__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
    def forward(self, x):
        return torch.sigmoid(self.linear(x))
model = Neuron(3, 1)
x = torch.rand(5, 3)
onx = to_onnx(
    model,
    (x,),
    input_names=["x"],
    dynamic_shapes={"x": {0: torch.export.Dim("batch")}},
)
print(pretty_onnx(onx))
>>>
opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='x' type=dtype('float32') shape=['batch', 3]
init: name='p_linear_weight' type=dtype('float32') shape=(1, 3) -- array([-0.101, 0.321, 0.226], dtype=float32)
init: name='p_linear_bias' type=dtype('float32') shape=(1,) -- array([0.082], dtype=float32)
Gemm(x, p_linear_weight, p_linear_bias, transB=1) -> linear
Sigmoid(linear) -> output_0
output: name='output_0' type=dtype('float32') shape=['batch', 1]
Fallback¶
The current library does not always have a converting function for every aten function implemented in torch. A mechanism exists to intercept the function returned by the interpreter and replace it with a function coming from another source such as onnxscript.
<<<
import torch
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.torch_interpreter.oxs_dispatcher import OxsDispatcher
class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int, n_targets: int):
        super(Neuron, self).__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
    def forward(self, x):
        return torch.celu(self.linear(x))
x = torch.rand(5, 3)
model = Neuron(3, 1)
onx = to_onnx(model, (x,), input_names=["x"], dispatcher=OxsDispatcher(verbose=2))
print(pretty_onnx(onx))
>>>
opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='x' type=dtype('float32') shape=[5, 3]
init: name='p_linear_weight' type=dtype('float32') shape=(1, 3) -- array([-0.476, 0.319, 0.042], dtype=float32)
init: name='p_linear_bias' type=dtype('float32') shape=(1,) -- array([-0.313], dtype=float32)
Gemm(x, p_linear_weight, p_linear_bias, transB=1) -> linear
Celu(linear, alpha=1.00) -> output_0
output: name='output_0' type=dtype('float32') shape=[5, 1]
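Conceptually, the dispatcher is only consulted when the converter has no function of its own, and nothing falls back unless a dispatcher is explicitly given. The sketch below shows that resolution order; ToyDispatcher, find_function and resolve are hypothetical names and do not reproduce the interface of OxsDispatcher.

```python
from typing import Callable, Dict, Optional


class ToyDispatcher:
    """Supplies converters the main registry does not have (hypothetical sketch)."""

    def __init__(self, functions: Dict[str, Callable], verbose: int = 0) -> None:
        self.functions = functions
        self.verbose = verbose

    def find_function(self, name: str) -> Optional[Callable]:
        fct = self.functions.get(name)
        if fct is not None and self.verbose:
            print(f"[ToyDispatcher] use fallback for {name!r}")
        return fct


def resolve(name: str, registry: Dict[str, Callable],
            dispatcher: Optional[ToyDispatcher] = None) -> Callable:
    # 1. the package's own converters always come first
    if name in registry:
        return registry[name]
    # 2. the dispatcher, only if the user provided one
    if dispatcher is not None:
        fct = dispatcher.find_function(name)
        if fct is not None:
            return fct
    # 3. no fallback by default: fail loudly
    raise LookupError(f"no converter for {name!r}")


own = {"aten::sigmoid": lambda *args: "Sigmoid"}
other = ToyDispatcher({"aten::celu": lambda *args: "Celu"}, verbose=1)
print(resolve("aten::sigmoid", own)())  # Sigmoid
print(resolve("aten::celu", own, other)())  # Celu, after the verbose message
```

Calling resolve("aten::celu", own) without a dispatcher raises LookupError, mirroring the failure shown in the debugging section.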
Fake Tensors¶
See the example in function experimental_experiment.torch_interpreter.match_input_parameters().
Export Options¶
Parameter export_options of function to_onnx can take an instance of class ExportOptions
to change the way the model is exported into a graph, for example with strict=True or strict=False.
bypass_export_some_errors¶
If the conversion to ONNX fails, function bypass_export_some_errors
may help solve some of the errors.
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.torch_interpreter.onnx_export_errors import bypass_export_some_errors
with bypass_export_some_errors():
    onx = to_onnx(...)
Use of local functions¶
Local functions may appear when a model is converted with torch.autocast
enabled. They can be inlined by setting inline=True when calling to_onnx.
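Inlining replaces each call to a local function with the function body. The sketch below shows the usual technique, assuming nodes are plain dictionaries and functions are identified by name only (ONNX identifies them by domain and name): formal inputs and outputs are renamed to the names used at the call site, and intermediate results receive a prefix unique to each call so that two calls never collide. inline_functions and the prefix scheme are hypothetical.

```python
import itertools
from typing import Dict, Iterator, List, Optional


def inline_functions(nodes: List[Dict], functions: Dict[str, Dict],
                     counter: Optional[Iterator[int]] = None) -> List[Dict]:
    """Replaces every call to a local function by the function body."""
    if counter is None:
        counter = itertools.count()
    result: List[Dict] = []
    for node in nodes:
        fct = functions.get(node["op_type"])
        if fct is None:
            result.append(node)  # a regular operator, kept as it is
            continue
        # Formal inputs/outputs take the actual names used by the call...
        mapping = dict(zip(fct["inputs"], node["inputs"]))
        mapping.update(zip(fct["outputs"], node["outputs"]))
        # ...and intermediate results get a prefix unique to this call.
        prefix = f"_inl{next(counter)}_"
        body = [
            {
                "op_type": inner["op_type"],
                "inputs": [mapping.get(i, prefix + i) for i in inner["inputs"]],
                "outputs": [mapping.get(o, prefix + o) for o in inner["outputs"]],
            }
            for inner in fct["nodes"]
        ]
        # The body may itself call local functions.
        result.extend(inline_functions(body, functions, counter))
    return result


functions = {
    "Linear": {
        "inputs": ["x", "weight", "bias"],
        "outputs": ["output"],
        "nodes": [
            {"op_type": "Transpose", "inputs": ["weight"], "outputs": ["wt"]},
            {"op_type": "MatMul", "inputs": ["x", "wt"], "outputs": ["mm"]},
            {"op_type": "Add", "inputs": ["mm", "bias"], "outputs": ["output"]},
        ],
    }
}
graph = [{"op_type": "Linear", "inputs": ["X", "W", "B"], "outputs": ["Y"]}]
for n in inline_functions(graph, functions):
    print(n["op_type"], n["inputs"], n["outputs"])
# Transpose ['W'] ['_inl0_wt']
# MatMul ['X', '_inl0_wt'] ['_inl0_mm']
# Add ['_inl0_mm', 'B'] ['Y']
```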
Export module as local functions¶
to_onnx supports parameter export_modules_as_functions. When it is set to True,
all submodules are exported as local functions in the ONNX graph.
<<<
import torch
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx
class SubNeuron2(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear1 = torch.nn.Linear(n_dims, n_targets)
        self.linear2 = torch.nn.Linear(n_dims, n_targets)
    def forward(self, x):
        z1 = self.linear1(x)
        z2 = self.linear2(x)
        return torch.sigmoid(z1) + torch.sigmoid(z2)
class Neuron2(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.neuron = SubNeuron2(n_dims, n_targets)
    def forward(self, x):
        z = self.neuron(x)
        return torch.relu(z)
model = Neuron2()
inputs = (torch.randn(1, 5),)
expected = model(*inputs)
feeds = {"x": inputs[0].numpy()}
onx = to_onnx(
    model,
    inputs,
    export_modules_as_functions=True,
    optimize=False,
    verbose=0,
)
print(pretty_onnx(onx))
>>>
opset: domain='' version=18
opset: domain='aten_local_function' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='x' type=dtype('float32') shape=[1, 5]
<locals>.SubNeuron2[aten_local_function](x) -> neuron
Relu(neuron) -> relu
Identity(relu) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 3]
----- function name=Linear domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
input: 'x'
input: 'weight'
input: 'bias'
Transpose(weight, perm=[1,0]) -> _onx_transpose0
MatMul(x, _onx_transpose0) -> _onx_matmul0
Add(_onx_matmul0, bias) -> linear
Identity(linear) -> output
output: name='output' type=? shape=?
----- function name=<locals>.SubNeuron2 domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'x'
Constant(value=[[-0.40519...) -> neuron.linear1.weight
Constant(value=[-0.117329...) -> neuron.linear1.bias
Linear[aten_local_function](x, neuron.linear1.weight, neuron.linear1.bias) -> linear1
Sigmoid(linear1) -> sigmoid
Constant(value=[[-0.18301...) -> neuron.linear2.weight
Constant(value=[-0.075738...) -> neuron.linear2.bias
Linear[aten_local_function](x, neuron.linear2.weight, neuron.linear2.bias) -> linear2
Sigmoid(linear2) -> sigmoid_1
Add(sigmoid, sigmoid_1) -> add
Identity(add) -> output
output: name='output' type=? shape=?