101: Some dummy examples with torch.export.export

This example demonstrates the behaviour of torch.export.export() in various situations.

Easy Case

A simple model.

import torch


class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        z = self.linear(x)
        return torch.sigmoid(z)


exported_program = torch.export.export(Neuron(), (torch.randn(1, 5),))
print(exported_program.graph)
graph():
    %p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
    %p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    return (sigmoid,)
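
The exported program remains executable. A quick sketch, using the public ExportedProgram.module() API:

model = exported_program.module()
# runs the captured graph, equivalent to calling Neuron()(x)
print(model(torch.randn(1, 5)))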

With an integer as input

As stated in the torch.export.export documentation, integers do not show up in the graph: the value is specialized into the graph as a constant. An exporter based on torch.export.export() cannot treat the integer as an input.

class NeuronIInt(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x: torch.Tensor, i_input: int):
        z = self.linear(x)
        return torch.sigmoid(z)[:, i_input]


exported_program = torch.export.export(NeuronIInt(), (torch.randn(1, 5), 2))
print(exported_program.graph)
graph():
    %p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
    %p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
    %x : [num_users=1] = placeholder[target=x]
    %i_input : [num_users=0] = placeholder[target=i_input]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%sigmoid, 1, 2), kwargs = {})
    return (select,)
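
The value 2 was burned into the graph as a constant (see select(..., 1, 2)). A quick sketch to confirm: calling the exported module with the export-time value works; other values are not part of the graph and, depending on the PyTorch version, may be rejected by an input check.

model = exported_program.module()
# 2 matches the value used at export time; the graph itself hard-codes it
print(model(torch.randn(1, 5), 2))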

With a Tensor-wrapped integer

If the integer is wrapped into a Tensor, it becomes a genuine graph input.

class NeuronIInt(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x: torch.Tensor, i_input):
        z = self.linear(x)
        return torch.sigmoid(z)[:, i_input]


exported_program = torch.export.export(
    NeuronIInt(), (torch.randn(1, 5), torch.Tensor([2]).to(torch.int32))
)
print(exported_program.graph)
graph():
    %p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
    %p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
    %x : [num_users=1] = placeholder[target=x]
    %i_input : [num_users=1] = placeholder[target=i_input]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %index : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%sigmoid, [None, %i_input]), kwargs = {})
    return (index,)
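
Since the index is now a real graph input, different values select different columns at run time. A quick sketch:

model = exported_program.module()
# each call selects the column given by the tensor-wrapped index
print(model(torch.randn(1, 5), torch.Tensor([0]).to(torch.int32)))
print(model(torch.randn(1, 5), torch.Tensor([2]).to(torch.int32)))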

Wrapped

When the model is wrapped into another module forwarding *args and **kwargs, the export continues to work.

class WrappedNeuronIInt(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        return self.model.forward(*args, **kwargs)


exported_program = torch.export.export(
    WrappedNeuronIInt(NeuronIInt()), (torch.randn(1, 5), torch.Tensor([2]).to(torch.int32))
)
print(exported_program.graph)
graph():
    %p_model_linear_weight : [num_users=1] = placeholder[target=p_model_linear_weight]
    %p_model_linear_bias : [num_users=1] = placeholder[target=p_model_linear_bias]
    %args_0 : [num_users=1] = placeholder[target=args_0]
    %args_1 : [num_users=1] = placeholder[target=args_1]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%args_0, %p_model_linear_weight, %p_model_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %index : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%sigmoid, [None, %args_1]), kwargs = {})
    return (index,)
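
torch.export.export() also takes keyword arguments through its kwargs parameter; the wrapper forwards them to the inner model. A minimal sketch:

exported_program = torch.export.export(
    WrappedNeuronIInt(NeuronIInt()),
    (torch.randn(1, 5),),
    # forwarded as a keyword argument to NeuronIInt.forward
    kwargs={"i_input": torch.Tensor([2]).to(torch.int32)},
)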

List

The following model does not export: the slice bound is read from a tensor at run time, which makes it a data-dependent expression. An exporter based on torch.export.export() cannot handle this case.

class NeuronNoneListInt(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x, yz, i_input):
        z = self.linear(x + yz[0] * yz[3])
        return torch.sigmoid(z)[:i_input]


try:
    exported_program = torch.export.export(
        NeuronNoneListInt(),
        (
            torch.randn(1, 5),
            [torch.randn(1, 5), None, None, torch.randn(1, 5)],
            torch.Tensor([2]).to(torch.int32),
        ),
    )
    print(exported_program.graph)
except (torch._dynamo.exc.Unsupported, RuntimeError) as e:
    print(f"-- an error {type(e)} occured:")
    print(e)
def forward(self, arg0_1: "f32[3, 5]", arg1_1: "f32[3]", arg2_1: "f32[1, 5]", arg3_1: "f32[1, 5]", arg4_1, arg5_1, arg6_1: "f32[1, 5]", arg7_1: "i32[1]"):
     # File: ~/github/experimental-experiment/_doc/examples/plot_torch_export_101.py:114 in forward, code: z = self.linear(x + yz[0] * yz[3])
    mul: "f32[1, 5]" = torch.ops.aten.mul.Tensor(arg3_1, arg6_1);  arg3_1 = arg6_1 = None
    add: "f32[1, 5]" = torch.ops.aten.add.Tensor(arg2_1, mul);  arg2_1 = mul = None

     # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
    linear: "f32[1, 3]" = torch.ops.aten.linear.default(add, arg0_1, arg1_1);  add = arg0_1 = arg1_1 = None

     # File: ~/github/experimental-experiment/_doc/examples/plot_torch_export_101.py:115 in forward, code: return torch.sigmoid(z)[:i_input]
    sigmoid: "f32[1, 3]" = torch.ops.aten.sigmoid.default(linear);  linear = sigmoid = None
    item: "Sym(u0)" = torch.ops.aten.item.default(arg7_1);  arg7_1 = item = None

-- an error <class 'torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode'> occurred:
Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)

Caused by: (_export/non_strict_utils.py:1066 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "~/github/experimental-experiment/_doc/examples/plot_torch_export_101.py", line 115, in forward
    return torch.sigmoid(z)[:i_input]


The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
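
As the message above suggests, torch.export.draft_export (available in recent PyTorch versions) runs the model with the given inputs and reports data-dependent issues instead of failing. A sketch:

# draft_export traces with real inputs and collects all export problems
exported_program = torch.export.draft_export(
    NeuronNoneListInt(),
    (
        torch.randn(1, 5),
        [torch.randn(1, 5), None, None, torch.randn(1, 5)],
        torch.Tensor([2]).to(torch.int32),
    ),
)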

Loops

Loops are not captured: they are unrolled at export time, so the graph hard-codes one set of operations per iteration.

class NeuronLoop(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x, xs):
        z = self.linear(x)
        for i in range(len(xs)):
            x += xs[i] * (i + 1)
        return z


exported_program = torch.export.export(
    NeuronLoop(),
    (
        torch.randn(1, 5),
        [torch.randn(1, 5), torch.randn(1, 5)],
    ),
)
print(exported_program.graph)
graph():
    %p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
    %p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
    %x : [num_users=2] = placeholder[target=x]
    %xs_0 : [num_users=1] = placeholder[target=xs_0]
    %xs_1 : [num_users=1] = placeholder[target=xs_1]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%xs_0, 1), kwargs = {})
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %mul), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%xs_1, 2), kwargs = {})
    %add__1 : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%add_, %mul_1), kwargs = {})
    return (linear,)
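
The unrolling becomes obvious when the iteration count changes. A sketch, assuming the same NeuronLoop class: exporting with three tensors produces three mul/add_ pairs, and the resulting program only accepts lists of that length.

exported_program = torch.export.export(
    NeuronLoop(),
    # three tensors instead of two: the graph now contains three unrolled steps
    (torch.randn(1, 5), [torch.randn(1, 5) for _ in range(3)]),
)
print(exported_program.graph)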

Export for training

In that case, the weights are exported as inputs (placeholders) of the graph.

class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        z = self.linear(x)
        return torch.sigmoid(z)


print("-- training")
mod = Neuron()
mod.train()
exported_program = torch.export.export_for_training(mod, (torch.randn(1, 5),))
print(exported_program.graph)
-- training
~/github/experimental-experiment/_doc/examples/plot_torch_export_101.py:181: FutureWarning: `torch.export.export_for_training` is deprecated and will be removed in PyTorch 2.10. Please use `torch.export.export` instead, which is functionally equivalent.
  exported_program = torch.export.export_for_training(mod, (torch.randn(1, 5),))
graph():
    %p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
    %p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    return (sigmoid,)
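
As the FutureWarning indicates, the non-deprecated equivalent on recent versions is the regular call. A sketch:

# functionally equivalent according to the warning above
exported_program = torch.export.export(mod, (torch.randn(1, 5),))
print(exported_program.graph)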

Preserve Modules

class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        z = self.linear(x)
        return torch.sigmoid(z)


class NeuronNeuron(torch.nn.Module):
    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.my_neuron = Neuron(n_dims, n_targets)

    def forward(self, x):
        z = self.my_neuron(x)
        return -z

The list of the modules.

mod = NeuronNeuron()
for item in mod.named_modules():
    print(item)
('', NeuronNeuron(
  (my_neuron): Neuron(
    (linear): Linear(in_features=5, out_features=3, bias=True)
  )
))
('my_neuron', Neuron(
  (linear): Linear(in_features=5, out_features=3, bias=True)
))
('my_neuron.linear', Linear(in_features=5, out_features=3, bias=True))

The exported graph did not change: the submodule call is still inlined. preserve_module_call_signature only records the information needed to restore the hierarchy afterwards.

print("-- preserved?")
exported_program = torch.export.export(
    mod, (torch.randn(1, 5),), preserve_module_call_signature=("my_neuron",)
)
print(exported_program.graph)
-- preserved?
graph():
    %p_my_neuron_linear_weight : [num_users=1] = placeholder[target=p_my_neuron_linear_weight]
    %p_my_neuron_linear_bias : [num_users=1] = placeholder[target=p_my_neuron_linear_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_my_neuron_linear_weight, %p_my_neuron_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%sigmoid,), kwargs = {})
    return (neg,)
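
The preserved call signature becomes visible with the public torch.export.unflatten API, which rebuilds the original module hierarchy from the exported program. A sketch:

# my_neuron reappears as a submodule of the unflattened module
unflat = torch.export.unflatten(exported_program)
print(unflat)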

And now?

import torch.export._swap

swapped_gm = torch.export._swap._swap_modules(exported_program, {"my_neuron": Neuron()})

print("-- preserved?")
print(swapped_gm.graph)
~/vv/this312/lib/python3.12/site-packages/torch/export/unflatten.py:975: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  spec_node = gm.graph.get_attr(name)
~/vv/this312/lib/python3.12/site-packages/torch/export/unflatten.py:967: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  spec_node = gm.graph.get_attr(name)
-- preserved?
graph():
    %x_1 : [num_users=1] = placeholder[target=x]
    %_spec_0 : [num_users=1] = get_attr[target=_spec_0]
    %_spec_1 : [num_users=1] = get_attr[target=_spec_1]
    %_spec_2 : [num_users=1] = get_attr[target=_spec_2]
    %tree_flatten : [num_users=1] = call_function[target=torch.utils._pytree.tree_flatten](args = ((%x_1,),), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten, 0), kwargs = {})
    %x : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 0), kwargs = {})
    %tree_unflatten_1 : [num_users=1] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%x], %_spec_1), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 0), kwargs = {})
    %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem_1, 0), kwargs = {})
    %my_neuron : [num_users=1] = call_module[target=my_neuron](args = (%getitem_2,), kwargs = {})
    %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (%my_neuron, %_spec_2), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
    %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_4,), kwargs = {})
    %tree_unflatten : [num_users=1] = call_function[target=torch.utils._pytree.tree_unflatten](args = ((%neg,), %_spec_0), kwargs = {})
    return tree_unflatten
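
The swapped GraphModule remains callable like the original model. A sketch:

print(swapped_gm(torch.randn(1, 5)))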

Unfortunately, this approach does not work well on big models, and it is a private API.

Total running time of the script: (0 minutes 0.210 seconds)

Related examples

201: Evaluate different ways to export a torch model to ONNX

201: Use torch to export a scikit-learn model into ONNX

101: A custom backend for torch

101: Linear Regression and export to ONNX

102: Tweak onnx export
