Half certain nonzero

torch.nonzero() returns the indices of all nonzero elements of a tensor. The output shape depends on the data, so it is unknown in the generic case. However, if you know that a 2D tensor has exactly one nonzero value in every row, you can infer the first dimension of the output: it equals the number of rows. But torch.export.export() does not know what you know.

A Model

import torch
from onnx_diagnostic import doc


class Model(torch.nn.Module):
    """Toy module whose export runs into a data-dependent ``nonzero`` shape.

    Only the *shapes* of the inputs matter: ``forward`` reads ``x.shape[1]``
    and ``y.shape[0]`` and ignores the tensor values. The interesting part is
    :meth:`adaptive_enc_mask`. The number of rows returned by ``nonzero()``
    depends on the data, so ``torch.export`` cannot prove it equals
    ``x_len``. It records that fact as a runtime assertion instead.
    """

    def adaptive_enc_mask(self, x_len, chunk_start_idx, left_window=0, right_window=0):
        """Build a boolean chunk-based attention mask.

        The sequence ``0 .. x_len - 1`` is split into chunks.
        ``chunk_start_idx`` gives the start of every chunk except the first,
        which always starts at 0. Chunk ``c`` covers
        ``[start_pad[c], end_pad[c])``. Row ``i`` of the mask allows column
        ``j`` when ``j`` lies between the start of chunk
        ``c_i - left_window`` (clamped to 0) and the end of chunk
        ``c_i + right_window`` (clamped to the last chunk). Here ``c_i`` is
        the chunk containing position ``i``.

        :param x_len: sequence length (a Python int or symbolic size)
        :param chunk_start_idx: start positions of the chunks after the first
            one, as a list or tensor. It is converted to int64. NOTE(review):
            assumed strictly increasing and inside ``(0, x_len)`` so that
            every position belongs to exactly one chunk. This is not checked
            here.
        :param left_window: number of extra chunks visible on the left
        :param right_window: number of extra chunks visible on the right
        :return: boolean tensor of shape ``(x_len, x_len)``, provided every
            position falls in exactly one chunk (see the note above). The
            export graph relies on this through ``Eq(s16, u0)``.
        """
        # Accept a list or a tensor and normalise it to an int64 tensor.
        chunk_start_idx = torch.Tensor(chunk_start_idx).long()
        # Chunk c spans [start_pad[c], end_pad[c]).
        start_pad = torch.cat((torch.tensor([0], dtype=torch.int64), chunk_start_idx), dim=0)
        end_pad = torch.cat((chunk_start_idx, torch.tensor([x_len], dtype=torch.int64)), dim=0)
        # (x_len, 1) column of positions, broadcast against every chunk.
        seq_range = torch.arange(0, x_len).unsqueeze(-1)
        # nonzero() yields (position, chunk) pairs where the position lies in
        # the chunk; column 1 keeps the chunk index. Its length is
        # data-dependent: it equals x_len only if each position falls in
        # exactly one chunk, which the exporter cannot prove.
        idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1]
        # (x_len, x_len) matrix whose every row is 0 .. x_len - 1.
        seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)
        idx_left = idx - left_window
        # Clamp to the first chunk (in place on the fresh tensor; idx is unchanged).
        idx_left[idx_left < 0] = 0
        boundary_left = start_pad[idx_left]
        # Keep columns at or after each row's left boundary.
        mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)
        idx_right = idx + right_window
        # Clamp to the last chunk: end_pad has len(chunk_start_idx) + 1 entries.
        idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx)
        boundary_right = end_pad[idx_right]
        # Keep columns strictly before each row's right boundary.
        mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
        return mask_left & mask_right

    def forward(self, x, y):
        """Return the mask for a single chunk covering ``x.shape[1]`` positions.

        ``chunk_start_idx`` is empty, so there is only one chunk.
        ``y.shape[0]`` is used as ``left_window``. The values of ``x`` and
        ``y`` are never read.
        """
        return self.adaptive_enc_mask(
            x.shape[1], torch.tensor([], dtype=torch.int64), left_window=y.shape[0]
        )


model = Model()
# Only the shapes matter: x.shape[1] is the sequence length and y.shape[0]
# the left window used by Model.forward.
x, y = torch.rand((2, 546)), torch.rand((18,))
z = model(x, y)
# Label each shape with the tensor it actually belongs to.
print(f"x.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")
y.shape=torch.Size([2, 546]), y.shape=torch.Size([18]), z.shape=torch.Size([546, 546])

Export

DYN = torch.export.Dim.DYNAMIC
# Dynamic axes keyed by forward() argument name: both axes of ``x`` and the
# single axis of ``y`` are allowed to vary.
dynamic_shapes = {"x": {0: DYN, 1: DYN}, "y": {0: DYN}}
ep = torch.export.export(model, (x, y), dynamic_shapes=dynamic_shapes)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, c_lifted_tensor_0: "i64[0]", c_lifted_tensor_1: "i64[1]", c_lifted_tensor_2: "i64[]", c_lifted_tensor_3: "i64[]", x: "f32[s35, s16]", y: "f32[s58]"):
             #
            sym_size_int_3: "Sym(s16)" = torch.ops.aten.sym_size.int(x, 1);  x = None
            sym_size_int_4: "Sym(s58)" = torch.ops.aten.sym_size.int(y, 0);  y = None

             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_nonzero.py:42 in forward, code: x.shape[1], torch.tensor([], dtype=torch.int64), left_window=y.shape[0]
            lift_fresh_copy: "i64[0]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0);  c_lifted_tensor_0 = None
            detach_: "i64[0]" = torch.ops.aten.detach_.default(lift_fresh_copy);  lift_fresh_copy = None

             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_nonzero.py:41 in forward, code: return self.adaptive_enc_mask(
            alias: "i64[0]" = torch.ops.aten.alias.default(detach_);  detach_ = None
            _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(alias, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default = None
            to: "i64[0]" = torch.ops.aten.to.dtype(alias, torch.int64);  alias = None
            lift_fresh_copy_1: "i64[1]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_1);  c_lifted_tensor_1 = None
            detach__1: "i64[1]" = torch.ops.aten.detach_.default(lift_fresh_copy_1);  lift_fresh_copy_1 = None
            cat: "i64[1]" = torch.ops.aten.cat.default([detach__1, to]);  detach__1 = None
            scalar_tensor: "i64[]" = torch.ops.aten.scalar_tensor.default(sym_size_int_3, dtype = torch.int64, device = device(type='cpu'), pin_memory = False)
            stack: "i64[1]" = torch.ops.aten.stack.default([scalar_tensor]);  scalar_tensor = None
            _assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(stack, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_1 = None
            to_1: "i64[1]" = torch.ops.aten.to.device(stack, device(type='cpu'), torch.int64);  stack = None
            detach__2: "i64[1]" = torch.ops.aten.detach_.default(to_1);  to_1 = None
            cat_1: "i64[1]" = torch.ops.aten.cat.default([to, detach__2]);  to = detach__2 = None
            arange: "i64[s16]" = torch.ops.aten.arange.start(0, sym_size_int_3, device = device(type='cpu'), pin_memory = False)
            unsqueeze: "i64[s16, 1]" = torch.ops.aten.unsqueeze.default(arange, -1);  arange = None
            lt: "b8[s16, 1]" = torch.ops.aten.lt.Tensor(unsqueeze, cat_1)
            ge: "b8[s16, 1]" = torch.ops.aten.ge.Tensor(unsqueeze, cat);  unsqueeze = None
            and_1: "b8[s16, 1]" = torch.ops.aten.__and__.Tensor(lt, ge);  lt = ge = None
            nonzero: "i64[s16, 2]" = torch.ops.aten.nonzero.default(and_1);  and_1 = None

             #
            sym_size_int_5: "Sym(s16)" = torch.ops.aten.sym_size.int(nonzero, 0)
            sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_5);  sym_constrain_range_for_size_default = None

             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_nonzero.py:41 in forward, code: return self.adaptive_enc_mask(
            ge_2: "Sym(s16 >= 2)" = sym_size_int_5 >= 2
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 2 on node 'ge_2'");  ge_2 = _assert_scalar_default = None

             #
            eq: "Sym(True)" = sym_size_int_3 == sym_size_int_5;  sym_size_int_5 = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s16, u0) on node 'eq'");  eq = _assert_scalar_default_1 = None

             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_nonzero.py:41 in forward, code: return self.adaptive_enc_mask(
            slice_1: "i64[s16, 2]" = torch.ops.aten.slice.Tensor(nonzero);  nonzero = None
            select: "i64[s16]" = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
            arange_1: "i64[s16]" = torch.ops.aten.arange.start(0, sym_size_int_3, device = device(type='cpu'), pin_memory = False)
            unsqueeze_1: "i64[1, s16]" = torch.ops.aten.unsqueeze.default(arange_1, 0);  arange_1 = None
            expand: "i64[s16, s16]" = torch.ops.aten.expand.default(unsqueeze_1, [sym_size_int_3, -1]);  unsqueeze_1 = sym_size_int_3 = None
            sub: "i64[s16]" = torch.ops.aten.sub.Tensor(select, sym_size_int_4);  sym_size_int_4 = None
            lt_1: "b8[s16]" = torch.ops.aten.lt.Scalar(sub, 0)
            lift_fresh_copy_2: "i64[]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_2);  c_lifted_tensor_2 = None
            index_put_: "i64[s16]" = torch.ops.aten.index_put_.default(sub, [lt_1], lift_fresh_copy_2);  sub = lt_1 = lift_fresh_copy_2 = None
            index: "i64[s16]" = torch.ops.aten.index.Tensor(cat, [index_put_]);  cat = index_put_ = None
            unsqueeze_2: "i64[s16, 1]" = torch.ops.aten.unsqueeze.default(index, -1);  index = None
            ge_1: "b8[s16, s16]" = torch.ops.aten.ge.Tensor(expand, unsqueeze_2);  unsqueeze_2 = None
            add: "i64[s16]" = torch.ops.aten.add.Tensor(select, 0);  select = None
            gt: "b8[s16]" = torch.ops.aten.gt.Scalar(add, 0)
            lift_fresh_copy_3: "i64[]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_3);  c_lifted_tensor_3 = None
            index_put__1: "i64[s16]" = torch.ops.aten.index_put_.default(add, [gt], lift_fresh_copy_3);  add = gt = lift_fresh_copy_3 = None
            index_1: "i64[s16]" = torch.ops.aten.index.Tensor(cat_1, [index_put__1]);  cat_1 = index_put__1 = None
            unsqueeze_3: "i64[s16, 1]" = torch.ops.aten.unsqueeze.default(index_1, -1);  index_1 = None
            lt_2: "b8[s16, s16]" = torch.ops.aten.lt.Tensor(expand, unsqueeze_3);  expand = unsqueeze_3 = None
            and_2: "b8[s16, s16]" = torch.ops.aten.__and__.Tensor(ge_1, lt_2);  ge_1 = lt_2 = None
            return (and_2,)

Graph signature:
    # inputs
    c_lifted_tensor_0: CONSTANT_TENSOR target='lifted_tensor_0'
    c_lifted_tensor_1: CONSTANT_TENSOR target='lifted_tensor_1'
    c_lifted_tensor_2: CONSTANT_TENSOR target='lifted_tensor_2'
    c_lifted_tensor_3: CONSTANT_TENSOR target='lifted_tensor_3'
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    and_2: USER_OUTPUT

Range constraints: {u0: VR[2, 9223372036854775806], s35: VR[2, int_oo], s16: VR[2, 9223372036854775806], s58: VR[2, int_oo]}

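The following line in the exported program shows what the exporter could not verify statically. It cannot prove that the number of rows returned by nonzero() (u0) equals the sequence length x.shape[1] (s16), so it turns that assumption into a runtime assertion: torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s16, u0) on node 'eq'");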

# onnx_diagnostic documentation helper. NOTE(review): presumably draws the
# legend/thumbnail image for this gallery page; confirm in onnx_diagnostic.doc.
doc.plot_legend("dynamic shapes\nnonzero", "dynamic shapes", "yellow")
plot dynamic shapes nonzero

Total running time of the script: (0 minutes 0.428 seconds)

Related examples

Cannot export torch.sym_max(x.shape[0], y.shape[0])

Cannot export torch.sym_max(x.shape[0], y.shape[0])

Builds dynamic shapes from any input

Builds dynamic shapes from any input

0, 1, 2 for a Dynamic Dimension in the dummy example to export a model

0, 1, 2 for a Dynamic Dimension in the dummy example to export a model

Gallery generated by Sphinx-Gallery