Exported Programs with Static Shapes

The following script shows the exported programs for many short cases and the various ways to retrieve the torch.fx.Graph equivalent to the original model. The tested scenarios are described in Tested Scenarios.

<<<

import inspect
import textwrap
import pandas
from experimental_experiment.torch_interpreter.eval import discover, run_exporter
from experimental_experiment.ext_test_case import unit_test_going

def _print_literal_block(text):
    """Print *text* as an indented reStructuredText literal block (``::``)."""
    print("::")
    print()
    print(textwrap.indent(text, "    "))
    print()


# Every registered case; each value is a model class whose ``forward`` is
# shown and then exported with each exporter listed below.
cases = discover()
sorted_cases = sorted(cases.items())
if unit_test_going():
    # Keep the documentation build short when running under unit tests.
    sorted_cases = sorted_cases[:3]

# Table of contents: one link to the summary, one link per case section.
print()
print(":ref:`Summary <le-summary-exported-program>`")
print()
for name, _ in sorted_cases:
    print(f"* :ref:`{name} <le-model-case-export-{name}>`")
print()

# Exporters compared on every case; the "export-" prefix is stripped to
# build both the subsection title and the summary column name.
exporters = (
    "export-strict",
    "export-strict-decall",
    "export-nostrict",
    "export-nostrict-decall",
    "export-jit",
    "export-jit-decall",
    "export-tracing",
)

obs = []
# Iterate over the same (possibly truncated) list used for the table of
# contents so every link above has a matching section below and the
# unit-test truncation actually limits the number of exports.
for name, cls_model in sorted_cases:
    case_ref = f":ref:`{name} <le-model-case-export-{name}>`"
    print()
    print(f".. _le-model-case-export-{name}:")
    print()
    print(name)
    print("=" * len(name))
    print()
    print("forward")
    print("+++++++")
    print()
    _print_literal_block(textwrap.dedent(inspect.getsource(cls_model.forward)))
    for exporter in exporters:
        expname = exporter.replace("export-", "")
        print()
        print(expname)
        print("+" * len(expname))
        print()
        res = run_exporter(exporter, cls_model, False, quiet=True)
        if "exported" in res:
            _print_literal_block(str(res["exported"].graph))
            obs.append(dict(case=case_ref, error="", exporter=expname))
        else:
            print("**FAILED**")
            print()
            _print_literal_block(str(res["error"]))
            obs.append(dict(case=case_ref, error="FAIL", exporter=expname))

print()
print(".. _le-summary-exported-program:")
print()
print("Summary")
# Same underline as the case sections so Summary is their peer rather than
# a subsection of the last case (the "+" style is used for exporters).
print("=" * len("Summary"))
print()
if obs:
    # One row per case, one column per exporter, "FAIL" where export failed.
    df = pandas.DataFrame(obs)
    piv = df.pivot(index="case", columns="exporter", values="error")
    print(piv.to_markdown(tablefmt="rst"))
print()

>>>

Summary

AtenAsStrided

forward

def forward(self, x):
    y = torch.as_strided(x, (2, 2, 8, 4), (128, 8, 16, 1))
    return y

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %as_strided : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%x, [2, 2, 8, 4], [128, 8, 16, 1]), kwargs = {})
    return (as_strided,)

strict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %as_strided : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%x, [2, 2, 8, 4], [128, 8, 16, 1]), kwargs = {})
    return (as_strided,)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %as_strided : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%x, [2, 2, 8, 4], [128, 8, 16, 1]), kwargs = {})
    return (as_strided,)

nostrict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %as_strided : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%x, [2, 2, 8, 4], [128, 8, 16, 1]), kwargs = {})
    return (as_strided,)

jit

graph():
    %x : [num_users=1] = placeholder[target=x]
    %as_strided : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%x, [2, 2, 8, 4], [128, 8, 16, 1]), kwargs = {})
    return (as_strided,)

jit-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %as_strided : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%x, [2, 2, 8, 4], [128, 8, 16, 1]), kwargs = {})
    return (as_strided,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %as_strided : [num_users=1] = call_function[target=torch.as_strided](args = (%x, (2, 2, 8, 4), (128, 8, 16, 1)), kwargs = {})
    return as_strided

AtenInterpolate

forward

def forward(self, x):
    y = torch.nn.functional.interpolate(
        x,
        scale_factor=2.0,
        mode="bilinear",
        recompute_scale_factor=False,
    )
    return y

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.vec](args = (%x, None, False, [2.0, 2.0]), kwargs = {})
    return (upsample_bilinear2d,)

strict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.vec](args = (%x, None, False, [2.0, 2.0]), kwargs = {})
    return (upsample_bilinear2d,)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.vec](args = (%x, None, False, [2.0, 2.0]), kwargs = {})
    return (upsample_bilinear2d,)

nostrict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.vec](args = (%x, None, False, [2.0, 2.0]), kwargs = {})
    return (upsample_bilinear2d,)

jit

graph():
    %x : [num_users=1] = placeholder[target=x]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.vec](args = (%x, None, False, [2.0, 2.0]), kwargs = {})
    return (upsample_bilinear2d,)

jit-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.vec](args = (%x, None, False, [2.0, 2.0]), kwargs = {})
    return (upsample_bilinear2d,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %interpolate : [num_users=1] = call_function[target=torch.nn.functional.interpolate](args = (%x,), kwargs = {size: None, scale_factor: 2.0, mode: bilinear, align_corners: None, recompute_scale_factor: False, antialias: False})
    return interpolate

AtenNonZero

forward

def forward(self, x):
    y = torch.nonzero(x)
    return y

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=2] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u0 <= 12 on node 'le'), kwargs = {})
    return (nonzero,)

strict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=2] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 0 on node 'ge_1'), kwargs = {})
    %le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 12 on node 'le_1'), kwargs = {})
    return (nonzero,)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=2] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u0 <= 12 on node 'le'), kwargs = {})
    return (nonzero,)

nostrict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=2] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 0 on node 'ge_1'), kwargs = {})
    %le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 12 on node 'le_1'), kwargs = {})
    return (nonzero,)

jit

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=2] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int_2 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_2,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_2, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    return (nonzero,)

jit-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=2] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int_3 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_3,), kwargs = {})
    %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_3, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 0 on node 'ge_1'), kwargs = {})
    return (nonzero,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=1] = call_function[target=torch.nonzero](args = (%x,), kwargs = {})
    return nonzero

AtenNonZeroTuple

forward

def forward(self, x):
    y = torch.nonzero(x, as_tuple=True)
    return y[0], y[1]

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero_numpy : [num_users=2] = call_function[target=torch.ops.aten.nonzero_numpy.default](args = (%x,), kwargs = {})
    %getitem_2 : [num_users=2] = call_function[target=operator.getitem](args = (%nonzero_numpy, 0), kwargs = {})
    %sym_size_int : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%getitem_2, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u0 <= 12 on node 'le'), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%nonzero_numpy, 1), kwargs = {})
    return (getitem_2, getitem_1)

strict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=3] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u0 >= 0 on node 'ge_2'), kwargs = {})
    %le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 12 on node 'le_1'), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%nonzero, 1, 0, 1), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%nonzero, 1, 1, 2), kwargs = {})
    %squeeze : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%slice_1, [1]), kwargs = {})
    %squeeze_1 : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%slice_2, [1]), kwargs = {})
    return (squeeze, squeeze_1)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero_numpy : [num_users=2] = call_function[target=torch.ops.aten.nonzero_numpy.default](args = (%x,), kwargs = {})
    %getitem_2 : [num_users=2] = call_function[target=operator.getitem](args = (%nonzero_numpy, 0), kwargs = {})
    %sym_size_int : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%getitem_2, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u0 <= 12 on node 'le'), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%nonzero_numpy, 1), kwargs = {})
    return (getitem_2, getitem_1)

nostrict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=3] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u0 >= 0 on node 'ge_2'), kwargs = {})
    %le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 12 on node 'le_1'), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%nonzero, 1, 0, 1), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%nonzero, 1, 1, 2), kwargs = {})
    %squeeze : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%slice_1, [1]), kwargs = {})
    %squeeze_1 : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%slice_2, [1]), kwargs = {})
    return (squeeze, squeeze_1)

jit

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero_numpy : [num_users=2] = call_function[target=torch.ops.aten.nonzero_numpy.default](args = (%x,), kwargs = {})
    %getitem_2 : [num_users=2] = call_function[target=operator.getitem](args = (%nonzero_numpy, 0), kwargs = {})
    %sym_size_int_2 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getitem_2, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_2,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_2, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%nonzero_numpy, 1), kwargs = {})
    return (getitem_2, getitem_1)

jit-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=3] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int_3 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_3,), kwargs = {})
    %ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_3, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u0 >= 0 on node 'ge_2'), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%nonzero, 1, 0, 1), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%nonzero, 1, 1, 2), kwargs = {})
    %squeeze : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%slice_1, [1]), kwargs = {})
    %squeeze_1 : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%slice_2, [1]), kwargs = {})
    return (squeeze, squeeze_1)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=2] = call_function[target=torch.nonzero](args = (%x,), kwargs = {as_tuple: True})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%nonzero, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%nonzero, 1), kwargs = {})
    return (getitem, getitem_1)

AtenRollPos

forward

def forward(self, x):
    return torch.roll(x, 1, -1)

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [1], [-1]), kwargs = {})
    return (roll,)

