Parameter torch.export.export(…, strict: bool)
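The exporter relies on torch.export.export(). This function exposes a parameter strict: bool, which defaults to True in the PyTorch version this page was written against. More recent PyTorch releases default to strict=False, so it is safer to pass the value explicitly.
The two modes behave differently in some specific configurations, described below.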
strict=True
torch.ops.higher_order.scan
Not all signatures work in this mode. The following example, based on scan, is one that exports successfully with strict=True.
<<<
import torch


def add(carry: torch.Tensor, y: torch.Tensor):
    # Combine function: returns [next carry, output of this step].
    next_carry = carry + y
    return [next_carry, next_carry]


class ScanModel(torch.nn.Module):
    def forward(self, x):
        init = torch.zeros_like(x[0])
        # Scan over the rows of x, accumulating them into the carry.
        carry, out = torch.ops.higher_order.scan(
            add, [init], [x], dim=0, reverse=False, additional_inputs=[]
        )
        return carry


x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32)
model = ScanModel()
expected = model(x)
print("------")
print(expected, x.sum(dim=0))
print("------")
print(torch.export.export(model, (x,), strict=True).graph)
>>>
------
tensor([12., 15., 18.]) tensor([12., 15., 18.])
------
graph():
    %x : [num_users=2] = placeholder[target=x]
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%select,), kwargs = {pin_memory: False})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like], [%x], 0, False, []), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    return (getitem,)
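To make the semantics of the scan call explicit, here is an eager reference loop. It is a simplified sketch that assumes a single carry tensor, a single scanned tensor, dim=0 and reverse=False, as in the example; reference_scan is a hypothetical helper, not a PyTorch API. The final carry is the sum over the first dimension, which is why the model output matches the sum printed above.

```python
import torch


def add(carry: torch.Tensor, y: torch.Tensor):
    # Same combine function as in the example: [next carry, step output].
    next_carry = carry + y
    return [next_carry, next_carry]


def reference_scan(combine_fn, init: torch.Tensor, xs: torch.Tensor):
    # Hypothetical eager equivalent of scan along dim=0 with reverse=False.
    carry = init
    outputs = []
    for i in range(xs.shape[0]):
        carry, out = combine_fn(carry, xs[i])
        outputs.append(out)
    # Step outputs are stacked along the scanned dimension.
    return carry, torch.stack(outputs, dim=0)


x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32)
carry, outs = reference_scan(add, torch.zeros_like(x[0]), x)
print(carry)  # tensor([12., 15., 18.])
print(outs)   # running sums, equal to torch.cumsum(x, dim=0)
```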
strict=False
‘from_node’ missing in node.meta
No node of the graph obtained with strict=False has the key 'from_node'
in its dictionary node.meta. It is therefore difficult to trace where a parameter
comes from unless this information is passed along while walking through
the submodules.
In-place x[..., i] = y
This in-place indexed assignment cannot be captured with strict=False.
<<<
import torch


class UpdateModel(torch.nn.Module):
    def forward(
        self, x: torch.Tensor, update: torch.Tensor, kv_index: torch.Tensor
    ):
        # Copy first so that the in-place assignment does not modify the input.
        x = x.clone()
        x[..., kv_index] = update
        return x


model = UpdateModel()
example_inputs = (
    torch.ones((4, 4, 10), dtype=torch.float32),
    (torch.arange(2) + 10).to(torch.float32).reshape((1, 1, 2)),
    torch.tensor([1, 2], dtype=torch.int32),
)

try:
    torch.export.export(model, example_inputs, strict=False)
except Exception as e:
    print(e)
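A possible workaround, given here as an untested assumption rather than a guaranteed fix, is to express the update out of place with Tensor.index_copy, which writes update into a copy of x at positions kv_index along the last dimension. index_copy expects an int64 index and a source whose shape matches x except along that dimension, hence the conversion and the expand. UpdateModelOutOfPlace is a hypothetical name; the sketch only checks that both formulations agree in eager mode, and whether the rewritten model exports with strict=False may depend on the PyTorch version.

```python
import torch


class UpdateModel(torch.nn.Module):
    # Original formulation with an in-place indexed assignment.
    def forward(self, x, update, kv_index):
        x = x.clone()
        x[..., kv_index] = update
        return x


class UpdateModelOutOfPlace(torch.nn.Module):
    # Hypothetical rewrite producing the same result without mutating x.
    def forward(self, x, update, kv_index):
        index = kv_index.to(torch.int64)
        dim = x.dim() - 1
        # Broadcast update to x.shape[:-1] + (len(index),) as index_copy needs.
        source = update.expand(*x.shape[:-1], index.shape[0])
        return x.index_copy(dim, index, source)


example_inputs = (
    torch.ones((4, 4, 10), dtype=torch.float32),
    (torch.arange(2) + 10).to(torch.float32).reshape((1, 1, 2)),
    torch.tensor([1, 2], dtype=torch.int32),
)
expected = UpdateModel()(*example_inputs)
got = UpdateModelOutOfPlace()(*example_inputs)
torch.testing.assert_close(got, expected)
```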