101: Some dummy examples with torch.export.export

torch.export.export() behaviour in various situations.

Easy Case

A simple model.

import torch

class Neuron(torch.nn.Module):
    """Single linear layer followed by a sigmoid.

    :param n_dims: number of input features
    :param n_targets: number of output features
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        # torch.nn.Module must be initialised before a submodule is
        # assigned, otherwise the attribute assignment below raises.
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        """Return ``sigmoid(linear(x))``."""
        z = self.linear(x)
        return torch.sigmoid(z)

# Export the model with a single example input and display the captured
# graph (this is the listing shown right below).
exported_program = torch.export.export(Neuron(), (torch.randn(1, 5),))
print(exported_program.graph)
    %p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
    %p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    return (sigmoid,)

With an integer as input

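As stated in the torch.export.export documentation, integers are not treated as graph inputs: in the graph below, the placeholder for the integer has no user and its value is baked into the graph as a constant. An exporter based on torch.export.export() cannot consider the integer as an input.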

class NeuronIInt(torch.nn.Module):
    """Linear layer and sigmoid, returning one column selected by an integer.

    :param n_dims: number of input features
    :param n_targets: number of output features
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        # Required before any submodule can be registered on the module.
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x: torch.Tensor, i_input: int):
        """Return column ``i_input`` of ``sigmoid(linear(x))``."""
        z = self.linear(x)
        return torch.sigmoid(z)[:, i_input]

# The second example input is a plain Python integer.
exported_program = torch.export.export(NeuronIInt(), (torch.randn(1, 5), 2))
print(exported_program.graph)
    %p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
    %p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
    %x : [num_users=1] = placeholder[target=x]
    %i_input : [num_users=0] = placeholder[target=i_input]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%sigmoid, 0, 0, 9223372036854775807), kwargs = {})
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%slice_1, 1, 2), kwargs = {})
    return (select,)

With an integer as input

But if the integer is wrapped into a Tensor, it works.

class NeuronIInt(torch.nn.Module):
    """Linear layer and sigmoid, returning the columns selected by ``i_input``.

    :param n_dims: number of input features
    :param n_targets: number of output features
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        # Required before any submodule can be registered on the module.
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x: torch.Tensor, i_input):
        """Index the columns of ``sigmoid(linear(x))`` with ``i_input``.

        ``i_input`` carries no annotation on purpose: this variant is
        called with an index wrapped into a tensor rather than an ``int``.
        """
        z = self.linear(x)
        return torch.sigmoid(z)[:, i_input]

# The index is now passed as an int32 tensor instead of a Python integer.
exported_program = torch.export.export(
    NeuronIInt(), (torch.randn(1, 5), torch.Tensor([2]).to(torch.int32))
)
print(exported_program.graph)
    %p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
    %p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
    %x : [num_users=1] = placeholder[target=x]
    %i_input : [num_users=1] = placeholder[target=i_input]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%sigmoid, 0, 0, 9223372036854775807), kwargs = {})
    %index : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%slice_1, [None, %i_input]), kwargs = {})
    return (index,)


Wrapped, it continues to work.

class WrappedNeuronIInt(torch.nn.Module):
    """Thin wrapper delegating every call to another module.

    :param model: the module to wrap
    """

    def __init__(self, model):
        # Required before ``model`` can be registered as a submodule.
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        """Forward all positional and keyword arguments to ``model.forward``."""
        return self.model.forward(*args, **kwargs)

# Same inputs as before, but the model is hidden behind a generic wrapper
# whose forward only declares *args and **kwargs.
exported_program = torch.export.export(
    WrappedNeuronIInt(NeuronIInt()), (torch.randn(1, 5), torch.Tensor([2]).to(torch.int32))
)
print(exported_program.graph)
    %p_model_linear_weight : [num_users=1] = placeholder[target=p_model_linear_weight]
    %p_model_linear_bias : [num_users=1] = placeholder[target=p_model_linear_bias]
    %args_0 : [num_users=1] = placeholder[target=args_0]
    %args_1 : [num_users=1] = placeholder[target=args_1]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%args_0, %p_model_linear_weight, %p_model_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%sigmoid, 0, 0, 9223372036854775807), kwargs = {})
    %index : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%slice_1, [None, %args_1]), kwargs = {})
    return (index,)


The next model does not export. An exporter based on torch.export.export() cannot handle it.

class NeuronNoneListInt(torch.nn.Module):
    """Linear layer and sigmoid fed with a tensor, a list and an integer.

    :param n_dims: number of input features
    :param n_targets: number of output features
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        # Required before any submodule can be registered on the module.
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x, yz, i_input):
        """Return the first ``i_input`` rows of ``sigmoid(linear(x + yz[0] * yz[3]))``.

        Only items 0 and 3 of ``yz`` are read; the other items are ignored.
        """
        z = self.linear(x + yz[0] * yz[3])
        return torch.sigmoid(z)[:i_input]

    exported_program = torch.export.export(
            torch.randn(1, 5),
            [torch.randn(1, 5), None, None, torch.randn(1, 5)],
except (torch._dynamo.exc.Unsupported, RuntimeError) as e:
    print(f"-- an error {type(e)} occured:")
-- an error <class 'torch._dynamo.exc.InternalTorchDynamoError'> occured:
RuntimeError: Node z referenced target L__self___linear but that target was not provided in ``root``!

from user code:
   File "/home/xadupre/github/experimental-experiment/_doc/examples/plot_torch_export_101.py", line 115, in forward
    return torch.sigmoid(z)[:i_input]

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


Loops are not captured.

class NeuronLoop(torch.nn.Module):
    """Linear layer followed by a Python loop over a list of tensors.

    :param n_dims: number of input features
    :param n_targets: number of output features
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        # Required before any submodule can be registered on the module.
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x, xs):
        """Return ``linear(x)`` and accumulate ``xs`` into ``x`` in place.

        The output is computed before the loop, so the in-place updates of
        ``x`` are a side effect on the input and do not change the result.
        """
        z = self.linear(x)
        # The k-th tensor of ``xs`` (1-based) is weighted by k.
        for weight, item in enumerate(xs, start=1):
            x += item * weight
        return z

# forward(x, xs) receives one tensor and a list of two tensors.
exported_program = torch.export.export(
    NeuronLoop(),
    (
        torch.randn(1, 5),
        [torch.randn(1, 5), torch.randn(1, 5)],
    ),
)
print(exported_program.graph)
    %p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
    %p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
    %x : [num_users=2] = placeholder[target=x]
    %xs_0 : [num_users=1] = placeholder[target=xs_0]
    %xs_1 : [num_users=1] = placeholder[target=xs_1]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%xs_0, 1), kwargs = {})
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %mul), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%xs_1, 2), kwargs = {})
    %add__1 : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%add_, %mul_1), kwargs = {})
    return (linear,)

Export for training

In that case, the weights are exported as inputs.

class Neuron(torch.nn.Module):
    """Linear layer followed by a sigmoid, used for the training export.

    :param n_dims: number of input features
    :param n_targets: number of output features
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        # Without this call, registering ``self.linear`` raises.
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        """Apply the linear layer, then the sigmoid."""
        z = self.linear(x)
        return torch.sigmoid(z)

print("-- training")
mod = Neuron()
# Same model as before, exported through the training entry point.
exported_program = torch.export.export_for_training(mod, (torch.randn(1, 5),))
print(exported_program.graph)
-- training
    %p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
    %p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    return (sigmoid,)

Preserve Modules

class Neuron(torch.nn.Module):
    """Inner module of the example: a linear layer followed by a sigmoid.

    :param n_dims: number of input features
    :param n_targets: number of output features
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        # Without this call, registering ``self.linear`` raises.
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        """Return ``sigmoid(linear(x))``."""
        z = self.linear(x)
        return torch.sigmoid(z)

class NeuronNeuron(torch.nn.Module):
    """Outer module negating the output of a nested ``Neuron``.

    :param n_dims: number of input features of the nested module
    :param n_targets: number of output features of the nested module
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        # Required before ``my_neuron`` can be registered as a submodule.
        super().__init__()
        self.my_neuron = Neuron(n_dims, n_targets)

    def forward(self, x):
        """Return ``-my_neuron(x)``."""
        z = self.my_neuron(x)
        return -z

The list of the modules.

mod = NeuronNeuron()
# Each item is a (qualified name, module) pair.
for item in mod.named_modules():
    print(item)
('', NeuronNeuron(
  (my_neuron): Neuron(
    (linear): Linear(in_features=5, out_features=3, bias=True)
('my_neuron', Neuron(
  (linear): Linear(in_features=5, out_features=3, bias=True)
('my_neuron.linear', Linear(in_features=5, out_features=3, bias=True))

The exported module did not change.

print("-- preserved?")
# preserve_module_call_signature lists the submodule paths whose calling
# convention is recorded by the exporter.
exported_program = torch.export.export(
    mod, (torch.randn(1, 5),), preserve_module_call_signature=("my_neuron",)
)
print(exported_program.graph)
-- preserved?
    %p_my_neuron_linear_weight : [num_users=1] = placeholder[target=p_my_neuron_linear_weight]
    %p_my_neuron_linear_bias : [num_users=1] = placeholder[target=p_my_neuron_linear_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_my_neuron_linear_weight, %p_my_neuron_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%sigmoid,), kwargs = {})
    return (neg,)

And now?

import torch.export._swap

# NOTE: _swap_modules is a private torch API and may change without notice.
swapped_gm = torch.export._swap._swap_modules(exported_program, {"my_neuron": Neuron()})

print("-- preserved?")
print(swapped_gm.graph)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/unflatten.py:848: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  spec_node = gm.graph.get_attr(name)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/unflatten.py:840: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  spec_node = gm.graph.get_attr(name)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/graph.py:1794: UserWarning: Node _spec_0 target _spec_0 _spec_0 of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/graph.py:1794: UserWarning: Node _spec_1 target _spec_1 _spec_1 of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/graph.py:1794: UserWarning: Node _spec_2 target _spec_2 _spec_2 of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
-- preserved?
    %x_1 : [num_users=1] = placeholder[target=x]
    %_spec_0 : [num_users=1] = get_attr[target=_spec_0]
    %_spec_1 : [num_users=1] = get_attr[target=_spec_1]
    %_spec_2 : [num_users=1] = get_attr[target=_spec_2]
    %tree_flatten : [num_users=1] = call_function[target=torch.utils._pytree.tree_flatten](args = ((%x_1,),), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten, 0), kwargs = {})
    %x : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 0), kwargs = {})
    %tree_unflatten_1 : [num_users=1] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%x], %_spec_1), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 0), kwargs = {})
    %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem_1, 0), kwargs = {})
    %my_neuron : [num_users=1] = call_module[target=my_neuron](args = (%getitem_2,), kwargs = {})
    %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (%my_neuron, %_spec_2), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
    %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_4,), kwargs = {})
    %tree_unflatten : [num_users=1] = call_function[target=torch.utils._pytree.tree_unflatten](args = ((%neg,), %_spec_0), kwargs = {})
    return tree_unflatten

Unfortunately, this approach does not work well on big models, and it relies on a private API.

Total running time of the script: (0 minutes 0.706 seconds)

Related examples

102: Tweak onnx export

102: Tweak onnx export

101: A custom backend for torch

101: A custom backend for torch

101: Onnx Model Rewriting

101: Onnx Model Rewriting

201: Use torch to export a scikit-learn model into ONNX

201: Use torch to export a scikit-learn model into ONNX

102: Convolution and Matrix Multiplication

102: Convolution and Matrix Multiplication

Gallery generated by Sphinx-Gallery