Exported Programs with Dynamic Shapes

The following script shows the exported program for many short cases and various export strategies. See l-plot-export-with-dynamic-shape to retrieve an ONNX model equivalent to the original model.

<<<

import inspect
import textwrap
import pandas
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches.eval import discover, run_exporter
from onnx_diagnostic.ext_test_case import unit_test_going

cases = discover()
print()
print(":ref:`Summary <led-summary-exported-program>`")
print()
sorted_cases = sorted(cases.items())
if unit_test_going():
    sorted_cases = sorted_cases[:3]
for name, cls_model in sorted_cases:
    print(f"* :ref:`{name} <led-model-case-export-{name}>`")
print()

obs = []
for name, cls_model in sorted_cases:
    print()
    print(f".. _led-model-case-export-{name}:")
    print()
    print(name)
    print("=" * len(name))
    print()
    print("forward")
    print("+++++++")
    print()
    print("::")
    print()
    print(
        textwrap.indent(textwrap.dedent(inspect.getsource(cls_model.forward)), "    ")
    )
    print()
    for exporter in (
        "export-strict",
        "export-nostrict",
        "export-nostrict-decall",
    ):
        expname = exporter.replace("export-", "")
        print()
        print(expname)
        print("+" * len(expname))
        print()
        res = run_exporter(exporter, cls_model, True, quiet=True)
        case_ref = f":ref:`{name} <led-model-case-export-{name}>`"
        expo = exporter.split("-", maxsplit=1)[-1]
        if "inputs" in res:
            print(f"* **inputs:** ``{string_type(res['inputs'], with_shape=True)}``")
        if "dynamic_shapes" in res:
            print(f"* **shapes:** ``{string_type(res['dynamic_shapes'])}``")
        print()
        if "exported" in res:
            print("::")
            print()
            print(textwrap.indent(str(res["exported"].graph), "    "))
            print()
            obs.append(dict(case=case_ref, error="", exporter=expo))
        else:
            print("**FAILED**")
            print()
            print("::")
            print()
            print(textwrap.indent(str(res["error"]), "    "))
            print()
            obs.append(dict(case=case_ref, error="FAIL", exporter=expo))

print()
print(".. _led-summary-exported-program:")
print()
print("Summary")
print("+++++++")
print()
df = pandas.DataFrame(obs)
piv = df.pivot(index="case", columns="exporter", values="error")
print(piv.to_markdown(tablefmt="rst"))
print()

>>>
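
To reproduce a single case outside the documentation loop, run_exporter can be called directly with one of the three export strategies. A minimal sketch reusing only names from the script above:

from onnx_diagnostic.torch_export_patches.eval import discover, run_exporter

cases = discover()
# same call as in the loop above; quiet=True stores failures
# under res["error"] instead of raising
res = run_exporter("export-nostrict", cases["AtenNonZero"], True, quiet=True)
print(res["exported"].graph if "exported" in res else res["error"])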

Summary

AtenAsStrided

forward

def forward(self, x):
    y = torch.as_strided(x, (2, 2, 8, 4), (128, 8, 16, 1))
    return y

strict

  • inputs: #1[(T1s2x2x8x8,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %as_strided : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%x, [2, 2, 8, 4], [128, 8, 16, 1]), kwargs = {})
    return (as_strided,)

nostrict

  • inputs: #1[(T1s2x2x8x8,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %as_strided : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%x, [2, 2, 8, 4], [128, 8, 16, 1]), kwargs = {})
    return (as_strided,)

nostrict-decall

  • inputs: #1[(T1s2x2x8x8,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %as_strided : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%x, [2, 2, 8, 4], [128, 8, 16, 1]), kwargs = {})
    return (as_strided,)

AtenInterpolate

forward

def forward(self, x):
    y = torch.nn.functional.interpolate(
        x,
        scale_factor=2.0,
        mode="bilinear",
        recompute_scale_factor=False,
    )
    return y

strict

  • inputs: #1[(T1s2x2x3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.vec](args = (%x, None, False, [2.0, 2.0]), kwargs = {})
    return (upsample_bilinear2d,)

nostrict

  • inputs: #1[(T1s2x2x3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.vec](args = (%x, None, False, [2.0, 2.0]), kwargs = {})
    return (upsample_bilinear2d,)

nostrict-decall

  • inputs: #1[(T1s2x2x3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.vec](args = (%x, None, False, [2.0, 2.0]), kwargs = {})
    return (upsample_bilinear2d,)

AtenNonZero

forward

def forward(self, x):
    y = torch.nonzero(x)
    return y

strict

  • inputs: #1[(T1s3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=2] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int_1 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    return (nonzero,)

nostrict

  • inputs: #1[(T1s3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=2] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int_1 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    return (nonzero,)

nostrict-decall

  • inputs: #1[(T1s3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=2] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int_2 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_2,), kwargs = {})
    %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_2, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 0 on node 'ge_1'), kwargs = {})
    return (nonzero,)
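
The first dimension of nonzero's output depends on the data. Export models it with the unbacked symbol u0, constrained by sym_constrain_range_for_size and checked by a runtime assertion. A minimal standalone sketch (plain torch.export, same forward as above) showing that the exported module still accepts any number of nonzero entries:

import torch

class AtenNonZero(torch.nn.Module):
    def forward(self, x):
        return torch.nonzero(x)

ep = torch.export.export(
    AtenNonZero(),
    (torch.randn(3, 4),),
    dynamic_shapes={"x": {0: torch.export.Dim("batch")}},
)
# the first output dimension (u0) is recomputed at every call
print(ep.module()(torch.randn(5, 4)).shape)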

AtenNonZeroTuple

forward

def forward(self, x):
    y = torch.nonzero(x, as_tuple=True)
    return y[0], y[1]

strict

  • inputs: #1[(T1s3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero_numpy : [num_users=2] = call_function[target=torch.ops.aten.nonzero_numpy.default](args = (%x,), kwargs = {})
    %getitem_2 : [num_users=2] = call_function[target=operator.getitem](args = (%nonzero_numpy, 0), kwargs = {})
    %sym_size_int_1 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getitem_2, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%nonzero_numpy, 1), kwargs = {})
    return (getitem_2, getitem_1)

nostrict

  • inputs: #1[(T1s3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero_numpy : [num_users=2] = call_function[target=torch.ops.aten.nonzero_numpy.default](args = (%x,), kwargs = {})
    %getitem_2 : [num_users=2] = call_function[target=operator.getitem](args = (%nonzero_numpy, 0), kwargs = {})
    %sym_size_int_1 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getitem_2, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%nonzero_numpy, 1), kwargs = {})
    return (getitem_2, getitem_1)

nostrict-decall

  • inputs: #1[(T1s3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=3] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int_2 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_2,), kwargs = {})
    %ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_2, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u0 >= 0 on node 'ge_2'), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%nonzero, 1, 0, 1), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%nonzero, 1, 1, 2), kwargs = {})
    %squeeze : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%slice_1, [1]), kwargs = {})
    %squeeze_1 : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%slice_2, [1]), kwargs = {})
    return (squeeze, squeeze_1)

AtenRollPos

forward

def forward(self, x):
    return torch.roll(x, 1, -1)

strict

  • inputs: #1[(T1s2x3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [1], [-1]), kwargs = {})
    return (roll,)

nostrict

  • inputs: #1[(T1s2x3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [1], [-1]), kwargs = {})
    return (roll,)

nostrict-decall

  • inputs: #1[(T1s2x3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %arange : [num_users=1] = call_function[target=torch.ops.aten.arange.start_step](args = (0, 4), kwargs = {layout: torch.strided, device: cpu, pin_memory: False})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arange, 3), kwargs = {})
    %fmod : [num_users=1] = call_function[target=torch.ops.aten.fmod.Scalar](args = (%add, 4), kwargs = {})
    %index_select : [num_users=1] = call_function[target=torch.ops.aten.index_select.default](args = (%x, 2, %fmod), kwargs = {})
    return (index_select,)
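
With decompositions enabled, roll disappears from the graph: only dimension 0 is dynamic, so the decomposition hard-codes the static last-dimension size (4) and rewrites the roll as a shifted arange followed by index_select. A quick sketch of that equivalence:

import torch

x = torch.randn(2, 3, 4)
idx = (torch.arange(0, 4) + 3).fmod(4)  # [3, 0, 1, 2], the shift baked into the graph
assert torch.equal(torch.roll(x, 1, -1), torch.index_select(x, 2, idx))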

AtenRollRelu

forward

def forward(self, x):
    return torch.relu(torch.roll(x, -1, -1))

strict

  • inputs: #1[(T1s2x3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [-1], [-1]), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%roll,), kwargs = {})
    return (relu,)

nostrict

  • inputs: #1[(T1s2x3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [-1], [-1]), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%roll,), kwargs = {})
    return (relu,)

nostrict-decall

  • inputs: #1[(T1s2x3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %arange : [num_users=1] = call_function[target=torch.ops.aten.arange.start_step](args = (0, 4), kwargs = {layout: torch.strided, device: cpu, pin_memory: False})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arange, 1), kwargs = {})
    %fmod : [num_users=1] = call_function[target=torch.ops.aten.fmod.Scalar](args = (%add, 4), kwargs = {})
    %index_select : [num_users=1] = call_function[target=torch.ops.aten.index_select.default](args = (%x, 2, %fmod), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%index_select,), kwargs = {})
    return (relu,)

BuildInIsInstance

forward

def forward(self, x, lx: list | torch.Tensor):
    if isinstance(lx, list):
        t = lx[0] * lx[1].sum(axis=1, keepdim=True)
        return torch.sigmoid(self.linear(x)) - self.buff + t
    return torch.sigmoid(self.linear(x)) - self.buff + lx

strict

  • inputs: #2[(T1s4x3,#2[T1s4x1,T1s4x2]),(T1s8x3,#2[T1s8x1,T1s8x2])]

  • shapes: dict(x:{0:Dim(batch)},lx:#2[{0:Dim(batch)},{0:Dim(batch)}])

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

nostrict

  • inputs: #2[(T1s4x3,#2[T1s4x1,T1s4x2]),(T1s8x3,#2[T1s8x1,T1s8x2])]

  • shapes: dict(x:{0:Dim(batch)},lx:#2[{0:Dim(batch)},{0:Dim(batch)}])

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

nostrict-decall

  • inputs: #2[(T1s4x3,#2[T1s4x1,T1s4x2]),(T1s8x3,#2[T1s8x1,T1s8x2])]

  • shapes: dict(x:{0:Dim(batch)},lx:#2[{0:Dim(batch)},{0:Dim(batch)}])

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub_4 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add_15 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub_4, %mul_1), kwargs = {})
    return (add_15,)
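
The isinstance test is resolved while tracing: the recorded inputs are lists, so only the list branch appears in the graph and the list is flattened into the placeholders lx_0 and lx_1. A standalone sketch (the Linear sizes and the buffer are assumptions chosen to match the T1s4x3 inputs above):

import torch

class BuildInIsInstance(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 1)           # assumed, matches x: 4x3
        self.register_buffer("buff", torch.zeros(1))  # assumed buffer

    def forward(self, x, lx):
        if isinstance(lx, list):
            t = lx[0] * lx[1].sum(axis=1, keepdim=True)
            return torch.sigmoid(self.linear(x)) - self.buff + t
        return torch.sigmoid(self.linear(x)) - self.buff + lx

# isinstance(lx, list) is True at export time, so only that branch is traced
ep = torch.export.export(
    BuildInIsInstance(),
    (torch.randn(4, 3), [torch.randn(4, 1), torch.randn(4, 2)]),
)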

BuildInLen

forward

def forward(self, x, lx: list):
    t = lx[0] * lx[1].sum(axis=1, keepdim=True)
    if len(lx) > 2:
        t = t + lx[2].sum(axis=1, keepdim=True)
    return torch.sigmoid(self.linear(x)) - self.buff + t

strict

FAILED

Trying to flatten user inputs with exported input tree spec: 
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*,
    TreeSpec(list, None, [*,
      *])]),
  TreeSpec(dict, [], [])])
but actually got inputs with tree spec of: 
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*,
    TreeSpec(list, None, [*,
      *,
      *])]),
  TreeSpec(dict, [], [])]).
Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.

nostrict

FAILED

Trying to flatten user inputs with exported input tree spec: 
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*,
    TreeSpec(list, None, [*,
      *])]),
  TreeSpec(dict, [], [])])
but actually got inputs with tree spec of: 
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*,
    TreeSpec(list, None, [*,
      *,
      *])]),
  TreeSpec(dict, [], [])]).
Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.

nostrict-decall

FAILED

Trying to flatten user inputs with exported input tree spec: 
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*,
    TreeSpec(list, None, [*,
      *])]),
  TreeSpec(dict, [], [])])
but actually got inputs with tree spec of: 
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*,
    TreeSpec(list, None, [*,
      *,
      *])]),
  TreeSpec(dict, [], [])]).
Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.
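
All three strategies fail for the same reason: len(lx) is evaluated while tracing, so the exported input tree spec is frozen for the list length seen at export time and rejects a list of any other length afterwards. Sketched with hypothetical tensors x, a, b, c and the module above:

# exported with a two-element list: the tree spec records exactly two entries
ep = torch.export.export(model, (x, [a, b]))
ep.module()(x, [a, b])     # works
ep.module()(x, [a, b, c])  # fails with the tree spec mismatch shown above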

ComplexPolar

forward

def forward(self, x, angle):
    return torch.polar(x, angle)

strict

  • inputs: #1[(T1s4x4,T1s4x4)]

  • shapes: dict(x:{0:Dim(batch)},angle:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %angle : [num_users=1] = placeholder[target=angle]
    %polar : [num_users=1] = call_function[target=torch.ops.aten.polar.default](args = (%x, %angle), kwargs = {})
    return (polar,)

nostrict

  • inputs: #1[(T1s4x4,T1s4x4)]

  • shapes: dict(x:{0:Dim(batch)},angle:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %angle : [num_users=1] = placeholder[target=angle]
    %polar : [num_users=1] = call_function[target=torch.ops.aten.polar.default](args = (%x, %angle), kwargs = {})
    return (polar,)

nostrict-decall

  • inputs: #1[(T1s4x4,T1s4x4)]

  • shapes: dict(x:{0:Dim(batch)},angle:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %angle : [num_users=1] = placeholder[target=angle]
    %polar : [num_users=1] = call_function[target=torch.ops.aten.polar.default](args = (%x, %angle), kwargs = {})
    return (polar,)

ControlFlowCond

forward

def forward(self, x):
    def true_fn(x):
        return torch.sin(x)

    def false_fn(x):
        return torch.cos(x)

    return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

strict

  • inputs: #1[(T1s5x3,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict

  • inputs: #1[(T1s5x3,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict-decall

  • inputs: #1[(T1s5x3,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

ControlFlowCond2Inputs

forward

def forward(self, x, y):
    def true_fn(x, y):
        return torch.sin(x), torch.cos(x) + y

    def false_fn(x, y):
        return torch.cos(x), torch.sin(x) + y

    return torch.cond(x.sum() > 0, true_fn, false_fn, [x, y])

strict

  • inputs: #1[(T1s5x3,T1s5x3)]

  • shapes: dict(x:{0:Dim(batch)},y:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x, %y)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

nostrict

  • inputs: #1[(T1s5x3,T1s5x3)]

  • shapes: dict(x:{0:Dim(batch)},y:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x, %y)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

nostrict-decall

  • inputs: #1[(T1s5x3,T1s5x3)]

  • shapes: dict(x:{0:Dim(batch)},y:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x, %y)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

ControlFlowCond2Outputs

forward

def forward(self, x):
    def true_fn(x):
        return torch.sin(x), torch.cos(x)

    def false_fn(x):
        return torch.cos(x), torch.sin(x)

    return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

strict

  • inputs: #1[(T1s5x3,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

nostrict

  • inputs: #1[(T1s5x3,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

nostrict-decall

  • inputs: #1[(T1s5x3,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

ControlFlowCondConstant

forward

def forward(self, x):
    def true_fn(x):
        return torch.sin(x) - torch.ones(x.shape, dtype=x.dtype)

    def false_fn(x):
        return torch.cos(x) + torch.ones((1, 1024), dtype=x.dtype)

    return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

strict

  • inputs: #1[(T1s1024x1024,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict

  • inputs: #1[(T1s1024x1024,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict-decall

  • inputs: #1[(T1s1024x1024,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

ControlFlowCondIdentity_153832

forward

def forward(self, x, y):

    def branch_cond_then_1(x):
        x = torch.abs(x) + 1
        return x

    def branch_cond_else_1(x):
        return x  # fails but succeeds with x.clone()

    x = torch.cond(x.sum() > 0, branch_cond_then_1, branch_cond_else_1, [x])
    return x + y

strict

FAILED

Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py", line 380, in forward
    x = torch.cond(x.sum() > 0, branch_cond_then_1, branch_cond_else_1, [x])
  File "~/vv/this312/lib/python3.12/site-packages/torch/_higher_order_ops/cond.py", line 135, in cond
    return cond_op(pred, true_fn, false_fn, operands)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

nostrict

FAILED

Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/vv/this312/lib/python3.12/site-packages/torch/_higher_order_ops/cond.py", line 184, in _cond_op_wrapper
    return cond_op(*args, **kwargs)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

nostrict-decall

FAILED

Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/vv/this312/lib/python3.12/site-packages/torch/_higher_order_ops/cond.py", line 184, in _cond_op_wrapper
    return cond_op(*args, **kwargs)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
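
The comment in the forward already hints at the issue: a cond branch returning its input unchanged aliases input and output, which the higher-order op rejects. A plausible fix is to return a copy:

def branch_cond_else_1(x):
    return x.clone()  # break the aliasing between the branch input and output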

ControlFlowCondNestedModule

forward

def forward(self, x):
    def true_fn(x):
        return self.submodule(x)

    def false_fn(x):
        return x - self.weight

    y = torch.cond(x.sum() > 0, true_fn, false_fn, [x])
    return y

strict

  • inputs: #1[(T7s2,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %submodule_weight : [num_users=1] = get_attr[target=submodule.weight]
    %weight : [num_users=1] = get_attr[target=weight]
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x, %submodule_weight, %weight)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict

  • inputs: #1[(T7s2,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %weight : [num_users=1] = get_attr[target=weight]
    %submodule_weight : [num_users=1] = get_attr[target=submodule.weight]
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x, %submodule_weight, %weight)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict-decall

  • inputs: #1[(T7s2,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %weight : [num_users=1] = get_attr[target=weight]
    %submodule_weight : [num_users=1] = get_attr[target=submodule.weight]
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x, %submodule_weight, %weight)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

ControlFlowCondNonZero

forward

def forward(self, input_ids, image_features, vocab_size):
    def then_branch(input_ids, image_features, vocab_size):
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        condition = (input_ids < 0) & (input_ids > -int(1e9))
        positions = torch.nonzero(condition, as_tuple=True)
        input_ids = input_ids.clamp_min(0).clamp_max(vocab_size)
        return (input_ids, positions[0], positions[1])

    def else_branch(input_ids, image_features, vocab_size):
        r = torch.where(torch.zeros((1, 1), dtype=torch.bool))
        return (input_ids, r[0], r[1])

    a, b, c = torch.cond(
        image_features.numel() > 0,
        then_branch,
        else_branch,
        [input_ids, image_features, vocab_size],
    )
    return a, b, c

strict

FAILED

Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py", line 341, in forward
    a, b, c = torch.cond(
  File "~/vv/this312/lib/python3.12/site-packages/torch/_higher_order_ops/cond.py", line 135, in cond
    return cond_op(pred, true_fn, false_fn, operands)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

nostrict

FAILED

Expect operands to be a tuple of possibly nested dict/list/tuple that only consists of tensor leaves, but got [FakeTensor(..., size=(s35, 12), dtype=torch.int64), FakeTensor(..., size=(s58, s43)), 1025].

nostrict-decall

FAILED

Expect operands to be a tuple of possibly nested dict/list/tuple that only consists of tensor leaves, but got [FakeTensor(..., size=(s35, 12), dtype=torch.int64), FakeTensor(..., size=(s58, s43)), 1025].
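
strict mode stops at the graph break; the non-strict modes fail because cond operands must be tensor leaves and vocab_size arrives as the Python int 1025. One plausible workaround, sketched below, is to capture the integer in the branch closures instead of passing it as an operand (the data-dependent nonzero inside then_branch may still need extra care):

a, b, c = torch.cond(
    image_features.numel() > 0,
    lambda ids, feats: then_branch(ids, feats, vocab_size),
    lambda ids, feats: else_branch(ids, feats, vocab_size),
    [input_ids, image_features],
)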

ControlFlowNestCond

forward

def forward(self, x):
    def true_fn2(x):
        def true_fn1(x):
            return torch.sin(x)

        def false_fn1(x):
            return torch.cos(x)

        return torch.cond(x.sum() < 0, true_fn1, false_fn1, [x])

    def false_fn2(x):
        return -x

    return torch.cond(x.sum() > 0, true_fn2, false_fn2, [x])

strict

  • inputs: #1[(T1s5x3,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict

  • inputs: #1[(T1s5x3,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict-decall

  • inputs: #1[(T1s5x3,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

ControlFlowScan

forward

def forward(self, x):
    init = torch.zeros_like(x[0])
    carry, out = torch.ops.higher_order.scan(
        ControlFlowScan.add, [init], [x], additional_inputs=[]
    )
    return carry

strict

FAILED

scan must be captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py", line 399, in forward
    carry, out = torch.ops.higher_order.scan(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

nostrict

  • inputs: #1[(T1s3x3,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%select,), kwargs = {pin_memory: False})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like], [%x], ()), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem,)

nostrict-decall

FAILED

scan might be aliasing the input or the output!

While executing %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like], [%x], ()), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self, x):
        x: "f32[s35, 3][3, 1]"; 

        x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
         # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py:398 in forward, code: init = torch.zeros_like(x[0])
        select: "f32[3][1]" = torch.ops.aten.select.int(x, 0, 0)
        zeros_like: "f32[3][1]" = torch.ops.aten.zeros_like.default(select, pin_memory = False);  select = None
    
         # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py:399 in forward, code: carry, out = torch.ops.higher_order.scan(
        scan_combine_graph_0 = self.scan_combine_graph_0
        scan = torch.ops.higher_order.scan(scan_combine_graph_0, [zeros_like], [x], ());  scan_combine_graph_0 = zeros_like = x = None
        getitem: "f32[3][1]" = scan[0]
        getitem_1: "f32[s35, 3][3, 1]" = scan[1];  scan = getitem_1 = None
        return pytree.tree_unflatten((getitem,), self._out_spec)
    
    class scan_combine_graph_0(torch.nn.Module):
        def forward(self, carry_1: "f32[3][1]", y_1: "f32[3][1]"):
             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py:399 in forward, code: carry, out = torch.ops.higher_order.scan(
            add: "f32[3][1]" = torch.ops.aten.add.Tensor(carry_1, y_1);  carry_1 = y_1 = None
            return [add, add]
        

Original traceback:
File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py", line 399, in forward
    carry, out = torch.ops.higher_order.scan(
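
The decomposed graph shows what nostrict-decall objects to: the combine function returns [add, add], so the carry and the stacked per-step output are the same tensor. A plausible fix, assuming ControlFlowScan.add is the combine function used above, is to return a copy for one of the two:

@staticmethod
def add(carry, y):
    out = carry + y
    return out, out.clone()  # carry and per-step output no longer alias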

ControlFlowScan2Carried

forward

def forward(self, x):
    init1 = torch.zeros_like(x[0])
    init2 = torch.ones_like(x[0])
    carry1, carry2, out1, out2 = torch.ops.higher_order.scan(
        ControlFlowScan2Carried.add,
        [init1, init2],
        [x, x * 2],
        # dim=0,  # 01/31/2025, not supported anymore
        additional_inputs=[],
    )
    return carry1, carry2, out1, out2

strict

FAILED

scan must be captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py", line 418, in forward
    carry1, carry2, out1, out2 = torch.ops.higher_order.scan(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

nostrict

  • inputs: #1[(T1s3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=4] = placeholder[target=x]
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%select,), kwargs = {pin_memory: False})
    %select_1 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %ones_like : [num_users=1] = call_function[target=torch.ops.aten.ones_like.default](args = (%select_1,), kwargs = {pin_memory: False})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 2), kwargs = {})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=4] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like, %ones_like], [%x, %mul], ()), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 2), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 3), kwargs = {})
    return (getitem, getitem_1, getitem_2, getitem_3)

nostrict-decall

FAILED

scan might be aliasing the input or the output!

While executing %scan : [num_users=4] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like, %ones_like], [%x, %mul], ()), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self, x):
        x: "f32[s35, 4][4, 1]"; 

        x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
         # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py:416 in forward, code: init1 = torch.zeros_like(x[0])
        select: "f32[4][1]" = torch.ops.aten.select.int(x, 0, 0)
        zeros_like: "f32[4][1]" = torch.ops.aten.zeros_like.default(select, pin_memory = False);  select = None
    
         # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py:417 in forward, code: init2 = torch.ones_like(x[0])
        select_1: "f32[4][1]" = torch.ops.aten.select.int(x, 0, 0)
        ones_like: "f32[4][1]" = torch.ops.aten.ones_like.default(select_1, pin_memory = False);  select_1 = None
    
         # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py:421 in forward, code: [x, x * 2],
        mul: "f32[s35, 4][4, 1]" = torch.ops.aten.mul.Tensor(x, 2)
    
         # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py:418 in forward, code: carry1, carry2, out1, out2 = torch.ops.higher_order.scan(
        scan_combine_graph_0 = self.scan_combine_graph_0
        scan = torch.ops.higher_order.scan(scan_combine_graph_0, [zeros_like, ones_like], [x, mul], ());  scan_combine_graph_0 = zeros_like = ones_like = x = mul = None
        getitem: "f32[4][1]" = scan[0]
        getitem_1: "f32[4][1]" = scan[1]
        getitem_2: "f32[s35, 4][4, 1]" = scan[2]
        getitem_3: "f32[s35, 4][4, 1]" = scan[3];  scan = None
        return pytree.tree_unflatten((getitem, getitem_1, getitem_2, getitem_3), self._out_spec)
    
    class scan_combine_graph_0(torch.nn.Module):
        def forward(self, carry1_1: "f32[4][1]", carry2_1: "f32[4][1]", y1_1: "f32[4][1]", y2_1: "f32[4][1]"):
             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py:418 in forward, code: carry1, carry2, out1, out2 = torch.ops.higher_order.scan(
            add: "f32[4][1]" = torch.ops.aten.add.Tensor(carry1_1, y1_1);  carry1_1 = y1_1 = None
            mul: "f32[4][1]" = torch.ops.aten.mul.Tensor(carry2_1, y2_1);  carry2_1 = y2_1 = None
            return [add, mul, add, mul]
        

Original traceback:
File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py", line 418, in forward
    carry1, carry2, out1, out2 = torch.ops.higher_order.scan(

ControlFlowScanCDist

forward

def forward(self, x):
    carry, out = torch.ops.higher_order.scan(
        ControlFlowScanCDist.dist,
        [x],
        [x],
        # dim=0,  # 01/31/2025, not supported anymore
        additional_inputs=[],
    )
    return out

strict

FAILED

scan must be captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py", line 444, in forward
    carry, out = torch.ops.higher_order.scan(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

nostrict

  • inputs: #1[(T1s3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%x], [%x], ()), kwargs = {})
    %getitem : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

nostrict-decall

  • inputs: #1[(T1s3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%x], [%x], ()), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

ControlFlowScanCDist2

forward

def forward(self, x):
    z = torch.tensor([0], dtype=torch.float32)
    y = x.clone()
    out = torch.ops.higher_order.scan(
        ControlFlowScanCDist2.dist,
        [z],
        [x],
        # dim=0,  # 01/31/2025, not supported anymore
        additional_inputs=[y],
    )
    return out[1]

strict

FAILED

scan must be captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py", line 472, in forward
    out = torch.ops.higher_order.scan(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

nostrict

  • inputs: #1[(T1s3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %detach_ : [num_users=1] = call_function[target=torch.ops.aten.detach_.default](args = (%lift_fresh_copy,), kwargs = {})
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%detach_], [%x], (%clone,)), kwargs = {})
    %getitem : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

nostrict-decall

  • inputs: #1[(T1s3x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_0,), kwargs = {})
    %clone_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%clone], [%x], (%clone_1,)), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

ControlFlowScanCDistXY

forward

def forward(self, x, y):
    carry, out = torch.ops.higher_order.scan(
        ControlFlowScanCDistXY.dist,
        [y],
        [x],
        # dim=0,  # 01/31/2025, not supported anymore
        additional_inputs=[],
    )
    return out

strict

FAILED

scan must be captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py", line 499, in forward
    carry, out = torch.ops.higher_order.scan(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

nostrict

  • inputs: #2[(T1s3x4,T1s5x4),(T1s13x14,T1s15x14)]

  • shapes: dict(x:{0:Dim(x_rows),1:Dim(dim)},y:{0:Dim(y_rows),1:Dim(dim)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%y], [%x], ()), kwargs = {})
    %getitem : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

nostrict-decall

  • inputs: #2[(T1s3x4,T1s5x4),(T1s13x14,T1s15x14)]

  • shapes: dict(x:{0:Dim(x_rows),1:Dim(dim)},y:{0:Dim(y_rows),1:Dim(dim)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%y], [%x], ()), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)
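
This case exercises several independent dynamic dimensions at once: x_rows and y_rows vary freely while dim is shared by both inputs. Written out with torch.export.Dim, the dynamic shapes above correspond to:

import torch

x_rows = torch.export.Dim("x_rows")
y_rows = torch.export.Dim("y_rows")
dim = torch.export.Dim("dim")  # shared, so both inputs must agree on it
dynamic_shapes = {"x": {0: x_rows, 1: dim}, "y": {0: y_rows, 1: dim}}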

ControlFlowScanDecomposition_151564

forward

def forward(self, images, position):
    return self.select_when_exporting(self.dummy_loop, self.dummy_loop_with_scan)(
        images, position
    )

strict

FAILED

scan must be captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py", line 577, in forward
    return self.select_when_exporting(self.dummy_loop, self.dummy_loop_with_scan)(
  File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py", line 565, in dummy_loop_with_scan
    return torch.ops.higher_order.scan(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

nostrict

  • inputs: #1[(T1s5x6,T7s5)]

  • shapes: dict(images:{0:DYNAMIC,1:DYNAMIC},position:{0:DYNAMIC})

graph():
    %images : [num_users=1] = placeholder[target=images]
    %position : [num_users=1] = placeholder[target=position]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [], [%images, %position], ()), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    return (getitem,)

nostrict-decall

FAILED

Could not guard on data-dependent expression u1 < 0 (unhinted: u1 < 0).  (Size-like symbols: none)

Caused by: (_decomp/decompositions.py:734 in slice_forward)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u1"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

While executing %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [], [%images, %position], ()), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self, images, position):
        images: "f32[s35, s16][s16, 1]"; position: "i64[s58][1]"; 

        images, position, = fx_pytree.tree_flatten_spec(([images, position], {}), self._in_spec)
         # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py:577 in forward, code: return self.select_when_exporting(self.dummy_loop, self.dummy_loop_with_scan)(
        scan_combine_graph_0 = self.scan_combine_graph_0
        scan = torch.ops.higher_order.scan(scan_combine_graph_0, [], [images, position], ());  scan_combine_graph_0 = images = position = None
        getitem: "f32[s35, s16][s16, 1]" = scan[0];  scan = None
        return pytree.tree_unflatten((getitem,), self._out_spec)
    
    class scan_combine_graph_0(torch.nn.Module):
        def forward(self, padded_1: "f32[s16][1]", p_1: "i64[][]"):
             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py:577 in forward, code: return self.select_when_exporting(self.dummy_loop, self.dummy_loop_with_scan)(
            sym_size_int: "Sym(s16)" = torch.ops.aten.sym_size.int(padded_1, 0)
            zeros: "f32[s16][1]" = torch.ops.aten.zeros.default([sym_size_int], device = device(type='cpu'), pin_memory = False);  sym_size_int = None
            item: "Sym(u0)" = torch.ops.aten.item.default(p_1);  p_1 = None
            slice_1: "f32[u0][1]" = torch.ops.aten.slice.Tensor(padded_1, 0, 0, item);  padded_1 = None
            slice_2: "f32[u0][1]" = torch.ops.aten.slice.Tensor(zeros, 0, 0, item);  item = None
            copy_: "f32[u0][1]" = torch.ops.aten.copy_.default(slice_2, slice_1);  slice_2 = slice_1 = copy_ = None
            return (zeros,)
        

Original traceback:
File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py", line 577, in forward
    return self.select_when_exporting(self.dummy_loop, self.dummy_loop_with_scan)(
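
The blocking guard is on an unbacked symbol: p.item() creates an unbacked integer, and the slice decomposition cannot decide whether it is negative. A common mitigation, assuming the position is known to be a valid length, is to constrain the symbol with torch._check before slicing; a minimal sketch of the idea:

import torch

def body_with_hints(padded, p):
    n = p.item()
    # tell the compiler the unbacked symbol is a valid slice bound,
    # so decompositions such as slice_forward can resolve their guards
    torch._check(n >= 0)
    torch._check(n <= padded.shape[0])
    out = torch.zeros_like(padded)
    out[:n] = padded[:n]
    return (out,)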

ControlFlowScanInplace_153705

forward

def forward(self, x, y):
    def loop_body_1(z, iv, x, y):
        z = z.clone()
        i = iv.item()
        z[i, :] = ((x[i, :] - y) ** 2).sum(dim=-1)
        return [z, iv]

    z = torch.empty((x.shape[0], y.shape[0]))
    r = torch.ops.higher_order.scan(
        loop_body_1, [z], [torch.arange(x.shape[0], dtype=torch.int64)], [x, y]
    )
    return r[0]

strict

FAILED

scan must be captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/eval/model_cases.py", line 529, in forward
    r = torch.ops.higher_order.scan(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

nostrict

FAILED

only integers, slices (`:`), ellipsis (`...`), None and long or byte Variables are valid indices (got SymInt)

nostrict-decall

FAILED

only integers, slices (`:`), ellipsis (`...`), None and long or byte Variables are valid indices (got SymInt)
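
All three modes fail on z[i, :] = ...: iv.item() yields a SymInt, and advanced indexing rejects it. A common rewrite is to let scan emit one row per step instead of mutating a preallocated buffer, in the style of ControlFlowScanCDistXY above; a sketch:

import torch

def forward_rewritten(x, y):
    def body(carry, x_row, y):
        # emit the distance row as the per-step output; nothing is indexed in place
        row = ((x_row - y) ** 2).sum(dim=-1)
        return [carry.clone(), row]

    init = torch.zeros(())
    _, rows = torch.ops.higher_order.scan(body, [init], [x], [y])
    return rows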

CreateFromShape

forward

def forward(self, x):
    y = torch.ones((x.shape[0], x.shape[1] + 1))
    return y

strict

  • inputs: #2[(T1s4x4,),(T1s5x5,)]

  • shapes: dict(x:{0:Dim(dx),1:Dim(dy)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%sym_size_int_3, 1), kwargs = {})
    %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([%sym_size_int_2, %add],), kwargs = {device: cpu, pin_memory: False})
    return (ones,)

nostrict

  • inputs: #2[(T1s4x4,),(T1s5x5,)]

  • shapes: dict(x:{0:Dim(dx),1:Dim(dy)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%sym_size_int_3, 1), kwargs = {})
    %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([%sym_size_int_2, %add],), kwargs = {device: cpu, pin_memory: False})
    return (ones,)

nostrict-decall

  • inputs: #2[(T1s4x4,),(T1s5x5,)]

  • shapes: dict(x:{0:Dim(dx),1:Dim(dy)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%sym_size_int_3, 1), kwargs = {})
    %full : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_2, %add], 1), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})
    return (full,)
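
Shape arithmetic exports cleanly in every mode: the sizes become sym_size_int nodes and the addition stays symbolic, so the output follows the input. A minimal sketch of the export call behind these tables, assuming the standard torch.export API:

import torch
from torch.export import Dim, export

class CreateFromShape(torch.nn.Module):
    def forward(self, x):
        return torch.ones((x.shape[0], x.shape[1] + 1))

ep = export(
    CreateFromShape(),
    (torch.rand(4, 4),),
    dynamic_shapes={"x": {0: Dim("dx"), 1: Dim("dy")}},
)
print(ep.graph)  # matches the strict graph above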

CreateFromShapeThroughFunction

forward

def forward(self, x):
    dy1 = CreateFromShapeThroughFunction.add_one(x.shape[1])
    y = torch.ones((x.shape[0], dy1))
    return y
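
add_one is a plain Python helper that is not shown; the graphs below reveal it simply adds one to the (possibly symbolic) dimension, so it inlines into the same graph as CreateFromShape. A plausible definition:

def add_one(dim):
    # receives an int or a SymInt; the addition stays symbolic at export time
    return dim + 1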

strict

  • inputs: #2[(T1s4x4,),(T1s5x5,)]

  • shapes: dict(x:{0:Dim(dx),1:Dim(dy)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%sym_size_int_3, 1), kwargs = {})
    %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([%sym_size_int_2, %add],), kwargs = {device: cpu, pin_memory: False})
    return (ones,)

nostrict

  • inputs: #2[(T1s4x4,),(T1s5x5,)]

  • shapes: dict(x:{0:Dim(dx),1:Dim(dy)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%sym_size_int_3, 1), kwargs = {})
    %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([%sym_size_int_2, %add],), kwargs = {device: cpu, pin_memory: False})
    return (ones,)

nostrict-decall

  • inputs: #2[(T1s4x4,),(T1s5x5,)]

  • shapes: dict(x:{0:Dim(dx),1:Dim(dy)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%sym_size_int_3, 1), kwargs = {})
    %full : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_2, %add], 1), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})
    return (full,)

CropLastDimensionWithTensorContent

forward

def forward(self, x, shape):
    return x[..., : shape[0]]

strict

FAILED

When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['x', 'shape'] of `inputs`, but here they are ['x']. Alternatively, you could also ignore arg names entirely and specify `dynamic_shapes` as a list/tuple matching `inputs`. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

nostrict

FAILED

When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['x', 'shape'] of `inputs`, but here they are ['x']. Alternatively, you could also ignore arg names entirely and specify `dynamic_shapes` as a list/tuple matching `inputs`. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

nostrict-decall

FAILED

When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['x', 'shape'] of `inputs`, but here they are ['x']. Alternatively, you could also ignore arg names entirely and specify `dynamic_shapes` as a list/tuple matching `inputs`. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
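
All three failures are raised by dynamic_shapes validation, not by the model: given as a dict, the spec must name every argument of forward. A spec that would pass validation is sketched below; note that even then, cropping by the content of shape (a tensor value rather than a size) is data-dependent and may still need extra hints:

from torch.export import Dim

dynamic_shapes = {"x": {0: Dim("batch")}, "shape": None}  # one entry per argument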

CropLastDimensionWithTensorShape

forward

def forward(self, x, y):
    return x[..., : y.shape[0]]
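
Unlike CropLastDimensionWithTensorContent above, the crop bound here is y.shape[0], a symbolic size rather than a tensor value, so all three modes export, as the subsections below show.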

strict

  • inputs: #2[(T1s3x4x4,T1s2),(T1s6x4x4,T1s3)]

  • shapes: dict(x:{0:Dim(batch)},y:{0:Dim(crop)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 0), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, %sym_size_int_2), kwargs = {})
    return (slice_1,)

nostrict

  • inputs: #2[(T1s3x4x4,T1s2),(T1s6x4x4,T1s3)]

  • shapes: dict(x:{0:Dim(batch)},y:{0:Dim(crop)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 0), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, %sym_size_int_2), kwargs = {})
    return (slice_1,)

nostrict-decall

  • inputs: #2[(T1s3x4x4,T1s2),(T1s6x4x4,T1s3)]

  • shapes: dict(x:{0:Dim(batch)},y:{0:Dim(crop)})

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sym_size_int_4 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 0), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, %sym_size_int_4), kwargs = {})
    return (slice_1,)

InplaceAdd

forward

def forward(self, x):
    x += self.bias
    return x

strict

  • inputs: #2[(T1s3x4,),(T1s5x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    return (add_,)

nostrict

  • inputs: #2[(T1s3x4,),(T1s5x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    return (add_,)

nostrict-decall

  • inputs: #2[(T1s3x4,),(T1s5x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=2] = placeholder[target=x]
    %add_3 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %bias), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add_3), kwargs = {})
    return (copy__default,)
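
The nostrict-decall graph shows how decompositions functionalize the mutation: the in-place add_ becomes a plain add, and a trailing copy_ writes the result back into the input placeholder so the caller still observes the side effect. The next cases repeat the pattern; note that cloning first, as InplaceCloneAdd does, removes the need for the write-back entirely.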

InplaceAdd2

forward

def forward(self, x):
    x.add_(self.bias)
    return x

strict

  • inputs: #2[(T1s3x4,),(T1s5x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    return (add_,)

nostrict

  • inputs: #2[(T1s3x4,),(T1s5x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    return (add_,)

nostrict-decall

  • inputs: #2[(T1s3x4,),(T1s5x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=2] = placeholder[target=x]
    %add_3 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %bias), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add_3), kwargs = {})
    return (copy__default,)

InplaceAdd_Mul

forward

def forward(self, x):
    x.add_(self.bias)
    return x * 2

strict

  • inputs: #2[(T1s3x4,),(T1s5x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_, 2), kwargs = {})
    return (mul,)

nostrict

  • inputs: #2[(T1s3x4,),(T1s5x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_, 2), kwargs = {})
    return (mul,)

nostrict-decall

  • inputs: #2[(T1s3x4,),(T1s5x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=2] = placeholder[target=x]
    %add_3 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %bias), kwargs = {})
    %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_3, 2), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add_3), kwargs = {})
    return (mul_4,)

InplaceCloneAdd

forward

def forward(self, x):
    x = x.clone()
    x.add_(self.bias)
    return x

strict

  • inputs: #2[(T1s3x4,),(T1s5x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%clone, %bias), kwargs = {})
    return (add_,)

nostrict

  • inputs: #2[(T1s3x4,),(T1s5x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%clone, %bias), kwargs = {})
    return (add_,)

nostrict-decall

  • inputs: #2[(T1s3x4,),(T1s5x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%clone, %bias), kwargs = {})
    return (add_6,)

InplaceSetItemEllipsis_1

forward

def forward(self, index, update):
    copy = self.params.clone()
    copy[..., index] = update
    return copy

strict

FAILED

When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['index', 'update'] of `inputs`, but here they are ['x']. Alternatively, you could also ignore arg names entirely and specify `dynamic_shapes` as a list/tuple matching `inputs`. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

nostrict

FAILED

When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['index', 'update'] of `inputs`, but here they are ['x']. Alternatively, you could also ignore arg names entirely and specify `dynamic_shapes` as a list/tuple matching `inputs`. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

nostrict-decall

FAILED

When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['index', 'update'] of `inputs`, but here they are ['x']. Alternatively, you could also ignore arg names entirely and specify `dynamic_shapes` as a list/tuple matching `inputs`. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

InplaceSetItemEllipsis_2

forward

def forward(self, index, update):
    copy = self.params.clone()
    copy[..., index] = update
    return copy

strict

FAILED

When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['index', 'update'] of `inputs`, but here they are ['x']. Alternatively, you could also ignore arg names entirely and specify `dynamic_shapes` as a list/tuple matching `inputs`. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

nostrict

FAILED

When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['index', 'update'] of `inputs`, but here they are ['x']. Alternatively, you could also ignore arg names entirely and specify `dynamic_shapes` as a list/tuple matching `inputs`. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

nostrict-decall

FAILED

When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['index', 'update'] of `inputs`, but here they are ['x']. Alternatively, you could also ignore arg names entirely and specify `dynamic_shapes` as a list/tuple matching `inputs`. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
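
Both InplaceSetItemEllipsis cases fail for the same bookkeeping reason as CropLastDimensionWithTensorContent: the dynamic-shapes dict names an argument x that does not exist in the signature (index, update). A spec matching the signature, with hypothetical dimension names:

from torch.export import Dim

dynamic_shapes = {"index": {0: Dim("nidx")}, "update": None}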

InplaceSetItemMask

forward

def forward(self, x):
    mask = x.to(bool)
    x[mask] = 2
    return x

strict

  • inputs: #2[(T1s2x3x3,),(T1s3x3x3,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=3] = placeholder[target=x]
    %_assert_tensor_metadata_default : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x,), kwargs = {dtype: torch.float32, device: cpu, layout: torch.strided})
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.bool), kwargs = {})
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%x, [%to], %lift_fresh_copy), kwargs = {})
    return (index_put_,)

nostrict

  • inputs: #2[(T1s2x3x3,),(T1s3x3x3,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=3] = placeholder[target=x]
    %_assert_tensor_metadata_default : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x,), kwargs = {dtype: torch.float32, device: cpu, layout: torch.strided})
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.bool), kwargs = {})
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%x, [%to], %lift_fresh_copy), kwargs = {})
    return (index_put_,)

nostrict-decall

  • inputs: #2[(T1s2x3x3,),(T1s3x3x3,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %_assert_tensor_metadata : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x, None, None, torch.float32), kwargs = {device: cpu, layout: torch.strided})
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.bool})
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_0,), kwargs = {})
    %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%x, [%_to_copy], %clone), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %index_put), kwargs = {})
    return (copy__default,)

InplaceSetItemSquare

forward

def forward(self, x):
    x[:2, :3] = 1
    return x

strict

  • inputs: #2[(T1s5x5,),(T1s7x5,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    return (x,)

nostrict

  • inputs: #2[(T1s5x5,),(T1s7x5,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    return (x,)

nostrict-decall

  • inputs: #2[(T1s5x5,),(T1s7x5,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %clone), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %copy, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (copy__default,)
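
Under decompositions the sliced in-place fill becomes a functional chain: copy the updated block, scatter it back with slice_scatter, then copy_ the result into the input, mirroring the write-back seen in the InplaceAdd cases.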

InplaceSetItemSquareAdd

forward

def forward(self, x):
    x[:2, :3] = 1
    return x + 2

strict

  • inputs: #2[(T1s5x5,),(T1s7x5,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 2), kwargs = {})
    return (add,)

nostrict

  • inputs: #2[(T1s5x5,),(T1s7x5,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 2), kwargs = {})
    return (add,)

nostrict-decall

  • inputs: #2[(T1s5x5,),(T1s7x5,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %clone), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %copy, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=2] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 2), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (add,)

InplaceSetItemSquareAdd2

forward

def forward(self, x):
    x[:2, :3] = 1
    return x + 2, x + 3

strict

  • inputs: #2[(T1s5x5,),(T1s7x5,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=3] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 3), kwargs = {})
    return (add, add_1)

nostrict

  • inputs: #2[(T1s5x5,),(T1s7x5,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=3] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 3), kwargs = {})
    return (add, add_1)

nostrict-decall

  • inputs: #2[(T1s5x5,),(T1s7x5,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %clone), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %copy, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=3] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 2), kwargs = {})
    %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 3), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (add, add_4)

SignatureFloat1

forward

def forward(self, x, alpha: float = 2.0):
    return torch.sigmoid(self.linear(x)) - self.buff * alpha

strict

FAILED

Expected input at *args[1] to be equal to 1.5, but got 2.5

nostrict

FAILED

Expected input at *args[1] to be equal to 1.5, but got 2.5

nostrict-decall

FAILED

Expected input at *args[1] to be equal to 1.5, but got 2.5
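
torch.export specializes Python scalars: alpha was traced as 1.5 and baked into the exported program, so validating with 2.5 is rejected. The SignatureInt cases below fail or specialize for the same reason. A common workaround, sketched with hypothetical layer sizes, is to pass the scalar as a 0-d tensor so it becomes a real graph input:

import torch

class SignatureFloatTensor(torch.nn.Module):
    # hypothetical variant of SignatureFloat1 taking alpha as a tensor
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 1)
        self.register_buffer("buff", torch.zeros(1))

    def forward(self, x, alpha):
        return torch.sigmoid(self.linear(x)) - self.buff * alpha

ep = torch.export.export(
    SignatureFloatTensor(), (torch.rand(4, 3), torch.tensor(2.0))
)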

SignatureInt1

forward

def forward(self, x, i: int = 2):
    return torch.sigmoid(self.linear(x)) - self.buff + x[:, i : i + 1]

strict

FAILED

Expected input at *args[1] to be equal to 1, but got 2

nostrict

FAILED

Expected input at *args[1] to be equal to 1, but got 2

nostrict-decall

FAILED

Expected input at *args[1] to be equal to 1, but got 2

SignatureInt2

forward

def forward(self, x, i: int = 2):
    return torch.sigmoid(self.linear(x)) - self.buff + x[:, i]

strict

  • inputs: #1[(T1s4x3,int)]

  • shapes: dict(x:{0:Dim(batch)},i:None)

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 9223372036854775807), kwargs = {})
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%slice_1, 1, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %select), kwargs = {})
    return (add,)

nostrict

  • inputs: #1[(T1s4x3,int)]

  • shapes: dict(x:{0:Dim(batch)},i:None)

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x,), kwargs = {})
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%slice_1, 1, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %select), kwargs = {})
    return (add,)

nostrict-decall

  • inputs: #1[(T1s4x3,int)]

  • shapes: dict(x:{0:Dim(batch)},i:None)

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub_2 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x,), kwargs = {})
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%slice_1, 1, 1), kwargs = {})
    %add_14 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub_2, %select), kwargs = {})
    return (add_14,)
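
Note the %i placeholder with num_users=0 in every graph: the integer is specialized to the value seen at trace time (here the constant 1 in select), so the exported program ignores whatever i is passed later. The case succeeds only because a single input set is replayed.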

SignatureListFixedLength

forward

def forward(self, x, lx: list):
    return (
        torch.sigmoid(self.linear(x)) - self.buff + lx[0] * lx[1].sum(axis=1, keepdim=True)
    )

strict

  • inputs: #2[(T1s4x3,#2[T1s4x1,T1s4x2]),(T1s8x3,#2[T1s8x1,T1s8x2])]

  • shapes: dict(x:{0:Dim(batch)},lx:#2[{0:Dim(batch)},{0:Dim(batch)}])

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

nostrict

  • inputs: #2[(T1s4x3,#2[T1s4x1,T1s4x2]),(T1s8x3,#2[T1s8x1,T1s8x2])]

  • shapes: dict(x:{0:Dim(batch)},lx:#2[{0:Dim(batch)},{0:Dim(batch)}])

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

nostrict-decall

  • inputs: #2[(T1s4x3,#2[T1s4x1,T1s4x2]),(T1s8x3,#2[T1s8x1,T1s8x2])]

  • shapes: dict(x:{0:Dim(batch)},lx:#2[{0:Dim(batch)},{0:Dim(batch)}])

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub_2 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %add_15 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub_2, %mul_4), kwargs = {})
    return (add_15,)
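
Lists of tensors are flattened by pytree into individual placeholders (lx_0, lx_1), so the exported program commits to a list of exactly this length and structure; the next two cases fail precisely because the runtime list no longer matches the traced structure.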

SignatureListFixedWithNone

forward

def forward(self, lx):
    x = lx[0]
    if lx[1] is not None:
        x += lx[1]
    if lx[2] is not None:
        x += lx[2]
    return x

strict

FAILED

Detected mismatch between the structure of `inputs` and `dynamic_shapes`: `inputs['lx']` has 3 elements, but `dynamic_shapes['lx']` has 2 elements
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

nostrict

FAILED

Detected mismatch between the structure of `inputs` and `dynamic_shapes`: `inputs['lx']` has 3 elements, but `dynamic_shapes['lx']` has 2 elements
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

nostrict-decall

FAILED

Detected mismatch between the structure of `inputs` and `dynamic_shapes`: `inputs['lx']` has 3 elements, but `dynamic_shapes['lx']` has 2 elements
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
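
The mismatch is structural: the input list has three elements (one of them None) while the spec lists two. The spec must mirror the input pytree element by element, with None entries for the None tensors; a sketch:

from torch.export import Dim

dynamic_shapes = {"lx": [{0: Dim("batch")}, None, None]}  # one entry per list element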

SignatureListVariableLength

forward

def forward(self, x, lx: list):
    t = torch.cat(lx, dim=1).sum(axis=1, keepdim=True)
    return torch.sigmoid(self.linear(x)) - self.buff + t

strict

FAILED

Trying to flatten user inputs with exported input tree spec: 
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*,
    TreeSpec(list, None, [*,
      *])]),
  TreeSpec(dict, [], [])])
but actually got inputs with tree spec of: 
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*,
    TreeSpec(list, None, [*,
      *,
      *])]),
  TreeSpec(dict, [], [])]).
Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.

nostrict

FAILED

Trying to flatten user inputs with exported input tree spec: 
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*,
    TreeSpec(list, None, [*,
      *])]),
  TreeSpec(dict, [], [])])
but actually got inputs with tree spec of: 
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*,
    TreeSpec(list, None, [*,
      *,
      *])]),
  TreeSpec(dict, [], [])]).
Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.

nostrict-decall

FAILED

Trying to flatten user inputs with exported input tree spec: 
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*,
    TreeSpec(list, None, [*,
      *])]),
  TreeSpec(dict, [], [])])
but actually got inputs with tree spec of: 
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*,
    TreeSpec(list, None, [*,
      *,
      *])]),
  TreeSpec(dict, [], [])]).
Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.
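
An exported program freezes the input tree spec: a list traced with two tensors cannot be fed three. Variable-length inputs need either one export per length or a fixed-structure encoding, for instance padding to a maximum length plus an explicit length tensor.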

SignatureShapeAsIndex

forward

def forward(self, x, y):
    t = torch.sigmoid(self.linear(x)) + x
    return t[:, : y.shape[1]]

strict

  • inputs: #1[(T1s4x3,T1s4x2)]

  • shapes: dict(x:{0:Dim(batch)},y:{0:Dim(batch),1:Dim(length)})

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sigmoid, %x), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, %sym_size_int_3), kwargs = {})
    return (slice_2,)

nostrict

  • inputs: #1[(T1s4x3,T1s4x2)]

  • shapes: dict(x:{0:Dim(batch)},y:{0:Dim(batch),1:Dim(length)})

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=0] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sigmoid, %x), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add,), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, None, %sym_size_int_2), kwargs = {})
    return (slice_2,)

nostrict-decall

  • inputs: #1[(T1s4x3,T1s4x2)]

  • shapes: dict(x:{0:Dim(batch)},y:{0:Dim(batch),1:Dim(length)})

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=0] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sym_size_int_5 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 1), kwargs = {})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sigmoid, %x), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add_6,), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, None, %sym_size_int_5), kwargs = {})
    return (slice_2,)

TypeBFloat16

forward

def forward(self, x):
    xb = x.to(torch.bfloat16)
    return (xb + xb).to(torch.float32)

strict

  • inputs: #1[(T1s4x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %_assert_tensor_metadata_default : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x,), kwargs = {dtype: torch.float32, device: cpu, layout: torch.strided})
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.bfloat16), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%to, %to), kwargs = {})
    %_assert_tensor_metadata_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%add,), kwargs = {dtype: torch.bfloat16, device: cpu, layout: torch.strided})
    %to_1 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%add, torch.float32), kwargs = {})
    return (to_1,)

nostrict

  • inputs: #1[(T1s4x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %_assert_tensor_metadata_default : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x,), kwargs = {dtype: torch.float32, device: cpu, layout: torch.strided})
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.bfloat16), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%to, %to), kwargs = {})
    %_assert_tensor_metadata_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%add,), kwargs = {dtype: torch.bfloat16, device: cpu, layout: torch.strided})
    %to_1 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%add, torch.float32), kwargs = {})
    return (to_1,)

nostrict-decall

  • inputs: #1[(T1s4x4,)]

  • shapes: dict(x:{0:Dim(batch)})

graph():
    %x : [num_users=2] = placeholder[target=x]
    %_assert_tensor_metadata : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x, None, None, torch.float32), kwargs = {device: cpu, layout: torch.strided})
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.bfloat16})
    %add_3 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%_to_copy, %_to_copy), kwargs = {})
    %_assert_tensor_metadata_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%add_3, None, None, torch.bfloat16), kwargs = {device: cpu, layout: torch.strided})
    %_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%add_3,), kwargs = {dtype: torch.float32})
    return (_to_copy_1,)

Summary