Exported Programs

The following script shows the exported programs for many short cases, obtained through various ways of retrieving the torch.fx.Graph equivalent to the original model.

<<<

import inspect
import textwrap
from experimental_experiment.torch_interpreter.eval import discover, run_exporter

# Generate an RST page which, for every model case returned by ``discover()``,
# shows the source of its ``forward`` method followed by the graph obtained
# with each exporter listed in EXPORTERS (or the error message on failure).
#
# NOTE(review): assumes ``run_exporter`` returns a dict holding either an
# ``"exported"`` entry (an object exposing ``.graph``) or an ``"error"`` entry;
# this contract is not visible here — confirm against its implementation.

# Exporters tried for every case, in display order. The section title is the
# exporter name without its "export-" prefix.
EXPORTERS = (
    "export-strict",
    "export-strict-decomposition",
    "export-nostrict",
    "export-nostrict-decomposition",
    "export-jit",
    "export-jit-decomposition",
    "export-tracing",
)

cases = discover()

# Table of contents: one cross-reference per case, only the names are needed.
print()
for name in sorted(cases):
    print(f"* :ref:`{name} <l-model-case-export-{name}>`")
print()

for name, cls_model in sorted(cases.items()):
    # Anchor targeted by the table of contents, then the case title.
    print()
    print(f".. _l-model-case-export-{name}:")
    print()
    print(name)
    print("=" * len(name))
    print()
    # Literal block with the source of the model's forward method.
    print("forward")
    print("+++++++")
    print()
    print("::")
    print()
    print(
        textwrap.indent(textwrap.dedent(inspect.getsource(cls_model.forward)), "    ")
    )
    print()
    for exporter in EXPORTERS:
        expname = exporter.replace("export-", "")
        print()
        print(expname)
        print("+" * len(expname))
        print()
        res = run_exporter(exporter, cls_model, False, quiet=True)
        if "exported" in res:
            # Success: dump the exported graph as an indented literal block.
            print("::")
            print()
            print(textwrap.indent(str(res["exported"].graph), "    "))
            print()
        else:
            # Failure: flag it and show the error message as a literal block.
            print("**FAILED**")
            print()
            print("::")
            print()
            print(textwrap.indent(str(res["error"]), "    "))
            print()

>>>

AtenAsStrided

forward

def forward(self, x):
    y = torch.as_strided(x, (2, 2, 8, 4), (128, 8, 16, 1))
    return y

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %as_strided : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%x, [2, 2, 8, 4], [128, 8, 16, 1]), kwargs = {})
    return (as_strided,)

strict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %as_strided : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%x, [2, 2, 8, 4], [128, 8, 16, 1]), kwargs = {})
    return (as_strided,)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %as_strided : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%x, [2, 2, 8, 4], [128, 8, 16, 1]), kwargs = {})
    return (as_strided,)

nostrict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %as_strided : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%x, [2, 2, 8, 4], [128, 8, 16, 1]), kwargs = {})
    return (as_strided,)

jit

graph():
    %x : [num_users=1] = placeholder[target=x]
    %as_strided : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%x, [2, 2, 8, 4], [128, 8, 16, 1]), kwargs = {})
    return (as_strided,)

jit-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %as_strided : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%x, [2, 2, 8, 4], [128, 8, 16, 1]), kwargs = {})
    return (as_strided,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %as_strided : [num_users=1] = call_function[target=torch.as_strided](args = (%x, (2, 2, 8, 4), (128, 8, 16, 1)), kwargs = {})
    return as_strided

AtenInterpolate

forward

def forward(self, x):
    y = torch.nn.functional.interpolate(
        x,
        scale_factor=2.0,
        mode="bilinear",
        recompute_scale_factor=False,
    )
    return y

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.vec](args = (%x, None, False, [2.0, 2.0]), kwargs = {})
    return (upsample_bilinear2d,)

strict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.vec](args = (%x, None, False, [2.0, 2.0]), kwargs = {})
    return (upsample_bilinear2d,)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.vec](args = (%x, None, False, [2.0, 2.0]), kwargs = {})
    return (upsample_bilinear2d,)

nostrict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.vec](args = (%x, None, False, [2.0, 2.0]), kwargs = {})
    return (upsample_bilinear2d,)

jit

graph():
    %x : [num_users=1] = placeholder[target=x]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.vec](args = (%x, None, False, [2.0, 2.0]), kwargs = {})
    return (upsample_bilinear2d,)

jit-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.vec](args = (%x, None, False, [2.0, 2.0]), kwargs = {})
    return (upsample_bilinear2d,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %interpolate : [num_users=1] = call_function[target=torch.nn.functional.interpolate](args = (%x,), kwargs = {size: None, scale_factor: 2.0, mode: bilinear, align_corners: None, recompute_scale_factor: False, antialias: False})
    return interpolate

AtenNonZero

forward

def forward(self, x):
    y = torch.nonzero(x)
    return y

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=2] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 0 on node 'ge_1'), kwargs = {})
    %le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 12 on node 'le_1'), kwargs = {})
    return (nonzero,)

strict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=2] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 0 on node 'ge_1'), kwargs = {})
    %le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 12 on node 'le_1'), kwargs = {})
    return (nonzero,)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=2] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u0 <= 12 on node 'le'), kwargs = {})
    return (nonzero,)

nostrict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=2] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 0 on node 'ge_1'), kwargs = {})
    %le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 12 on node 'le_1'), kwargs = {})
    return (nonzero,)

jit

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=2] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u0 <= 12 on node 'le'), kwargs = {})
    return (nonzero,)

jit-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=2] = call_function[target=torch.ops.aten.nonzero.default](args = (%x,), kwargs = {})
    %sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%nonzero, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 0 on node 'ge_1'), kwargs = {})
    %le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 12 on node 'le_1'), kwargs = {})
    return (nonzero,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=1] = call_function[target=torch.nonzero](args = (%x,), kwargs = {})
    return nonzero

AtenNonZeroTuple

forward

def forward(self, x):
    y = torch.nonzero(x, as_tuple=True)
    return y[0], y[1]

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero_numpy : [num_users=2] = call_function[target=torch.ops.aten.nonzero_numpy.default](args = (%x,), kwargs = {})
    %getitem_2 : [num_users=2] = call_function[target=operator.getitem](args = (%nonzero_numpy, 0), kwargs = {})
    %sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%getitem_2, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 0 on node 'ge_1'), kwargs = {})
    %le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 12 on node 'le_1'), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%nonzero_numpy, 1), kwargs = {})
    return (getitem_2, getitem_1)

strict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero_numpy : [num_users=2] = call_function[target=torch.ops.aten.nonzero_numpy.default](args = (%x,), kwargs = {})
    %getitem_2 : [num_users=2] = call_function[target=operator.getitem](args = (%nonzero_numpy, 0), kwargs = {})
    %sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%getitem_2, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 0 on node 'ge_1'), kwargs = {})
    %le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 12 on node 'le_1'), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%nonzero_numpy, 1), kwargs = {})
    return (getitem_2, getitem_1)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero_numpy : [num_users=2] = call_function[target=torch.ops.aten.nonzero_numpy.default](args = (%x,), kwargs = {})
    %getitem_2 : [num_users=2] = call_function[target=operator.getitem](args = (%nonzero_numpy, 0), kwargs = {})
    %sym_size_int : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%getitem_2, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u0 <= 12 on node 'le'), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%nonzero_numpy, 1), kwargs = {})
    return (getitem_2, getitem_1)