strict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %arange : [num_users=1] = call_function[target=torch.ops.aten.arange.start_step](args = (0, 4), kwargs = {layout: torch.strided, device: cpu, pin_memory: False})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arange, 3), kwargs = {})
    %fmod : [num_users=1] = call_function[target=torch.ops.aten.fmod.Scalar](args = (%add, 4), kwargs = {})
    %index_select : [num_users=1] = call_function[target=torch.ops.aten.index_select.default](args = (%x, 2, %fmod), kwargs = {})
    return (index_select,)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [1], [-1]), kwargs = {})
    return (roll,)

nostrict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %arange : [num_users=1] = call_function[target=torch.ops.aten.arange.start_step](args = (0, 4), kwargs = {layout: torch.strided, device: cpu, pin_memory: False})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arange, 3), kwargs = {})
    %fmod : [num_users=1] = call_function[target=torch.ops.aten.fmod.Scalar](args = (%add, 4), kwargs = {})
    %index_select : [num_users=1] = call_function[target=torch.ops.aten.index_select.default](args = (%x, 2, %fmod), kwargs = {})
    return (index_select,)

jit

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [1], [-1]), kwargs = {})
    return (roll,)

jit-decall

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sym_size_int_5 : [num_users=4] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 2), kwargs = {})
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%sym_size_int_5, 1), kwargs = {})
    %mod : [num_users=1] = call_function[target=operator.mod](args = (%sub, %sym_size_int_5), kwargs = {})
    %arange : [num_users=1] = call_function[target=torch.ops.aten.arange.start_step](args = (0, %sym_size_int_5), kwargs = {layout: torch.strided, device: cpu, pin_memory: False})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arange, %mod), kwargs = {})
    %fmod : [num_users=1] = call_function[target=torch.ops.aten.fmod.Scalar](args = (%add, %sym_size_int_5), kwargs = {})
    %index_select : [num_users=1] = call_function[target=torch.ops.aten.index_select.default](args = (%x, 2, %fmod), kwargs = {})
    return (index_select,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.roll](args = (%x, 1, -1), kwargs = {})
    return roll

AtenRollRelu

forward

def forward(self, x):
    return torch.relu(torch.roll(x, -1, -1))

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [-1], [-1]), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%roll,), kwargs = {})
    return (relu,)

strict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %arange : [num_users=1] = call_function[target=torch.ops.aten.arange.start_step](args = (0, 4), kwargs = {layout: torch.strided, device: cpu, pin_memory: False})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arange, 1), kwargs = {})
    %fmod : [num_users=1] = call_function[target=torch.ops.aten.fmod.Scalar](args = (%add, 4), kwargs = {})
    %index_select : [num_users=1] = call_function[target=torch.ops.aten.index_select.default](args = (%x, 2, %fmod), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%index_select,), kwargs = {})
    return (relu,)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [-1], [-1]), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%roll,), kwargs = {})
    return (relu,)

nostrict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %arange : [num_users=1] = call_function[target=torch.ops.aten.arange.start_step](args = (0, 4), kwargs = {layout: torch.strided, device: cpu, pin_memory: False})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arange, 1), kwargs = {})
    %fmod : [num_users=1] = call_function[target=torch.ops.aten.fmod.Scalar](args = (%add, 4), kwargs = {})
    %index_select : [num_users=1] = call_function[target=torch.ops.aten.index_select.default](args = (%x, 2, %fmod), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%index_select,), kwargs = {})
    return (relu,)

jit

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [-1], [-1]), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%roll,), kwargs = {})
    return (relu,)

jit-decall

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sym_size_int_5 : [num_users=4] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 2), kwargs = {})
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%sym_size_int_5, -1), kwargs = {})
    %mod : [num_users=1] = call_function[target=operator.mod](args = (%sub, %sym_size_int_5), kwargs = {})
    %arange : [num_users=1] = call_function[target=torch.ops.aten.arange.start_step](args = (0, %sym_size_int_5), kwargs = {layout: torch.strided, device: cpu, pin_memory: False})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arange, %mod), kwargs = {})
    %fmod : [num_users=1] = call_function[target=torch.ops.aten.fmod.Scalar](args = (%add, %sym_size_int_5), kwargs = {})
    %index_select : [num_users=1] = call_function[target=torch.ops.aten.index_select.default](args = (%x, 2, %fmod), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%index_select,), kwargs = {})
    return (relu,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.roll](args = (%x, -1, -1), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%roll,), kwargs = {})
    return relu

BuildInIsInstance

forward

def forward(self, x, lx: list | torch.Tensor):
    if isinstance(lx, list):
        t = lx[0] * lx[1].sum(axis=1, keepdim=True)
        return torch.sigmoid(self.linear(x)) - self.buff + t
    return torch.sigmoid(self.linear(x)) - self.buff + lx

strict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

strict-decall

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

nostrict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

nostrict-decall

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

jit

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

jit-decall

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub_4 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add_15 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub_4, %mul_1), kwargs = {})
    return (add_15,)

tracing

Traceback (most recent call last):
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/graph_module.py", line 403, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl
return forward_call(*args, **kwargs)
File "<eval_with_key>.6722 from ~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py:776 in forward", line 9, in forward
add = sub + lx; sub = lx = None

~~~~^~~~

TypeError: unsupported operand type(s) for +: 'Tensor' and 'list'

Call using an FX-traced Module, line 9 of the traced Module's generated forward function:

sub = sigmoid - buff; sigmoid = buff = None

add = sub + lx; sub = lx = None

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

return add

FAILED

unsupported operand type(s) for +: 'Tensor' and 'list'

BuildInLen

forward

def forward(self, x, lx: list):
    t = lx[0] * lx[1].sum(axis=1, keepdim=True)
    if len(lx) > 2:
        t = t + lx[2].sum(axis=1, keepdim=True)
    return torch.sigmoid(self.linear(x)) - self.buff + t

strict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

strict-decall

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

nostrict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

nostrict-decall

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

jit

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

jit-decall

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub_4 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add_15 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub_4, %mul_1), kwargs = {})
    return (add_15,)

tracing

FAILED

len(.) expects an integer, len needs to be replaced. You should use _len.

ComplexPolar

forward

def forward(self, x, angle):
    return torch.polar(x, angle)

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %angle : [num_users=1] = placeholder[target=angle]
    %polar : [num_users=1] = call_function[target=torch.ops.aten.polar.default](args = (%x, %angle), kwargs = {})
    return (polar,)

strict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %angle : [num_users=1] = placeholder[target=angle]
    %polar : [num_users=1] = call_function[target=torch.ops.aten.polar.default](args = (%x, %angle), kwargs = {})
    return (polar,)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %angle : [num_users=1] = placeholder[target=angle]
    %polar : [num_users=1] = call_function[target=torch.ops.aten.polar.default](args = (%x, %angle), kwargs = {})
    return (polar,)

nostrict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %angle : [num_users=1] = placeholder[target=angle]
    %polar : [num_users=1] = call_function[target=torch.ops.aten.polar.default](args = (%x, %angle), kwargs = {})
    return (polar,)

jit

graph():
    %x : [num_users=1] = placeholder[target=x]
    %angle : [num_users=1] = placeholder[target=angle]
    %polar : [num_users=1] = call_function[target=torch.ops.aten.polar.default](args = (%x, %angle), kwargs = {})
    return (polar,)

jit-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %angle : [num_users=1] = placeholder[target=angle]
    %polar : [num_users=1] = call_function[target=torch.ops.aten.polar.default](args = (%x, %angle), kwargs = {})
    return (polar,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %angle : [num_users=1] = placeholder[target=angle]
    %polar : [num_users=1] = call_function[target=torch.polar](args = (%x, %angle), kwargs = {})
    return polar

ControlFlowCond

forward

def forward(self, x):
    def true_fn(x):
        return torch.sin(x)

    def false_fn(x):
        return torch.cos(x)

    return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

strict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

strict-decall

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict-decall

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

jit

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

jit-decall

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=operator.gt](args = (%sum_1, 0), kwargs = {})
    %_cb_cond_true_fn_0 : [num_users=1] = get_attr[target=_cb_cond_true_fn_0]
    %_cb_cond_false_fn_0 : [num_users=1] = get_attr[target=_cb_cond_false_fn_0]
    %condcc : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %_cb_cond_true_fn_0, %_cb_cond_false_fn_0, [%x]), kwargs = {})
    return condcc

ControlFlowCond2Inputs

forward

def forward(self, x, y):
    def true_fn(x, y):
        return torch.sin(x), torch.cos(x) + y

    def false_fn(x, y):
        return torch.cos(x), torch.sin(x) + y

    return torch.cond(x.sum() > 0, true_fn, false_fn, [x, y])

strict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x, %y)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

strict-decall

graph():
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x, %y)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

nostrict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x, %y)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

nostrict-decall

graph():
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x, %y)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

jit

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

jit-decall

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=operator.gt](args = (%sum_1, 0), kwargs = {})
    %_cb_cond_true_fn_0 : [num_users=1] = get_attr[target=_cb_cond_true_fn_0]
    %_cb_cond_false_fn_0 : [num_users=1] = get_attr[target=_cb_cond_false_fn_0]
    %condcc : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %_cb_cond_true_fn_0, %_cb_cond_false_fn_0, [%x, %y]), kwargs = {})
    return condcc

