102: Tweak onnx export

This example compares the graphs produced by torch.export.export, torch.export.unflatten and torch.compile, and shows how to preserve the module hierarchy when exporting to ONNX.

import torch
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx


class SubNeuron(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        z = self.linear(x)
        return torch.sigmoid(z)


class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.neuron = SubNeuron(n_dims, n_targets)

    def forward(self, x):
        z = self.neuron(x)
        return torch.relu(z)


model = Neuron()
inputs = (torch.randn(1, 5),)
expected = model(*inputs)
exported_program = torch.export.export(model, inputs)

print("-- fx graph with torch.export.export")
print(exported_program.graph)
-- fx graph with torch.export.export
graph():
    %p_neuron_linear_weight : [num_users=1] = placeholder[target=p_neuron_linear_weight]
    %p_neuron_linear_bias : [num_users=1] = placeholder[target=p_neuron_linear_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_neuron_linear_weight, %p_neuron_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%sigmoid,), kwargs = {})
    return (relu,)
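
The exported program remains executable. A quick sanity check (a minimal sketch relying on torch.testing) confirms it reproduces the eager result:

# the exported program exposes a callable module equivalent to the original model
torch.testing.assert_close(exported_program.module()(*inputs), expected)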

The exported program keeps track of the calls to the submodules.

print("-- module_call_graph", type(exported_program.module_call_graph))
print(exported_program.module_call_graph)
-- module_call_graph <class 'list'>
[ModuleCallEntry(fqn='', signature=ModuleCallSignature(inputs=[], outputs=[], in_spec=TreeSpec(tuple, None, [TreeSpec(tuple, None, [*]),
  TreeSpec(dict, [], [])]), out_spec=*, forward_arg_names=['x'])), ModuleCallEntry(fqn='neuron', signature=None), ModuleCallEntry(fqn='neuron.linear', signature=None)]
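
Each entry carries the fully qualified name of a traced submodule. A short sketch to list them:

for entry in exported_program.module_call_graph:
    # the empty fqn denotes the top-level module
    print(entry.fqn or "<root>")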

That information can be used to rebuild a module with the original hierarchy through torch.export.unflatten.

ep = torch.export.unflatten(exported_program)
print("-- unflatten", type(exported_program.graph))
print(ep.graph)
-- unflatten <class 'torch.fx.graph.Graph'>
graph():
    %x : [num_users=1] = placeholder[target=x]
    %neuron : [num_users=1] = call_module[target=neuron](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%neuron,), kwargs = {})
    return (relu,)
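
The unflattened module behaves like the original one. A minimal check (sketch):

# the unflattened module can be called with the same inputs as the original model
torch.testing.assert_close(ep(*inputs), expected)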

Another graph can be obtained with torch.compile, using a custom backend to capture the fx graph.

def my_compiler(gm, example_inputs):
    print("-- graph with torch.compile")
    print(gm.graph)
    return gm.forward


optimized_mod = torch.compile(model, fullgraph=True, backend=my_compiler)
optimized_mod(*inputs)
-- graph with torch.compile
graph():
    %l_self_modules_neuron_modules_linear_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_neuron_modules_linear_parameters_weight_]
    %l_self_modules_neuron_modules_linear_parameters_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_neuron_modules_linear_parameters_bias_]
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %z : [num_users=1] = call_function[target=torch._C._nn.linear](args = (%l_x_, %l_self_modules_neuron_modules_linear_parameters_weight_, %l_self_modules_neuron_modules_linear_parameters_bias_), kwargs = {})
    %z_1 : [num_users=1] = call_function[target=torch.sigmoid](args = (%z,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%z_1,), kwargs = {})
    return (relu,)

tensor([[0.5898, 0.6889, 0.4627]], grad_fn=<ReluBackward0>)
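
The compiled module should produce the same values as eager mode. A quick check (sketch):

# compilation must not change the numerical result
torch.testing.assert_close(optimized_mod(*inputs), expected)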

Export to ONNX

class SubNeuron2(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        z = self.linear(x)
        return torch.sigmoid(z)


class Neuron2(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.neuron = SubNeuron2(n_dims, n_targets)

    def forward(self, x):
        z = self.neuron(x)
        return torch.relu(z)


model = Neuron2()
inputs = (torch.randn(1, 5),)
expected = model(*inputs)

onx = to_onnx(model, inputs)
print(pretty_onnx(onx))
opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='x' type=dtype('float32') shape=[1, 5]
init: name='p_neuron_linear_weight' type=dtype('float32') shape=(3, 5)
init: name='p_neuron_linear_bias' type=dtype('float32') shape=(3,) -- array([ 0.2922235 , -0.24605706, -0.41816953], dtype=float32)
Gemm(x, p_neuron_linear_weight, p_neuron_linear_bias, transB=1) -> linear
  Sigmoid(linear) -> sigmoid
    Relu(sigmoid) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 3]
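
The ONNX model can be validated numerically against the eager result. A minimal sketch, assuming onnxruntime is installed:

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession(
    onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
got = sess.run(None, {"x": inputs[0].numpy()})[0]
# compare with the output computed by the eager model
np.testing.assert_allclose(expected.detach().numpy(), got, rtol=1e-5)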

Let’s now preserve the submodules as local ONNX functions.

onx = to_onnx(model, inputs, export_modules_as_functions=True)
print(pretty_onnx(onx))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='x' type=dtype('float32') shape=[1, 5]
__main__.SubNeuron2[aten_local_function](x) -> neuron
  Relu(neuron) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 3]
----- function name=Linear domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
input: 'x'
input: 'weight'
input: 'bias'
Transpose(weight, perm=[1,0]) -> _onx_transpose0
  Transpose(_onx_transpose0, perm=[1,0]) -> GemmTransposePattern--_onx_transpose0
    Gemm(x, GemmTransposePattern--_onx_transpose0, bias, transB=1) -> output
output: name='output' type=? shape=?
----- function name=__main__.SubNeuron2 domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'x'
Constant(value=[[0.379791...) -> neuron.linear.weight
Constant(value=[0.2922235...) -> neuron.linear.bias
  Linear[aten_local_function](x, neuron.linear.weight, neuron.linear.bias) -> linear
    Sigmoid(linear) -> output
output: name='output' type=? shape=?
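
The local functions are stored in the functions field of the ModelProto. A short sketch to enumerate them:

for f in onx.functions:
    # every submodule preserved as a function appears here with its own domain
    print(f.domain, f.name, [node.op_type for node in f.node])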