nostrict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero_numpy : [num_users=2] = call_function[target=torch.ops.aten.nonzero_numpy.default](args = (%x,), kwargs = {})
    %getitem_2 : [num_users=2] = call_function[target=operator.getitem](args = (%nonzero_numpy, 0), kwargs = {})
    %sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%getitem_2, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 0 on node 'ge_1'), kwargs = {})
    %le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 12 on node 'le_1'), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%nonzero_numpy, 1), kwargs = {})
    return (getitem_2, getitem_1)

jit

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero_numpy : [num_users=2] = call_function[target=torch.ops.aten.nonzero_numpy.default](args = (%x,), kwargs = {})
    %getitem_2 : [num_users=2] = call_function[target=operator.getitem](args = (%nonzero_numpy, 0), kwargs = {})
    %sym_size_int : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%getitem_2, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u0 <= 12 on node 'le'), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%nonzero_numpy, 1), kwargs = {})
    return (getitem_2, getitem_1)

jit-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero_numpy : [num_users=2] = call_function[target=torch.ops.aten.nonzero_numpy.default](args = (%x,), kwargs = {})
    %getitem_2 : [num_users=2] = call_function[target=operator.getitem](args = (%nonzero_numpy, 0), kwargs = {})
    %sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%getitem_2, 0), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
    %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 0 on node 'ge_1'), kwargs = {})
    %le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 12), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 12 on node 'le_1'), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%nonzero_numpy, 1), kwargs = {})
    return (getitem_2, getitem_1)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %nonzero : [num_users=2] = call_function[target=torch.nonzero](args = (%x,), kwargs = {as_tuple: True})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%nonzero, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%nonzero, 1), kwargs = {})
    return (getitem, getitem_1)

AtenRollPos

forward

def forward(self, x):
    return torch.roll(x, 1, -1)

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [1], [-1]), kwargs = {})
    return (roll,)

strict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [1], [-1]), kwargs = {})
    return (roll,)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [1], [-1]), kwargs = {})
    return (roll,)

nostrict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [1], [-1]), kwargs = {})
    return (roll,)

jit

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [1], [-1]), kwargs = {})
    return (roll,)

jit-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [1], [-1]), kwargs = {})
    return (roll,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.roll](args = (%x, 1, -1), kwargs = {})
    return roll

AtenRollRelu

forward

def forward(self, x):
    return torch.relu(torch.roll(x, -1, -1))

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [-1], [-1]), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%roll,), kwargs = {})
    return (relu,)

strict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [-1], [-1]), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%roll,), kwargs = {})
    return (relu,)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [-1], [-1]), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%roll,), kwargs = {})
    return (relu,)

nostrict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [-1], [-1]), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%roll,), kwargs = {})
    return (relu,)

jit

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [-1], [-1]), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%roll,), kwargs = {})
    return (relu,)

jit-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.ops.aten.roll.default](args = (%x, [-1], [-1]), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%roll,), kwargs = {})
    return (relu,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %roll : [num_users=1] = call_function[target=torch.roll](args = (%x, -1, -1), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%roll,), kwargs = {})
    return relu

BuildInIsInstance

forward

def forward(self, x, lx: list | torch.Tensor):
    if isinstance(lx, list):
        t = lx[0] * lx[1].sum(axis=1, keepdim=True)
        return torch.sigmoid(self.linear(x)) - self.buff + t
    return torch.sigmoid(self.linear(x)) - self.buff + lx

strict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

strict-decomposition

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

nostrict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

nostrict-decomposition

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

jit

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

jit-decomposition

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

tracing

Traceback (most recent call last):
File "/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph_module.py", line 387, in __call__

return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]

File "/home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl

return self._call_impl(*args, **kwargs)

File "/home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl

return forward_call(*args, **kwargs)

File "<eval_with_key>.996 from /home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py:610 in forward", line 9, in forward

add = sub + lx; sub = lx = None

TypeError: unsupported operand type(s) for +: 'Tensor' and 'list'

Call using an FX-traced Module, line 9 of the traced Module's generated forward function:

sub = sigmoid - buff; sigmoid = buff = None
add = sub + lx; sub = lx = None

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

return add

FAILED

unsupported operand type(s) for +: 'Tensor' and 'list'

BuildInLen

forward

def forward(self, x, lx: list):
    t = lx[0] * lx[1].sum(axis=1, keepdim=True)
    if len(lx) > 2:
        t = t + lx[2].sum(axis=1, keepdim=True)
    return torch.sigmoid(self.linear(x)) - self.buff + t

strict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

strict-decomposition

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

nostrict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

nostrict-decomposition

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

jit

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

jit-decomposition

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

tracing

FAILED

len(.) expects an integer, len needs to be replaced. You should use _len.

ComplexPolar

forward

def forward(self, x, angle):
    return torch.polar(x, angle)

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %angle : [num_users=1] = placeholder[target=angle]
    %polar : [num_users=1] = call_function[target=torch.ops.aten.polar.default](args = (%x, %angle), kwargs = {})
    return (polar,)

strict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %angle : [num_users=1] = placeholder[target=angle]
    %polar : [num_users=1] = call_function[target=torch.ops.aten.polar.default](args = (%x, %angle), kwargs = {})
    return (polar,)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %angle : [num_users=1] = placeholder[target=angle]
    %polar : [num_users=1] = call_function[target=torch.ops.aten.polar.default](args = (%x, %angle), kwargs = {})
    return (polar,)

nostrict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %angle : [num_users=1] = placeholder[target=angle]
    %polar : [num_users=1] = call_function[target=torch.ops.aten.polar.default](args = (%x, %angle), kwargs = {})
    return (polar,)

jit

graph():
    %x : [num_users=1] = placeholder[target=x]
    %angle : [num_users=1] = placeholder[target=angle]
    %polar : [num_users=1] = call_function[target=torch.ops.aten.polar.default](args = (%x, %angle), kwargs = {})
    return (polar,)

jit-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %angle : [num_users=1] = placeholder[target=angle]
    %polar : [num_users=1] = call_function[target=torch.ops.aten.polar.default](args = (%x, %angle), kwargs = {})
    return (polar,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %angle : [num_users=1] = placeholder[target=angle]
    %polar : [num_users=1] = call_function[target=torch.polar](args = (%x, %angle), kwargs = {})
    return polar

ControlFlowCond

forward

def forward(self, x):
    def true_fn(x):
        return torch.sin(x)

    def false_fn(x):
        return torch.cos(x)

    return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

strict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

strict-decomposition

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict-decomposition

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

jit

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_higher_order_ops/cond.py:169: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

if isinstance(pred, torch.Tensor) and pred.numel() != 1:

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

jit-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_higher_order_ops/cond.py:169: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

if isinstance(pred, torch.Tensor) and pred.numel() != 1:

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=operator.gt](args = (%sum_1, 0), kwargs = {})
    %_cb_cond_true_fn_0 : [num_users=1] = get_attr[target=_cb_cond_true_fn_0]
    %_cb_cond_false_fn_0 : [num_users=1] = get_attr[target=_cb_cond_false_fn_0]
    %condcc : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %_cb_cond_true_fn_0, %_cb_cond_false_fn_0, [%x]), kwargs = {})
    return condcc