ControlFlowCond2Outputs

forward

def forward(self, x):
    def true_fn(x):
        return torch.sin(x), torch.cos(x)

    def false_fn(x):
        return torch.cos(x), torch.sin(x)

    return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

strict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

strict-decall

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

nostrict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

nostrict-decall

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

jit

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

jit-decall

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=operator.gt](args = (%sum_1, 0), kwargs = {})
    %_cb_cond_true_fn_0 : [num_users=1] = get_attr[target=_cb_cond_true_fn_0]
    %_cb_cond_false_fn_0 : [num_users=1] = get_attr[target=_cb_cond_false_fn_0]
    %condcc : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %_cb_cond_true_fn_0, %_cb_cond_false_fn_0, [%x]), kwargs = {})
    return condcc

ControlFlowCondConstant

forward

def forward(self, x):
    def true_fn(x):
        return torch.sin(x) - torch.ones(x.shape, dtype=x.dtype)

    def false_fn(x):
        return torch.cos(x) + torch.ones((1, 1024), dtype=x.dtype)

    return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

strict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

strict-decall

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict-decall

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

jit

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

jit-decall

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=operator.gt](args = (%sum_1, 0), kwargs = {})
    %_cb_cond_true_fn_0 : [num_users=1] = get_attr[target=_cb_cond_true_fn_0]
    %_cb_cond_false_fn_0 : [num_users=1] = get_attr[target=_cb_cond_false_fn_0]
    %condcc : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %_cb_cond_true_fn_0, %_cb_cond_false_fn_0, [%x]), kwargs = {})
    return condcc

ControlFlowCondNestedModule

forward

def forward(self, x):
    def true_fn(x):
        return self.submodule(x)

    def false_fn(x):
        return x - self.weight

    y = torch.cond(x.sum() > 0, true_fn, false_fn, [x])
    return y

strict

graph():
    %submodule_weight : [num_users=1] = get_attr[target=submodule.weight]
    %weight : [num_users=1] = get_attr[target=weight]
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x, %submodule_weight, %weight)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

strict-decall

graph():
    %weight : [num_users=1] = get_attr[target=weight]
    %submodule_weight : [num_users=1] = get_attr[target=submodule.weight]
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x, %submodule_weight, %weight)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict

graph():
    %weight : [num_users=1] = get_attr[target=weight]
    %submodule_weight : [num_users=1] = get_attr[target=submodule.weight]
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x, %submodule_weight, %weight)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict-decall

graph():
    %weight : [num_users=1] = get_attr[target=weight]
    %submodule_weight : [num_users=1] = get_attr[target=submodule.weight]
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x, %submodule_weight, %weight)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

jit

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

jit-decall

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=operator.gt](args = (%sum_1, 0), kwargs = {})
    %_cb_cond_true_fn_0 : [num_users=1] = get_attr[target=_cb_cond_true_fn_0]
    %_cb_cond_false_fn_0 : [num_users=1] = get_attr[target=_cb_cond_false_fn_0]
    %condcc : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %_cb_cond_true_fn_0, %_cb_cond_false_fn_0, [%x]), kwargs = {})
    return condcc

ControlFlowCondNonZero

forward

def forward(self, input_ids, image_features, vocab_size):
    def then_branch(input_ids, image_features, vocab_size):
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        condition = (input_ids < 0) & (input_ids > -int(1e9))
        positions = torch.nonzero(condition, as_tuple=True)
        input_ids = input_ids.clamp_min(0).clamp_max(vocab_size)
        return (input_ids, positions[0], positions[1])

    def else_branch(input_ids, image_features, vocab_size):
        r = torch.where(torch.zeros((1, 1), dtype=torch.bool))
        return (input_ids, r[0], r[1])

    a, b, c = torch.cond(
        image_features.numel() > 0,
        then_branch,
        else_branch,
        [input_ids, image_features, vocab_size],
    )
    return a, b, c

strict

graph():
    %input_ids : [num_users=1] = placeholder[target=input_ids]
    %image_features : [num_users=0] = placeholder[target=image_features]
    %vocab_size : [num_users=0] = placeholder[target=vocab_size]
    %view : [num_users=3] = call_function[target=torch.ops.aten.view.default](args = (%input_ids, [-1, 12]), kwargs = {})
    %lt : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%view, 0), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%view, -1000000000), kwargs = {})
    %and_1 : [num_users=1] = call_function[target=torch.ops.aten.__and__.Tensor](args = (%lt, %gt), kwargs = {})
    %nonzero_numpy : [num_users=2] = call_function[target=torch.ops.aten.nonzero_numpy.default](args = (%and_1,), kwargs = {})
    %getitem_2 : [num_users=2] = call_function[target=operator.getitem](args = (%nonzero_numpy, 0), kwargs = {})
    %sym_size_int : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%getitem_2, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 24), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u0 <= 24 on node 'le'), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%nonzero_numpy, 1), kwargs = {})
    %clamp_min : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%view, 0), kwargs = {})
    %clamp_max : [num_users=1] = call_function[target=torch.ops.aten.clamp_max.default](args = (%clamp_min, 1025), kwargs = {})
    return (clamp_max, getitem_2, getitem_1)

strict-decall

graph():
    %input_ids : [num_users=1] = placeholder[target=input_ids]
    %image_features : [num_users=0] = placeholder[target=image_features]
    %vocab_size : [num_users=0] = placeholder[target=vocab_size]
    %view : [num_users=3] = call_function[target=torch.ops.aten.view.default](args = (%input_ids, [-1, 12]), kwargs = {})
    %lt : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%view, 0), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%view, -1000000000), kwargs = {})
    %bitwise_and : [num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%lt, %gt), kwargs = {})
    %nonzero : [num_users=3] = call_function[target=torch.ops.aten.nonzero.default](args = (%bitwise_and,), kwargs = {})
    %sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u0 >= 0 on node 'ge_2'), kwargs = {})
    %le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 24), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 24 on node 'le_1'), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%nonzero, 1, 0, 1), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%nonzero, 1, 1, 2), kwargs = {})
    %squeeze : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%slice_1, [1]), kwargs = {})
    %squeeze_1 : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%slice_2, [1]), kwargs = {})
    %clamp : [num_users=1] = call_function[target=torch.ops.aten.clamp.default](args = (%view, 0), kwargs = {})
    %clamp_1 : [num_users=1] = call_function[target=torch.ops.aten.clamp.default](args = (%clamp, None, 1025), kwargs = {})
    return (clamp_1, squeeze, squeeze_1)

nostrict

graph():
    %input_ids : [num_users=1] = placeholder[target=input_ids]
    %image_features : [num_users=0] = placeholder[target=image_features]
    %vocab_size : [num_users=0] = placeholder[target=vocab_size]
    %view : [num_users=3] = call_function[target=torch.ops.aten.view.default](args = (%input_ids, [-1, 12]), kwargs = {})
    %lt : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%view, 0), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%view, -1000000000), kwargs = {})
    %and_1 : [num_users=1] = call_function[target=torch.ops.aten.__and__.Tensor](args = (%lt, %gt), kwargs = {})
    %nonzero_numpy : [num_users=2] = call_function[target=torch.ops.aten.nonzero_numpy.default](args = (%and_1,), kwargs = {})
    %getitem_2 : [num_users=2] = call_function[target=operator.getitem](args = (%nonzero_numpy, 0), kwargs = {})
    %sym_size_int : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%getitem_2, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 24), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u0 <= 24 on node 'le'), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%nonzero_numpy, 1), kwargs = {})
    %clamp_min : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%view, 0), kwargs = {})
    %clamp_max : [num_users=1] = call_function[target=torch.ops.aten.clamp_max.default](args = (%clamp_min, 1025), kwargs = {})
    return (clamp_max, getitem_2, getitem_1)

nostrict-decall

graph():
    %input_ids : [num_users=1] = placeholder[target=input_ids]
    %image_features : [num_users=0] = placeholder[target=image_features]
    %vocab_size : [num_users=0] = placeholder[target=vocab_size]
    %view : [num_users=3] = call_function[target=torch.ops.aten.view.default](args = (%input_ids, [-1, 12]), kwargs = {})
    %lt : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%view, 0), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%view, -1000000000), kwargs = {})
    %bitwise_and : [num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%lt, %gt), kwargs = {})
    %nonzero : [num_users=3] = call_function[target=torch.ops.aten.nonzero.default](args = (%bitwise_and,), kwargs = {})
    %sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u0 >= 0 on node 'ge_2'), kwargs = {})
    %le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 24), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 24 on node 'le_1'), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%nonzero, 1, 0, 1), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%nonzero, 1, 1, 2), kwargs = {})
    %squeeze : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%slice_1, [1]), kwargs = {})
    %squeeze_1 : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%slice_2, [1]), kwargs = {})
    %clamp : [num_users=1] = call_function[target=torch.ops.aten.clamp.default](args = (%view, 0), kwargs = {})
    %clamp_1 : [num_users=1] = call_function[target=torch.ops.aten.clamp.default](args = (%clamp, None, 1025), kwargs = {})
    return (clamp_1, squeeze, squeeze_1)

jit

FAILED

Type 'Tuple[Tensor, Tensor, int]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

jit-decall

FAILED

Type 'Tuple[Tensor, Tensor, int]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

tracing

