Note
Go to the end to download the full example code.
Half certain nonzero¶
torch.nonzero()
returns the indices of all the nonzero elements
of a tensor. The number of rows of its output depends on the values, so the
output shape is unknown in the generic case. However, if you have a 2D tensor with exactly one nonzero value
in every row, you know the result has as many rows as the input. But torch.export.export()
does not know what you know.
A Model¶
import torch
from onnx_diagnostic import doc
class Model(torch.nn.Module):
    """Module building a chunk-based boolean mask from the second input dimension.

    The mask relies on ``nonzero`` over a boolean tensor, whose number of
    rows depends on the data and therefore cannot be inferred statically.
    """

    def adaptive_enc_mask(self, x_len, chunk_start_idx, left_window=0, right_window=0):
        """Build a boolean mask restricting each position to nearby chunks.

        Positions ``0 .. x_len - 1`` are split into consecutive chunks whose
        first indices are given by ``chunk_start_idx`` (chunk 0 always
        starts at 0, the last chunk ends at ``x_len``).

        :param x_len: number of positions
        :param chunk_start_idx: sequence of start positions of the chunks
            following the first one (may be empty)
        :param left_window: number of chunks visible to the left
        :param right_window: number of chunks visible to the right
        :return: boolean tensor; entry ``[i, j]`` is True when position ``j``
            lies between the start of the leftmost visible chunk and the end
            of the rightmost visible chunk for position ``i``
        """
        chunk_starts = torch.Tensor(chunk_start_idx).long()
        # Lower (inclusive) and upper (exclusive) bound of every chunk.
        lower = torch.cat([torch.tensor([0], dtype=torch.int64), chunk_starts], dim=0)
        upper = torch.cat([chunk_starts, torch.tensor([x_len], dtype=torch.int64)], dim=0)

        # Column vector of positions, compared against every chunk at once.
        rows = torch.arange(0, x_len).unsqueeze(-1)
        in_chunk = (rows < upper) & (rows >= lower)
        # Column index of each True entry, i.e. the chunk holding each position.
        # The number of rows returned by nonzero is data dependent.
        chunk_of_row = in_chunk.nonzero()[:, 1]

        columns = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)

        # Leftmost visible chunk, clipped to the first chunk.
        first_chunk = chunk_of_row - left_window
        first_chunk[first_chunk < 0] = 0
        left_bound = lower[first_chunk]
        visible_from_left = columns >= left_bound.unsqueeze(-1)

        # Rightmost visible chunk, clipped to the last chunk.
        last_chunk = chunk_of_row + right_window
        n_following = len(chunk_starts)
        last_chunk[last_chunk > n_following] = n_following
        right_bound = upper[last_chunk]
        visible_from_right = columns < right_bound.unsqueeze(-1)

        return visible_from_left & visible_from_right

    def forward(self, x):
        """Return the mask for ``x.shape[1]`` positions and a single chunk."""
        seq_len = x.shape[1]
        return self.adaptive_enc_mask(seq_len, [])
# Run the model once in eager mode to check the output shape.
model = Model()
# Random input; forward only reads x.shape[1] (8 here).
x = torch.rand((5, 8))
y = model(x)
# The returned mask is square, with side x.shape[1].
print(f"x.shape={x.shape}, y.shape={y.shape}")
x.shape=torch.Size([5, 8]), y.shape=torch.Size([8, 8])
Export¶
# Declare both dimensions of the single input as dynamic, then export.
DYN = torch.export.Dim.DYNAMIC
ep = torch.export.export(
    model,
    (x,),
    dynamic_shapes=({0: DYN, 1: DYN},),
)
print(ep)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, c_lifted_tensor_0: "f32[0]", c_lifted_tensor_1: "i64[1]", c_lifted_tensor_2: "i64[]", c_lifted_tensor_3: "i64[]", x: "f32[s35, s16]"):
#
sym_size_int_2: "Sym(s16)" = torch.ops.aten.sym_size.int(x, 1); x = None
# File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_nonzero.py:39 in forward, code: return self.adaptive_enc_mask(x.shape[1], [])
lift_fresh_copy: "f32[0]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(lift_fresh_copy, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
to: "i64[0]" = torch.ops.aten.to.dtype(lift_fresh_copy, torch.int64); lift_fresh_copy = None
lift_fresh_copy_1: "i64[1]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_1); c_lifted_tensor_1 = None
detach_: "i64[1]" = torch.ops.aten.detach_.default(lift_fresh_copy_1); lift_fresh_copy_1 = None
cat: "i64[1]" = torch.ops.aten.cat.default([detach_, to]); detach_ = None
scalar_tensor: "i64[]" = torch.ops.aten.scalar_tensor.default(sym_size_int_2, dtype = torch.int64, device = device(type='cpu'), pin_memory = False)
stack: "i64[1]" = torch.ops.aten.stack.default([scalar_tensor]); scalar_tensor = None
_assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(stack, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_1 = None
to_1: "i64[1]" = torch.ops.aten.to.device(stack, device(type='cpu'), torch.int64); stack = None
detach__1: "i64[1]" = torch.ops.aten.detach_.default(to_1); to_1 = None
cat_1: "i64[1]" = torch.ops.aten.cat.default([to, detach__1]); to = detach__1 = None
arange: "i64[s16]" = torch.ops.aten.arange.start(0, sym_size_int_2, device = device(type='cpu'), pin_memory = False)
unsqueeze: "i64[s16, 1]" = torch.ops.aten.unsqueeze.default(arange, -1); arange = None
lt: "b8[s16, 1]" = torch.ops.aten.lt.Tensor(unsqueeze, cat_1)
ge: "b8[s16, 1]" = torch.ops.aten.ge.Tensor(unsqueeze, cat); unsqueeze = None
and_1: "b8[s16, 1]" = torch.ops.aten.__and__.Tensor(lt, ge); lt = ge = None
nonzero: "i64[s16, 2]" = torch.ops.aten.nonzero.default(and_1); and_1 = None
#
sym_size_int_3: "Sym(s16)" = torch.ops.aten.sym_size.int(nonzero, 0)
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_3); sym_constrain_range_for_size_default = None
# File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_nonzero.py:39 in forward, code: return self.adaptive_enc_mask(x.shape[1], [])
ge_2: "Sym(s16 >= 2)" = sym_size_int_3 >= 2
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 2 on node 'ge_2'"); ge_2 = _assert_scalar_default = None
#
eq: "Sym(True)" = sym_size_int_2 == sym_size_int_3; sym_size_int_3 = None
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s16, u0) on node 'eq'"); eq = _assert_scalar_default_1 = None
# File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_nonzero.py:39 in forward, code: return self.adaptive_enc_mask(x.shape[1], [])
slice_1: "i64[s16, 2]" = torch.ops.aten.slice.Tensor(nonzero); nonzero = None
select: "i64[s16]" = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
arange_1: "i64[s16]" = torch.ops.aten.arange.start(0, sym_size_int_2, device = device(type='cpu'), pin_memory = False)
unsqueeze_1: "i64[1, s16]" = torch.ops.aten.unsqueeze.default(arange_1, 0); arange_1 = None
expand: "i64[s16, s16]" = torch.ops.aten.expand.default(unsqueeze_1, [sym_size_int_2, -1]); unsqueeze_1 = sym_size_int_2 = None
sub: "i64[s16]" = torch.ops.aten.sub.Tensor(select, 0)
lt_1: "b8[s16]" = torch.ops.aten.lt.Scalar(sub, 0)
lift_fresh_copy_2: "i64[]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_2); c_lifted_tensor_2 = None
index_put_: "i64[s16]" = torch.ops.aten.index_put_.default(sub, [lt_1], lift_fresh_copy_2); sub = lt_1 = lift_fresh_copy_2 = None
index: "i64[s16]" = torch.ops.aten.index.Tensor(cat, [index_put_]); cat = index_put_ = None
unsqueeze_2: "i64[s16, 1]" = torch.ops.aten.unsqueeze.default(index, -1); index = None
ge_1: "b8[s16, s16]" = torch.ops.aten.ge.Tensor(expand, unsqueeze_2); unsqueeze_2 = None
add: "i64[s16]" = torch.ops.aten.add.Tensor(select, 0); select = None
gt: "b8[s16]" = torch.ops.aten.gt.Scalar(add, 0)
lift_fresh_copy_3: "i64[]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_3); c_lifted_tensor_3 = None
index_put__1: "i64[s16]" = torch.ops.aten.index_put_.default(add, [gt], lift_fresh_copy_3); add = gt = lift_fresh_copy_3 = None
index_1: "i64[s16]" = torch.ops.aten.index.Tensor(cat_1, [index_put__1]); cat_1 = index_put__1 = None
unsqueeze_3: "i64[s16, 1]" = torch.ops.aten.unsqueeze.default(index_1, -1); index_1 = None
lt_2: "b8[s16, s16]" = torch.ops.aten.lt.Tensor(expand, unsqueeze_3); expand = unsqueeze_3 = None
and_2: "b8[s16, s16]" = torch.ops.aten.__and__.Tensor(ge_1, lt_2); ge_1 = lt_2 = None
return (and_2,)
Graph signature:
# inputs
c_lifted_tensor_0: CONSTANT_TENSOR target='lifted_tensor_0'
c_lifted_tensor_1: CONSTANT_TENSOR target='lifted_tensor_1'
c_lifted_tensor_2: CONSTANT_TENSOR target='lifted_tensor_2'
c_lifted_tensor_3: CONSTANT_TENSOR target='lifted_tensor_3'
x: USER_INPUT
# outputs
and_2: USER_OUTPUT
Range constraints: {u0: VR[2, 9223372036854775806], s35: VR[2, int_oo], s16: VR[2, 9223372036854775806]}
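The exported program contains the following line.
It is a runtime assertion for a condition the exporter could not verify statically:
the number of rows returned by nonzero (u0) must be equal to the dynamic dimension s16.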
torch.ops.aten._assert_scalar.default(eq,
"Runtime assertion failed for expression Eq(s16, u0) on node 'eq'");
# NOTE(review): helper from onnx_diagnostic.doc; presumably draws the legend
# image used to illustrate this example in the gallery — confirm.
doc.plot_legend("dynamic shapes\nnonzero", "dynamic shapes", "yellow")

Total running time of the script: (0 minutes 0.277 seconds)
Related examples