ControlFlowCond2Inputs

forward

def forward(self, x, y):
    def true_fn(x, y):
        return torch.sin(x), torch.cos(x) + y

    def false_fn(x, y):
        return torch.cos(x), torch.sin(x) + y

    return torch.cond(x.sum() > 0, true_fn, false_fn, [x, y])

strict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x, %y]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

strict-decomposition

graph():
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x, %y]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

nostrict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x, %y]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

nostrict-decomposition

graph():
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x, %y]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

jit

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_higher_order_ops/cond.py:169: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

if isinstance(pred, torch.Tensor) and pred.numel() != 1:

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

jit-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_higher_order_ops/cond.py:169: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

if isinstance(pred, torch.Tensor) and pred.numel() != 1:

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=operator.gt](args = (%sum_1, 0), kwargs = {})
    %_cb_cond_true_fn_0 : [num_users=1] = get_attr[target=_cb_cond_true_fn_0]
    %_cb_cond_false_fn_0 : [num_users=1] = get_attr[target=_cb_cond_false_fn_0]
    %condcc : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %_cb_cond_true_fn_0, %_cb_cond_false_fn_0, [%x, %y]), kwargs = {})
    return condcc

ControlFlowCond2Outputs

forward

def forward(self, x):
    def true_fn(x):
        return torch.sin(x), torch.cos(x)

    def false_fn(x):
        return torch.cos(x), torch.sin(x)

    return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

strict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

strict-decomposition

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

nostrict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

nostrict-decomposition

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 1), kwargs = {})
    return (getitem, getitem_1)

jit

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_higher_order_ops/cond.py:169: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

if isinstance(pred, torch.Tensor) and pred.numel() != 1:

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

jit-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_higher_order_ops/cond.py:169: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

if isinstance(pred, torch.Tensor) and pred.numel() != 1:

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=operator.gt](args = (%sum_1, 0), kwargs = {})
    %_cb_cond_true_fn_0 : [num_users=1] = get_attr[target=_cb_cond_true_fn_0]
    %_cb_cond_false_fn_0 : [num_users=1] = get_attr[target=_cb_cond_false_fn_0]
    %condcc : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %_cb_cond_true_fn_0, %_cb_cond_false_fn_0, [%x]), kwargs = {})
    return condcc

ControlFlowCondConstant

forward

def forward(self, x):
    def true_fn(x):
        return torch.sin(x) - torch.ones(x.shape, dtype=x.dtype)

    def false_fn(x):
        return torch.cos(x) + torch.ones((1, 1024), dtype=x.dtype)

    return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

strict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

strict-decomposition

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict-decomposition

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

jit

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_higher_order_ops/cond.py:169: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

if isinstance(pred, torch.Tensor) and pred.numel() != 1:

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

jit-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_higher_order_ops/cond.py:169: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

if isinstance(pred, torch.Tensor) and pred.numel() != 1:

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=operator.gt](args = (%sum_1, 0), kwargs = {})
    %_cb_cond_true_fn_0 : [num_users=1] = get_attr[target=_cb_cond_true_fn_0]
    %_cb_cond_false_fn_0 : [num_users=1] = get_attr[target=_cb_cond_false_fn_0]
    %condcc : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %_cb_cond_true_fn_0, %_cb_cond_false_fn_0, [%x]), kwargs = {})
    return condcc

ControlFlowCondNestedModule

forward

def forward(self, x):
    def true_fn(x):
        return self.submodule(x)

    def false_fn(x):
        return x - self.weight

    y = torch.cond(x.sum() > 0, true_fn, false_fn, [x])
    return y

strict

graph():
    %submodule_weight : [num_users=1] = get_attr[target=submodule.weight]
    %weight : [num_users=1] = get_attr[target=weight]
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x, %submodule_weight, %weight]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

strict-decomposition

graph():
    %weight : [num_users=1] = get_attr[target=weight]
    %submodule_weight : [num_users=1] = get_attr[target=submodule.weight]
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x, %submodule_weight, %weight)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict

graph():
    %weight : [num_users=1] = get_attr[target=weight]
    %submodule_weight : [num_users=1] = get_attr[target=submodule.weight]
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x, %submodule_weight, %weight]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict-decomposition

graph():
    %weight : [num_users=1] = get_attr[target=weight]
    %submodule_weight : [num_users=1] = get_attr[target=submodule.weight]
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%x, %submodule_weight, %weight)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

jit

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_higher_order_ops/cond.py:169: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

if isinstance(pred, torch.Tensor) and pred.numel() != 1:

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

jit-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_higher_order_ops/cond.py:169: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

if isinstance(pred, torch.Tensor) and pred.numel() != 1:

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=operator.gt](args = (%sum_1, 0), kwargs = {})
    %_cb_cond_true_fn_0 : [num_users=1] = get_attr[target=_cb_cond_true_fn_0]
    %_cb_cond_false_fn_0 : [num_users=1] = get_attr[target=_cb_cond_false_fn_0]
    %condcc : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %_cb_cond_true_fn_0, %_cb_cond_false_fn_0, [%x]), kwargs = {})
    return condcc

ControlFlowNestCond

forward

def forward(self, x):
    def true_fn2(x):
        def true_fn1(x):
            return torch.sin(x)

        def false_fn1(x):
            return torch.cos(x)

        return torch.cond(x.sum() < 0, true_fn1, false_fn1, [x])

    def false_fn2(x):
        return -x

    return torch.cond(x.sum() > 0, true_fn2, false_fn2, [x])

strict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

strict-decomposition

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

nostrict-decomposition

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%x]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

jit

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_higher_order_ops/cond.py:169: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can’t record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

if isinstance(pred, torch.Tensor) and pred.numel() != 1:

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

jit-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_higher_order_ops/cond.py:169: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can’t record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

if isinstance(pred, torch.Tensor) and pred.numel() != 1:

FAILED

Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=operator.gt](args = (%sum_1, 0), kwargs = {})
    %_cb_cond_true_fn2_0 : [num_users=1] = get_attr[target=_cb_cond_true_fn2_0]
    %_cb_cond_false_fn2_0 : [num_users=1] = get_attr[target=_cb_cond_false_fn2_0]
    %condcc : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %_cb_cond_true_fn2_0, %_cb_cond_false_fn2_0, [%x]), kwargs = {})
    return condcc

ControlFlowScan

forward

def forward(self, x):
    init = torch.zeros_like(x[0])
    carry, out = torch.ops.higher_order.scan(
        ControlFlowScan.add, [init], [x], dim=0, reverse=False, additional_inputs=[]
    )
    return carry

strict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%select,), kwargs = {pin_memory: False})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like], [%x], 0, False, []), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem,)