graph():
    %input_ids : [num_users=1] = placeholder[target=input_ids]
    %image_features : [num_users=2] = placeholder[target=image_features]
    %vocab_size : [num_users=1] = placeholder[target=vocab_size]
    %numel : [num_users=1] = call_method[target=numel](args = (%image_features,), kwargs = {})
    %gt : [num_users=1] = call_function[target=operator.gt](args = (%numel, 0), kwargs = {})
    %_cb_cond_then_branch_0 : [num_users=1] = get_attr[target=_cb_cond_then_branch_0]
    %_cb_cond_else_branch_0 : [num_users=1] = get_attr[target=_cb_cond_else_branch_0]
    %condcc : [num_users=3] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %_cb_cond_then_branch_0, %_cb_cond_else_branch_0, [%input_ids, %image_features, %vocab_size]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%condcc, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%condcc, 1), kwargs = {})
    %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%condcc, 2), kwargs = {})
    return (getitem, getitem_1, getitem_2)

ControlFlowNestCond

forward

def forward(self, x):
    def true_fn2(x):
        def true_fn1(x):
            return torch.sin(x)

        def false_fn1(x):
            return torch.cos(x)

        return torch.cond(x.sum() < 0, true_fn1, false_fn1, [x])

    def false_fn2(x):
        return -x

    return torch.cond(x.sum() > 0, true_fn2, false_fn2, [x])

strict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

strict-decall

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict-decall

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%x, []), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

jit

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

jit-decall

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=operator.gt](args = (%sum_1, 0), kwargs = {})
    %_cb_cond_true_fn2_0 : [num_users=1] = get_attr[target=_cb_cond_true_fn2_0]
    %_cb_cond_false_fn2_0 : [num_users=1] = get_attr[target=_cb_cond_false_fn2_0]
    %condcc : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %_cb_cond_true_fn2_0, %_cb_cond_false_fn2_0, [%x]), kwargs = {})
    return condcc

ControlFlowScan

forward

def forward(self, x):
    init = torch.zeros_like(x[0])
    carry, out = torch.ops.higher_order.scan(
        ControlFlowScan.add, [init], [x], additional_inputs=[]
    )
    return carry

strict

FAILED

scan must be captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 377, in forward
    carry, out = torch.ops.higher_order.scan(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

strict-decall

FAILED

scan must be captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 377, in forward
    carry, out = torch.ops.higher_order.scan(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

nostrict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%select,), kwargs = {pin_memory: False})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like], [%x], ()), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem,)

nostrict-decall

FAILED

scan might be aliasing the input or the output!

While executing %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like], [%x], ()), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self, x):
        x: "f32[3, 3][3, 1]"; 

        x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
         # File: ~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py:376 in forward, code: init = torch.zeros_like(x[0])
        select: "f32[3][1]" = torch.ops.aten.select.int(x, 0, 0)
        zeros_like: "f32[3][1]" = torch.ops.aten.zeros_like.default(select, pin_memory = False);  select = None
    
         # File: ~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py:377 in forward, code: carry, out = torch.ops.higher_order.scan(
        scan_combine_graph_0 = self.scan_combine_graph_0
        scan = torch.ops.higher_order.scan(scan_combine_graph_0, [zeros_like], [x], ());  scan_combine_graph_0 = zeros_like = x = None
        getitem: "f32[3][1]" = scan[0]
        getitem_1: "f32[3, 3][3, 1]" = scan[1];  scan = getitem_1 = None
        return pytree.tree_unflatten((getitem,), self._out_spec)
    
    class scan_combine_graph_0(torch.nn.Module):
        def forward(self, carry_1: "f32[3][1]", y_1: "f32[3][1]"):
             # File: ~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py:377 in forward, code: carry, out = torch.ops.higher_order.scan(
            add: "f32[3][1]" = torch.ops.aten.add.Tensor(carry_1, y_1);  carry_1 = y_1 = None
            return [add, add]
        

Original traceback:
File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 377, in forward
    carry, out = torch.ops.higher_order.scan(

jit

FAILED

could not find kernel for HigherOrderOperator scan at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

jit-decall

FAILED

could not find kernel for HigherOrderOperator scan at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

tracing

FAILED

Unable to symbolically trace HigherOrderOperators

ControlFlowScan2Carried

forward

def forward(self, x):
    init1 = torch.zeros_like(x[0])
    init2 = torch.ones_like(x[0])
    carry1, carry2, out1, out2 = torch.ops.higher_order.scan(
        ControlFlowScan2Carried.add,
        [init1, init2],
        [x, x * 2],
        # dim=0,  # 01/31/2025, not supported anymore
        additional_inputs=[],
    )
    return carry1, carry2, out1, out2

strict

FAILED

scan must be captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 397, in forward
    carry1, carry2, out1, out2 = torch.ops.higher_order.scan(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

strict-decall

FAILED

scan must be captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 397, in forward
    carry1, carry2, out1, out2 = torch.ops.higher_order.scan(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

nostrict

graph():
    %x : [num_users=4] = placeholder[target=x]
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%select,), kwargs = {pin_memory: False})
    %select_1 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %ones_like : [num_users=1] = call_function[target=torch.ops.aten.ones_like.default](args = (%select_1,), kwargs = {pin_memory: False})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 2), kwargs = {})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=4] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like, %ones_like], [%x, %mul], ()), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 2), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 3), kwargs = {})
    return (getitem, getitem_1, getitem_2, getitem_3)

nostrict-decall

FAILED

scan might be aliasing the input or the output!

While executing %scan : [num_users=4] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like, %ones_like], [%x, %mul], ()), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self, x):
        x: "f32[3, 4][4, 1]"; 

        x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
         # File: ~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py:395 in forward, code: init1 = torch.zeros_like(x[0])
        select: "f32[4][1]" = torch.ops.aten.select.int(x, 0, 0)
        zeros_like: "f32[4][1]" = torch.ops.aten.zeros_like.default(select, pin_memory = False);  select = None
    
         # File: ~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py:396 in forward, code: init2 = torch.ones_like(x[0])
        select_1: "f32[4][1]" = torch.ops.aten.select.int(x, 0, 0)
        ones_like: "f32[4][1]" = torch.ops.aten.ones_like.default(select_1, pin_memory = False);  select_1 = None
    
         # File: ~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py:400 in forward, code: [x, x * 2],
        mul: "f32[3, 4][4, 1]" = torch.ops.aten.mul.Tensor(x, 2)
    
         # File: ~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py:397 in forward, code: carry1, carry2, out1, out2 = torch.ops.higher_order.scan(
        scan_combine_graph_0 = self.scan_combine_graph_0
        scan = torch.ops.higher_order.scan(scan_combine_graph_0, [zeros_like, ones_like], [x, mul], ());  scan_combine_graph_0 = zeros_like = ones_like = x = mul = None
        getitem: "f32[4][1]" = scan[0]
        getitem_1: "f32[4][1]" = scan[1]
        getitem_2: "f32[3, 4][4, 1]" = scan[2]
        getitem_3: "f32[3, 4][4, 1]" = scan[3];  scan = None
        return pytree.tree_unflatten((getitem, getitem_1, getitem_2, getitem_3), self._out_spec)
    
    class scan_combine_graph_0(torch.nn.Module):
        def forward(self, carry1_1: "f32[4][1]", carry2_1: "f32[4][1]", y1_1: "f32[4][1]", y2_1: "f32[4][1]"):
             # File: ~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py:397 in forward, code: carry1, carry2, out1, out2 = torch.ops.higher_order.scan(
            add: "f32[4][1]" = torch.ops.aten.add.Tensor(carry1_1, y1_1);  carry1_1 = y1_1 = None
            mul: "f32[4][1]" = torch.ops.aten.mul.Tensor(carry2_1, y2_1);  carry2_1 = y2_1 = None
            return [add, mul, add, mul]
        

Original traceback:
File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 397, in forward
    carry1, carry2, out1, out2 = torch.ops.higher_order.scan(

jit

FAILED

could not find kernel for HigherOrderOperator scan at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

jit-decall

FAILED

could not find kernel for HigherOrderOperator scan at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

tracing

FAILED

Unable to symbolically trace HigherOrderOperators

ControlFlowScanCDist

forward

def forward(self, x):
    carry, out = torch.ops.higher_order.scan(
        ControlFlowScanCDist.dist,
        [x],
        [x],
        # dim=0,  # 01/31/2025, not supported anymore
        additional_inputs=[],
    )
    return out

strict

FAILED

scan must be captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 422, in forward
    carry, out = torch.ops.higher_order.scan(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

strict-decall

FAILED

scan must be captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 422, in forward
    carry, out = torch.ops.higher_order.scan(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%x], [%x], ()), kwargs = {})
    %getitem : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

nostrict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%x], [%x], ()), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

jit

FAILED

could not find kernel for HigherOrderOperator scan at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

jit-decall

FAILED

could not find kernel for HigherOrderOperator scan at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

tracing

FAILED

Unable to symbolically trace HigherOrderOperators

ControlFlowScanCDist2

forward

def forward(self, x):
    z = torch.tensor([0], dtype=torch.float32)
    y = x.clone()
    out = torch.ops.higher_order.scan(
        ControlFlowScanCDist2.dist,
        [z],
        [x],
        # dim=0,  # 01/31/2025, not supported anymore
        additional_inputs=[y],
    )
    return out[1]

