102: Tweak onnx export
export, unflatten and compile
import torch
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx


class SubNeuron(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        z = self.linear(x)
        return torch.sigmoid(z)


class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.neuron = SubNeuron(n_dims, n_targets)

    def forward(self, x):
        z = self.neuron(x)
        return torch.relu(z)


model = Neuron()
inputs = (torch.randn(1, 5),)
expected = model(*inputs)

exported_program = torch.export.export(model, inputs)
print("-- fx graph with torch.export.export")
print(exported_program.graph)
-- fx graph with torch.export.export
graph():
    %p_neuron_linear_weight : [num_users=1] = placeholder[target=p_neuron_linear_weight]
    %p_neuron_linear_bias : [num_users=1] = placeholder[target=p_neuron_linear_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_neuron_linear_weight, %p_neuron_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%sigmoid,), kwargs = {})
    return (relu,)
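The parameters that appear as placeholders above are lifted out of the modules but remain accessible through the exported program's state dict. A minimal sketch to list them:

# Sketch: the lifted parameters come from the exported program's state dict.
for name, value in exported_program.state_dict.items():
    print(name, tuple(value.shape))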
The export keeps track of the calls to submodules.
print("-- module_call_graph", type(exported_program.module_call_graph))
print(exported_program.module_call_graph)
-- module_call_graph <class 'list'>
[ModuleCallEntry(fqn='', signature=ModuleCallSignature(inputs=[], outputs=[], in_spec=TreeSpec(tuple, None, [TreeSpec(tuple, None, [*]),
TreeSpec(dict, [], [])]), out_spec=*, forward_arg_names=['x'])), ModuleCallEntry(fqn='neuron', signature=None), ModuleCallEntry(fqn='neuron.linear', signature=None)]
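As a quick check, one can iterate over these entries to see which fully qualified names were recorded. A short sketch based on the fqn and signature fields shown above:

# Sketch: list the recorded submodule calls and whether a signature was kept.
for entry in exported_program.module_call_graph:
    print(repr(entry.fqn), "with signature" if entry.signature else "no signature")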
That information can be used to unflatten the exported program and recover the module hierarchy.
ep = torch.export.unflatten(exported_program)
print("-- unflatten", type(exported_program.graph))
print(ep.graph)
-- unflatten <class 'torch.fx.graph.Graph'>
graph():
    %x : [num_users=1] = placeholder[target=x]
    %neuron : [num_users=1] = call_module[target=neuron](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%neuron,), kwargs = {})
    return (relu,)
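The unflattened object is callable like a regular nn.Module, so a quick numerical check against the eager model is straightforward (a minimal sketch):

# Sketch: the unflattened module should reproduce the eager result.
torch.testing.assert_close(ep(*inputs), expected)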
Another graph can be obtained with torch.compile.
def my_compiler(gm, example_inputs):
    print("-- graph with torch.compile")
    print(gm.graph)
    return gm.forward


optimized_mod = torch.compile(model, fullgraph=True, backend=my_compiler)
optimized_mod(*inputs)
-- graph with torch.compile
graph():
    %l_self_modules_neuron_modules_linear_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_neuron_modules_linear_parameters_weight_]
    %l_self_modules_neuron_modules_linear_parameters_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_neuron_modules_linear_parameters_bias_]
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %z : [num_users=1] = call_function[target=torch._C._nn.linear](args = (%l_x_, %l_self_modules_neuron_modules_linear_parameters_weight_, %l_self_modules_neuron_modules_linear_parameters_bias_), kwargs = {})
    %z_1 : [num_users=1] = call_function[target=torch.sigmoid](args = (%z,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%z_1,), kwargs = {})
    return (relu,)

tensor([[0.5898, 0.6889, 0.4627]], grad_fn=<ReluBackward0>)
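Since my_compiler returns gm.forward unchanged, the compiled module should match the eager result. A quick check (sketch):

# Sketch: the compiled module runs the same computation as the eager model.
torch.testing.assert_close(optimized_mod(*inputs), expected)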
Unflattened
class SubNeuron2(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        z = self.linear(x)
        return torch.sigmoid(z)


class Neuron2(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.neuron = SubNeuron2(n_dims, n_targets)

    def forward(self, x):
        z = self.neuron(x)
        return torch.relu(z)


model = Neuron2()
inputs = (torch.randn(1, 5),)
expected = model(*inputs)

onx = to_onnx(model, inputs)
print(pretty_onnx(onx))
opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='x' type=dtype('float32') shape=[1, 5]
init: name='p_neuron_linear_weight' type=dtype('float32') shape=(3, 5)
init: name='p_neuron_linear_bias' type=dtype('float32') shape=(3,) -- array([ 0.2922235 , -0.24605706, -0.41816953], dtype=float32)
Gemm(x, p_neuron_linear_weight, p_neuron_linear_bias, transB=1) -> linear
Sigmoid(linear) -> sigmoid
Relu(sigmoid) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 3]
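Before tweaking the export, it is worth checking that this flat ONNX model is numerically consistent with the eager model. A short sketch, assuming to_onnx returns an onnx ModelProto and onnxruntime is installed:

# Sketch: run the exported model with onnxruntime and compare with PyTorch.
import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession(
    onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
got = sess.run(None, {"x": inputs[0].numpy()})[0]
np.testing.assert_allclose(expected.detach().numpy(), got, atol=1e-5)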
Let’s preserve the modules as local functions.
onx = to_onnx(model, inputs, export_modules_as_functions=True)
print(pretty_onnx(onx))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='x' type=dtype('float32') shape=[1, 5]
__main__.SubNeuron2[aten_local_function](x) -> neuron
Relu(neuron) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 3]
----- function name=Linear domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
input: 'x'
input: 'weight'
input: 'bias'
Transpose(weight, perm=[1,0]) -> _onx_transpose0
Transpose(_onx_transpose0, perm=[1,0]) -> GemmTransposePattern--_onx_transpose0
Gemm(x, GemmTransposePattern--_onx_transpose0, bias, transB=1) -> output
output: name='output' type=? shape=?
----- function name=__main__.SubNeuron2 domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'x'
Constant(value=[[0.379791...) -> neuron.linear.weight
Constant(value=[0.2922235...) -> neuron.linear.bias
Linear[aten_local_function](x, neuron.linear.weight, neuron.linear.bias) -> linear
Sigmoid(linear) -> output
output: name='output' type=? shape=?
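The submodules now appear as ONNX local functions stored in the model itself. They can be enumerated directly from the proto (sketch, again assuming onx is an onnx ModelProto):

# Sketch: enumerate the local functions embedded in the model.
for f in onx.functions:
    print(f"{f.domain}.{f.name}: inputs={list(f.input)} outputs={list(f.output)}")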