strict-decomposition

graph():
    %x : [num_users=2] = placeholder[target=x]
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%select,), kwargs = {pin_memory: False})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like], [%x], 0, False, []), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    return (getitem,)

nostrict

graph():
    %x : [num_users=2] = placeholder[target=x]
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%select,), kwargs = {pin_memory: False})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like], [%x], 0, False, []), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem,)

nostrict-decomposition

graph():
    %x : [num_users=2] = placeholder[target=x]
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%select,), kwargs = {pin_memory: False})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like], [%x], 0, False, []), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    return (getitem,)

jit

FAILED

could not find kernel for HigherOrderOperator scan (<torch._higher_order_ops.scan.ScanOp object at 0x7fa78a3217e0>, <class 'torch._higher_order_ops.scan.ScanOp'>)at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

jit-decomposition

FAILED

could not find kernel for HigherOrderOperator scan (<torch._higher_order_ops.scan.ScanOp object at 0x7fa78a3217e0>, <class 'torch._higher_order_ops.scan.ScanOp'>)at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

tracing

FAILED

Unable to symbolically trace HigherOrderOperators

ControlFlowScan2Carried

forward

def forward(self, x):
    init1 = torch.zeros_like(x[0])
    init2 = torch.ones_like(x[0])
    carry1, carry2, out1, out2 = torch.ops.higher_order.scan(
        ControlFlowScan2Carried.add,
        [init1, init2],
        [x, x * 2],
        dim=0,
        reverse=False,
        additional_inputs=[],
    )
    return carry1, carry2, out1, out2

strict

graph():
    %x : [num_users=4] = placeholder[target=x]
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%select,), kwargs = {pin_memory: False})
    %select_1 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %ones_like : [num_users=1] = call_function[target=torch.ops.aten.ones_like.default](args = (%select_1,), kwargs = {pin_memory: False})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 2), kwargs = {})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=4] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like, %ones_like], [%x, %mul], 0, False, []), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 2), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 3), kwargs = {})
    return (getitem, getitem_1, getitem_2, getitem_3)

strict-decomposition

graph():
    %x : [num_users=4] = placeholder[target=x]
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%select,), kwargs = {pin_memory: False})
    %select_1 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %ones_like : [num_users=1] = call_function[target=torch.ops.aten.ones_like.default](args = (%select_1,), kwargs = {pin_memory: False})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 2), kwargs = {})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=4] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like, %ones_like], [%x, %mul], 0, False, []), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 2), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 3), kwargs = {})
    return (getitem, getitem_1, getitem_2, getitem_3)

nostrict

graph():
    %x : [num_users=4] = placeholder[target=x]
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%select,), kwargs = {pin_memory: False})
    %select_1 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %ones_like : [num_users=1] = call_function[target=torch.ops.aten.ones_like.default](args = (%select_1,), kwargs = {pin_memory: False})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 2), kwargs = {})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=4] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like, %ones_like], [%x, %mul], 0, False, []), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 2), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 3), kwargs = {})
    return (getitem, getitem_1, getitem_2, getitem_3)

nostrict-decomposition

graph():
    %x : [num_users=4] = placeholder[target=x]
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%select,), kwargs = {pin_memory: False})
    %select_1 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %ones_like : [num_users=1] = call_function[target=torch.ops.aten.ones_like.default](args = (%select_1,), kwargs = {pin_memory: False})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 2), kwargs = {})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=4] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like, %ones_like], [%x, %mul], 0, False, []), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 2), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 3), kwargs = {})
    return (getitem, getitem_1, getitem_2, getitem_3)

jit

FAILED

could not find kernel for HigherOrderOperator scan (<torch._higher_order_ops.scan.ScanOp object at 0x7fa78a3217e0>, <class 'torch._higher_order_ops.scan.ScanOp'>)at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

jit-decomposition

FAILED

could not find kernel for HigherOrderOperator scan (<torch._higher_order_ops.scan.ScanOp object at 0x7fa78a3217e0>, <class 'torch._higher_order_ops.scan.ScanOp'>)at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

tracing

FAILED

Unable to symbolically trace HigherOrderOperators

ControlFlowScanCDist

forward

def forward(self, x):
    carry, out = torch.ops.higher_order.scan(
        ControlFlowScanCDist.dist,
        [x],
        [x],
        dim=0,
        reverse=False,
        additional_inputs=[],
    )
    return out

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%x], [%x], 0, False, []), kwargs = {})
    %getitem : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

strict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%x], [%x], 0, False, []), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%x], [%x], 0, False, []), kwargs = {})
    %getitem : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

nostrict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%x], [%x], 0, False, []), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

jit

FAILED

could not find kernel for HigherOrderOperator scan (<torch._higher_order_ops.scan.ScanOp object at 0x7fa78a3217e0>, <class 'torch._higher_order_ops.scan.ScanOp'>)at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

jit-decomposition

FAILED

could not find kernel for HigherOrderOperator scan (<torch._higher_order_ops.scan.ScanOp object at 0x7fa78a3217e0>, <class 'torch._higher_order_ops.scan.ScanOp'>)at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

tracing

FAILED

Unable to symbolically trace HigherOrderOperators

ControlFlowScanCDist2

forward

def forward(self, x):
    z = torch.tensor([0], dtype=torch.float32)
    y = x.clone()
    out = torch.ops.higher_order.scan(
        ControlFlowScanCDist2.dist,
        [z],
        [x],
        dim=0,
        reverse=False,
        additional_inputs=[y],
    )
    return out[1]

strict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %detach_ : [num_users=1] = call_function[target=torch.ops.aten.detach_.default](args = (%lift_fresh_copy,), kwargs = {})
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%detach_], [%x], 0, False, [%clone]), kwargs = {})
    %getitem : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

strict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%lift_fresh_copy], [%x], 0, False, [%clone]), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

nostrict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %detach_ : [num_users=1] = call_function[target=torch.ops.aten.detach_.default](args = (%lift_fresh_copy,), kwargs = {})
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%detach_], [%x], 0, False, [%clone]), kwargs = {})
    %getitem : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

nostrict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%lift_fresh_copy], [%x], 0, False, [%clone]), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

jit

/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py:405: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.

z = torch.tensor([0], dtype=torch.float32)

FAILED

could not find kernel for HigherOrderOperator scan (<torch._higher_order_ops.scan.ScanOp object at 0x7fa78a3217e0>, <class 'torch._higher_order_ops.scan.ScanOp'>)at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

jit-decomposition

/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py:405: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.

z = torch.tensor([0], dtype=torch.float32)

FAILED

could not find kernel for HigherOrderOperator scan (<torch._higher_order_ops.scan.ScanOp object at 0x7fa78a3217e0>, <class 'torch._higher_order_ops.scan.ScanOp'>)at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

tracing

FAILED

[CustomProxy(clone)] can only be of (<class 'torch.Tensor'>, <class 'int'>, <class 'torch.SymInt'>) but got (<class 'experimental_experiment.torch_interpreter.tracing.CustomProxy'>,)

ControlFlowScanCDistXY

forward