strict

FAILED

scan must be captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 449, in forward
    out = torch.ops.higher_order.scan(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

strict-decall

FAILED

scan must be captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 449, in forward
    out = torch.ops.higher_order.scan(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

nostrict

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %detach_ : [num_users=1] = call_function[target=torch.ops.aten.detach_.default](args = (%lift_fresh_copy,), kwargs = {})
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%detach_], [%x], (%clone,)), kwargs = {})
    %getitem : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

nostrict-decall

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_0,), kwargs = {})
    %clone_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%clone], [%x], (%clone_1,)), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

jit

~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py:447: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.

z = torch.tensor([0], dtype=torch.float32)

FAILED

could not find kernel for HigherOrderOperator scan at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

jit-decall

FAILED

could not find kernel for HigherOrderOperator scan at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

tracing

FAILED

(CustomProxy(clone),) can only be of (<class 'torch.Tensor'>, <class 'int'>, <class 'torch.SymInt'>) but got (<class 'experimental_experiment.torch_interpreter.tracing.CustomProxy'>,)

ControlFlowScanCDistXY

forward

def forward(self, x, y):
    carry, out = torch.ops.higher_order.scan(
        ControlFlowScanCDistXY.dist,
        [y],
        [x],
        # dim=0,  # 01/31/2025, not supported anymore
        additional_inputs=[],
    )
    return out

strict

FAILED

scan must be captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 475, in forward
    carry, out = torch.ops.higher_order.scan(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

strict-decall

FAILED

scan must be captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 475, in forward
    carry, out = torch.ops.higher_order.scan(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%y], [%x], ()), kwargs = {})
    %getitem : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

nostrict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%y], [%x], ()), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

jit

FAILED

could not find kernel for HigherOrderOperator scan at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

jit-decall

FAILED

could not find kernel for HigherOrderOperator scan at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

tracing

FAILED

Unable to symbolically trace HigherOrderOperators

CreateFromShape

forward

def forward(self, x):
    y = torch.ones((x.shape[0], x.shape[1] + 1))
    return y

strict

graph():
    %x : [num_users=0] = placeholder[target=x]
    %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([4, 5],), kwargs = {device: cpu, pin_memory: False})
    return (ones,)

strict-decall

graph():
    %x : [num_users=0] = placeholder[target=x]
    %full : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([4, 5], 1), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})
    return (full,)

nostrict

graph():
    %x : [num_users=0] = placeholder[target=x]
    %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([4, 5],), kwargs = {device: cpu, pin_memory: False})
    return (ones,)

nostrict-decall

graph():
    %x : [num_users=0] = placeholder[target=x]
    %full : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([4, 5], 1), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})
    return (full,)

jit

graph():
    %lifted_tensor_5 : [num_users=1] = get_attr[target=lifted_tensor_5]
    %x : [num_users=2] = placeholder[target=x]
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 1), kwargs = {})
    %scalar_tensor : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%sym_size_int_3,), kwargs = {dtype: torch.int64})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%scalar_tensor, %lifted_tensor_5), kwargs = {})
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%add,), kwargs = {dtype: torch.int32})
    %_local_scalar_dense : [num_users=3] = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%_to_copy,), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%_local_scalar_dense,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([%sym_size_int_2, %_local_scalar_dense],), kwargs = {device: cpu, pin_memory: False})
    return (ones,)

jit-decall

graph():
    %lifted_tensor_5 : [num_users=1] = get_attr[target=lifted_tensor_5]
    %x : [num_users=2] = placeholder[target=x]
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 1), kwargs = {})
    %scalar_tensor : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%sym_size_int_3,), kwargs = {dtype: torch.int64})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%scalar_tensor, %lifted_tensor_5), kwargs = {})
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%add,), kwargs = {dtype: torch.int32})
    %_local_scalar_dense : [num_users=3] = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%_to_copy,), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%_local_scalar_dense,), kwargs = {})
    %ge_3 : [num_users=1] = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_3, Runtime assertion failed for expression u0 >= 0 on node 'ge_3'), kwargs = {})
    %full : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_2, %_local_scalar_dense], 1), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})
    return (full,)

tracing

FAILED

ones(): argument 'size' (position 1) must be tuple of ints, but found element of type CustomProxy at pos 0

CreateFromShapeThroughFunction

forward

def forward(self, x):
    dy1 = CreateFromShapeThroughFunction.add_one(x.shape[1])
    y = torch.ones((x.shape[0], dy1))
    return y

strict

graph():
    %x : [num_users=0] = placeholder[target=x]
    %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([4, 5],), kwargs = {device: cpu, pin_memory: False})
    return (ones,)

strict-decall

graph():
    %x : [num_users=0] = placeholder[target=x]
    %full : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([4, 5], 1), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})
    return (full,)

nostrict

graph():
    %x : [num_users=0] = placeholder[target=x]
    %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([4, 5],), kwargs = {device: cpu, pin_memory: False})
    return (ones,)

nostrict-decall

graph():
    %x : [num_users=0] = placeholder[target=x]
    %full : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([4, 5], 1), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})
    return (full,)

jit

graph():
    %lifted_tensor_6 : [num_users=1] = get_attr[target=lifted_tensor_6]
    %x : [num_users=2] = placeholder[target=x]
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 1), kwargs = {})
    %scalar_tensor : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%sym_size_int_3,), kwargs = {dtype: torch.int64})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%scalar_tensor, %lifted_tensor_6), kwargs = {})
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%add,), kwargs = {dtype: torch.int32})
    %_local_scalar_dense : [num_users=3] = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%_to_copy,), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%_local_scalar_dense,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([%sym_size_int_2, %_local_scalar_dense],), kwargs = {device: cpu, pin_memory: False})
    return (ones,)

jit-decall

graph():
    %lifted_tensor_6 : [num_users=1] = get_attr[target=lifted_tensor_6]
    %x : [num_users=2] = placeholder[target=x]
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 1), kwargs = {})
    %scalar_tensor : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%sym_size_int_3,), kwargs = {dtype: torch.int64})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%scalar_tensor, %lifted_tensor_6), kwargs = {})
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%add,), kwargs = {dtype: torch.int32})
    %_local_scalar_dense : [num_users=3] = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%_to_copy,), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%_local_scalar_dense,), kwargs = {})
    %ge_3 : [num_users=1] = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_3, Runtime assertion failed for expression u0 >= 0 on node 'ge_3'), kwargs = {})
    %full : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_2, %_local_scalar_dense], 1), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})
    return (full,)

tracing

FAILED

ones(): argument 'size' (position 1) must be tuple of ints, but found element of type CustomProxy at pos 0

CropLastDimensionWithTensorContent

forward

def forward(self, x, shape):
    return x[..., : shape[0]]

strict

FAILED

Dynamic slicing with Tensor arguments
  Explanation: Creating slices with Tensor arguments is not supported. e.g. `l[:x]`, where `x` is a 1-element tensor.
  Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.

  Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: TensorVariable(), step: ConstantVariable(NoneType: None)


from user code:
   File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 865, in forward
    return x[..., : shape[0]]

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

strict-decall

FAILED

Dynamic slicing with Tensor arguments
  Explanation: Creating slices with Tensor arguments is not supported. e.g. `l[:x]`, where `x` is a 1-element tensor.
  Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.

  Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: TensorVariable(), step: ConstantVariable(NoneType: None)


from user code:
   File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 865, in forward
    return x[..., : shape[0]]

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

nostrict

def forward(self, arg0_1: "f32[s35, s16, s90]", arg1_1: "i64[1]"):
     # File: ~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py:865 in forward, code: return x[..., : shape[0]]
    select: "i64[]" = torch.ops.aten.select.int(arg1_1, 0, 0);  arg1_1 = None
    item: "Sym(u0)" = torch.ops.aten.item.default(select);  select = item = None

def forward(self, arg0_1: "f32[s35, s16, s90]", arg1_1: "i64[1]"):
     # File: ~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py:865 in forward, code: return x[..., : shape[0]]
    select: "i64[]" = torch.ops.aten.select.int(arg1_1, 0, 0);  arg1_1 = None
    item: "Sym(u0)" = torch.ops.aten.item.default(select);  select = item = None

FAILED

Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)

Caused by: (_export/non_strict_utils.py:973 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 865, in forward
    return x[..., : shape[0]]


The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

nostrict-decall

def forward(self, arg0_1: "f32[s35, s16, s90]", arg1_1: "i64[1]"):
     # File: ~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py:865 in forward, code: return x[..., : shape[0]]
    select: "i64[]" = torch.ops.aten.select.int(arg1_1, 0, 0);  arg1_1 = None
    item: "Sym(u0)" = torch.ops.aten.item.default(select);  select = item = None

def forward(self, arg0_1: "f32[s35, s16, s90]", arg1_1: "i64[1]"):
     # File: ~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py:865 in forward, code: return x[..., : shape[0]]
    select: "i64[]" = torch.ops.aten.select.int(arg1_1, 0, 0);  arg1_1 = None
    item: "Sym(u0)" = torch.ops.aten.item.default(select);  select = item = None

FAILED

Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)

