Half certain nonzero

torch.nonzero() returns the indices of the nonzero elements of a tensor. The output shape is unknown in the general case because it depends on the values. However, if you know that every row of a 2D tensor contains exactly one nonzero value, you can deduce the first dimension of the output. But torch.export.export() does not know what you know.
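
A quick illustration (not part of the model below): two tensors with the same shape can produce a different number of indices, so the shape of the output of nonzero() cannot be known before looking at the values.

import torch

# Two tensors with the same shape...
a = torch.tensor([[0, 1, 0], [2, 0, 3]])
b = torch.tensor([[0, 0, 0], [4, 0, 0]])

# ...but torch.nonzero returns one row per nonzero element,
# so the first dimension of the result depends on the values.
print(torch.nonzero(a).shape)  # torch.Size([3, 2])
print(torch.nonzero(b).shape)  # torch.Size([1, 2])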

A Model

import torch
from onnx_diagnostic import doc


class Model(torch.nn.Module):
    def adaptive_enc_mask(self, x_len, chunk_start_idx, left_window=0, right_window=0):
        chunk_start_idx = torch.Tensor(chunk_start_idx).long()
        start_pad = torch.cat((torch.tensor([0], dtype=torch.int64), chunk_start_idx), dim=0)
        end_pad = torch.cat((chunk_start_idx, torch.tensor([x_len], dtype=torch.int64)), dim=0)
        seq_range = torch.arange(0, x_len).unsqueeze(-1)
        # nonzero() makes the graph data-dependent: the size of its
        # output is only known at run time.
        idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1]
        seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)
        idx_left = idx - left_window
        idx_left[idx_left < 0] = 0
        boundary_left = start_pad[idx_left]
        mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)
        idx_right = idx + right_window
        idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx)
        boundary_right = end_pad[idx_right]
        mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
        return mask_left & mask_right

    def forward(self, x):
        return self.adaptive_enc_mask(x.shape[1], [])


model = Model()
x = torch.rand((5, 8))
y = model(x)
print(f"x.shape={x.shape}, y.shape={y.shape}")
x.shape=torch.Size([5, 8]), y.shape=torch.Size([8, 8])
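
Why does y have 8 rows? With an empty chunk_start_idx, every position of seq_range falls into exactly one interval, so the boolean mask passed to nonzero() contains exactly one True per row. The small sketch below reproduces that intermediate mask with the same variable names as adaptive_enc_mask:

x_len = x.shape[1]
start_pad = torch.tensor([0], dtype=torch.int64)   # chunk_start_idx is empty
end_pad = torch.tensor([x_len], dtype=torch.int64)
seq_range = torch.arange(0, x_len).unsqueeze(-1)
mask = (seq_range < end_pad) & (seq_range >= start_pad)
# Exactly one True per row: nonzero() returns x_len index pairs.
print(mask.sum().item(), mask.nonzero().shape)  # 8 torch.Size([8, 2])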

Export

DYN = torch.export.Dim.DYNAMIC
ep = torch.export.export(model, (x,), dynamic_shapes=({0: DYN, 1: DYN},))
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, c_lifted_tensor_0: "f32[0]", c_lifted_tensor_1: "i64[1]", c_lifted_tensor_2: "i64[]", c_lifted_tensor_3: "i64[]", x: "f32[s35, s16]"):
             #
            sym_size_int_2: "Sym(s16)" = torch.ops.aten.sym_size.int(x, 1);  x = None

             # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_nonzero.py:39 in forward, code: return self.adaptive_enc_mask(x.shape[1], [])
            lift_fresh_copy: "f32[0]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0);  c_lifted_tensor_0 = None
            _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(lift_fresh_copy, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default = None
            to: "i64[0]" = torch.ops.aten.to.dtype(lift_fresh_copy, torch.int64);  lift_fresh_copy = None
            lift_fresh_copy_1: "i64[1]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_1);  c_lifted_tensor_1 = None
            detach_: "i64[1]" = torch.ops.aten.detach_.default(lift_fresh_copy_1);  lift_fresh_copy_1 = None
            cat: "i64[1]" = torch.ops.aten.cat.default([detach_, to]);  detach_ = None
            scalar_tensor: "i64[]" = torch.ops.aten.scalar_tensor.default(sym_size_int_2, dtype = torch.int64, device = device(type='cpu'), pin_memory = False)
            stack: "i64[1]" = torch.ops.aten.stack.default([scalar_tensor]);  scalar_tensor = None
            _assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(stack, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_1 = None
            to_1: "i64[1]" = torch.ops.aten.to.device(stack, device(type='cpu'), torch.int64);  stack = None
            detach__1: "i64[1]" = torch.ops.aten.detach_.default(to_1);  to_1 = None
            cat_1: "i64[1]" = torch.ops.aten.cat.default([to, detach__1]);  to = detach__1 = None
            arange: "i64[s16]" = torch.ops.aten.arange.start(0, sym_size_int_2, device = device(type='cpu'), pin_memory = False)
            unsqueeze: "i64[s16, 1]" = torch.ops.aten.unsqueeze.default(arange, -1);  arange = None
            lt: "b8[s16, 1]" = torch.ops.aten.lt.Tensor(unsqueeze, cat_1)
            ge: "b8[s16, 1]" = torch.ops.aten.ge.Tensor(unsqueeze, cat);  unsqueeze = None
            and_1: "b8[s16, 1]" = torch.ops.aten.__and__.Tensor(lt, ge);  lt = ge = None
            nonzero: "i64[s16, 2]" = torch.ops.aten.nonzero.default(and_1);  and_1 = None

             #
            sym_size_int_3: "Sym(s16)" = torch.ops.aten.sym_size.int(nonzero, 0)
            sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_3);  sym_constrain_range_for_size_default = None

             # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_nonzero.py:39 in forward, code: return self.adaptive_enc_mask(x.shape[1], [])
            ge_2: "Sym(s16 >= 2)" = sym_size_int_3 >= 2
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 2 on node 'ge_2'");  ge_2 = _assert_scalar_default = None

             #
            eq: "Sym(True)" = sym_size_int_2 == sym_size_int_3;  sym_size_int_3 = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s16, u0) on node 'eq'");  eq = _assert_scalar_default_1 = None

             # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_nonzero.py:39 in forward, code: return self.adaptive_enc_mask(x.shape[1], [])
            slice_1: "i64[s16, 2]" = torch.ops.aten.slice.Tensor(nonzero);  nonzero = None
            select: "i64[s16]" = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
            arange_1: "i64[s16]" = torch.ops.aten.arange.start(0, sym_size_int_2, device = device(type='cpu'), pin_memory = False)
            unsqueeze_1: "i64[1, s16]" = torch.ops.aten.unsqueeze.default(arange_1, 0);  arange_1 = None
            expand: "i64[s16, s16]" = torch.ops.aten.expand.default(unsqueeze_1, [sym_size_int_2, -1]);  unsqueeze_1 = sym_size_int_2 = None
            sub: "i64[s16]" = torch.ops.aten.sub.Tensor(select, 0)
            lt_1: "b8[s16]" = torch.ops.aten.lt.Scalar(sub, 0)
            lift_fresh_copy_2: "i64[]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_2);  c_lifted_tensor_2 = None
            index_put_: "i64[s16]" = torch.ops.aten.index_put_.default(sub, [lt_1], lift_fresh_copy_2);  sub = lt_1 = lift_fresh_copy_2 = None
            index: "i64[s16]" = torch.ops.aten.index.Tensor(cat, [index_put_]);  cat = index_put_ = None
            unsqueeze_2: "i64[s16, 1]" = torch.ops.aten.unsqueeze.default(index, -1);  index = None
            ge_1: "b8[s16, s16]" = torch.ops.aten.ge.Tensor(expand, unsqueeze_2);  unsqueeze_2 = None
            add: "i64[s16]" = torch.ops.aten.add.Tensor(select, 0);  select = None
            gt: "b8[s16]" = torch.ops.aten.gt.Scalar(add, 0)
            lift_fresh_copy_3: "i64[]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_3);  c_lifted_tensor_3 = None
            index_put__1: "i64[s16]" = torch.ops.aten.index_put_.default(add, [gt], lift_fresh_copy_3);  add = gt = lift_fresh_copy_3 = None
            index_1: "i64[s16]" = torch.ops.aten.index.Tensor(cat_1, [index_put__1]);  cat_1 = index_put__1 = None
            unsqueeze_3: "i64[s16, 1]" = torch.ops.aten.unsqueeze.default(index_1, -1);  index_1 = None
            lt_2: "b8[s16, s16]" = torch.ops.aten.lt.Tensor(expand, unsqueeze_3);  expand = unsqueeze_3 = None
            and_2: "b8[s16, s16]" = torch.ops.aten.__and__.Tensor(ge_1, lt_2);  ge_1 = lt_2 = None
            return (and_2,)

Graph signature:
    # inputs
    c_lifted_tensor_0: CONSTANT_TENSOR target='lifted_tensor_0'
    c_lifted_tensor_1: CONSTANT_TENSOR target='lifted_tensor_1'
    c_lifted_tensor_2: CONSTANT_TENSOR target='lifted_tensor_2'
    c_lifted_tensor_3: CONSTANT_TENSOR target='lifted_tensor_3'
    x: USER_INPUT

    # outputs
    and_2: USER_OUTPUT

Range constraints: {u0: VR[2, 9223372036854775806], s35: VR[2, int_oo], s16: VR[2, 9223372036854775806]}

The exported program contains the following runtime assertion. It expresses the property the exporter could not verify symbolically: the number of rows returned by nonzero() (the data-dependent symbol u0) must equal the dynamic dimension s16.

torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s16, u0) on node 'eq'")
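
Because the exporter cannot prove this equality, it only checks it at run time. As a quick sanity check (a sketch, assuming the variables model and ep defined above are still in scope), calling the exported module with an input of a different shape still succeeds, because the mask always has exactly one nonzero value per row and the assertion Eq(s16, u0) therefore holds:

z = torch.rand((3, 11))
print(model(z).shape)        # torch.Size([11, 11])
print(ep.module()(z).shape)  # torch.Size([11, 11]), the runtime assertion passes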

doc.plot_legend("dynamic shapes\nnonzero", "dynamic shapes", "yellow")

Total running time of the script: (0 minutes 0.435 seconds)

Related examples

Do not use python int with dynamic shapes

Half certain nonzero