def forward(self, x, y):
    carry, out = torch.ops.higher_order.scan(
        ControlFlowScanCDistXY.dist,
        [y],
        [x],
        dim=0,
        reverse=False,
        additional_inputs=[],
    )
    return out

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%y], [%x], 0, False, []), kwargs = {})
    %getitem : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

strict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%y], [%x], 0, False, []), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%y], [%x], 0, False, []), kwargs = {})
    %getitem : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

nostrict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%y], [%x], 0, False, []), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

jit

FAILED

could not find kernel for HigherOrderOperator scan (<torch._higher_order_ops.scan.ScanOp object at 0x7fa78a3217e0>, <class 'torch._higher_order_ops.scan.ScanOp'>)at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

jit-decomposition

FAILED

could not find kernel for HigherOrderOperator scan (<torch._higher_order_ops.scan.ScanOp object at 0x7fa78a3217e0>, <class 'torch._higher_order_ops.scan.ScanOp'>)at dispatch key DispatchKey.??? (resolved from DispatchKey.???)

tracing

FAILED

Unable to symbolically trace HigherOrderOperators

CropLastDimensionWithTensorContent

forward

def forward(self, x, shape):
    return x[..., : shape[0]]

strict

FAILED

Dynamic slicing on data-dependent value is not supported

from user code:
   File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 699, in forward
    return x[..., : shape[0]]

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

strict-decomposition

FAILED

Dynamic slicing on data-dependent value is not supported

from user code:
   File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 699, in forward
    return x[..., : shape[0]]

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

nostrict

FAILED

Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)

Caused by: (_export/non_strict_utils.py:557 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 699, in forward
    return x[..., : shape[0]]

nostrict-decomposition

FAILED

Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)

Caused by: (_export/non_strict_utils.py:557 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 699, in forward
    return x[..., : shape[0]]

jit

FAILED

Could not guard on data-dependent expression u0 < 0 (unhinted: u0 < 0).  (Size-like symbols: none)

Caused by: (_decomp/decompositions.py:771 in slice_forward)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "<string>", line 1, in <lambda>


While executing %slice_tensor : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, %_local_scalar_dense_default, 1), kwargs = {})
Original traceback:
None

jit-decomposition

FAILED

Could not guard on data-dependent expression u0 < 0 (unhinted: u0 < 0).  (Size-like symbols: none)

Caused by: (_decomp/decompositions.py:771 in slice_forward)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "<string>", line 1, in <lambda>


While executing %slice_tensor : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, %_local_scalar_dense_default, 1), kwargs = {})
Original traceback:
None

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %shape : [num_users=1] = placeholder[target=shape]
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%shape, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%x, (Ellipsis, slice(None, getitem, None))), kwargs = {})
    return getitem_1

CropLastDimensionWithTensorShape

forward

def forward(self, x, y):
    return x[..., : y.shape[0]]

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, 2), kwargs = {})
    return (slice_1,)

strict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, 2), kwargs = {})
    return (slice_1,)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, 2), kwargs = {})
    return (slice_1,)

nostrict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, 2), kwargs = {})
    return (slice_1,)

jit

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, 2), kwargs = {})
    return (slice_1,)

jit-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 2, 0, 2), kwargs = {})
    return (slice_1,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%y, shape), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%getattr_1, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%x, (Ellipsis, slice(None, getitem, None))), kwargs = {})
    return getitem_1

InplaceAdd

forward

def forward(self, x):
    x += self.bias
    return x

strict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    return (add_,)

strict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=2] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %bias), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add), kwargs = {})
    return (copy__default,)

nostrict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    return (add_,)

nostrict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=2] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %bias), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add), kwargs = {})
    return (copy__default,)

jit

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_3 target lifted_tensor_3 lifted_tensor_3 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %lifted_tensor_3), kwargs = {})
    return (add_,)

jit-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_3 target lifted_tensor_3 lifted_tensor_3 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_3 target lifted_tensor_3 lifted_tensor_3 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %x : [num_users=2] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lifted_tensor_3), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add), kwargs = {})
    return (copy__default,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %bias : [num_users=1] = get_attr[target=bias]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %bias), kwargs = {})
    return add

InplaceAdd_

forward

def forward(self, x):
    x.add_(self.bias)
    return x

strict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    return (add_,)

strict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=2] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %bias), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add), kwargs = {})
    return (copy__default,)

nostrict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    return (add_,)

nostrict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=2] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %bias), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add), kwargs = {})
    return (copy__default,)

jit

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_3 target lifted_tensor_3 lifted_tensor_3 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %lifted_tensor_3), kwargs = {})
    return (add_,)

jit-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_3 target lifted_tensor_3 lifted_tensor_3 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_3 target lifted_tensor_3 lifted_tensor_3 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %x : [num_users=2] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lifted_tensor_3), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add), kwargs = {})
    return (copy__default,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %bias : [num_users=1] = get_attr[target=bias]
    %add_ : [num_users=1] = call_method[target=add_](args = (%x, %bias), kwargs = {})
    return add_

InplaceAdd_Mul

forward

def forward(self, x):
    x.add_(self.bias)
    return x * 2

strict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_, 2), kwargs = {})
    return (mul,)

strict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=2] = placeholder[target=x]
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %bias), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 2), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add), kwargs = {})
    return (mul,)

nostrict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x, %bias), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_, 2), kwargs = {})
    return (mul,)

nostrict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=2] = placeholder[target=x]
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %bias), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 2), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %add), kwargs = {})
    return (mul,)

jit

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_2 target lifted_tensor_2 lifted_tensor_2 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_4 target lifted_tensor_4 lifted_tensor_4 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_2 : [num_users=1] = get_attr[target=lifted_tensor_2]
    %lifted_tensor_4 : [num_users=1] = get_attr[target=lifted_tensor_4]
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%x_1, %lifted_tensor_4), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_, %lifted_tensor_2), kwargs = {})
    return (mul,)

jit-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_2 target lifted_tensor_2 lifted_tensor_2 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_4 target lifted_tensor_4 lifted_tensor_4 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_4 target lifted_tensor_4 lifted_tensor_4 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_4 : [num_users=1] = get_attr[target=lifted_tensor_4]
    %x_1 : [num_users=2] = placeholder[target=x_1]
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%x_1, %lifted_tensor_4), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 2), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x_1, %add), kwargs = {})
    return (mul,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %bias : [num_users=1] = get_attr[target=bias]
    %add_ : [num_users=1] = call_method[target=add_](args = (%x, %bias), kwargs = {})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%add_, 2), kwargs = {})
    return mul

InplaceCloneAdd

forward

def forward(self, x):
    x = x.clone()
    x.add_(self.bias)
    return x

strict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%clone, %bias), kwargs = {})
    return (add_,)

strict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%clone, %bias), kwargs = {})
    return (add,)

nostrict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%clone, %bias), kwargs = {})
    return (add_,)

nostrict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node bias target bias bias of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %bias : [num_users=1] = get_attr[target=bias]
    %x : [num_users=1] = placeholder[target=x]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%clone, %bias), kwargs = {})
    return (add,)

jit

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_3 target lifted_tensor_3 lifted_tensor_3 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x_1,), kwargs = {})
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%clone, %lifted_tensor_3), kwargs = {})
    return (add_,)