Caused by: (_export/non_strict_utils.py:973 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 865, in forward
    return x[..., : shape[0]]


The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

jit

def forward(self, arg0_1: "f32[s35, s16, s90]", arg1_1: "i64[1]"):
    # No stacktrace found for following nodes
    select: "i64[]" = torch.ops.aten.select.int(arg1_1, 0, 0);  arg1_1 = None
    _to_copy: "i32[]" = torch.ops.aten._to_copy.default(select, dtype = torch.int32);  select = None
    _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(_to_copy);  _to_copy = None
    slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, _local_scalar_dense);  arg0_1 = _local_scalar_dense = slice_1 = None

def forward(self, arg0_1: "f32[s35, s16, s90]", arg1_1: "i64[1]"):
    # No stacktrace found for following nodes
    select: "i64[]" = torch.ops.aten.select.int(arg1_1, 0, 0);  arg1_1 = None
    _to_copy: "i32[]" = torch.ops.aten._to_copy.default(select, dtype = torch.int32);  select = None
    _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(_to_copy);  _to_copy = None
    slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, _local_scalar_dense);  arg0_1 = _local_scalar_dense = slice_1 = None

FAILED

Could not guard on data-dependent expression u0 < 0 (unhinted: u0 < 0).  (Size-like symbols: none)

Caused by: (_decomp/decompositions.py:734 in slice_forward)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "<string>", line 1, in <lambda>


While executing %slice_tensor : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, %_local_scalar_dense_default, 1), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self, x, shape):
        # No stacktrace found for following nodes
        select_int = torch.ops.aten.select.int(shape, 0, 0);  shape = None
        _to_copy_default = torch.ops.aten._to_copy.default(select_int, dtype = torch.int32);  select_int = None
        _local_scalar_dense_default = torch.ops.aten._local_scalar_dense.default(_to_copy_default);  _to_copy_default = None
        slice_tensor = torch.ops.aten.slice.Tensor(x, 2, 0, _local_scalar_dense_default, 1);  x = _local_scalar_dense_default = None
        return slice_tensor
    

Original traceback:
None

jit-decall

def forward(self, arg0_1: "f32[s35, s16, s90]", arg1_1: "i64[1]"):
    # No stacktrace found for following nodes
    select: "i64[]" = torch.ops.aten.select.int(arg1_1, 0, 0);  arg1_1 = None
    _to_copy: "i32[]" = torch.ops.aten._to_copy.default(select, dtype = torch.int32);  select = None
    _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(_to_copy);  _to_copy = None
    slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, _local_scalar_dense);  arg0_1 = _local_scalar_dense = slice_1 = None

def forward(self, arg0_1: "f32[s35, s16, s90]", arg1_1: "i64[1]"):
    # No stacktrace found for following nodes
    select: "i64[]" = torch.ops.aten.select.int(arg1_1, 0, 0);  arg1_1 = None
    _to_copy: "i32[]" = torch.ops.aten._to_copy.default(select, dtype = torch.int32);  select = None
    _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(_to_copy);  _to_copy = None
    slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, _local_scalar_dense);  arg0_1 = _local_scalar_dense = slice_1 = None

FAILED

Could not guard on data-dependent expression u0 < 0 (unhinted: u0 < 0).  (Size-like symbols: none)

Caused by: (_decomp/decompositions.py:734 in slice_forward)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "<string>", line 1, in <lambda>


While executing %slice_tensor : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, %_local_scalar_dense_default, 1), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self, x, shape):
        # No stacktrace found for following nodes
        select_int = torch.ops.aten.select.int(shape, 0, 0);  shape = None
        _to_copy_default = torch.ops.aten._to_copy.default(select_int, dtype = torch.int32);  select_int = None
        _local_scalar_dense_default = torch.ops.aten._local_scalar_dense.default(_to_copy_default);  _to_copy_default = None
        slice_tensor = torch.ops.aten.slice.Tensor(x, 2, 0, _local_scalar_dense_default, 1);  x = _local_scalar_dense_default = None
        return slice_tensor
    

Original traceback:
None

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %shape : [num_users=1] = placeholder[target=shape]
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%shape, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%x, (Ellipsis, slice(None, getitem, None))), kwargs = {})
    return getitem_1

CropLastDimensionWithTensorShape

forward

def forward(self, x, y):
    return x[..., : y.shape[0]]

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, 2), kwargs = {})
    return (slice_1,)

strict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, 2), kwargs = {})
    return (slice_1,)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, 2), kwargs = {})
    return (slice_1,)

nostrict-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, 2), kwargs = {})
    return (slice_1,)

jit

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sym_size_int_4 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 0), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, %sym_size_int_4), kwargs = {})
    return (slice_1,)

jit-decall

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sym_size_int_7 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 0), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, %sym_size_int_7), kwargs = {})
    return (slice_1,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%y, shape), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%getattr_1, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%x, (Ellipsis, slice(None, getitem, None))), kwargs = {})
    return getitem_1

InplaceAdd

forward

def forward(self, x):
    x += self.bias
    return x

strict

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    return (add_,)

strict-decall

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=2] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %bias), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add), kwargs = {})
    return (copy__default,)

nostrict

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    return (add_,)

nostrict-decall

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=2] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %bias), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add), kwargs = {})
    return (copy__default,)

jit

graph():
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %lifted_tensor_3), kwargs = {})
    return (add_,)

jit-decall

graph():
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %x : [num_users=2] = placeholder[target=x]
    %add_3 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lifted_tensor_3), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add_3), kwargs = {})
    return (copy__default,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %bias : [num_users=1] = get_attr[target=bias]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %bias), kwargs = {})
    return add

InplaceAdd2

forward

def forward(self, x):
    x.add_(self.bias)
    return x

strict

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    return (add_,)

strict-decall

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=2] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %bias), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add), kwargs = {})
    return (copy__default,)

nostrict

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    return (add_,)

nostrict-decall

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=2] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %bias), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add), kwargs = {})
    return (copy__default,)

jit

graph():
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %lifted_tensor_3), kwargs = {})
    return (add_,)

jit-decall

graph():
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %x : [num_users=2] = placeholder[target=x]
    %add_3 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lifted_tensor_3), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add_3), kwargs = {})
    return (copy__default,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %bias : [num_users=1] = get_attr[target=bias]
    %add_ : [num_users=1] = call_method[target=add_](args = (%x, %bias), kwargs = {})
    return add_

InplaceAdd_Mul

forward

def forward(self, x):
    x.add_(self.bias)
    return x * 2

strict

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_, 2), kwargs = {})
    return (mul,)

strict-decall

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=2] = placeholder[target=x]
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %bias), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 2), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add), kwargs = {})
    return (mul,)

nostrict

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_, 2), kwargs = {})
    return (mul,)

nostrict-decall

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=2] = placeholder[target=x]
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %bias), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 2), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add), kwargs = {})
    return (mul,)

jit

graph():
    %lifted_tensor_2 : [num_users=1] = get_attr[target=lifted_tensor_2]
    %lifted_tensor_4 : [num_users=1] = get_attr[target=lifted_tensor_4]
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x_1, %lifted_tensor_4), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_, %lifted_tensor_2), kwargs = {})
    return (mul,)

jit-decall

graph():
    %lifted_tensor_2 : [num_users=1] = get_attr[target=lifted_tensor_2]
    %lifted_tensor_4 : [num_users=1] = get_attr[target=lifted_tensor_4]
    %x_1 : [num_users=2] = placeholder[target=x_1]
    %add_3 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%x_1, %lifted_tensor_4), kwargs = {})
    %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_3, %lifted_tensor_2), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x_1, %add_3), kwargs = {})
    return (mul_4,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %bias : [num_users=1] = get_attr[target=bias]
    %add_ : [num_users=1] = call_method[target=add_](args = (%x, %bias), kwargs = {})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%add_, 2), kwargs = {})
    return mul

InplaceCloneAdd

forward

def forward(self, x):
    x = x.clone()
    x.add_(self.bias)
    return x

strict

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%clone, %bias), kwargs = {})
    return (add_,)

strict-decall

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%clone, %bias), kwargs = {})
    return (add,)

nostrict

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%clone, %bias), kwargs = {})
    return (add_,)

nostrict-decall

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%clone, %bias), kwargs = {})
    return (add,)

jit

graph():
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x_1,), kwargs = {})
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%clone, %lifted_tensor_3), kwargs = {})
    return (add_,)

jit-decall

graph():
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x_1,), kwargs = {})
    %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%clone, %lifted_tensor_3), kwargs = {})
    return (add_6,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %clone : [num_users=1] = call_method[target=clone](args = (%x,), kwargs = {})
    %bias : [num_users=1] = get_attr[target=bias]
    %add_ : [num_users=1] = call_method[target=add_](args = (%clone, %bias), kwargs = {})
    return add_

InplaceSetItemEllipsis_1

forward

def forward(self, index, update):
    copy = self.params.clone()
    copy[..., index] = update
    return copy

strict

graph():
    %params : [num_users=1] = get_attr[target=params]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%params,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put_,)

strict-decall

graph():
    %params : [num_users=1] = get_attr[target=params]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%params,), kwargs = {})
    %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put,)

nostrict

graph():
    %params : [num_users=1] = get_attr[target=params]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%params,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put_,)

nostrict-decall

graph():
    %params : [num_users=1] = get_attr[target=params]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%params,), kwargs = {})
    %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put,)

jit

graph():
    %lifted_tensor_5 : [num_users=1] = get_attr[target=lifted_tensor_5]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_5,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put_,)

jit-decall

