Note
Go to the end to download the full example code.
101: Some dummy examples with torch.export.export¶
This example shows the behaviour of torch.export.export()
in various situations.
Easy Case¶
A simple model.
import torch
class Neuron(torch.nn.Module):
    """A single fully-connected layer followed by a sigmoid.

    :param n_dims: number of input features
    :param n_targets: number of output features
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        # Affine projection, then element-wise logistic activation.
        return torch.sigmoid(self.linear(x))
# Trace the simple model with one (1, 5) example input and show the ATen graph.
exported_program = torch.export.export(
    Neuron(),
    args=(torch.randn(1, 5),),
)
print(exported_program.graph)
graph():
%p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
%p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
%x : [num_users=1] = placeholder[target=x]
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
%sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
return (sigmoid,)
With an integer as input¶
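As the torch.export.export
documentation explains, integer inputs do not show up on the graph as real inputs: their value is specialized and baked in as a constant (below, the i_input placeholder has no users and select receives the literal 2).
An exporter based on torch.export.export()
therefore cannot consider
the integer as an input.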
class NeuronIInt(torch.nn.Module):
    """Linear layer + sigmoid, returning a single output column.

    The column is selected by ``i_input``, a plain Python ``int``.

    :param n_dims: number of input features
    :param n_targets: number of output features
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x: torch.Tensor, i_input: int):
        activations = torch.sigmoid(self.linear(x))
        # Keep every row, pick column ``i_input``.
        return activations[:, i_input]
# The trailing ``2`` is a Python int; export specializes on its value.
exported_program = torch.export.export(
    NeuronIInt(),
    args=(torch.randn(1, 5), 2),
)
print(exported_program.graph)
graph():
%p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
%p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
%x : [num_users=1] = placeholder[target=x]
%i_input : [num_users=0] = placeholder[target=i_input]
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
%sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
%slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%sigmoid,), kwargs = {})
%select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%slice_1, 1, 2), kwargs = {})
return (select,)
With an integer wrapped into a tensor as input¶
If the integer is wrapped into a Tensor, it works: i_input becomes a real graph input consumed by aten.index.
class NeuronIInt(torch.nn.Module):
    """Linear layer + sigmoid, with columns selected by ``i_input``.

    In this example's export call ``i_input`` is an integer tensor, so the
    selection is tensor (advanced) indexing rather than a constant index.

    :param n_dims: number of input features
    :param n_targets: number of output features
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x: torch.Tensor, i_input):
        activations = torch.sigmoid(self.linear(x))
        return activations[:, i_input]
# Wrap the index in an int32 tensor so it becomes a genuine graph input.
exported_program = torch.export.export(
    NeuronIInt(),
    args=(torch.randn(1, 5), torch.tensor([2], dtype=torch.int32)),
)
print(exported_program.graph)
graph():
%p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
%p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
%x : [num_users=1] = placeholder[target=x]
%i_input : [num_users=1] = placeholder[target=i_input]
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
%sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
%slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%sigmoid, 0, 0, 9223372036854775807), kwargs = {})
%index : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%slice_1, [None, %i_input]), kwargs = {})
return (index,)
Wrapped¶
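Wrapping the model into another module still works; the graph inputs are renamed args_0 and args_1 after the wrapper's *args signature.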
class WrappedNeuronIInt(torch.nn.Module):
    """Thin wrapper forwarding all positional and keyword arguments to ``model``.

    :param model: the module to wrap; registered as the submodule ``model``
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        # Invoke the submodule itself rather than ``self.model.forward`` so that
        # nn.Module.__call__ runs: hooks registered on the wrapped model
        # (forward hooks, pre-hooks) fire exactly as they would without the
        # wrapper.
        return self.model(*args, **kwargs)
# Export through the *args/**kwargs wrapper; inputs are flattened to args_0, args_1.
exported_program = torch.export.export(
    WrappedNeuronIInt(NeuronIInt()),
    args=(torch.randn(1, 5), torch.tensor([2], dtype=torch.int32)),
)
print(exported_program.graph)
graph():
%p_model_linear_weight : [num_users=1] = placeholder[target=p_model_linear_weight]
%p_model_linear_bias : [num_users=1] = placeholder[target=p_model_linear_bias]
%args_0 : [num_users=1] = placeholder[target=args_0]
%args_1 : [num_users=1] = placeholder[target=args_1]
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%args_0, %p_model_linear_weight, %p_model_linear_bias), kwargs = {})
%sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
%slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%sigmoid, 0, 0, 9223372036854775807), kwargs = {})
%index : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%slice_1, [None, %args_1]), kwargs = {})
return (index,)
List¶
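This last example does not export: a tensor cannot be used as a slice bound (aten::slice expects an optional int for end). An exporter based on
torch.export.export()
cannot work either.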
class NeuronNoneListInt(torch.nn.Module):
    """Linear layer + sigmoid fed by a list input and truncated by ``i_input``.

    Only ``yz[0]`` and ``yz[3]`` are read; the other list entries (``None``
    in this example) are ignored. The output keeps its first ``i_input`` rows.

    :param n_dims: number of input features
    :param n_targets: number of output features
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x, yz, i_input):
        first, last = yz[0], yz[3]
        shifted = x + first * last
        return torch.sigmoid(self.linear(shifted))[:i_input]
# Slicing with a tensor bound is expected to fail during export; report the error.
try:
    exported_program = torch.export.export(
        NeuronNoneListInt(),
        args=(
            torch.randn(1, 5),
            [torch.randn(1, 5), None, None, torch.randn(1, 5)],
            torch.tensor([2], dtype=torch.int32),
        ),
    )
    print(exported_program.graph)
except (torch._dynamo.exc.Unsupported, RuntimeError) as e:
    print(f"-- an error {type(e)} occured:")
    print(e)
-- an error <class 'RuntimeError'> occured:
Overloaded torch operator invoked from Python failed to match any schema:
aten::slice() Expected a value of type 'Optional[int]' for argument 'end' but instead found type 'FakeTensor'.
Position: 3
Value: FakeTensor(..., size=(1,), dtype=torch.int32)
Declaration: aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
Cast error details: Unable to cast Python instance of type <class 'torch._subclasses.fake_tensor.FakeTensor'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)
aten::slice() expected at most 4 argument(s) but received 5 argument(s). Declaration: aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> t[]
aten::slice() expected at most 4 argument(s) but received 5 argument(s). Declaration: aten::slice.str(str string, int? start=None, int? end=None, int step=1) -> str
Loops¶
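Loops are not captured as loops: they are unrolled, and each iteration adds its own operations to the graph (mul/add_ for the first, mul_1/add__1 for the second).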
class NeuronLoop(torch.nn.Module):
    """Linear layer applied to ``x``, followed by a Python loop over ``xs``.

    :param n_dims: number of input features
    :param n_targets: number of output features
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x, xs):
        z = self.linear(x)
        # NOTE(review): the loop accumulates into ``x`` in place *after* ``z``
        # is computed, so it mutates the caller's tensor but never affects the
        # returned value -- confirm this is intended for the demonstration.
        for weight, term in enumerate(xs, start=1):
            x += term * weight
        return z
# A two-element list input: the loop in forward is unrolled into two steps.
exported_program = torch.export.export(
    NeuronLoop(),
    args=(
        torch.randn(1, 5),
        [torch.randn(1, 5), torch.randn(1, 5)],
    ),
)
print(exported_program.graph)
graph():
%p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
%p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
%x : [num_users=2] = placeholder[target=x]
%xs_0 : [num_users=1] = placeholder[target=xs_0]
%xs_1 : [num_users=1] = placeholder[target=xs_1]
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%xs_0, 1), kwargs = {})
%add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %mul), kwargs = {})
%mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%xs_1, 2), kwargs = {})
%add__1 : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%add_, %mul_1), kwargs = {})
return (linear,)
Export for training¶
With torch.export.export_for_training too, the weights are exported as graph inputs (placeholders).
class Neuron(torch.nn.Module):
    """Fully-connected layer with a sigmoid on top.

    :param n_dims: size of each input sample
    :param n_targets: size of each output sample
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        logits = self.linear(x)
        return torch.sigmoid(logits)
print("-- training")
# nn.Module.train() returns the module itself, so creation and mode switch chain.
mod = Neuron().train()
exported_program = torch.export.export_for_training(mod, args=(torch.randn(1, 5),))
print(exported_program.graph)
-- training
graph():
%p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
%p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
%x : [num_users=1] = placeholder[target=x]
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
%sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
return (sigmoid,)
Preserve Modules¶
class Neuron(torch.nn.Module):
    """Inner building block: ``sigmoid(linear(x))``.

    :param n_dims: number of input features
    :param n_targets: number of output features
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))
class NeuronNeuron(torch.nn.Module):
    """Holds a ``Neuron`` submodule named ``my_neuron`` and negates its output.

    :param n_dims: number of input features, forwarded to the inner Neuron
    :param n_targets: number of output features, forwarded to the inner Neuron
    """

    def __init__(self, n_dims: int = 5, n_targets: int = 3):
        super().__init__()
        self.my_neuron = Neuron(n_dims, n_targets)

    def forward(self, x):
        return -self.my_neuron(x)
The list of the modules.
mod = NeuronNeuron()
# named_modules() yields (qualified_name, module) pairs, root first.
for name, submodule in mod.named_modules():
    print((name, submodule))
('', NeuronNeuron(
(my_neuron): Neuron(
(linear): Linear(in_features=5, out_features=3, bias=True)
)
))
('my_neuron', Neuron(
(linear): Linear(in_features=5, out_features=3, bias=True)
))
('my_neuron.linear', Linear(in_features=5, out_features=3, bias=True))
The exported graph does not change: the submodule my_neuron is still inlined even with preserve_module_call_signature.
print("-- preserved?")
# Ask export to keep the calling convention of the ``my_neuron`` submodule.
exported_program = torch.export.export(
    mod,
    args=(torch.randn(1, 5),),
    preserve_module_call_signature=("my_neuron",),
)
print(exported_program.graph)
-- preserved?
graph():
%p_my_neuron_linear_weight : [num_users=1] = placeholder[target=p_my_neuron_linear_weight]
%p_my_neuron_linear_bias : [num_users=1] = placeholder[target=p_my_neuron_linear_bias]
%x : [num_users=1] = placeholder[target=x]
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_my_neuron_linear_weight, %p_my_neuron_linear_bias), kwargs = {})
%sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
%neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%sigmoid,), kwargs = {})
return (neg,)
And now?
from torch.export._swap import _swap_modules

# NOTE: _swap_modules is a private PyTorch API and may change without notice.
swapped_gm = _swap_modules(exported_program, {"my_neuron": Neuron()})
print("-- preserved?")
print(swapped_gm.graph)
~/vv/this312/lib/python3.12/site-packages/torch/export/unflatten.py:872: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
spec_node = gm.graph.get_attr(name)
~/vv/this312/lib/python3.12/site-packages/torch/export/unflatten.py:864: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
spec_node = gm.graph.get_attr(name)
-- preserved?
graph():
%x_1 : [num_users=1] = placeholder[target=x]
%_spec_0 : [num_users=1] = get_attr[target=_spec_0]
%_spec_1 : [num_users=1] = get_attr[target=_spec_1]
%_spec_2 : [num_users=1] = get_attr[target=_spec_2]
%tree_flatten : [num_users=1] = call_function[target=torch.utils._pytree.tree_flatten](args = ((%x_1,),), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten, 0), kwargs = {})
%x : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 0), kwargs = {})
%tree_unflatten_1 : [num_users=1] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%x], %_spec_1), kwargs = {})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 0), kwargs = {})
%getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem_1, 0), kwargs = {})
%my_neuron : [num_users=1] = call_module[target=my_neuron](args = (%getitem_2,), kwargs = {})
%tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (%my_neuron, %_spec_2), kwargs = {})
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
%neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_4,), kwargs = {})
%tree_unflatten : [num_users=1] = call_function[target=torch.utils._pytree.tree_unflatten](args = ((%neg,), %_spec_0), kwargs = {})
return tree_unflatten
Unfortunately, this approach does not work well on big models, and it relies on a private API (torch.export._swap).
Total running time of the script: (0 minutes 0.325 seconds)
Related examples

201: Evaluate different ways to export a torch model to ONNX