jit-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_3 target lifted_tensor_3 lifted_tensor_3 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_3 target lifted_tensor_3 lifted_tensor_3 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%x_1,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%clone, %lifted_tensor_3), kwargs = {})
    return (add,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %clone : [num_users=1] = call_method[target=clone](args = (%x,), kwargs = {})
    %bias : [num_users=1] = get_attr[target=bias]
    %add_ : [num_users=1] = call_method[target=add_](args = (%clone, %bias), kwargs = {})
    return add_

InplaceSetItemEllipsis_1

forward

def forward(self, index, update):
    copy = self.params.clone()
    copy[..., index] = update
    return copy

strict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node params target params params of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %params : [num_users=1] = get_attr[target=params]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%params,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put_,)

strict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node params target params params of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

FAILED

false INTERNAL ASSERT FAILED at "/pytorch/build/aten/src/ATen/RegisterFunctionalization_1.cpp":5939, please report a bug to PyTorch. mutating a non-functional tensor with a functional tensor is not allowed. Please ensure that all of your inputs are wrapped inside of a functionalize() call.

While executing %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
Original traceback:
  File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 116, in forward
    copy[..., index] = update

nostrict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node params target params params of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %params : [num_users=1] = get_attr[target=params]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%params,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put_,)

nostrict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node params target params params of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

FAILED

false INTERNAL ASSERT FAILED at "/pytorch/build/aten/src/ATen/RegisterFunctionalization_1.cpp":5939, please report a bug to PyTorch. mutating a non-functional tensor with a functional tensor is not allowed. Please ensure that all of your inputs are wrapped inside of a functionalize() call.

While executing %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
Original traceback:
  File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 116, in forward
    copy[..., index] = update

jit

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_5 target lifted_tensor_5 lifted_tensor_5 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_5 : [num_users=1] = get_attr[target=lifted_tensor_5]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_5,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put_,)

jit-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_5 target lifted_tensor_5 lifted_tensor_5 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

FAILED

false INTERNAL ASSERT FAILED at "/pytorch/build/aten/src/ATen/RegisterFunctionalization_1.cpp":5939, please report a bug to PyTorch. mutating a non-functional tensor with a functional tensor is not allowed. Please ensure that all of your inputs are wrapped inside of a functionalize() call.

While executing %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
Original traceback:
None

tracing

graph():
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %_tensor_constant0 : [num_users=1] = get_attr[target=_tensor_constant0]
    %setitem : [num_users=1] = call_function[target=operator.setitem](args = (%_tensor_constant0, (Ellipsis, %index), %update), kwargs = {})
    return setitem

InplaceSetItemEllipsis_2

forward

def forward(self, index, update):
    copy = self.params.clone()
    copy[..., index] = update
    return copy

strict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node params target params params of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %params : [num_users=1] = get_attr[target=params]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%params,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put_,)

strict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node params target params params of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

FAILED

false INTERNAL ASSERT FAILED at "/pytorch/build/aten/src/ATen/RegisterFunctionalization_1.cpp":5939, please report a bug to PyTorch. mutating a non-functional tensor with a functional tensor is not allowed. Please ensure that all of your inputs are wrapped inside of a functionalize() call.

While executing %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
Original traceback:
  File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 134, in forward
    copy[..., index] = update

nostrict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node params target params params of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %params : [num_users=1] = get_attr[target=params]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%params,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put_,)

nostrict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node params target params params of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

FAILED

false INTERNAL ASSERT FAILED at "/pytorch/build/aten/src/ATen/RegisterFunctionalization_1.cpp":5939, please report a bug to PyTorch. mutating a non-functional tensor with a functional tensor is not allowed. Please ensure that all of your inputs are wrapped inside of a functionalize() call.

While executing %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
Original traceback:
  File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/eval/model_cases.py", line 134, in forward
    copy[..., index] = update

jit

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_5 target lifted_tensor_5 lifted_tensor_5 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_5 : [num_users=1] = get_attr[target=lifted_tensor_5]
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%lifted_tensor_5,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
    return (index_put_,)

jit-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_5 target lifted_tensor_5 lifted_tensor_5 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

FAILED

false INTERNAL ASSERT FAILED at "/pytorch/build/aten/src/ATen/RegisterFunctionalization_1.cpp":5939, please report a bug to PyTorch. mutating a non-functional tensor with a functional tensor is not allowed. Please ensure that all of your inputs are wrapped inside of a functionalize() call.

While executing %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [None, None, %index], %update), kwargs = {})
Original traceback:
None

tracing

graph():
    %index : [num_users=1] = placeholder[target=index]
    %update : [num_users=1] = placeholder[target=update]
    %_tensor_constant0 : [num_users=1] = get_attr[target=_tensor_constant0]
    %setitem : [num_users=1] = call_function[target=operator.setitem](args = (%_tensor_constant0, (Ellipsis, %index), %update), kwargs = {})
    return setitem

InplaceSetItemMask

forward

def forward(self, x):
    mask = x.to(bool)
    x[mask] = 2
    return x

strict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.bool), kwargs = {})
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%x, [%to], %lift_fresh_copy), kwargs = {})
    return (index_put_,)

strict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=3] = placeholder[target=x]
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.bool})
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%x, [%_to_copy], %lift_fresh_copy), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %index_put), kwargs = {})
    return (copy__default,)

nostrict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.bool), kwargs = {})
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%x, [%to], %lift_fresh_copy), kwargs = {})
    return (index_put_,)

nostrict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=3] = placeholder[target=x]
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.bool})
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%x, [%_to_copy], %lift_fresh_copy), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %index_put), kwargs = {})
    return (copy__default,)

jit

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_2 target lifted_tensor_2 lifted_tensor_2 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_2 : [num_users=1] = get_attr[target=lifted_tensor_2]
    %x : [num_users=2] = placeholder[target=x]
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.bool), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%x, [%to], %lifted_tensor_2), kwargs = {})
    return (index_put_,)

jit-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_2 target lifted_tensor_2 lifted_tensor_2 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_2 target lifted_tensor_2 lifted_tensor_2 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_2 : [num_users=1] = get_attr[target=lifted_tensor_2]
    %x : [num_users=3] = placeholder[target=x]
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.bool})
    %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%x, [%_to_copy], %lifted_tensor_2), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %index_put), kwargs = {})
    return (copy__default,)

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %to : [num_users=1] = call_method[target=to](args = (%x, torch.bool), kwargs = {})
    %setitem : [num_users=1] = call_function[target=operator.setitem](args = (%x, %to, 2), kwargs = {})
    return setitem

InplaceSetItemSquare

forward

def forward(self, x):
    x[:2, :3] = 1
    return x

strict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    return (x,)

strict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill : [num_users=1] = call_function[target=torch.ops.aten.fill.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %fill, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (copy__default,)

nostrict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    return (x,)

nostrict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill : [num_users=1] = call_function[target=torch.ops.aten.fill.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %fill, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (copy__default,)

jit

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_2 target lifted_tensor_2 lifted_tensor_2 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_2 : [num_users=1] = get_attr[target=lifted_tensor_2]
    %x : [num_users=2] = placeholder[target=x]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lifted_tensor_2), kwargs = {})
    return (x,)