graph():
    %lifted_tensor_5 : [num_users=1] = get_attr[target=lifted_tensor_5]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_5,), kwargs = {})
    %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put,)

tracing

graph():
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %_tensor_constant0 : [num_users=1] = get_attr[target=_tensor_constant0]
    %setitem : [num_users=1] = call_function[target=operator.setitem](args = (%_tensor_constant0, (Ellipsis, %index), %update), kwargs = {})
    return setitem

InplaceSetItemEllipsis_2

forward

def forward(self, index, update):
    copy = self.params.clone()
    copy[..., index] = update
    return copy

strict

graph():
    %params : [num_users=1] = get_attr[target=params]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%params,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put_,)

strict-decall

graph():
    %params : [num_users=1] = get_attr[target=params]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%params,), kwargs = {})
    %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put,)

nostrict

graph():
    %params : [num_users=1] = get_attr[target=params]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%params,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put_,)

nostrict-decall

graph():
    %params : [num_users=1] = get_attr[target=params]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%params,), kwargs = {})
    %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put,)

jit

graph():
    %lifted_tensor_5 : [num_users=1] = get_attr[target=lifted_tensor_5]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_5,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put_,)

jit-decall

graph():
    %lifted_tensor_5 : [num_users=1] = get_attr[target=lifted_tensor_5]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_5,), kwargs = {})
    %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put,)

tracing

graph():
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %_tensor_constant0 : [num_users=1] = get_attr[target=_tensor_constant0]
    %setitem : [num_users=1] = call_function[target=operator.setitem](args = (%_tensor_constant0, (Ellipsis, %index), %update), kwargs = {})
    return setitem

InplaceSetItemMask

forward

def forward(self, x):
    mask = x.to(bool)
    x[mask] = 2
    return x

strict

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=3] = placeholder[target=x]
    %_assert_tensor_metadata_default : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x,), kwargs = {dtype: torch.float32, device: cpu, layout: torch.strided})
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.bool), kwargs = {})
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%x, [%to], %lift_fresh_copy), kwargs = {})
    return (index_put_,)

strict-decall

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %_assert_tensor_metadata : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x, None, None, torch.float32), kwargs = {device: cpu, layout: torch.strided})
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.bool})
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_0,), kwargs = {})
    %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%x, [%_to_copy], %clone), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %index_put), kwargs = {})
    return (copy__default,)

nostrict

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=3] = placeholder[target=x]
    %_assert_tensor_metadata_default : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x,), kwargs = {dtype: torch.float32, device: cpu, layout: torch.strided})
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.bool), kwargs = {})
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%x, [%to], %lift_fresh_copy), kwargs = {})
    return (index_put_,)

nostrict-decall

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %_assert_tensor_metadata : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x, None, None, torch.float32), kwargs = {device: cpu, layout: torch.strided})
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.bool})
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_0,), kwargs = {})
    %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%x, [%_to_copy], %clone), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %index_put), kwargs = {})
    return (copy__default,)

jit

graph():
    %lifted_tensor_2 : [num_users=1] = get_attr[target=lifted_tensor_2]
    %x : [num_users=3] = placeholder[target=x]
    %_assert_tensor_metadata_default : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x,), kwargs = {dtype: torch.float32, device: cpu, layout: torch.strided})
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.bool), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%x, [%to], %lifted_tensor_2), kwargs = {})
    return (index_put_,)

jit-decall

graph():
    %lifted_tensor_2 : [num_users=1] = get_attr[target=lifted_tensor_2]
    %x : [num_users=4] = placeholder[target=x]
    %_assert_tensor_metadata : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x, None, None, torch.float32), kwargs = {device: cpu, layout: torch.strided})
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.bool})
    %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%x, [%_to_copy], %lifted_tensor_2), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %index_put), kwargs = {})
    return (copy__default,)

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %to : [num_users=1] = call_method[target=to](args = (%x, torch.bool), kwargs = {})
    %setitem : [num_users=1] = call_function[target=operator.setitem](args = (%x, %to, 2), kwargs = {})
    return setitem

InplaceSetItemSquare

forward

def forward(self, x):
    x[:2, :3] = 1
    return x

strict

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    return (x,)

strict-decall

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %clone), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %copy, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (copy__default,)

nostrict

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    return (x,)

nostrict-decall

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %clone), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %copy, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (copy__default,)

jit

graph():
    %lifted_tensor_2 : [num_users=1] = get_attr[target=lifted_tensor_2]
    %x : [num_users=2] = placeholder[target=x]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lifted_tensor_2), kwargs = {})
    return (x,)

jit-decall

graph():
    %lifted_tensor_2 : [num_users=1] = get_attr[target=lifted_tensor_2]
    %x : [num_users=4] = placeholder[target=x]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %lifted_tensor_2), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %copy, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (copy__default,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %setitem : [num_users=1] = call_function[target=operator.setitem](args = (%x, (slice(None, 2, None), slice(None, 3, None)), 1), kwargs = {})
    return setitem

InplaceSetItemSquareAdd

forward

def forward(self, x):
    x[:2, :3] = 1
    return x + 2

strict

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 2), kwargs = {})
    return (add,)

strict-decall

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %clone), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %copy, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=2] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 2), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (add,)

nostrict

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 2), kwargs = {})
    return (add,)

nostrict-decall

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %clone), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %copy, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=2] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 2), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (add,)

jit

graph():
    %lifted_tensor_2 : [num_users=1] = get_attr[target=lifted_tensor_2]
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %x : [num_users=2] = placeholder[target=x]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lifted_tensor_3), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lifted_tensor_2), kwargs = {})
    return (add,)

jit-decall

graph():
    %lifted_tensor_2 : [num_users=1] = get_attr[target=lifted_tensor_2]
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %x : [num_users=4] = placeholder[target=x]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %lifted_tensor_3), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %copy, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=2] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %add_12 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, %lifted_tensor_2), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (add_12,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %setitem : [num_users=1] = call_function[target=operator.setitem](args = (%x, (slice(None, 2, None), slice(None, 3, None)), 1), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%setitem, 2), kwargs = {})
    return add

InplaceSetItemSquareAdd2

forward

def forward(self, x):
    x[:2, :3] = 1
    return x + 2, x + 3

strict

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=3] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 3), kwargs = {})
    return (add, add_1)

strict-decall

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %clone), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %copy, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=3] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 3), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (add, add_1)

nostrict

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=3] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 3), kwargs = {})
    return (add, add_1)

nostrict-decall

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %clone), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %copy, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=3] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 3), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (add, add_1)

jit

graph():
    %lifted_tensor_2 : [num_users=1] = get_attr[target=lifted_tensor_2]
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %lifted_tensor_4 : [num_users=1] = get_attr[target=lifted_tensor_4]
    %x : [num_users=3] = placeholder[target=x]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lifted_tensor_4), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lifted_tensor_3), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lifted_tensor_2), kwargs = {})
    return (add, add_1)

jit-decall

graph():
    %lifted_tensor_2 : [num_users=1] = get_attr[target=lifted_tensor_2]
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %lifted_tensor_4 : [num_users=1] = get_attr[target=lifted_tensor_4]
    %x : [num_users=4] = placeholder[target=x]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %lifted_tensor_4), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %copy, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=3] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %add_12 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, %lifted_tensor_3), kwargs = {})
    %add_16 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, %lifted_tensor_2), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (add_12, add_16)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %setitem : [num_users=2] = call_function[target=operator.setitem](args = (%x, (slice(None, 2, None), slice(None, 3, None)), 1), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%setitem, 2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%setitem, 3), kwargs = {})
    return (add, add_1)

SignatureFloat1

forward

def forward(self, x, alpha: float = 2.0):
    return torch.sigmoid(self.linear(x)) - self.buff * alpha

strict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %alpha : [num_users=0] = placeholder[target=alpha]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%buff, 1.5), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %mul), kwargs = {})
    return (sub,)

strict-decall

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %alpha : [num_users=0] = placeholder[target=alpha]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%buff, 1.5), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %mul), kwargs = {})
    return (sub,)

nostrict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %alpha : [num_users=0] = placeholder[target=alpha]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%buff, 1.5), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %mul), kwargs = {})
    return (sub,)

nostrict-decall

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %alpha : [num_users=0] = placeholder[target=alpha]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%buff, 1.5), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %mul), kwargs = {})
    return (sub,)

jit

FAILED

Type 'Tuple[Tensor, float]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

jit-decall

FAILED

Type 'Tuple[Tensor, float]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %alpha : float [num_users=1] = placeholder[target=alpha](default=2.0)
    %linear : [num_users=1] = call_module[target=linear](args = (%x,), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.sigmoid](args = (%linear,), kwargs = {})
    %buff : [num_users=1] = get_attr[target=buff]
    %mul : [num_users=1] = call_method[target=mul](args = (%buff, %alpha), kwargs = {})
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%sigmoid, %mul), kwargs = {})
    return sub

SignatureInt1

forward

def forward(self, x, i: int = 2):
    return torch.sigmoid(self.linear(x)) - self.buff + x[:, i : i + 1]

strict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 1, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %slice_2), kwargs = {})
    return (add,)

strict-decall

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 1, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %slice_2), kwargs = {})
    return (add,)

nostrict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x,), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 1, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %slice_2), kwargs = {})
    return (add,)

nostrict-decall

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x,), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 1, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %slice_2), kwargs = {})
    return (add,)

jit

FAILED

