102: Tweak onnx export

export, unflatten and compile

import torch
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx

class SubNeuron(torch.nn.Module):
    """Linear layer followed by a sigmoid activation.

    :param n_dims: number of input features given to ``torch.nn.Linear``
    :param n_targets: number of output features given to ``torch.nn.Linear``
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        # Module.__init__ must run before any submodule is assigned,
        # otherwise torch raises AttributeError on ``self.linear = ...``.
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        """Return ``sigmoid(linear(x))``."""
        z = self.linear(x)
        return torch.sigmoid(z)

class Neuron(torch.nn.Module):
    """Wraps a :class:`SubNeuron` and applies a ReLU on its output.

    :param n_dims: number of input features, forwarded to ``SubNeuron``
    :param n_targets: number of output features, forwarded to ``SubNeuron``
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        # Module.__init__ must run before any submodule is assigned,
        # otherwise torch raises AttributeError on ``self.neuron = ...``.
        super().__init__()
        self.neuron = SubNeuron(n_dims, n_targets)

    def forward(self, x):
        """Return ``relu(neuron(x))``."""
        z = self.neuron(x)
        return torch.relu(z)

# Build the model with its default sizes and one sample input of shape (1, 5).
model = Neuron()
inputs = (torch.randn(1, 5),)
# Eager-mode output for the sample input.
expected = model(*inputs)
# Export the model with torch.export.export using the sample input.
exported_program = torch.export.export(model, inputs)

# Header printed before the FX graph dump.
print("-- fx graph with torch.export.export")
-- fx graph with torch.export.export
    %p_neuron_linear_weight : [num_users=1] = placeholder[target=p_neuron_linear_weight]
    %p_neuron_linear_bias : [num_users=1] = placeholder[target=p_neuron_linear_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_neuron_linear_weight, %p_neuron_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%sigmoid,), kwargs = {})
    return (relu,)

The export keeps track of the submodule calls.

# Show the type of the container holding the module call entries recorded by the export.
print("-- module_call_graph", type(exported_program.module_call_graph))
-- module_call_graph <class 'list'>
[ModuleCallEntry(fqn='', signature=ModuleCallSignature(inputs=[], outputs=[], in_spec=TreeSpec(tuple, None, [TreeSpec(tuple, None, [*]),
  TreeSpec(dict, [], [])]), out_spec=*, forward_arg_names=['x'])), ModuleCallEntry(fqn='neuron', signature=None), ModuleCallEntry(fqn='neuron.linear', signature=None)]

That information can be used to unflatten the exported program, i.e. to rebuild a module with the same submodule hierarchy as the original model.

# Rebuild a module that keeps the submodule hierarchy of the original model.
ep = torch.export.unflatten(exported_program)
# Inspect the graph of the unflattened module, not the flat exported program.
print("-- unflatten", type(ep.graph))
-- unflatten <class 'torch.fx.graph.Graph'>
    %x : [num_users=1] = placeholder[target=x]
    %neuron : [num_users=1] = call_module[target=neuron](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%neuron,), kwargs = {})
    return (relu,)

Another graph obtained with torch.compile.

def my_compiler(gm, example_inputs):
    """Minimal backend for ``torch.compile``: announce the call, compile nothing.

    :param gm: captured graph object exposing a ``forward`` callable
        (presumably a ``torch.fx.GraphModule`` -- confirm against the caller)
    :param example_inputs: example inputs supplied by the caller, unused here
    :return: ``gm.forward``, unchanged
    """
    banner = "-- graph with torch.compile"
    print(banner)
    # Hand back the captured forward as is, no optimization is applied.
    forward_fn = gm.forward
    return forward_fn

# Use my_compiler as the backend; fullgraph=True requires the whole model
# to be captured as a single graph (torch.compile errors out otherwise).
optimized_mod = torch.compile(model, fullgraph=True, backend=my_compiler)
-- graph with torch.compile
    %l_self_modules_neuron_modules_linear_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_neuron_modules_linear_parameters_weight_]
    %l_self_modules_neuron_modules_linear_parameters_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_neuron_modules_linear_parameters_bias_]
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %z : [num_users=1] = call_function[target=torch._C._nn.linear](args = (%l_x_, %l_self_modules_neuron_modules_linear_parameters_weight_, %l_self_modules_neuron_modules_linear_parameters_bias_), kwargs = {})
    %z_1 : [num_users=1] = call_function[target=torch.sigmoid](args = (%z,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%z_1,), kwargs = {})
    return (relu,)

tensor([[0.4336, 0.2997, 0.5592]], grad_fn=<ReluBackward0>)


class SubNeuron2(torch.nn.Module):
    """Linear layer followed by a sigmoid activation.

    :param n_dims: number of input features given to ``torch.nn.Linear``
    :param n_targets: number of output features given to ``torch.nn.Linear``
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        # Module.__init__ must run before any submodule is assigned,
        # otherwise torch raises AttributeError on ``self.linear = ...``.
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        """Return ``sigmoid(linear(x))``."""
        z = self.linear(x)
        return torch.sigmoid(z)

class Neuron2(torch.nn.Module):
    """Wraps a :class:`SubNeuron2` and applies a ReLU on its output.

    :param n_dims: number of input features, forwarded to ``SubNeuron2``
    :param n_targets: number of output features, forwarded to ``SubNeuron2``
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        # Module.__init__ must run before any submodule is assigned,
        # otherwise torch raises AttributeError on ``self.neuron = ...``.
        super().__init__()
        self.neuron = SubNeuron2(n_dims, n_targets)

    def forward(self, x):
        """Return ``relu(neuron(x))``."""
        z = self.neuron(x)
        return torch.relu(z)

# Same setup as before, this time with Neuron2.
model = Neuron2()
inputs = (torch.randn(1, 5),)
expected = model(*inputs)

# Convert to ONNX with the default options.
onx = to_onnx(model, inputs)
opset: domain='' version=18
input: name='x' type=dtype('float32') shape=[1, 5]
init: name='neuron.linear.weight' type=float32 shape=(3, 5)           -- DynamoInterpret.placeholder.1/P(neuron.linear.weight)
init: name='neuron.linear.bias' type=float32 shape=(3,) -- array([ 0.05158822, -0.40079534, -0.25078472], dtype=float32)-- DynamoInterpret.placeholder.1/P(neuron.linear.bias)
Gemm(x, neuron.linear.weight, neuron.linear.bias, transB=1) -> linear
  Sigmoid(linear) -> sigmoid
    Relu(sigmoid) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 3]

Let’s preserve the module.

# export_modules_as_functions=True: per the output shown after this call, submodules
# are emitted as ONNX local functions in the 'aten_local_function' domain.
onx = to_onnx(model, inputs, export_modules_as_functions=True)
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='x' type=dtype('float32') shape=[1, 5]
__main__.SubNeuron2[aten_local_function](x) -> neuron
  Relu(neuron) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 3]
----- function name=Linear domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
input: 'x'
Constant(value=[[0.013307...) -> weight
Constant(value=[0.0515882...) -> bias
  Gemm(x, weight, bias, transB=1) -> output
Constant(value=[[0.013307...) -> neuron.linear.weight
Constant(value=[0.0515882...) -> neuron.linear.bias
output: name='output' type=? shape=?
----- function name=__main__.SubNeuron2 domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'x'
Linear[aten_local_function](x) -> linear
  Sigmoid(linear) -> output
output: name='output' type=? shape=?

Total running time of the script: (0 minutes 0.361 seconds)