jit-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_2 target lifted_tensor_2 lifted_tensor_2 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_2 target lifted_tensor_2 lifted_tensor_2 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_2 : [num_users=1] = get_attr[target=lifted_tensor_2]
    %x : [num_users=4] = placeholder[target=x]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill : [num_users=1] = call_function[target=torch.ops.aten.fill.Tensor](args = (%slice_2, %lifted_tensor_2), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %fill, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (copy__default,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %setitem : [num_users=1] = call_function[target=operator.setitem](args = (%x, (slice(None, 2, None), slice(None, 3, None)), 1), kwargs = {})
    return setitem

InplaceSetItemSquareAdd

forward

def forward(self, x):
    x[:2, :3] = 1
    return x + 2

strict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 2), kwargs = {})
    return (add,)

strict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill : [num_users=1] = call_function[target=torch.ops.aten.fill.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %fill, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=2] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 2), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (add,)

nostrict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=2] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 2), kwargs = {})
    return (add,)

nostrict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill : [num_users=1] = call_function[target=torch.ops.aten.fill.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %fill, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=2] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 2), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (add,)

jit

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_2 target lifted_tensor_2 lifted_tensor_2 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_3 target lifted_tensor_3 lifted_tensor_3 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_2 : [num_users=1] = get_attr[target=lifted_tensor_2]
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %x : [num_users=2] = placeholder[target=x]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lifted_tensor_3), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lifted_tensor_2), kwargs = {})
    return (add,)

jit-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_2 target lifted_tensor_2 lifted_tensor_2 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_3 target lifted_tensor_3 lifted_tensor_3 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_3 target lifted_tensor_3 lifted_tensor_3 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %x : [num_users=4] = placeholder[target=x]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill : [num_users=1] = call_function[target=torch.ops.aten.fill.Tensor](args = (%slice_2, %lifted_tensor_3), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %fill, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=2] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 2), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (add,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %setitem : [num_users=1] = call_function[target=operator.setitem](args = (%x, (slice(None, 2, None), slice(None, 3, None)), 1), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%setitem, 2), kwargs = {})
    return add

InplaceSetItemSquareAdd2

forward

def forward(self, x):
    x[:2, :3] = 1
    return x + 2, x + 3

strict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=3] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 3), kwargs = {})
    return (add, add_1)

strict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill : [num_users=1] = call_function[target=torch.ops.aten.fill.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %fill, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=3] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 3), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (add, add_1)

nostrict

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=3] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 3), kwargs = {})
    return (add, add_1)

nostrict-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=4] = placeholder[target=x]
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill : [num_users=1] = call_function[target=torch.ops.aten.fill.Tensor](args = (%slice_2, %lift_fresh_copy), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %fill, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=3] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 3), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (add, add_1)

jit

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_2 target lifted_tensor_2 lifted_tensor_2 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_3 target lifted_tensor_3 lifted_tensor_3 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_4 target lifted_tensor_4 lifted_tensor_4 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_2 : [num_users=1] = get_attr[target=lifted_tensor_2]
    %lifted_tensor_3 : [num_users=1] = get_attr[target=lifted_tensor_3]
    %lifted_tensor_4 : [num_users=1] = get_attr[target=lifted_tensor_4]
    %x : [num_users=3] = placeholder[target=x]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill_ : [num_users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%slice_2, %lifted_tensor_4), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lifted_tensor_3), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lifted_tensor_2), kwargs = {})
    return (add, add_1)

jit-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_2 target lifted_tensor_2 lifted_tensor_2 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_3 target lifted_tensor_3 lifted_tensor_3 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_4 target lifted_tensor_4 lifted_tensor_4 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

getattr_node = gm.graph.get_attr(lifted_node)

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_4 target lifted_tensor_4 lifted_tensor_4 of does not reference an nn.Module, nn.Parameter, or buffer, which is what ‘get_attr’ Nodes typically target

warnings.warn(

graph():
    %lifted_tensor_4 : [num_users=1] = get_attr[target=lifted_tensor_4]
    %x : [num_users=4] = placeholder[target=x]
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 3), kwargs = {})
    %fill : [num_users=1] = call_function[target=torch.ops.aten.fill.Tensor](args = (%slice_2, %lifted_tensor_4), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 2), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_3, %fill, 1, 0, 3), kwargs = {})
    %slice_scatter_1 : [num_users=3] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%x, %slice_scatter, 0, 0, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_1, 3), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%x, %slice_scatter_1), kwargs = {})
    return (add, add_1)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %setitem : [num_users=2] = call_function[target=operator.setitem](args = (%x, (slice(None, 2, None), slice(None, 3, None)), 1), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%setitem, 2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%setitem, 3), kwargs = {})
    return (add, add_1)

SignatureFloat1

forward

def forward(self, x, alpha: float = 2.0):
    return torch.sigmoid(self.linear(x)) - self.buff * alpha

strict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %alpha : [num_users=0] = placeholder[target=alpha]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%buff, 1.5), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %mul), kwargs = {})
    return (sub,)

strict-decomposition

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %alpha : [num_users=0] = placeholder[target=alpha]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%buff, 1.5), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %mul), kwargs = {})
    return (sub,)

nostrict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %alpha : [num_users=0] = placeholder[target=alpha]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%buff, 1.5), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %mul), kwargs = {})
    return (sub,)

nostrict-decomposition

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %alpha : [num_users=0] = placeholder[target=alpha]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%buff, 1.5), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %mul), kwargs = {})
    return (sub,)

jit

FAILED

Type 'Tuple[Tensor, float]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

jit-decomposition

FAILED

Type 'Tuple[Tensor, float]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %alpha : float [num_users=1] = placeholder[target=alpha](default=2.0)
    %linear : [num_users=1] = call_module[target=linear](args = (%x,), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.sigmoid](args = (%linear,), kwargs = {})
    %buff : [num_users=1] = get_attr[target=buff]
    %mul : [num_users=1] = call_method[target=mul](args = (%buff, %alpha), kwargs = {})
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%sigmoid, %mul), kwargs = {})
    return sub

SignatureInt1

forward

def forward(self, x, i: int = 2):
    return torch.sigmoid(self.linear(x)) - self.buff + x[:, i : i + 1]

strict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 1, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %slice_2), kwargs = {})
    return (add,)

strict-decomposition

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 1, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %slice_2), kwargs = {})
    return (add,)

nostrict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 1, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %slice_2), kwargs = {})
    return (add,)

nostrict-decomposition

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 1, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %slice_2), kwargs = {})
    return (add,)

jit

FAILED

Type 'Tuple[Tensor, int]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

jit-decomposition

FAILED

Type 'Tuple[Tensor, int]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %i : int [num_users=2] = placeholder[target=i](default=2)
    %linear : [num_users=1] = call_module[target=linear](args = (%x,), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.sigmoid](args = (%linear,), kwargs = {})
    %buff : [num_users=1] = get_attr[target=buff]
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%i, 1), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%x, (slice(None, None, None), slice(i, add, None))), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%sub, %getitem), kwargs = {})
    return add_1