Type 'Tuple[Tensor, int]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

jit-decall

FAILED

Type 'Tuple[Tensor, int]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %i : int [num_users=2] = placeholder[target=i](default=2)
    %linear : [num_users=1] = call_module[target=linear](args = (%x,), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.sigmoid](args = (%linear,), kwargs = {})
    %buff : [num_users=1] = get_attr[target=buff]
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%i, 1), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%x, (slice(None, None, None), slice(i, add, None))), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%sub, %getitem), kwargs = {})
    return add_1

SignatureInt2

forward

def forward(self, x, i: int = 2):
    return torch.sigmoid(self.linear(x)) - self.buff + x[:, i]

strict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 9223372036854775807), kwargs = {})
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%slice_1, 1, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %select), kwargs = {})
    return (add,)

strict-decall

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 9223372036854775807), kwargs = {})
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%slice_1, 1, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %select), kwargs = {})
    return (add,)

nostrict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x,), kwargs = {})
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%slice_1, 1, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %select), kwargs = {})
    return (add,)

nostrict-decall

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x,), kwargs = {})
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%slice_1, 1, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %select), kwargs = {})
    return (add,)

jit

FAILED

Type 'Tuple[Tensor, int]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

jit-decall

FAILED

Type 'Tuple[Tensor, int]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %i : int [num_users=1] = placeholder[target=i](default=2)
    %linear : [num_users=1] = call_module[target=linear](args = (%x,), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.sigmoid](args = (%linear,), kwargs = {})
    %buff : [num_users=1] = get_attr[target=buff]
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%sigmoid, %buff), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%x, (slice(None, None, None), %i)), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%sub, %getitem), kwargs = {})
    return add

SignatureListFixedLength

forward

def forward(self, x, lx: list):
    return (
        torch.sigmoid(self.linear(x)) - self.buff + lx[0] * lx[1].sum(axis=1, keepdim=True)
    )

strict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

strict-decall

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

nostrict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

nostrict-decall

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

jit

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

jit-decall

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub_2 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %add_15 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub_2, %mul_4), kwargs = {})
    return (add_15,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %lx : list [num_users=2] = placeholder[target=lx]
    %linear : [num_users=1] = call_module[target=linear](args = (%x,), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.sigmoid](args = (%linear,), kwargs = {})
    %buff : [num_users=1] = get_attr[target=buff]
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%sigmoid, %buff), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%lx, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%lx, 1), kwargs = {})
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%getitem_1,), kwargs = {axis: 1, keepdim: True})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%getitem, %sum_1), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%sub, %mul), kwargs = {})
    return add

SignatureListFixedWithNone

forward

def forward(self, lx):
    x = lx[0]
    if lx[1] is not None:
        x += lx[1]
    if lx[2] is not None:
        x += lx[2]
    return x

strict

FAILED

Unable to clone type <class 'NoneType'>, x=None into numpy

strict-decall

FAILED

Unable to clone type <class 'NoneType'>, x=None into numpy

nostrict

FAILED

Unable to clone type <class 'NoneType'>, x=None into numpy

nostrict-decall

FAILED

Unable to clone type <class 'NoneType'>, x=None into numpy

jit

FAILED

Type 'Tuple[List[Optional[Tensor]]]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

jit-decall

FAILED

Type 'Tuple[List[Optional[Tensor]]]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

tracing

FAILED

Unable to clone type <class 'NoneType'>, x=None into numpy

SignatureListVariableLength

forward

def forward(self, x, lx: list):
    t = torch.cat(lx, dim=1).sum(axis=1, keepdim=True)
    return torch.sigmoid(self.linear(x)) - self.buff + t

strict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%lx_0, %lx_1], 1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%cat, [1], True), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %sum_1), kwargs = {})
    return (add,)

strict-decall

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%lx_0, %lx_1], 1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%cat, [1], True), kwargs = {})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %sum_1), kwargs = {})
    return (add,)

nostrict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%lx_0, %lx_1], 1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%cat, [1], True), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %sum_1), kwargs = {})
    return (add,)

nostrict-decall

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%lx_0, %lx_1], 1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%cat, [1], True), kwargs = {})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %sum_1), kwargs = {})
    return (add,)

jit

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%lx_0, %lx_1], 1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%cat, [1], True), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %sum_1), kwargs = {})
    return (add,)

jit-decall

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%lx_0, %lx_1], 1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%cat, [1], True), kwargs = {})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %sub_5 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add_15 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub_5, %sum_1), kwargs = {})
    return (add_15,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %lx : list [num_users=1] = placeholder[target=lx]
    %cat : [num_users=1] = call_function[target=torch.cat](args = (%lx, 1), kwargs = {})
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%cat,), kwargs = {axis: 1, keepdim: True})
    %linear : [num_users=1] = call_module[target=linear](args = (%x,), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.sigmoid](args = (%linear,), kwargs = {})
    %buff : [num_users=1] = get_attr[target=buff]
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%sub, %sum_1), kwargs = {})
    return add

SignatureShapeAsIndex

forward

def forward(self, x, y):
    t = torch.sigmoid(self.linear(x)) + x
    return t[:, : y.shape[1]]

strict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sigmoid, %x), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 2), kwargs = {})
    return (slice_2,)

strict-decall

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sigmoid, %x), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 2), kwargs = {})
    return (slice_2,)

nostrict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=0] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sigmoid, %x), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add,), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, None, 2), kwargs = {})
    return (slice_2,)

nostrict-decall

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=0] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sigmoid, %x), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add,), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, None, 2), kwargs = {})
    return (slice_2,)

jit

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sigmoid, %x), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, %sym_size_int_3), kwargs = {})
    return (slice_2,)

jit-decall

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sym_size_int_6 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 1), kwargs = {})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sigmoid, %x), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add_6, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, %sym_size_int_6), kwargs = {})
    return (slice_2,)

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %linear : [num_users=1] = call_module[target=linear](args = (%x,), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.sigmoid](args = (%linear,), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%sigmoid, %x), kwargs = {})
    %getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%y, shape), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%getattr_1, 1), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%add, (slice(None, None, None), slice(None, getitem, None))), kwargs = {})
    return getitem_1

TypeBFloat16

forward

def forward(self, x):
    xb = x.to(torch.bfloat16)
    return (xb + xb).to(torch.float32)

strict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %_assert_tensor_metadata_default : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x,), kwargs = {dtype: torch.float32, device: cpu, layout: torch.strided})
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.bfloat16), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%to, %to), kwargs = {})
    %_assert_tensor_metadata_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%add,), kwargs = {dtype: torch.bfloat16, device: cpu, layout: torch.strided})
    %to_1 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%add, torch.float32), kwargs = {})
    return (to_1,)

strict-decall

graph():
    %x : [num_users=2] = placeholder[target=x]
    %_assert_tensor_metadata : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x, None, None, torch.float32), kwargs = {device: cpu, layout: torch.strided})
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.bfloat16})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%_to_copy, %_to_copy), kwargs = {})
    %_assert_tensor_metadata_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%add, None, None, torch.bfloat16), kwargs = {device: cpu, layout: torch.strided})
    %_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%add,), kwargs = {dtype: torch.float32})
    return (_to_copy_1,)

nostrict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %_assert_tensor_metadata_default : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x,), kwargs = {dtype: torch.float32, device: cpu, layout: torch.strided})
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.bfloat16), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%to, %to), kwargs = {})
    %_assert_tensor_metadata_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%add,), kwargs = {dtype: torch.bfloat16, device: cpu, layout: torch.strided})
    %to_1 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%add, torch.float32), kwargs = {})
    return (to_1,)

nostrict-decall

graph():
    %x : [num_users=2] = placeholder[target=x]
    %_assert_tensor_metadata : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x, None, None, torch.float32), kwargs = {device: cpu, layout: torch.strided})
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.bfloat16})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%_to_copy, %_to_copy), kwargs = {})
    %_assert_tensor_metadata_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%add, None, None, torch.bfloat16), kwargs = {device: cpu, layout: torch.strided})
    %_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%add,), kwargs = {dtype: torch.float32})
    return (_to_copy_1,)

jit

graph():
    %x : [num_users=2] = placeholder[target=x]
    %_assert_tensor_metadata_default : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x,), kwargs = {dtype: torch.float32, device: cpu, layout: torch.strided})
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.bfloat16), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%to, %to), kwargs = {})
    %_assert_tensor_metadata_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%add,), kwargs = {dtype: torch.bfloat16, device: cpu, layout: torch.strided})
    %to_1 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%add, torch.float32), kwargs = {})
    return (to_1,)

jit-decall

graph():
    %x : [num_users=2] = placeholder[target=x]
    %_assert_tensor_metadata : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%x, None, None, torch.float32), kwargs = {device: cpu, layout: torch.strided})
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.bfloat16})
    %add_3 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%_to_copy, %_to_copy), kwargs = {})
    %_assert_tensor_metadata_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%add_3, None, None, torch.bfloat16), kwargs = {device: cpu, layout: torch.strided})
    %_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%add_3,), kwargs = {dtype: torch.float32})
    return (_to_copy_1,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %to : [num_users=1] = call_method[target=to](args = (%x, torch.bfloat16), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%to, %to), kwargs = {})
    %to_1 : [num_users=1] = call_method[target=to](args = (%add, torch.float32), kwargs = {})
    return to_1

Summary