SignatureInt2

forward

def forward(self, x, i: int = 2):
    return torch.sigmoid(self.linear(x)) - self.buff + x[:, i]

strict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 9223372036854775807), kwargs = {})
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%slice_1, 1, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %select), kwargs = {})
    return (add,)

strict-decomposition

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 9223372036854775807), kwargs = {})
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%slice_1, 1, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %select), kwargs = {})
    return (add,)

nostrict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 9223372036854775807), kwargs = {})
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%slice_1, 1, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %select), kwargs = {})
    return (add,)

nostrict-decomposition

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %i : [num_users=0] = placeholder[target=i]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 9223372036854775807), kwargs = {})
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%slice_1, 1, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %select), kwargs = {})
    return (add,)

jit

FAILED

Type 'Tuple[Tensor, int]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

jit-decomposition

FAILED

Type 'Tuple[Tensor, int]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %i : int [num_users=1] = placeholder[target=i](default=2)
    %linear : [num_users=1] = call_module[target=linear](args = (%x,), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.sigmoid](args = (%linear,), kwargs = {})
    %buff : [num_users=1] = get_attr[target=buff]
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%sigmoid, %buff), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%x, (slice(None, None, None), %i)), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%sub, %getitem), kwargs = {})
    return add

SignatureListFixedLength

forward

def forward(self, x, lx: list):
    return (
        torch.sigmoid(self.linear(x)) - self.buff + lx[0] * lx[1].sum(axis=1, keepdim=True)
    )

strict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

strict-decomposition

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

nostrict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

nostrict-decomposition

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

jit

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

jit-decomposition

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%lx_1, [1], True), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%lx_0, %sum_1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %mul), kwargs = {})
    return (add,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %lx : list [num_users=2] = placeholder[target=lx]
    %linear : [num_users=1] = call_module[target=linear](args = (%x,), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.sigmoid](args = (%linear,), kwargs = {})
    %buff : [num_users=1] = get_attr[target=buff]
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%sigmoid, %buff), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%lx, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%lx, 1), kwargs = {})
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%getitem_1,), kwargs = {axis: 1, keepdim: True})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%getitem, %sum_1), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%sub, %mul), kwargs = {})
    return add

SignatureListVariableLength

forward

def forward(self, x, lx: list):
    t = torch.cat(lx, dim=1).sum(axis=1, keepdim=True)
    return torch.sigmoid(self.linear(x)) - self.buff + t

strict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%lx_0, %lx_1], 1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%cat, [1], True), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %sum_1), kwargs = {})
    return (add,)

strict-decomposition

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%lx_0, %lx_1], 1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%cat, [1], True), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %sum_1), kwargs = {})
    return (add,)

nostrict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%lx_0, %lx_1], 1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%cat, [1], True), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %sum_1), kwargs = {})
    return (add,)

nostrict-decomposition

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%lx_0, %lx_1], 1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%cat, [1], True), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %sum_1), kwargs = {})
    return (add,)

jit

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%lx_0, %lx_1], 1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%cat, [1], True), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %sum_1), kwargs = {})
    return (add,)

jit-decomposition

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %buff : [num_users=1] = get_attr[target=buff]
    %x : [num_users=1] = placeholder[target=x]
    %lx_0 : [num_users=1] = placeholder[target=lx_0]
    %lx_1 : [num_users=1] = placeholder[target=lx_1]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%lx_0, %lx_1], 1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%cat, [1], True), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sub, %sum_1), kwargs = {})
    return (add,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %lx : list [num_users=1] = placeholder[target=lx]
    %cat : [num_users=1] = call_function[target=torch.cat](args = (%lx, 1), kwargs = {})
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%cat,), kwargs = {axis: 1, keepdim: True})
    %linear : [num_users=1] = call_module[target=linear](args = (%x,), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.sigmoid](args = (%linear,), kwargs = {})
    %buff : [num_users=1] = get_attr[target=buff]
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%sigmoid, %buff), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%sub, %sum_1), kwargs = {})
    return add

SignatureShapeAsIndex

forward

def forward(self, x, y):
    t = torch.sigmoid(self.linear(x)) + x
    return t[:, : y.shape[1]]

strict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sigmoid, %x), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 2), kwargs = {})
    return (slice_2,)

strict-decomposition

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sigmoid, %x), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 2), kwargs = {})
    return (slice_2,)

nostrict

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=0] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sigmoid, %x), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 2), kwargs = {})
    return (slice_2,)

nostrict-decomposition

graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %buff : [num_users=0] = get_attr[target=buff]
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sigmoid, %x), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 2), kwargs = {})
    return (slice_2,)

jit

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_export/converter.py:1474: UserWarning: Manually populate buff into state_dict ExportedProgram, but it is never used by the ExportedProgram.

warnings.warn(

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sigmoid, %x), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 2), kwargs = {})
    return (slice_2,)

jit-decomposition

/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_export/converter.py:1474: UserWarning: Manually populate buff into state_dict ExportedProgram, but it is never used by the ExportedProgram.

warnings.warn(

graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %linear_weight, %linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sigmoid, %x), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 2), kwargs = {})
    return (slice_2,)

tracing

graph():
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %linear : [num_users=1] = call_module[target=linear](args = (%x,), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.sigmoid](args = (%linear,), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%sigmoid, %x), kwargs = {})
    %getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%y, shape), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%getattr_1, 1), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%add, (slice(None, None, None), slice(None, getitem, None))), kwargs = {})
    return getitem_1

TypeBFloat16

forward

def forward(self, x):
    xb = x.to(torch.bfloat16)
    return (xb + xb).to(torch.float32)

strict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.bfloat16), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%to, %to), kwargs = {})
    %to_1 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%add, torch.float32), kwargs = {})
    return (to_1,)

strict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.bfloat16})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_to_copy, %_to_copy), kwargs = {})
    %_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%add,), kwargs = {dtype: torch.float32})
    return (_to_copy_1,)

nostrict

graph():
    %x : [num_users=1] = placeholder[target=x]
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.bfloat16), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%to, %to), kwargs = {})
    %to_1 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%add, torch.float32), kwargs = {})
    return (to_1,)

nostrict-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.bfloat16})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_to_copy, %_to_copy), kwargs = {})
    %_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%add,), kwargs = {dtype: torch.float32})
    return (_to_copy_1,)

jit

graph():
    %x : [num_users=1] = placeholder[target=x]
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.bfloat16), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%to, %to), kwargs = {})
    %to_1 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%add, torch.float32), kwargs = {})
    return (to_1,)

jit-decomposition

graph():
    %x : [num_users=1] = placeholder[target=x]
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.bfloat16})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_to_copy, %_to_copy), kwargs = {})
    %_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%add,), kwargs = {dtype: torch.float32})
    return (_to_copy_1,)

tracing

graph():
    %x : [num_users=1] = placeholder[target=x]
    %to : [num_users=1] = call_method[target=to](args = (%x, torch.bfloat16), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%to, %to), kwargs = {})
    %to_1 : [num_users=1] = call_method[target=to](args = (%add, torch.float32), kwargs = {})
    return to_1