Use the custom exporter in torch

Subject to change

File onnxruntime.py

This change enables the custom rewriter if the environment variable ONNXRT_CHANGE_REWRITER is set to 1. Look for the substring TODO: to locate the modified lines.

def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs):
    """This function replaces GraphModule._wrapped_call in compiled model.

    The _wrapped_call is the underlying implementation of forward method. Replacing
    it means we delegate the computation to _ort_acclerated_call and therefore
    onnxruntime.InferenceSession.

    A cached session is reused when one is found for ``graph_module`` and
    ``args``. Otherwise the graph is converted to ONNX, an
    ``onnxruntime.InferenceSession`` is created for it and cached. The
    conversion uses the custom exporter (``experimental_experiment``) when
    the environment variable ``ONNXRT_CHANGE_REWRITER`` equals ``"1"``,
    and the regular FX-to-ONNX interpreter otherwise.

    Args:
        graph_module: FX graph to execute.
        *args: Inputs of the graph; they are passed to the ONNX session.
        **kwargs: Only forwarded to ``FakeTensorProp.propagate`` when
            reference outputs are computed without dynamic shapes.

    Returns:
        The first ONNX output when the reference output is a single
        ``torch.Tensor``, otherwise the outputs returned by ``self.run``.

    Raises:
        Exception: Any failure raised by ``FakeTensorProp`` is logged and
            re-raised unchanged.
    """
    cached_execution_info_per_session = (
        self._all_ort_execution_info.search_reusable_session_execution_info(
            graph_module, *args
        )
    )
    if cached_execution_info_per_session:
        onnx_session = cached_execution_info_per_session.session
        input_names = cached_execution_info_per_session.input_names
        output_names = cached_execution_info_per_session.output_names
        input_value_infos = cached_execution_info_per_session.input_value_infos
        output_value_infos = cached_execution_info_per_session.output_value_infos
        input_devices = cached_execution_info_per_session.input_devices
        output_devices = cached_execution_info_per_session.output_devices
        prim_outputs = cached_execution_info_per_session.example_outputs
    else:
        # It's the first time seeing such a graph. Let's make a new session
        # (type: onnxruntime.InferenceSession) for it.

        ##########################
        # TODO: Insert these lines
        ##########################

        # Environment values are always strings, so only "1" can match;
        # comparing against the integer 1 would never succeed.
        use_other_rewriter = os.environ.get("ONNXRT_CHANGE_REWRITER") == "1"
        if use_other_rewriter:
            from experimental_experiment.torch_interpreter import to_onnx
            from experimental_experiment.torch_interpreter._torch_models import create_input_names
            from experimental_experiment.xbuilder import OptimizationOptions
            from experimental_experiment.torch_interpreter.oxs_dispatcher import OxsDispatcher

            input_names = create_input_names(graph_module, args)
            dispatcher = OxsDispatcher()
            target_opset = self._resolved_onnx_exporter_options.onnx_registry.opset_version
            options = OptimizationOptions(
                remove_unused=True,
                constant_folding=False,
                patterns="default",
                verbose=1,
            )
            # The builder is returned because return_builder=True,
            # but it is not needed afterwards.
            onnx_model, _builder = to_onnx(
                graph_module,
                tuple(args),
                input_names=input_names,
                options=options,
                verbose=1,
                target_opset=target_opset,
                return_builder=True,
                dispatcher=dispatcher,
            )

            def maybe_map_to_meta_val(value):
                if hasattr(value, "meta") and "val" in value.meta:
                    # Select outputs with "val" information. Without "val",
                    # it's not possible to access output_arg.meta["val"].device.
                    return value.meta["val"]
                return value

            # Reference outputs are taken from the graph metadata; they are
            # used below to determine output devices.
            extracted_outputs = _extract_graph_module_outputs(graph_module)
            prim_outputs = _pytree.tree_map(maybe_map_to_meta_val, extracted_outputs)

        else:

        ####################################
        # TODO: end of the insertion
        # TODO: indent what follows
        ####################################

            graph_module = torch.onnx._internal.fx.passes.MovePlaceholderToFront(
                self._resolved_onnx_exporter_options.diagnostic_context,
                graph_module,
            ).run()
            # Generate reference outputs. They are used to indicate output
            # tensors' types and devices when calling ORT.
            #
            # WARNING: The downstream code should not change prim_outputs and
            # this backend should always produce outputs with a schema identical
            # to prim_outputs'.

            if self._resolved_onnx_exporter_options.dynamic_shapes:
                # No pre-allocation when dynamic shape is enabled.
                self.preallocate_output = False
                extracted_outputs = _extract_graph_module_outputs(graph_module)

                def maybe_map_to_meta_val(value):
                    if hasattr(value, "meta") and "val" in value.meta:
                        # Select outputs with "val" information. Without "val",
                        # it's not possible to access output_arg.meta["val"].device.
                        return value.meta["val"]
                    else:
                        return value

                prim_outputs = _pytree.tree_map(
                    maybe_map_to_meta_val, extracted_outputs
                )
            else:
                try:
                    prim_outputs = FakeTensorProp(graph_module).propagate(
                        *args, **kwargs
                    )
                except Exception:
                    logger.warning("FakeTensorProb failed for %s", graph_module)
                    # When FakeTensorProp fails, it is not possible to preallocate output buffers
                    # because the output shapes are not inferred.
                    self.preallocate_output = False

                    # Re-raise the FakeTensorProp failure because it is not handled yet.
                    raise

            # Create the object to iterate through the nodes in graph one-by-one
            # and calls the corresponding ONNX exporter for each node.
            fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter(
                diagnostic_context=self._resolved_onnx_exporter_options.diagnostic_context
            )
            # Cast FX variables if they would result in a schema mismatch when searching
            # for ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch,
            # but ONNX expects add(double_tensor, double_tensor).
            graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion(
                self._resolved_onnx_exporter_options.diagnostic_context, graph_module
            ).run()
            # Start the per-node exporting process. It's conceptually a for loop
            # scanning through the nodes in the graph.
            exported = fx_interpreter.run(
                fx_graph_module=graph_module,
                onnxfunction_dispatcher=self._resolved_onnx_exporter_options.onnxfunction_dispatcher,
                op_level_debug=self._resolved_onnx_exporter_options.op_level_debug,
            )
            # Convert the exported result to ONNX ModelProto.
            onnx_model = exported.to_model_proto(
                opset_version=self._resolved_onnx_exporter_options.onnx_registry.opset_version,
            )

        ####################################
        # TODO: end of the modification
        ####################################

        # Modify ONNX model using pre-registered graph transforms.
        # They are in-place modifications for avoiding unnecessary
        # copy of ONNX initializers.
        if self._options.pre_ort_model_transforms:
            for transform in self._options.pre_ort_model_transforms:
                transform(onnx_model)

        onnx_model_bytes = onnx_model.SerializeToString()
        if os.environ.get("ONNXRT_DUMP_PATH", None):
            # If not empty, environment variable ONNXRT_DUMP_PATH defines the path
            # where generated onnx files should be stored.
            # This module keeps a global variable keeping track of the
            # stored models.
            # If ONNXRT_DUMP_PATH="dumped/dumped_model_"
            # The first file name will be 'dumped/dumped_model_0.onnx'.
            # For every dumped model, a text file 'dumped/dumped_model_0.txt'
            # is created as well to contain the string representing the graph_module.
            _dump_onnx_model(onnx_model_bytes, graph_module=graph_module)

        # Initialize a ORT session to execute this ONNX model.
        # Note that TorchDynamo assumes all inputs/outputs are on the
        # same device, but it's subject to change (very likely with
        # dynamic shape support), so we add execution providers
        # based on the logic in _select_eps: (explicitly preferred EPs,
        # EPs inferred from inputs or graph, and the fallback default EP).
        #
        # TODO(wschin): enable external allocators.
        # See https://github.com/pytorch/pytorch/issues/106867
        onnx_session = onnxruntime.InferenceSession(
            path_or_bytes=onnx_model_bytes,
            sess_options=self._options.ort_session_options,
            providers=self._select_eps(graph_module, *args),
        )

        # Cache ORT session. It's reused for the same "graph_module".
        # Generate ONNX model and extract its input and output names.
        # Note: the names are always re-read from the ONNX model, whichever
        # exporter produced it.
        input_names = tuple(input.name for input in onnx_model.graph.input)
        output_names = tuple(output.name for output in onnx_model.graph.output)
        input_devices = _get_onnx_devices(args)
        # Cache devices for inputs and outputs. They are used to invoke
        # ORT session. Output devices indicate where (e.g., GPU or CPU)
        # to store outputs
        if isinstance(prim_outputs, tuple):
            output_devices = _get_onnx_devices(prim_outputs)
        else:
            output_devices = _get_onnx_devices((prim_outputs,))

        input_value_infos = tuple(input for input in onnx_model.graph.input)
        output_value_infos = tuple(output for output in onnx_model.graph.output)

        execution_info_per_session = OrtExecutionInfoPerSession(
            session=onnx_session,
            input_names=input_names,
            input_value_infos=input_value_infos,
            output_names=output_names,
            output_value_infos=output_value_infos,
            input_devices=input_devices,
            output_devices=output_devices,
            example_outputs=prim_outputs,
        )

        self._all_ort_execution_info.cache_session_execution_info(
            graph_module, execution_info_per_session
        )

    self.execution_count += 1

    # ORT always returns a tuple of outputs. If the original output is a tensor,
    # ORT output's first element must be extracted and returned. Otherwise, type
    # mismatch may happen in downstream computation.
    is_single_tensor_output = isinstance(prim_outputs, torch.Tensor)
    normalized_prim_outputs = (
        (prim_outputs,) if is_single_tensor_output else prim_outputs
    )
    assert isinstance(normalized_prim_outputs, tuple)
    assert all(
        isinstance(elem, (torch.Tensor, torch.SymInt, int))
        for elem in normalized_prim_outputs
    )

    _nvtx_range_push("run_onnx_session_with_ortvaluevector")
    onnx_outputs = self.run(
        onnx_session,
        input_names,
        args,
        input_devices,
        output_names,
        normalized_prim_outputs,
        output_devices,
        self._options.preallocate_output,
        input_value_infos,
        normalized_prim_outputs,
    )
    _nvtx_range_pop()

    if self._assert_allclose_to_baseline:
        # Compute baseline.
        baseline_outputs = torch._prims.executor.execute(
            graph_module, *args, executor="aten"
        )
        normalized_baseline_outputs = (
            (baseline_outputs,) if is_single_tensor_output else baseline_outputs
        )
        # Ensure every output tensor is close to the corresponding baseline.
        for onnx_output, baseline_output in zip(
            onnx_outputs, normalized_baseline_outputs
        ):
            torch.testing.assert_close(onnx_output, baseline_output)
    return onnx_outputs[0] if is_single_tensor_output else onnx_outputs

Examples

Baseline

<<<

import os
import warnings
import numpy as np
import onnx
import torch
import torch.onnx
from experimental_experiment.torch_models.training_helper import (
    make_aot_ort,
    train_loop,
)
from experimental_experiment.torch_models.dump_helper import dump_onnx

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from transformers import LlamaConfig
    from transformers.models.llama.modeling_llama import LlamaModel


def ids_tensor(shape, vocab_size):
    """Return a random tensor of token ids with the requested shape.

    Args:
        shape: Sequence of dimension sizes of the returned tensor.
        vocab_size: Ids are drawn from ``[0, vocab_size - 1)``; the upper
            bound is excluded because ``np.random.randint`` excludes it.

    Returns:
        A contiguous ``torch.long`` tensor of shape ``shape``.
    """
    # Draw every id in one vectorized call instead of one call per element.
    values = np.random.randint(0, vocab_size - 1, size=tuple(shape))
    return torch.tensor(values, dtype=torch.long).contiguous()


# Small LLaMA configuration: one hidden layer and two attention heads.
config = LlamaConfig(
    hidden_size=16,
    num_hidden_layers=1,
    vocab_size=1024,
    intermediate_size=16,
    max_position_embeddings=1024,
    num_attention_heads=2,
)
config._attn_implementation = "eager"

model = LlamaModel(config)

batch, seq, vocab_size = 2, 1024, 1024

input_ids = ids_tensor([batch, seq], vocab_size)
# Lower-triangular matrix of ones used as the attention mask.
input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))

# Run the model once eagerly before compiling it.
model(input_ids, input_mask)

local_aot_ort, _ = make_aot_ort(
    dynamic=True,
    rewrite=True,
    verbose=1,
)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    optimized_mod = torch.compile(model, backend=local_aot_ort, fullgraph=True)
    # NOTE(review): dump_onnx presumably stores the generated ONNX models in
    # folder "dump_llama"; the listing below relies on it.
    with dump_onnx("dort-llama-ort", folder="dump_llama", clean=True):
        train_loop(optimized_mod, input_ids, input_mask)

names = [name for name in os.listdir("dump_llama") if name.endswith(".onnx")]
print("------------------------------------------")
print(f"exported model: {names}")
for name in names:
    print()
    # f-string so that the model file name is printed, not the placeholder.
    print(f"NODES in {name!r}")
    onx = onnx.load(os.path.join("dump_llama", name))
    n_nodes = len(onx.graph.node)
    for i, node in enumerate(onx.graph.node, start=1):
        print(f"{i}/{n_nodes}: {node.op_type} {node.input} -> {node.output}")

>>>

    [2024-05-27 18:42:21,156] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
    Applied 1 of general pattern rewrite rules.
    Applied 0 of general pattern rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific pattern rewrite rules.
    Applied 0 of general pattern rewrite rules.
    Applied 0 of general pattern rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific pattern rewrite rules.
    ------------------------------------------
    exported model: ['dort-llama-ort_1.onnx', 'dort-llama-ort_0.onnx']
    
    NODES in {name!r}
    1/305: Cos ['cat'] -> ['cos']
    2/305: Constant [] -> ['_val_34']
    3/305: Pow ['add_7', '_val_34'] -> ['pow_5']
    4/305: Constant [] -> ['aten_view_114_size_0']
    5/305: Reshape ['mm_4', 'aten_view_114_size_0'] -> ['view_23']
    6/305: Constant [] -> ['aten_view_116_size_0']
    7/305: Reshape ['mm_5', 'aten_view_116_size_0'] -> ['view_25']
    8/305: Constant [] -> ['aten_view_121_size_0']
    9/305: Reshape ['mm_3', 'aten_view_121_size_0'] -> ['view_21']
    10/305: Mul ['embedding', 'rsqrt'] -> ['mul_1']
    11/305: Mul ['tangents_1', 'primals_3'] -> ['mul_13']
    12/305: Mul ['add_7', 'rsqrt_2'] -> ['mul_11']
    13/305: Constant [] -> ['_val_50']
    14/305: Pow ['embedding', '_val_50'] -> ['pow_9']
    15/305: Sin ['cat'] -> ['sin']
    16/305: Constant [] -> ['_val_53']
    17/305: Equal ['primals_13', '_val_53'] -> ['eq_1']
    18/305: Constant [] -> ['aten_unsqueeze_133_dim_0']
    19/305: Unsqueeze ['cos', 'aten_unsqueeze_133_dim_0'] -> ['unsqueeze_10']
    20/305: Constant [] -> ['_val_57']
    21/305: Mul ['pow_5', '_val_57'] -> ['mul_19']
    22/305: Mul ['view_23', 'sigmoid'] -> ['mul_9']
    23/305: Constant [] -> ['fill']
    24/305: Constant [] -> ['alpha__1']
    25/305: Mul ['view_21', 'alpha__1'] -> ['other_1__1']
    26/305: Add ['embedding', 'other_1__1'] -> ['add_5']
    27/305: Mul ['mul_13', 'add_7'] -> ['mul_15']
    28/305: Mul ['mul_13', 'rsqrt_2'] -> ['mul_16']
    29/305: Mul ['tangents_1', 'mul_11'] -> ['mul_14']
    30/305: Constant [] -> ['_val_68']
    31/305: Mul ['pow_9', '_val_68'] -> ['mul_46']
    32/305: Constant [] -> ['aten_unsqueeze_147_dim_0']
    33/305: Unsqueeze ['sin', 'aten_unsqueeze_147_dim_0'] -> ['unsqueeze_11']
    34/305: Constant [] -> ['aten_unsqueeze_148_dim_0']
    35/305: Unsqueeze ['eq_1', 'aten_unsqueeze_148_dim_0'] -> ['unsqueeze_12']
    36/305: Constant [] -> ['alpha__2']
    37/305: Mul ['sigmoid', 'alpha__2'] -> ['other_1__2']
    38/305: Sub ['fill', 'other_1__2'] -> ['sub']
    39/305: Mul ['add_5', 'rsqrt_1'] -> ['mul_7']
    40/305: Constant [] -> ['_val_75']
    41/305: Pow ['add_5', '_val_75'] -> ['pow_7']
    42/305: Constant [] -> ['_val_79']
    43/305: ReduceSum ['mul_15', '_val_79'] -> ['sum_2']
    44/305: Constant [] -> ['_val_81']
    45/305: ReduceSum ['mul_14', '_val_81'] -> ['sum_1']
    46/305: Mul ['view_23', 'sub'] -> ['mul_23']
    47/305: Constant [] -> ['_val_85']
    48/305: Mul ['pow_7', '_val_85'] -> ['mul_32']
    49/305: Constant [] -> ['_val_89']
    50/305: Mul ['sum_2', '_val_89'] -> ['mul_17']
    51/305: Constant [] -> ['aten_view_169_size_0']
    52/305: Reshape ['sum_1', 'aten_view_169_size_0'] -> ['view_28']
    53/305: Constant [] -> ['aten_add_173_other_1']
    54/305: Add ['mul_23', 'aten_add_173_other_1'] -> ['add_10']
    55/305: Constant [] -> ['scalar_tensor_default_1']
    56/305: Pow ['rsqrt_1', 'scalar_tensor_default_1'] -> ['pow_6']
    57/305: Constant [] -> ['scalar_tensor_default_2']
    58/305: Pow ['rsqrt', 'scalar_tensor_default_2'] -> ['pow_8']
    59/305: Constant [] -> ['scalar_tensor_default_3']
    60/305: Pow ['rsqrt_2', 'scalar_tensor_default_3'] -> ['pow_4']
    61/305: Mul ['sigmoid', 'add_10'] -> ['mul_24']
    62/305: Mul ['mul_17', 'pow_4'] -> ['mul_18']
    63/305: Constant [] -> ['aten_expand_186_size_1']
    64/305: Expand ['mul_18', 'aten_expand_186_size_1'] -> ['expand_9']
    65/305: Constant [] -> ['scalar_tensor_default_4']
    66/305: Div ['expand_9', 'scalar_tensor_default_4'] -> ['div_1']
    67/305: Mul ['div_1', 'mul_19'] -> ['mul_20']
    68/305: Constant [] -> ['alpha__3']
    69/305: Mul ['mul_20', 'alpha__3'] -> ['other_1__3']
    70/305: Add ['mul_16', 'other_1__3'] -> ['add_9']
    71/305: Constant [] -> ['aten_view_193_size_0']
    72/305: Reshape ['add_9', 'aten_view_193_size_0'] -> ['view_29']
    73/305: Transpose ['view_29'] -> ['t_7']
    74/305: MatMul ['view_29', 't_9'] -> ['mm_8']
    75/305: MatMul ['t_7', 'view_26'] -> ['mm_7']
    76/305: Constant [] -> ['aten_view_198_size_0']
    77/305: Reshape ['mm_8', 'aten_view_198_size_0'] -> ['view_30']
    78/305: Transpose ['mm_7'] -> ['t_8']
    79/305: Mul ['view_30', 'mul_9'] -> ['mul_21']
    80/305: Mul ['view_30', 'view_25'] -> ['mul_22']
    81/305: Transpose ['t_8'] -> ['t_10']
    82/305: Constant [] -> ['aten_view_204_size_0']
    83/305: Reshape ['mul_21', 'aten_view_204_size_0'] -> ['view_31']
    84/305: Mul ['mul_22', 'mul_24'] -> ['mul_25']
    85/305: Transpose ['view_31'] -> ['t_11']
    86/305: MatMul ['view_31', 't_13'] -> ['mm_10']
    87/305: Constant [] -> ['aten_view_209_size_0']
    88/305: Reshape ['mul_25', 'aten_view_209_size_0'] -> ['view_33']
    89/305: MatMul ['t_11', 'view_22'] -> ['mm_9']
    90/305: Constant [] -> ['aten_view_212_size_0']
    91/305: Reshape ['mm_10', 'aten_view_212_size_0'] -> ['view_32']
    92/305: Transpose ['view_33'] -> ['t_15']
    93/305: MatMul ['view_33', 't_17'] -> ['mm_12']
    94/305: Transpose ['mm_9'] -> ['t_12']
    95/305: MatMul ['t_15', 'view_22'] -> ['mm_11']
    96/305: Constant [] -> ['aten_view_218_size_0']
    97/305: Reshape ['mm_12', 'aten_view_218_size_0'] -> ['view_34']
    98/305: Transpose ['t_12'] -> ['t_14']
    99/305: Transpose ['mm_11'] -> ['t_16']
    100/305: Constant [] -> ['alpha__4']
    101/305: Mul ['view_34', 'alpha__4'] -> ['other_1__4']
    102/305: Add ['view_32', 'other_1__4'] -> ['add_11']
    103/305: Transpose ['t_16'] -> ['t_18']
    104/305: Mul ['add_11', 'primals_2'] -> ['mul_26']
    105/305: Mul ['add_11', 'mul_7'] -> ['mul_27']
    106/305: Mul ['mul_26', 'add_5'] -> ['mul_28']
    107/305: Mul ['mul_26', 'rsqrt_1'] -> ['mul_29']
    108/305: Constant [] -> ['_val_150']
    109/305: ReduceSum ['mul_27', '_val_150'] -> ['sum_3']
    110/305: Constant [] -> ['_val_152']
    111/305: ReduceSum ['mul_28', '_val_152'] -> ['sum_4']
    112/305: Constant [] -> ['alpha__5']
    113/305: Mul ['mul_29', 'alpha__5'] -> ['other_1__5']
    114/305: Add ['add_9', 'other_1__5'] -> ['add_12']
    115/305: Constant [] -> ['aten_view_233_size_0']
    116/305: Reshape ['sum_3', 'aten_view_233_size_0'] -> ['view_35']
    117/305: Constant [] -> ['_val_157']
    118/305: Mul ['sum_4', '_val_157'] -> ['mul_30']
    119/305: Mul ['mul_30', 'pow_6'] -> ['mul_31']
    120/305: Constant [] -> ['aten_expand_238_size_1']
    121/305: Expand ['mul_31', 'aten_expand_238_size_1'] -> ['expand_10']
    122/305: Constant [] -> ['scalar_tensor_default_5']
    123/305: Div ['expand_10', 'scalar_tensor_default_5'] -> ['div_2']
    124/305: Mul ['div_2', 'mul_32'] -> ['mul_33']
    125/305: Constant [] -> ['alpha__6']
    126/305: Mul ['mul_33', 'alpha__6'] -> ['other_1__6']
    127/305: Add ['add_12', 'other_1__6'] -> ['add_13']
    128/305: Constant [] -> ['aten_view_245_size_0']
    129/305: Reshape ['add_13', 'aten_view_245_size_0'] -> ['view_36']
    130/305: Transpose ['view_36'] -> ['t_19']
    131/305: MatMul ['view_36', 't_21'] -> ['mm_14']
    132/305: MatMul ['t_19', 'view_20'] -> ['mm_13']
    133/305: Constant [] -> ['aten_view_250_size_0']
    134/305: Reshape ['mm_14', 'aten_view_250_size_0'] -> ['view_37']
    135/305: Transpose ['mm_13'] -> ['t_20']
    136/305: Constant [] -> ['aten_view_253_size_0']
    137/305: Reshape ['view_37', 'aten_view_253_size_0'] -> ['view_38']
    138/305: Transpose ['t_20'] -> ['t_22']
    139/305: Transpose ['view_38'] -> ['transpose_6']
    140/305: Constant [] -> ['aten_view_258_size_0']
    141/305: Reshape ['transpose_6', 'aten_view_258_size_0'] -> ['view_39']
    142/305: MatMul ['transpose_7', 'view_39'] -> ['bmm_3']
    143/305: MatMul ['view_39', 'transpose_8'] -> ['bmm_4']
    144/305: Constant [] -> ['aten_view_262_size_0']
    145/305: Reshape ['bmm_3', 'aten_view_262_size_0'] -> ['view_40']
    146/305: Constant [] -> ['aten_view_264_size_0']
    147/305: Reshape ['bmm_4', 'aten_view_264_size_0'] -> ['view_41']
    148/305: Constant [] -> ['alpha__7']
    149/305: Mul ['view_40', 'alpha__7'] -> ['other_1__7']
    150/305: Add ['tangents_3', 'other_1__7'] -> ['add_14']
    151/305: Mul ['view_41', 'detach_13'] -> ['mul_34']
    152/305: Transpose ['add_14'] -> ['transpose_12']
    153/305: Constant [] -> ['_val_191']
    154/305: ReduceSum ['mul_34', '_val_191'] -> ['sum_5']
    155/305: Mul ['detach_13', 'sum_5'] -> ['mul_35']
    156/305: Constant [] -> ['aten_view_273_size_0']
    157/305: Reshape ['transpose_12', 'aten_view_273_size_0'] -> ['view_45']
    158/305: Constant [] -> ['alpha__8']
    159/305: Mul ['mul_35', 'alpha__8'] -> ['other_1__8']
    160/305: Sub ['mul_34', 'other_1__8'] -> ['sub_1']
    161/305: Constant [] -> ['aten_view_276_size_0']
    162/305: Reshape ['view_45', 'aten_view_276_size_0'] -> ['view_48']
    163/305: Constant [] -> ['_val_200']
    164/305: Div ['sub_1', '_val_200'] -> ['div_3']
    165/305: Transpose ['view_48'] -> ['t_23']
    166/305: MatMul ['view_48', 't_25'] -> ['mm_16']
    167/305: Constant [] -> ['aten_view_282_size_0']
    168/305: Reshape ['div_3', 'aten_view_282_size_0'] -> ['view_42']
    169/305: MatMul ['t_23', 'view_1'] -> ['mm_15']
    170/305: Constant [] -> ['aten_view_285_size_0']
    171/305: Reshape ['mm_16', 'aten_view_285_size_0'] -> ['view_49']
    172/305: MatMul ['transpose_9', 'view_42'] -> ['bmm_5']
    173/305: MatMul ['view_42', 'transpose_10'] -> ['bmm_6']
    174/305: Transpose ['mm_15'] -> ['t_24']
    175/305: Constant [] -> ['aten_view_290_size_0']
    176/305: Reshape ['bmm_5', 'aten_view_290_size_0'] -> ['view_43']
    177/305: Constant [] -> ['aten_view_292_size_0']
    178/305: Reshape ['bmm_6', 'aten_view_292_size_0'] -> ['view_44']
    179/305: Transpose ['t_24'] -> ['t_26']
    180/305: Transpose ['view_43'] -> ['transpose_11']
    181/305: Mul ['view_44', 'unsqueeze_11'] -> ['mul_38']
    182/305: Mul ['view_44', 'unsqueeze_10'] -> ['mul_39']
    183/305: Constant [] -> ['alpha__9']
    184/305: Mul ['transpose_11', 'alpha__9'] -> ['other_1__9']
    185/305: Add ['tangents_2', 'other_1__9'] -> ['add_15']
    186/305: Constant [] -> ['_val_224']
    187/305: Constant [] -> ['_val_228']
    188/305: Constant [] -> ['_val_232']
    189/305: Constant [] -> ['_val_236']
    190/305: Slice ['mul_38', '_val_224', '_val_228', '_val_232', '_val_236'] -> ['slice_36']
    191/305: Constant [] -> ['_val_241']
    192/305: Constant [] -> ['_val_245']
    193/305: Constant [] -> ['_val_249']
    194/305: Constant [] -> ['_val_253']
    195/305: Slice ['mul_38', '_val_241', '_val_245', '_val_249', '_val_253'] -> ['slice_37']
    196/305: Mul ['add_15', 'unsqueeze_11'] -> ['mul_36']
    197/305: Mul ['add_15', 'unsqueeze_10'] -> ['mul_37']
    198/305: Neg ['slice_36'] -> ['neg_3']
    199/305: Constant [] -> ['_val_263']
    200/305: Constant [] -> ['_val_267']
    201/305: Constant [] -> ['_val_271']
    202/305: Constant [] -> ['_val_275']
    203/305: Slice ['mul_36', '_val_263', '_val_267', '_val_271', '_val_275'] -> ['slice_34']
    204/305: Constant [] -> ['_val_280']
    205/305: Constant [] -> ['_val_284']
    206/305: Constant [] -> ['_val_288']
    207/305: Constant [] -> ['_val_292']
    208/305: Slice ['mul_36', '_val_280', '_val_284', '_val_288', '_val_292'] -> ['slice_35']
    209/305: Constant [] -> ['_val_311']
    210/305: Transpose ['slice_37'] -> ['_val_312']
    211/305: Constant [] -> ['_val_313']
    212/305: ScatterND ['_val_313', '_val_311', '_val_312'] -> ['_val_314']
    213/305: Transpose ['_val_314'] -> ['slice_scatter_6']
    214/305: Neg ['slice_34'] -> ['neg_2']
    215/305: Constant [] -> ['_val_334']
    216/305: Transpose ['neg_3'] -> ['_val_335']
    217/305: Constant [] -> ['_val_336']
    218/305: ScatterND ['_val_336', '_val_334', '_val_335'] -> ['_val_337']
    219/305: Transpose ['_val_337'] -> ['slice_scatter_5']
    220/305: Constant [] -> ['_val_356']
    221/305: Transpose ['slice_35'] -> ['_val_357']
    222/305: Constant [] -> ['_val_358']
    223/305: ScatterND ['_val_358', '_val_356', '_val_357'] -> ['_val_359']
    224/305: Transpose ['_val_359'] -> ['slice_scatter_4']
    225/305: Constant [] -> ['alpha__10']
    226/305: Mul ['slice_scatter_6', 'alpha__10'] -> ['other_1__10']
    227/305: Add ['slice_scatter_5', 'other_1__10'] -> ['add_18']
    228/305: Constant [] -> ['_val_377']
    229/305: Transpose ['neg_2'] -> ['_val_378']
    230/305: Constant [] -> ['_val_379']
    231/305: ScatterND ['_val_379', '_val_377', '_val_378'] -> ['_val_380']
    232/305: Transpose ['_val_380'] -> ['slice_scatter_3']
    233/305: Constant [] -> ['alpha__11']
    234/305: Mul ['mul_39', 'alpha__11'] -> ['other_1__11']
    235/305: Add ['add_18', 'other_1__11'] -> ['add_19']
    236/305: Constant [] -> ['alpha__12']
    237/305: Mul ['slice_scatter_4', 'alpha__12'] -> ['other_1__12']
    238/305: Add ['slice_scatter_3', 'other_1__12'] -> ['add_16']
    239/305: Transpose ['add_19'] -> ['transpose_14']
    240/305: Constant [] -> ['alpha__13']
    241/305: Mul ['mul_37', 'alpha__13'] -> ['other_1__13']
    242/305: Add ['add_16', 'other_1__13'] -> ['add_17']
    243/305: Transpose ['add_17'] -> ['transpose_13']
    244/305: Constant [] -> ['aten_view_466_size_0']
    245/305: Reshape ['transpose_14', 'aten_view_466_size_0'] -> ['view_47']
    246/305: Constant [] -> ['aten_view_469_size_0']
    247/305: Reshape ['view_47', 'aten_view_469_size_0'] -> ['view_52']
    248/305: Constant [] -> ['aten_view_471_size_0']
    249/305: Reshape ['transpose_13', 'aten_view_471_size_0'] -> ['view_46']
    250/305: Transpose ['view_52'] -> ['t_31']
    251/305: MatMul ['view_52', 't_33'] -> ['mm_20']
    252/305: Constant [] -> ['aten_view_475_size_0']
    253/305: Reshape ['view_46', 'aten_view_475_size_0'] -> ['view_50']
    254/305: MatMul ['t_31', 'view_1'] -> ['mm_19']
    255/305: Constant [] -> ['aten_view_478_size_0']
    256/305: Reshape ['mm_20', 'aten_view_478_size_0'] -> ['view_53']
    257/305: Transpose ['view_50'] -> ['t_27']
    258/305: MatMul ['view_50', 't_29'] -> ['mm_18']
    259/305: Transpose ['mm_19'] -> ['t_32']
    260/305: MatMul ['t_27', 'view_1'] -> ['mm_17']
    261/305: Constant [] -> ['aten_view_484_size_0']
    262/305: Reshape ['mm_18', 'aten_view_484_size_0'] -> ['view_51']
    263/305: Transpose ['t_32'] -> ['t_34']
    264/305: Transpose ['mm_17'] -> ['t_28']
    265/305: Constant [] -> ['alpha__14']
    266/305: Mul ['view_51', 'alpha__14'] -> ['other_1__14']
    267/305: Add ['view_49', 'other_1__14'] -> ['add_20']
    268/305: Transpose ['t_28'] -> ['t_30']
    269/305: Constant [] -> ['alpha__15']
    270/305: Mul ['view_53', 'alpha__15'] -> ['other_1__15']
    271/305: Add ['add_20', 'other_1__15'] -> ['add_21']
    272/305: Mul ['add_21', 'primals_1'] -> ['mul_40']
    273/305: Mul ['add_21', 'mul_1'] -> ['mul_41']
    274/305: Mul ['mul_40', 'embedding'] -> ['mul_42']
    275/305: Mul ['mul_40', 'rsqrt'] -> ['mul_43']
    276/305: Constant [] -> ['_val_417']
    277/305: ReduceSum ['mul_41', '_val_417'] -> ['sum_6']
    278/305: Constant [] -> ['_val_419']
    279/305: ReduceSum ['mul_42', '_val_419'] -> ['sum_7']
    280/305: Constant [] -> ['alpha__16']
    281/305: Mul ['mul_43', 'alpha__16'] -> ['other_1__16']
    282/305: Add ['add_13', 'other_1__16'] -> ['add_22']
    283/305: Constant [] -> ['aten_view_500_size_0']
    284/305: Reshape ['sum_6', 'aten_view_500_size_0'] -> ['view_54']
    285/305: Constant [] -> ['_val_424']
    286/305: Mul ['sum_7', '_val_424'] -> ['mul_44']
    287/305: Mul ['mul_44', 'pow_8'] -> ['mul_45']
    288/305: Constant [] -> ['aten_expand_505_size_1']
    289/305: Expand ['mul_45', 'aten_expand_505_size_1'] -> ['expand_11']
    290/305: Constant [] -> ['scalar_tensor_default_6']
    291/305: Div ['expand_11', 'scalar_tensor_default_6'] -> ['div_4']
    292/305: Mul ['div_4', 'mul_46'] -> ['mul_47']
    293/305: Constant [] -> ['alpha__17']
    294/305: Mul ['mul_47', 'alpha__17'] -> ['other_1__17']
    295/305: Add ['add_22', 'other_1__17'] -> ['add_23']
    296/305: Constant [] -> ['aten_masked_fill_512_value_cast']
    297/305: Where ['unsqueeze_12', 'aten_masked_fill_512_value_cast', 'add_23'] -> ['masked_fill_1']
    298/305: Constant [] -> ['_val_436']
    299/305: ConstantOfShape ['_val_436'] -> ['aten_new_zeros_514_result']
    300/305: SequenceConstruct ['primals_13'] -> ['438']
    301/305: Constant [] -> ['int64_0__18']
    302/305: SequenceAt ['438', 'int64_0__18'] -> ['index__18']
    303/305: Constant [] -> ['int64_m1_1d__18']
    304/305: Unsqueeze ['index__18', 'int64_m1_1d__18'] -> ['new_index__18']
    305/305: ScatterND ['aten_new_zeros_514_result', 'new_index__18', 'masked_fill_1'] -> ['_unsafe_index_put']
    
    NODES in {name!r}
    1/275: Gather ['primals_4', 'primals_13'] -> ['embedding']
    2/275: Transpose ['primals_8'] -> ['t_3']
    3/275: Constant [] -> ['_val_22']
    4/275: Constant [] -> ['_val_23']
    5/275: Constant [] -> ['size_0__1']
    6/275: Constant [] -> ['fill_value_1__1']
    7/275: Expand ['fill_value_1__1', 'size_0__1'] -> ['full']
    8/275: Constant [] -> ['_val_36']
    9/275: Constant [] -> ['_val_40']
    10/275: Constant [] -> ['_val_44']
    11/275: Constant [] -> ['_val_48']
    12/275: Slice ['primals_14', '_val_36', '_val_40', '_val_44', '_val_48'] -> ['slice_8']
    13/275: Transpose ['primals_9'] -> ['t_4']
    14/275: Transpose ['primals_10'] -> ['t_5']
    15/275: Transpose ['primals_11'] -> ['t_6']
    16/275: Transpose ['primals_5'] -> ['t']
    17/275: Transpose ['primals_6'] -> ['t_1']
    18/275: Transpose ['primals_7'] -> ['t_2']
    19/275: Constant [] -> ['aten_unsqueeze_187_dim_0']
    20/275: Unsqueeze ['primals_12', 'aten_unsqueeze_187_dim_0'] -> ['unsqueeze_7']
    21/275: Constant [] -> ['scalar_tensor_default']
    22/275: Pow ['embedding', 'scalar_tensor_default'] -> ['pow_1']
    23/275: Transpose ['t_3'] -> ['t_21']
    24/275: Constant [] -> ['aten_triu_195_diagonal']
    25/275: Trilu ['full', 'aten_triu_195_diagonal'] -> ['triu']
    26/275: Constant [] -> ['aten_unsqueeze_196_dim_0']
    27/275: Unsqueeze ['slice_8', 'aten_unsqueeze_196_dim_0'] -> ['unsqueeze_5']
    28/275: Transpose ['t_4'] -> ['t_17']
    29/275: Transpose ['t_5'] -> ['t_13']
    30/275: Transpose ['t_6'] -> ['t_9']
    31/275: Transpose ['t'] -> ['t_33']
    32/275: Transpose ['t_1'] -> ['t_29']
    33/275: Transpose ['t_2'] -> ['t_25']
    34/275: Constant [] -> ['_val_75']
    35/275: Constant [] -> ['_val_79']
    36/275: Constant [] -> ['_val_83']
    37/275: Constant [] -> ['_val_87']
    38/275: Slice ['unsqueeze_7', '_val_75', '_val_79', '_val_83', '_val_87'] -> ['slice_21']
    39/275: Constant [] -> ['_val_89']
    40/275: ReduceMean ['pow_1', '_val_89'] -> ['mean']
    41/275: Constant [] -> ['gt']
    42/275: Constant [] -> ['aten_unsqueeze_240_dim_0']
    43/275: Unsqueeze ['unsqueeze_5', 'aten_unsqueeze_240_dim_0'] -> ['unsqueeze_6']
    44/275: Constant [] -> ['aten_unsqueeze_241_dim_0']
    45/275: Unsqueeze ['slice_21', 'aten_unsqueeze_241_dim_0'] -> ['unsqueeze_8']
    46/275: Constant [] -> ['aten_add_243_other_1']
    47/275: Add ['mean', 'aten_add_243_other_1'] -> ['add_1']
    48/275: Cast ['gt'] -> ['convert_element_type_default']
    49/275: Mul ['triu', 'convert_element_type_default'] -> ['mul']
    50/275: Constant [] -> ['_val_119']
    51/275: Constant [] -> ['_val_123']
    52/275: Constant [] -> ['_val_127']
    53/275: Constant [] -> ['_val_131']
    54/275: Slice ['unsqueeze_6', '_val_119', '_val_123', '_val_127', '_val_131'] -> ['slice_9']
    55/275: Constant [] -> ['aten_expand_265_size_1']
    56/275: Expand ['unsqueeze_8', 'aten_expand_265_size_1'] -> ['expand_2']
    57/275: Sqrt ['add_1'] -> ['aten_rsqrt_266_tmp']
    58/275: Reciprocal ['aten_rsqrt_266_tmp'] -> ['rsqrt']
    59/275: Constant [] -> ['aten_unsqueeze_284_dim_0']
    60/275: Unsqueeze ['mul', 'aten_unsqueeze_284_dim_0'] -> ['unsqueeze_3']
    61/275: Constant [] -> ['aten_expand_286_size_1']
    62/275: Expand ['expand_2', 'aten_expand_286_size_1'] -> ['expand_3']
    63/275: Mul ['embedding', 'rsqrt'] -> ['mul_1']
    64/275: Constant [] -> ['aten_unsqueeze_289_dim_0']
    65/275: Unsqueeze ['unsqueeze_3', 'aten_unsqueeze_289_dim_0'] -> ['unsqueeze_4']
    66/275: Mul ['primals_1', 'mul_1'] -> ['mul_2']
    67/275: Constant [] -> ['_val_167']
    68/275: Constant [] -> ['_val_171']
    69/275: Constant [] -> ['_val_175']
    70/275: Constant [] -> ['_val_179']
    71/275: Slice ['unsqueeze_4', '_val_167', '_val_171', '_val_175', '_val_179'] -> ['slice_3']
    72/275: Constant [] -> ['aten_view_313_size_0']
    73/275: Reshape ['mul_2', 'aten_view_313_size_0'] -> ['view_1']
    74/275: Constant [] -> ['view_11']
    75/275: Constant [] -> ['_val_188']
    76/275: Constant [] -> ['_val_192']
    77/275: Constant [] -> ['_val_196']
    78/275: Constant [] -> ['_val_200']
    79/275: Slice ['slice_3', '_val_188', '_val_192', '_val_196', '_val_200'] -> ['slice_4']
    80/275: MatMul ['view_1', 't'] -> ['mm']
    81/275: MatMul ['view_1', 't_1'] -> ['mm_1']
    82/275: MatMul ['view_1', 't_2'] -> ['mm_2']
    83/275: Constant [] -> ['aten_expand_338_size_1']
    84/275: Expand ['slice_4', 'aten_expand_338_size_1'] -> ['expand_1']
    85/275: Constant [] -> ['aten_view_340_size_0']
    86/275: Reshape ['mm', 'aten_view_340_size_0'] -> ['view_2']
    87/275: Constant [] -> ['aten_view_342_size_0']
    88/275: Reshape ['mm_1', 'aten_view_342_size_0'] -> ['view_4']
    89/275: Constant [] -> ['aten_view_344_size_0']
    90/275: Reshape ['mm_2', 'aten_view_344_size_0'] -> ['view_6']
    91/275: MatMul ['expand_3', 'view_11'] -> ['view_12']
    92/275: Constant [] -> ['aten_view_349_size_0']
    93/275: Reshape ['view_2', 'aten_view_349_size_0'] -> ['view_7']
    94/275: Constant [] -> ['aten_view_351_size_0']
    95/275: Reshape ['view_4', 'aten_view_351_size_0'] -> ['view_8']
    96/275: Constant [] -> ['aten_view_353_size_0']
    97/275: Reshape ['view_6', 'aten_view_353_size_0'] -> ['view_9']
    98/275: Transpose ['view_12'] -> ['transpose_3']
    99/275: Constant [] -> ['_val_227']
    100/275: Constant [] -> ['_val_231']
    101/275: Constant [] -> ['_val_235']
    102/275: Constant [] -> ['_val_239']
    103/275: Slice ['expand_1', '_val_227', '_val_231', '_val_235', '_val_239'] -> ['slice_5']
    104/275: Transpose ['view_7'] -> ['transpose']
    105/275: Transpose ['view_8'] -> ['transpose_1']
    106/275: Transpose ['view_9'] -> ['transpose_2']
    107/275: Concat ['transpose_3', 'transpose_3'] -> ['cat']
    108/275: Constant [] -> ['_val_249']
    109/275: Constant [] -> ['_val_253']
    110/275: Constant [] -> ['_val_257']
    111/275: Constant [] -> ['_val_261']
    112/275: Slice ['slice_5', '_val_249', '_val_253', '_val_257', '_val_261'] -> ['slice_6']
    113/275: Constant [] -> ['_val_266']
    114/275: Constant [] -> ['_val_270']
    115/275: Constant [] -> ['_val_274']
    116/275: Constant [] -> ['_val_278']
    117/275: Slice ['transpose', '_val_266', '_val_270', '_val_274', '_val_278'] -> ['slice_24']
    118/275: Constant [] -> ['_val_283']
    119/275: Constant [] -> ['_val_287']
    120/275: Constant [] -> ['_val_291']
    121/275: Constant [] -> ['_val_295']
    122/275: Slice ['transpose', '_val_283', '_val_287', '_val_291', '_val_295'] -> ['slice_25']
    123/275: Constant [] -> ['_val_300']
    124/275: Constant [] -> ['_val_304']
    125/275: Constant [] -> ['_val_308']
    126/275: Constant [] -> ['_val_312']
    127/275: Slice ['transpose_1', '_val_300', '_val_304', '_val_308', '_val_312'] -> ['slice_26']
    128/275: Constant [] -> ['_val_317']
    129/275: Constant [] -> ['_val_321']
    130/275: Constant [] -> ['_val_325']
    131/275: Constant [] -> ['_val_329']
    132/275: Slice ['transpose_1', '_val_317', '_val_321', '_val_325', '_val_329'] -> ['slice_27']
    133/275: Constant [] -> ['aten_expand_463_size_1']
    134/275: Expand ['transpose_2', 'aten_expand_463_size_1'] -> ['expand_8']
    135/275: Cos ['cat'] -> ['cos']
    136/275: Sin ['cat'] -> ['sin']
    137/275: Constant [] -> ['_val_338']
    138/275: Constant [] -> ['_val_342']
    139/275: Constant [] -> ['_val_346']
    140/275: Constant [] -> ['_val_350']
    141/275: Slice ['slice_6', '_val_338', '_val_342', '_val_346', '_val_350'] -> ['slice_7']
    142/275: Neg ['slice_25'] -> ['neg']
    143/275: Neg ['slice_27'] -> ['neg_1']
    144/275: Constant [] -> ['aten_unsqueeze_486_dim_0']
    145/275: Unsqueeze ['cos', 'aten_unsqueeze_486_dim_0'] -> ['unsqueeze_10']
    146/275: Constant [] -> ['aten_unsqueeze_487_dim_0']
    147/275: Unsqueeze ['sin', 'aten_unsqueeze_487_dim_0'] -> ['unsqueeze_11']
    148/275: Constant [] -> ['alpha__2']
    149/275: Mul ['slice_9', 'alpha__2'] -> ['other_1__2']
    150/275: Add ['slice_7', 'other_1__2'] -> ['add']
    151/275: Concat ['neg', 'slice_24'] -> ['cat_1']
    152/275: Concat ['neg_1', 'slice_26'] -> ['cat_2']
    153/275: Constant [] -> ['aten_view_494_size_0']
    154/275: Reshape ['expand_8', 'aten_view_494_size_0'] -> ['view_17']
    155/275: Mul ['transpose', 'unsqueeze_10'] -> ['mul_3']
    156/275: Mul ['transpose_1', 'unsqueeze_10'] -> ['mul_5']
    157/275: Constant [] -> ['scalar_tensor_default_1']
    158/275: Equal ['add', 'scalar_tensor_default_1'] -> ['eq']
    159/275: Mul ['cat_1', 'unsqueeze_11'] -> ['mul_4']
    160/275: Mul ['cat_2', 'unsqueeze_11'] -> ['mul_6']
    161/275: Transpose ['view_17'] -> ['transpose_8']
    162/275: Constant [] -> ['_val_372']
    163/275: Where ['eq', '_val_372', 'slice_7'] -> ['masked_fill']
    164/275: Constant [] -> ['alpha__3']
    165/275: Mul ['mul_4', 'alpha__3'] -> ['other_1__3']
    166/275: Add ['mul_3', 'other_1__3'] -> ['add_2']
    167/275: Constant [] -> ['alpha__4']
    168/275: Mul ['mul_6', 'alpha__4'] -> ['other_1__4']
    169/275: Add ['mul_5', 'other_1__4'] -> ['add_3']
    170/275: Constant [] -> ['aten_expand_509_size_1']
    171/275: Expand ['add_2', 'aten_expand_509_size_1'] -> ['expand_5']
    172/275: Transpose ['add_3'] -> ['transpose_4']
    173/275: Constant [] -> ['_val_395']
    174/275: Transpose ['masked_fill'] -> ['_val_396']
    175/275: Transpose ['slice_6'] -> ['_val_397']
    176/275: ScatterND ['_val_397', '_val_395', '_val_396'] -> ['_val_398']
    177/275: Transpose ['_val_398'] -> ['slice_scatter']
    178/275: Constant [] -> ['aten_expand_533_size_1']
    179/275: Expand ['transpose_4', 'aten_expand_533_size_1'] -> ['expand_6']
    180/275: Constant [] -> ['_val_418']
    181/275: Transpose ['slice_scatter'] -> ['_val_419']
    182/275: Transpose ['slice_5'] -> ['_val_420']
    183/275: ScatterND ['_val_420', '_val_418', '_val_419'] -> ['_val_421']
    184/275: Transpose ['_val_421'] -> ['slice_scatter_1']
    185/275: Constant [] -> ['aten_view_555_size_0']
    186/275: Reshape ['expand_5', 'aten_view_555_size_0'] -> ['view_13']
    187/275: Constant [] -> ['_val_441']
    188/275: ScatterND ['expand_1', '_val_441', 'slice_scatter_1'] -> ['slice_scatter_2']
    189/275: Transpose ['view_13'] -> ['transpose_9']
    190/275: Constant [] -> ['aten_view_576_size_0']
    191/275: Reshape ['expand_6', 'aten_view_576_size_0'] -> ['view_14']
    192/275: Constant [] -> ['_val_449']
    193/275: Constant [] -> ['_val_453']
    194/275: Constant [] -> ['_val_457']
    195/275: Constant [] -> ['_val_461']
    196/275: Slice ['slice_scatter_2', '_val_449', '_val_453', '_val_457', '_val_461'] -> ['slice_31']
    197/275: MatMul ['view_13', 'view_14'] -> ['bmm_1']
    198/275: Transpose ['view_14'] -> ['transpose_10']
    199/275: Constant [] -> ['_val_468']
    200/275: Constant [] -> ['_val_472']
    201/275: Constant [] -> ['_val_476']
    202/275: Constant [] -> ['_val_480']
    203/275: Slice ['slice_31', '_val_468', '_val_472', '_val_476', '_val_480'] -> ['slice_32']
    204/275: Constant [] -> ['aten_view_614_size_0']
    205/275: Reshape ['bmm_1', 'aten_view_614_size_0'] -> ['view_15']
    206/275: Constant [] -> ['_val_487']
    207/275: Constant [] -> ['_val_491']
    208/275: Constant [] -> ['_val_495']
    209/275: Constant [] -> ['_val_499']
    210/275: Slice ['slice_32', '_val_487', '_val_491', '_val_495', '_val_499'] -> ['slice_33']
    211/275: Constant [] -> ['_val_501']
    212/275: Div ['view_15', '_val_501'] -> ['div']
    213/275: Constant [] -> ['alpha__5']
    214/275: Mul ['slice_33', 'alpha__5'] -> ['other_1__5']
    215/275: Add ['div', 'other_1__5'] -> ['add_4']
    216/275: Softmax ['add_4'] -> ['_softmax']
    217/275: Constant [] -> ['aten_expand_640_size_1']
    218/275: Expand ['_softmax', 'aten_expand_640_size_1'] -> ['expand_7']
    219/275: Constant [] -> ['aten_view_643_size_0']
    220/275: Reshape ['expand_7', 'aten_view_643_size_0'] -> ['view_16']
    221/275: Identity ['_softmax'] -> ['detach_13']
    222/275: MatMul ['view_16', 'view_17'] -> ['bmm_2']
    223/275: Transpose ['view_16'] -> ['transpose_7']
    224/275: Constant [] -> ['aten_view_648_size_0']
    225/275: Reshape ['bmm_2', 'aten_view_648_size_0'] -> ['view_18']
    226/275: Transpose ['view_18'] -> ['transpose_5']
    227/275: Constant [] -> ['aten_view_652_size_0']
    228/275: Reshape ['transpose_5', 'aten_view_652_size_0'] -> ['view_19']
    229/275: Constant [] -> ['aten_view_654_size_0']
    230/275: Reshape ['view_19', 'aten_view_654_size_0'] -> ['view_20']
    231/275: MatMul ['view_20', 't_3'] -> ['mm_3']
    232/275: Constant [] -> ['aten_view_657_size_0']
    233/275: Reshape ['mm_3', 'aten_view_657_size_0'] -> ['view_21']
    234/275: Constant [] -> ['alpha__6']
    235/275: Mul ['view_21', 'alpha__6'] -> ['other_1__6']
    236/275: Add ['embedding', 'other_1__6'] -> ['add_5']
    237/275: Constant [] -> ['scalar_tensor_default_2']
    238/275: Pow ['add_5', 'scalar_tensor_default_2'] -> ['pow_2']
    239/275: Constant [] -> ['_val_531']
    240/275: ReduceMean ['pow_2', '_val_531'] -> ['mean_1']
    241/275: Constant [] -> ['aten_add_665_other_1']
    242/275: Add ['mean_1', 'aten_add_665_other_1'] -> ['add_6']
    243/275: Sqrt ['add_6'] -> ['aten_rsqrt_666_tmp']
    244/275: Reciprocal ['aten_rsqrt_666_tmp'] -> ['rsqrt_1']
    245/275: Mul ['add_5', 'rsqrt_1'] -> ['mul_7']
    246/275: Mul ['primals_2', 'mul_7'] -> ['mul_8']
    247/275: Constant [] -> ['aten_view_670_size_0']
    248/275: Reshape ['mul_8', 'aten_view_670_size_0'] -> ['view_22']
    249/275: MatMul ['view_22', 't_4'] -> ['mm_4']
    250/275: MatMul ['view_22', 't_5'] -> ['mm_5']
    251/275: Constant [] -> ['aten_view_674_size_0']
    252/275: Reshape ['mm_4', 'aten_view_674_size_0'] -> ['view_23']
    253/275: Constant [] -> ['aten_view_676_size_0']
    254/275: Reshape ['mm_5', 'aten_view_676_size_0'] -> ['view_25']
    255/275: Sigmoid ['view_23'] -> ['sigmoid']
    256/275: Mul ['view_23', 'sigmoid'] -> ['mul_9']
    257/275: Mul ['mul_9', 'view_25'] -> ['mul_10']
    258/275: Constant [] -> ['aten_view_681_size_0']
    259/275: Reshape ['mul_10', 'aten_view_681_size_0'] -> ['view_26']
    260/275: MatMul ['view_26', 't_6'] -> ['mm_6']
    261/275: Constant [] -> ['aten_view_684_size_0']
    262/275: Reshape ['mm_6', 'aten_view_684_size_0'] -> ['view_27']
    263/275: Constant [] -> ['alpha__7']
    264/275: Mul ['view_27', 'alpha__7'] -> ['other_1__7']
    265/275: Add ['add_5', 'other_1__7'] -> ['add_7']
    266/275: Constant [] -> ['scalar_tensor_default_3']
    267/275: Pow ['add_7', 'scalar_tensor_default_3'] -> ['pow_3']
    268/275: Constant [] -> ['_val_558']
    269/275: ReduceMean ['pow_3', '_val_558'] -> ['mean_2']
    270/275: Constant [] -> ['aten_add_692_other_1']
    271/275: Add ['mean_2', 'aten_add_692_other_1'] -> ['add_8']
    272/275: Sqrt ['add_8'] -> ['aten_rsqrt_693_tmp']
    273/275: Reciprocal ['aten_rsqrt_693_tmp'] -> ['rsqrt_2']
    274/275: Mul ['add_7', 'rsqrt_2'] -> ['mul_11']
    275/275: Mul ['primals_3', 'mul_11'] -> ['mul_12']
    [runpythonerror]
    /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/exporter.py:137: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
      warnings.warn(
    2024-05-27 18:42:25,465 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 4194304.
    2024-05-27 18:42:25,465 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue full due to large size 4194304.
    2024-05-27 18:42:25,504 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue triu due to large size 4194304.
    2024-05-27 18:42:25,565 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue convert_element_type_default due to large size 4194304.
    2024-05-27 18:42:25,575 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue mul due to large size 4194304.
    2024-05-27 18:42:25,587 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 4194304.
    2024-05-27 18:42:25,587 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue unsqueeze_3 due to large size 4194304.
    2024-05-27 18:42:25,589 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 4194304.
    2024-05-27 18:42:25,589 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue unsqueeze_4 due to large size 4194304.
    2024-05-27 18:42:25,597 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_3 due to large size 4194304.
    2024-05-27 18:42:25,603 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_4 due to large size 4194304.
    2024-05-27 18:42:25,614 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 8388608.
    2024-05-27 18:42:25,614 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue expand_1 due to large size 8388608.
    2024-05-27 18:42:25,633 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue clone due to large size 8388608.
    2024-05-27 18:42:25,640 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_5 due to large size 8388608.
    2024-05-27 18:42:25,649 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_6 due to large size 8388608.
    2024-05-27 18:42:25,674 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_7 due to large size 8388608.
    2024-05-27 18:42:25,689 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue _val_397 due to large size 8388608.
    2024-05-27 18:42:25,695 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue _val_420 due to large size 8388608.
    2024-05-27 18:42:25,881 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 4194304.
    2024-05-27 18:42:25,882 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue full due to large size 4194304.
    2024-05-27 18:42:25,885 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue triu due to large size 4194304.
    2024-05-27 18:42:25,891 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue convert_element_type_default due to large size 4194304.
    2024-05-27 18:42:25,892 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue mul due to large size 4194304.
    2024-05-27 18:42:25,894 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue unsqueeze_3 due to large size 4194304.
    2024-05-27 18:42:25,894 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue unsqueeze_4 due to large size 4194304.
    2024-05-27 18:42:25,895 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_3 due to large size 4194304.
    2024-05-27 18:42:25,897 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_4 due to large size 4194304.
    2024-05-27 18:42:25,899 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue expand_1 due to large size 8388608.
    2024-05-27 18:42:25,901 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_5 due to large size 8388608.
    2024-05-27 18:42:25,903 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_6 due to large size 8388608.
    2024-05-27 18:42:25,910 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_7 due to large size 8388608.
    2024-05-27 18:42:25,915 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue _val_397 due to large size 8388608.
    2024-05-27 18:42:25,916 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue _val_420 due to large size 8388608.
    2024-05-27 18:42:26.066551500 [W:onnxruntime:, graph.cc:4051 CleanUnusedInitializersAndNodeArgs] Removing initializer '_val_23'. It is not used by any node and should be removed from the model.
    2024-05-27 18:42:26.066618900 [W:onnxruntime:, graph.cc:4051 CleanUnusedInitializersAndNodeArgs] Removing initializer '_val_22'. It is not used by any node and should be removed from the model.

With the custom exporter

<<<

import os
import warnings
import numpy as np
import onnx

# from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
import torch
import torch.onnx
from experimental_experiment.torch_models.training_helper import (
    make_aot_ort,
    train_loop,
)
from experimental_experiment.torch_models.dump_helper import dump_onnx

# from experimental_experiment.torch_interpreter import to_onnx

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from transformers import LlamaConfig
    from transformers.models.llama.modeling_llama import LlamaModel


def ids_tensor(shape, vocab_size):
    """Return a tensor of random token ids.

    :param shape: sequence of dimensions of the tensor to create
    :param vocab_size: number of distinct ids, every value is drawn
        from ``[0, vocab_size)``
    :return: contiguous ``torch.long`` tensor of the requested shape
    :raises ValueError: raised by numpy when ids must be drawn and
        ``vocab_size`` is not positive

    ``np.random.randint`` excludes its upper bound. The bound is therefore
    ``vocab_size`` and not ``vocab_size - 1``: the last id can be drawn
    and ``vocab_size=1`` is valid (it returns only zeros).
    """
    total_dims = 1
    for dim in shape:
        total_dims *= dim

    # One vectorized draw instead of one Python-level call per element.
    values = np.random.randint(0, vocab_size, size=total_dims, dtype=np.int64)
    return torch.from_numpy(values).view(shape).contiguous()


# Deliberately tiny Llama configuration (a single hidden layer, hidden and
# intermediate sizes of 16, two attention heads).
config = LlamaConfig(
    hidden_size=16,
    num_hidden_layers=1,
    vocab_size=1024,
    intermediate_size=16,
    max_position_embeddings=1024,
    num_attention_heads=2,
)
# Selects the "eager" attention implementation.
# NOTE(review): presumably chosen so the exporter sees plain attention
# operators - confirm.
config._attn_implementation = "eager"

model = LlamaModel(config)

# seq and vocab_size equal max_position_embeddings and vocab_size
# of the configuration above.
batch, seq, vocab_size = 2, 1024, 1024

# Random token ids of shape (batch, seq).
input_ids = ids_tensor([batch, seq], vocab_size)
# Lower triangular part of a (batch, seq) matrix of ones, passed to the model
# as its second positional argument (presumably the attention mask - verify).
input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))

# Runs the model once before it is compiled, the result is discarded.
model(input_ids, input_mask)

# Opts in to the custom rewriter: the patched onnxruntime backend
# reads this environment variable and enables it when the value is "1".
os.environ["ONNXRT_CHANGE_REWRITER"] = "1"

# make_aot_ort returns a pair, only the first item (the backend given to
# torch.compile below) is used.
local_aot_ort, _ = make_aot_ort(
    dynamic=True,
    verbose=1,
)

# Warnings raised while compiling and training are silenced.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    optimized_mod = torch.compile(model, backend=local_aot_ort, fullgraph=True)
    # NOTE(review): dump_onnx presumably saves every exported model
    # as dump_llama/dort-llama-ort_<n>.onnx - confirm against the helper.
    with dump_onnx("dort-llama-ort", folder="dump_llama", clean=True):
        train_loop(optimized_mod, input_ids, input_mask)

# Lists the ONNX files found in folder dump_llama and prints, for each of
# them, every node of the main graph as
# "<index>/<count>: <op_type> <inputs> -> <outputs>".
names = [fname for fname in os.listdir("dump_llama") if fname.endswith(".onnx")]
print("------------------------------------------")
print(f"exported model: {names}")
for name in names:
    print()
    # This must be an f-string, otherwise the placeholder "{name!r}" is
    # printed verbatim instead of the file name.
    print(f"NODES in {name!r}")
    onx = onnx.load(os.path.join("dump_llama", name))
    nodes = onx.graph.node
    n_nodes = len(nodes)
    for i, node in enumerate(nodes, start=1):
        print(f"{i}/{n_nodes}: {node.op_type} {node.input} -> {node.output}")

# Any value other than "1" leaves the custom rewriter disabled
# for the code running after this script in the same process.
os.environ["ONNXRT_CHANGE_REWRITER"] = "0"

>>>

    [2024-05-27 18:42:34,034] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
    Applied 1 of general pattern rewrite rules.
    Applied 0 of general pattern rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific pattern rewrite rules.
    Applied 0 of general pattern rewrite rules.
    Applied 0 of general pattern rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific pattern rewrite rules.
    ------------------------------------------
    exported model: ['dort-llama-ort_1.onnx', 'dort-llama-ort_0.onnx']
    
    NODES in {name!r}
    1/305: Sin ['cat'] -> ['sin']
    2/305: Constant [] -> ['_val_34']
    3/305: Equal ['primals_13', '_val_34'] -> ['eq_1']
    4/305: Cos ['cat'] -> ['cos']
    5/305: Constant [] -> ['_val_38']
    6/305: Pow ['add_7', '_val_38'] -> ['pow_5']
    7/305: Constant [] -> ['aten_view_118_size_0']
    8/305: Reshape ['mm_4', 'aten_view_118_size_0'] -> ['view_23']
    9/305: Constant [] -> ['aten_view_120_size_0']
    10/305: Reshape ['mm_5', 'aten_view_120_size_0'] -> ['view_25']
    11/305: Mul ['add_7', 'rsqrt_2'] -> ['mul_11']
    12/305: Constant [] -> ['aten_view_126_size_0']
    13/305: Reshape ['mm_3', 'aten_view_126_size_0'] -> ['view_21']
    14/305: Mul ['embedding', 'rsqrt'] -> ['mul_1']
    15/305: Mul ['tangents_1', 'primals_3'] -> ['mul_13']
    16/305: Constant [] -> ['_val_54']
    17/305: Pow ['embedding', '_val_54'] -> ['pow_9']
    18/305: Constant [] -> ['aten_unsqueeze_133_dim_0']
    19/305: Unsqueeze ['sin', 'aten_unsqueeze_133_dim_0'] -> ['unsqueeze_11']
    20/305: Constant [] -> ['aten_unsqueeze_134_dim_0']
    21/305: Unsqueeze ['eq_1', 'aten_unsqueeze_134_dim_0'] -> ['unsqueeze_12']
    22/305: Constant [] -> ['aten_unsqueeze_136_dim_0']
    23/305: Unsqueeze ['cos', 'aten_unsqueeze_136_dim_0'] -> ['unsqueeze_10']
    24/305: Constant [] -> ['_val_60']
    25/305: Mul ['pow_5', '_val_60'] -> ['mul_19']
    26/305: Mul ['view_23', 'sigmoid'] -> ['mul_9']
    27/305: Constant [] -> ['fill']
    28/305: Mul ['tangents_1', 'mul_11'] -> ['mul_14']
    29/305: Constant [] -> ['alpha__1']
    30/305: Mul ['view_21', 'alpha__1'] -> ['other_1__1']
    31/305: Add ['embedding', 'other_1__1'] -> ['add_5']
    32/305: Mul ['mul_13', 'add_7'] -> ['mul_15']
    33/305: Mul ['mul_13', 'rsqrt_2'] -> ['mul_16']
    34/305: Constant [] -> ['_val_71']
    35/305: Mul ['pow_9', '_val_71'] -> ['mul_46']
    36/305: Constant [] -> ['alpha__2']
    37/305: Mul ['sigmoid', 'alpha__2'] -> ['other_1__2']
    38/305: Sub ['fill', 'other_1__2'] -> ['sub']
    39/305: Constant [] -> ['_val_75']
    40/305: ReduceSum ['mul_14', '_val_75'] -> ['sum_1']
    41/305: Mul ['add_5', 'rsqrt_1'] -> ['mul_7']
    42/305: Constant [] -> ['_val_78']
    43/305: Pow ['add_5', '_val_78'] -> ['pow_7']
    44/305: Constant [] -> ['_val_82']
    45/305: ReduceSum ['mul_15', '_val_82'] -> ['sum_2']
    46/305: Mul ['view_23', 'sub'] -> ['mul_23']
    47/305: Constant [] -> ['aten_view_164_size_0']
    48/305: Reshape ['sum_1', 'aten_view_164_size_0'] -> ['view_28']
    49/305: Constant [] -> ['_val_88']
    50/305: Mul ['pow_7', '_val_88'] -> ['mul_32']
    51/305: Constant [] -> ['_val_92']
    52/305: Mul ['sum_2', '_val_92'] -> ['mul_17']
    53/305: Constant [] -> ['scalar_tensor_default']
    54/305: Pow ['rsqrt_2', 'scalar_tensor_default'] -> ['pow_4']
    55/305: Constant [] -> ['aten_add_176_other_1']
    56/305: Add ['mul_23', 'aten_add_176_other_1'] -> ['add_10']
    57/305: Constant [] -> ['scalar_tensor_default_2']
    58/305: Pow ['rsqrt_1', 'scalar_tensor_default_2'] -> ['pow_6']
    59/305: Constant [] -> ['scalar_tensor_default_3']
    60/305: Pow ['rsqrt', 'scalar_tensor_default_3'] -> ['pow_8']
    61/305: Mul ['mul_17', 'pow_4'] -> ['mul_18']
    62/305: Mul ['sigmoid', 'add_10'] -> ['mul_24']
    63/305: Constant [] -> ['aten_expand_186_size_1']
    64/305: Expand ['mul_18', 'aten_expand_186_size_1'] -> ['expand_9']
    65/305: Constant [] -> ['scalar_tensor_default_4']
    66/305: Div ['expand_9', 'scalar_tensor_default_4'] -> ['div_1']
    67/305: Mul ['div_1', 'mul_19'] -> ['mul_20']
    68/305: Constant [] -> ['alpha__3']
    69/305: Mul ['mul_20', 'alpha__3'] -> ['other_1__3']
    70/305: Add ['mul_16', 'other_1__3'] -> ['add_9']
    71/305: Constant [] -> ['aten_view_193_size_0']
    72/305: Reshape ['add_9', 'aten_view_193_size_0'] -> ['view_29']
    73/305: Transpose ['view_29'] -> ['t_7']
    74/305: MatMul ['view_29', 't_9'] -> ['mm_8']
    75/305: MatMul ['t_7', 'view_26'] -> ['mm_7']
    76/305: Constant [] -> ['aten_view_198_size_0']
    77/305: Reshape ['mm_8', 'aten_view_198_size_0'] -> ['view_30']
    78/305: Transpose ['mm_7'] -> ['t_8']
    79/305: Mul ['view_30', 'mul_9'] -> ['mul_21']
    80/305: Mul ['view_30', 'view_25'] -> ['mul_22']
    81/305: Transpose ['t_8'] -> ['t_10']
    82/305: Constant [] -> ['aten_view_204_size_0']
    83/305: Reshape ['mul_21', 'aten_view_204_size_0'] -> ['view_31']
    84/305: Mul ['mul_22', 'mul_24'] -> ['mul_25']
    85/305: Transpose ['view_31'] -> ['t_11']
    86/305: MatMul ['view_31', 't_13'] -> ['mm_10']
    87/305: Constant [] -> ['aten_view_209_size_0']
    88/305: Reshape ['mul_25', 'aten_view_209_size_0'] -> ['view_33']
    89/305: MatMul ['t_11', 'view_22'] -> ['mm_9']
    90/305: Constant [] -> ['aten_view_212_size_0']
    91/305: Reshape ['mm_10', 'aten_view_212_size_0'] -> ['view_32']
    92/305: Transpose ['view_33'] -> ['t_15']
    93/305: MatMul ['view_33', 't_17'] -> ['mm_12']
    94/305: Transpose ['mm_9'] -> ['t_12']
    95/305: MatMul ['t_15', 'view_22'] -> ['mm_11']
    96/305: Constant [] -> ['aten_view_218_size_0']
    97/305: Reshape ['mm_12', 'aten_view_218_size_0'] -> ['view_34']
    98/305: Transpose ['t_12'] -> ['t_14']
    99/305: Transpose ['mm_11'] -> ['t_16']
    100/305: Constant [] -> ['alpha__4']
    101/305: Mul ['view_34', 'alpha__4'] -> ['other_1__4']
    102/305: Add ['view_32', 'other_1__4'] -> ['add_11']
    103/305: Transpose ['t_16'] -> ['t_18']
    104/305: Mul ['add_11', 'primals_2'] -> ['mul_26']
    105/305: Mul ['add_11', 'mul_7'] -> ['mul_27']
    106/305: Mul ['mul_26', 'add_5'] -> ['mul_28']
    107/305: Mul ['mul_26', 'rsqrt_1'] -> ['mul_29']
    108/305: Constant [] -> ['_val_150']
    109/305: ReduceSum ['mul_27', '_val_150'] -> ['sum_3']
    110/305: Constant [] -> ['_val_152']
    111/305: ReduceSum ['mul_28', '_val_152'] -> ['sum_4']
    112/305: Constant [] -> ['alpha__5']
    113/305: Mul ['mul_29', 'alpha__5'] -> ['other_1__5']
    114/305: Add ['add_9', 'other_1__5'] -> ['add_12']
    115/305: Constant [] -> ['aten_view_233_size_0']
    116/305: Reshape ['sum_3', 'aten_view_233_size_0'] -> ['view_35']
    117/305: Constant [] -> ['_val_157']
    118/305: Mul ['sum_4', '_val_157'] -> ['mul_30']
    119/305: Mul ['mul_30', 'pow_6'] -> ['mul_31']
    120/305: Constant [] -> ['aten_expand_238_size_1']
    121/305: Expand ['mul_31', 'aten_expand_238_size_1'] -> ['expand_10']
    122/305: Constant [] -> ['scalar_tensor_default_5']
    123/305: Div ['expand_10', 'scalar_tensor_default_5'] -> ['div_2']
    124/305: Mul ['div_2', 'mul_32'] -> ['mul_33']
    125/305: Constant [] -> ['alpha__6']
    126/305: Mul ['mul_33', 'alpha__6'] -> ['other_1__6']
    127/305: Add ['add_12', 'other_1__6'] -> ['add_13']
    128/305: Constant [] -> ['aten_view_245_size_0']
    129/305: Reshape ['add_13', 'aten_view_245_size_0'] -> ['view_36']
    130/305: Transpose ['view_36'] -> ['t_19']
    131/305: MatMul ['view_36', 't_21'] -> ['mm_14']
    132/305: MatMul ['t_19', 'view_20'] -> ['mm_13']
    133/305: Constant [] -> ['aten_view_250_size_0']
    134/305: Reshape ['mm_14', 'aten_view_250_size_0'] -> ['view_37']
    135/305: Transpose ['mm_13'] -> ['t_20']
    136/305: Constant [] -> ['aten_view_253_size_0']
    137/305: Reshape ['view_37', 'aten_view_253_size_0'] -> ['view_38']
    138/305: Transpose ['t_20'] -> ['t_22']
    139/305: Transpose ['view_38'] -> ['transpose_6']
    140/305: Constant [] -> ['aten_view_258_size_0']
    141/305: Reshape ['transpose_6', 'aten_view_258_size_0'] -> ['view_39']
    142/305: MatMul ['transpose_7', 'view_39'] -> ['bmm_3']
    143/305: MatMul ['view_39', 'transpose_8'] -> ['bmm_4']
    144/305: Constant [] -> ['aten_view_262_size_0']
    145/305: Reshape ['bmm_3', 'aten_view_262_size_0'] -> ['view_40']
    146/305: Constant [] -> ['aten_view_264_size_0']
    147/305: Reshape ['bmm_4', 'aten_view_264_size_0'] -> ['view_41']
    148/305: Constant [] -> ['alpha__7']
    149/305: Mul ['view_40', 'alpha__7'] -> ['other_1__7']
    150/305: Add ['tangents_3', 'other_1__7'] -> ['add_14']
    151/305: Mul ['view_41', 'detach_13'] -> ['mul_34']
    152/305: Transpose ['add_14'] -> ['transpose_12']
    153/305: Constant [] -> ['_val_191']
    154/305: ReduceSum ['mul_34', '_val_191'] -> ['sum_5']
    155/305: Mul ['detach_13', 'sum_5'] -> ['mul_35']
    156/305: Constant [] -> ['aten_view_273_size_0']
    157/305: Reshape ['transpose_12', 'aten_view_273_size_0'] -> ['view_45']
    158/305: Constant [] -> ['alpha__8']
    159/305: Mul ['mul_35', 'alpha__8'] -> ['other_1__8']
    160/305: Sub ['mul_34', 'other_1__8'] -> ['sub_1']
    161/305: Constant [] -> ['aten_view_276_size_0']
    162/305: Reshape ['view_45', 'aten_view_276_size_0'] -> ['view_48']
    163/305: Constant [] -> ['_val_200']
    164/305: Div ['sub_1', '_val_200'] -> ['div_3']
    165/305: Transpose ['view_48'] -> ['t_23']
    166/305: MatMul ['view_48', 't_25'] -> ['mm_16']
    167/305: Constant [] -> ['aten_view_282_size_0']
    168/305: Reshape ['div_3', 'aten_view_282_size_0'] -> ['view_42']
    169/305: MatMul ['t_23', 'view_1'] -> ['mm_15']
    170/305: Constant [] -> ['aten_view_285_size_0']
    171/305: Reshape ['mm_16', 'aten_view_285_size_0'] -> ['view_49']
    172/305: MatMul ['transpose_9', 'view_42'] -> ['bmm_5']
    173/305: MatMul ['view_42', 'transpose_10'] -> ['bmm_6']
    174/305: Transpose ['mm_15'] -> ['t_24']
    175/305: Constant [] -> ['aten_view_290_size_0']
    176/305: Reshape ['bmm_5', 'aten_view_290_size_0'] -> ['view_43']
    177/305: Constant [] -> ['aten_view_292_size_0']
    178/305: Reshape ['bmm_6', 'aten_view_292_size_0'] -> ['view_44']
    179/305: Transpose ['t_24'] -> ['t_26']
    180/305: Transpose ['view_43'] -> ['transpose_11']
    181/305: Mul ['view_44', 'unsqueeze_11'] -> ['mul_38']
    182/305: Mul ['view_44', 'unsqueeze_10'] -> ['mul_39']
    183/305: Constant [] -> ['alpha__9']
    184/305: Mul ['transpose_11', 'alpha__9'] -> ['other_1__9']
    185/305: Add ['tangents_2', 'other_1__9'] -> ['add_15']
    186/305: Constant [] -> ['_val_224']
    187/305: Constant [] -> ['_val_228']
    188/305: Constant [] -> ['_val_232']
    189/305: Constant [] -> ['_val_236']
    190/305: Slice ['mul_38', '_val_224', '_val_228', '_val_232', '_val_236'] -> ['slice_36']
    191/305: Constant [] -> ['_val_241']
    192/305: Constant [] -> ['_val_245']
    193/305: Constant [] -> ['_val_249']
    194/305: Constant [] -> ['_val_253']
    195/305: Slice ['mul_38', '_val_241', '_val_245', '_val_249', '_val_253'] -> ['slice_37']
    196/305: Mul ['add_15', 'unsqueeze_11'] -> ['mul_36']
    197/305: Mul ['add_15', 'unsqueeze_10'] -> ['mul_37']
    198/305: Neg ['slice_36'] -> ['neg_3']
    199/305: Constant [] -> ['_val_263']
    200/305: Constant [] -> ['_val_267']
    201/305: Constant [] -> ['_val_271']
    202/305: Constant [] -> ['_val_275']
    203/305: Slice ['mul_36', '_val_263', '_val_267', '_val_271', '_val_275'] -> ['slice_34']
    204/305: Constant [] -> ['_val_280']
    205/305: Constant [] -> ['_val_284']
    206/305: Constant [] -> ['_val_288']
    207/305: Constant [] -> ['_val_292']
    208/305: Slice ['mul_36', '_val_280', '_val_284', '_val_288', '_val_292'] -> ['slice_35']
    209/305: Constant [] -> ['_val_311']
    210/305: Transpose ['slice_37'] -> ['_val_312']
    211/305: Constant [] -> ['_val_313']
    212/305: ScatterND ['_val_313', '_val_311', '_val_312'] -> ['_val_314']
    213/305: Transpose ['_val_314'] -> ['slice_scatter_6']
    214/305: Neg ['slice_34'] -> ['neg_2']
    215/305: Constant [] -> ['_val_334']
    216/305: Transpose ['neg_3'] -> ['_val_335']
    217/305: Constant [] -> ['_val_336']
    218/305: ScatterND ['_val_336', '_val_334', '_val_335'] -> ['_val_337']
    219/305: Transpose ['_val_337'] -> ['slice_scatter_5']
    220/305: Constant [] -> ['_val_356']
    221/305: Transpose ['slice_35'] -> ['_val_357']
    222/305: Constant [] -> ['_val_358']
    223/305: ScatterND ['_val_358', '_val_356', '_val_357'] -> ['_val_359']
    224/305: Transpose ['_val_359'] -> ['slice_scatter_4']
    225/305: Constant [] -> ['alpha__10']
    226/305: Mul ['slice_scatter_6', 'alpha__10'] -> ['other_1__10']
    227/305: Add ['slice_scatter_5', 'other_1__10'] -> ['add_18']
    228/305: Constant [] -> ['_val_377']
    229/305: Transpose ['neg_2'] -> ['_val_378']
    230/305: Constant [] -> ['_val_379']
    231/305: ScatterND ['_val_379', '_val_377', '_val_378'] -> ['_val_380']
    232/305: Transpose ['_val_380'] -> ['slice_scatter_3']
    233/305: Constant [] -> ['alpha__11']
    234/305: Mul ['mul_39', 'alpha__11'] -> ['other_1__11']
    235/305: Add ['add_18', 'other_1__11'] -> ['add_19']
    236/305: Constant [] -> ['alpha__12']
    237/305: Mul ['slice_scatter_4', 'alpha__12'] -> ['other_1__12']
    238/305: Add ['slice_scatter_3', 'other_1__12'] -> ['add_16']
    239/305: Transpose ['add_19'] -> ['transpose_14']
    240/305: Constant [] -> ['alpha__13']
    241/305: Mul ['mul_37', 'alpha__13'] -> ['other_1__13']
    242/305: Add ['add_16', 'other_1__13'] -> ['add_17']
    243/305: Transpose ['add_17'] -> ['transpose_13']
    244/305: Constant [] -> ['aten_view_466_size_0']
    245/305: Reshape ['transpose_14', 'aten_view_466_size_0'] -> ['view_47']
    246/305: Constant [] -> ['aten_view_469_size_0']
    247/305: Reshape ['view_47', 'aten_view_469_size_0'] -> ['view_52']
    248/305: Constant [] -> ['aten_view_471_size_0']
    249/305: Reshape ['transpose_13', 'aten_view_471_size_0'] -> ['view_46']
    250/305: Transpose ['view_52'] -> ['t_31']
    251/305: MatMul ['view_52', 't_33'] -> ['mm_20']
    252/305: Constant [] -> ['aten_view_475_size_0']
    253/305: Reshape ['view_46', 'aten_view_475_size_0'] -> ['view_50']
    254/305: MatMul ['t_31', 'view_1'] -> ['mm_19']
    255/305: Constant [] -> ['aten_view_478_size_0']
    256/305: Reshape ['mm_20', 'aten_view_478_size_0'] -> ['view_53']
    257/305: Transpose ['view_50'] -> ['t_27']
    258/305: MatMul ['view_50', 't_29'] -> ['mm_18']
    259/305: Transpose ['mm_19'] -> ['t_32']
    260/305: MatMul ['t_27', 'view_1'] -> ['mm_17']
    261/305: Constant [] -> ['aten_view_484_size_0']
    262/305: Reshape ['mm_18', 'aten_view_484_size_0'] -> ['view_51']
    263/305: Transpose ['t_32'] -> ['t_34']
    264/305: Transpose ['mm_17'] -> ['t_28']
    265/305: Constant [] -> ['alpha__14']
    266/305: Mul ['view_51', 'alpha__14'] -> ['other_1__14']
    267/305: Add ['view_49', 'other_1__14'] -> ['add_20']
    268/305: Transpose ['t_28'] -> ['t_30']
    269/305: Constant [] -> ['alpha__15']
    270/305: Mul ['view_53', 'alpha__15'] -> ['other_1__15']
    271/305: Add ['add_20', 'other_1__15'] -> ['add_21']
    272/305: Mul ['add_21', 'primals_1'] -> ['mul_40']
    273/305: Mul ['add_21', 'mul_1'] -> ['mul_41']
    274/305: Mul ['mul_40', 'embedding'] -> ['mul_42']
    275/305: Mul ['mul_40', 'rsqrt'] -> ['mul_43']
    276/305: Constant [] -> ['_val_417']
    277/305: ReduceSum ['mul_41', '_val_417'] -> ['sum_6']
    278/305: Constant [] -> ['_val_419']
    279/305: ReduceSum ['mul_42', '_val_419'] -> ['sum_7']
    280/305: Constant [] -> ['alpha__16']
    281/305: Mul ['mul_43', 'alpha__16'] -> ['other_1__16']
    282/305: Add ['add_13', 'other_1__16'] -> ['add_22']
    283/305: Constant [] -> ['aten_view_500_size_0']
    284/305: Reshape ['sum_6', 'aten_view_500_size_0'] -> ['view_54']
    285/305: Constant [] -> ['_val_424']
    286/305: Mul ['sum_7', '_val_424'] -> ['mul_44']
    287/305: Mul ['mul_44', 'pow_8'] -> ['mul_45']
    288/305: Constant [] -> ['aten_expand_505_size_1']
    289/305: Expand ['mul_45', 'aten_expand_505_size_1'] -> ['expand_11']
    290/305: Constant [] -> ['scalar_tensor_default_6']
    291/305: Div ['expand_11', 'scalar_tensor_default_6'] -> ['div_4']
    292/305: Mul ['div_4', 'mul_46'] -> ['mul_47']
    293/305: Constant [] -> ['alpha__17']
    294/305: Mul ['mul_47', 'alpha__17'] -> ['other_1__17']
    295/305: Add ['add_22', 'other_1__17'] -> ['add_23']
    296/305: Constant [] -> ['aten_masked_fill_512_value_cast']
    297/305: Where ['unsqueeze_12', 'aten_masked_fill_512_value_cast', 'add_23'] -> ['masked_fill_1']
    298/305: Constant [] -> ['_val_436']
    299/305: ConstantOfShape ['_val_436'] -> ['aten_new_zeros_514_result']
    300/305: SequenceConstruct ['primals_13'] -> ['438']
    301/305: Constant [] -> ['int64_0__18']
    302/305: SequenceAt ['438', 'int64_0__18'] -> ['index__18']
    303/305: Constant [] -> ['int64_m1_1d__18']
    304/305: Unsqueeze ['index__18', 'int64_m1_1d__18'] -> ['new_index__18']
    305/305: ScatterND ['aten_new_zeros_514_result', 'new_index__18', 'masked_fill_1'] -> ['_unsafe_index_put']
    
    NODES in {name!r}
    1/275: Gather ['primals_4', 'primals_13'] -> ['embedding']
    2/275: Transpose ['primals_8'] -> ['t_3']
    3/275: Constant [] -> ['_val_22']
    4/275: Constant [] -> ['_val_23']
    5/275: Constant [] -> ['size_0__1']
    6/275: Constant [] -> ['fill_value_1__1']
    7/275: Expand ['fill_value_1__1', 'size_0__1'] -> ['full']
    8/275: Constant [] -> ['_val_36']
    9/275: Constant [] -> ['_val_40']
    10/275: Constant [] -> ['_val_44']
    11/275: Constant [] -> ['_val_48']
    12/275: Slice ['primals_14', '_val_36', '_val_40', '_val_44', '_val_48'] -> ['slice_8']
    13/275: Transpose ['primals_9'] -> ['t_4']
    14/275: Transpose ['primals_10'] -> ['t_5']
    15/275: Transpose ['primals_11'] -> ['t_6']
    16/275: Transpose ['primals_5'] -> ['t']
    17/275: Transpose ['primals_6'] -> ['t_1']
    18/275: Transpose ['primals_7'] -> ['t_2']
    19/275: Constant [] -> ['aten_unsqueeze_187_dim_0']
    20/275: Unsqueeze ['primals_12', 'aten_unsqueeze_187_dim_0'] -> ['unsqueeze_7']
    21/275: Constant [] -> ['scalar_tensor_default']
    22/275: Pow ['embedding', 'scalar_tensor_default'] -> ['pow_1']
    23/275: Transpose ['t_3'] -> ['t_21']
    24/275: Constant [] -> ['aten_triu_195_diagonal']
    25/275: Trilu ['full', 'aten_triu_195_diagonal'] -> ['triu']
    26/275: Constant [] -> ['aten_unsqueeze_196_dim_0']
    27/275: Unsqueeze ['slice_8', 'aten_unsqueeze_196_dim_0'] -> ['unsqueeze_5']
    28/275: Transpose ['t_4'] -> ['t_17']
    29/275: Transpose ['t_5'] -> ['t_13']
    30/275: Transpose ['t_6'] -> ['t_9']
    31/275: Transpose ['t'] -> ['t_33']
    32/275: Transpose ['t_1'] -> ['t_29']
    33/275: Transpose ['t_2'] -> ['t_25']
    34/275: Constant [] -> ['_val_75']
    35/275: Constant [] -> ['_val_79']
    36/275: Constant [] -> ['_val_83']
    37/275: Constant [] -> ['_val_87']
    38/275: Slice ['unsqueeze_7', '_val_75', '_val_79', '_val_83', '_val_87'] -> ['slice_21']
    39/275: Constant [] -> ['_val_89']
    40/275: ReduceMean ['pow_1', '_val_89'] -> ['mean']
    41/275: Constant [] -> ['gt']
    42/275: Constant [] -> ['aten_unsqueeze_240_dim_0']
    43/275: Unsqueeze ['unsqueeze_5', 'aten_unsqueeze_240_dim_0'] -> ['unsqueeze_6']
    44/275: Constant [] -> ['aten_unsqueeze_241_dim_0']
    45/275: Unsqueeze ['slice_21', 'aten_unsqueeze_241_dim_0'] -> ['unsqueeze_8']
    46/275: Constant [] -> ['aten_add_243_other_1']
    47/275: Add ['mean', 'aten_add_243_other_1'] -> ['add_1']
    48/275: Cast ['gt'] -> ['convert_element_type_default']
    49/275: Mul ['triu', 'convert_element_type_default'] -> ['mul']
    50/275: Constant [] -> ['_val_119']
    51/275: Constant [] -> ['_val_123']
    52/275: Constant [] -> ['_val_127']
    53/275: Constant [] -> ['_val_131']
    54/275: Slice ['unsqueeze_6', '_val_119', '_val_123', '_val_127', '_val_131'] -> ['slice_9']
    55/275: Constant [] -> ['aten_expand_265_size_1']
    56/275: Expand ['unsqueeze_8', 'aten_expand_265_size_1'] -> ['expand_2']
    57/275: Sqrt ['add_1'] -> ['aten_rsqrt_266_tmp']
    58/275: Reciprocal ['aten_rsqrt_266_tmp'] -> ['rsqrt']
    59/275: Constant [] -> ['aten_unsqueeze_284_dim_0']
    60/275: Unsqueeze ['mul', 'aten_unsqueeze_284_dim_0'] -> ['unsqueeze_3']
    61/275: Constant [] -> ['aten_expand_286_size_1']
    62/275: Expand ['expand_2', 'aten_expand_286_size_1'] -> ['expand_3']
    63/275: Mul ['embedding', 'rsqrt'] -> ['mul_1']
    64/275: Constant [] -> ['aten_unsqueeze_289_dim_0']
    65/275: Unsqueeze ['unsqueeze_3', 'aten_unsqueeze_289_dim_0'] -> ['unsqueeze_4']
    66/275: Mul ['primals_1', 'mul_1'] -> ['mul_2']
    67/275: Constant [] -> ['_val_167']
    68/275: Constant [] -> ['_val_171']
    69/275: Constant [] -> ['_val_175']
    70/275: Constant [] -> ['_val_179']
    71/275: Slice ['unsqueeze_4', '_val_167', '_val_171', '_val_175', '_val_179'] -> ['slice_3']
    72/275: Constant [] -> ['aten_view_313_size_0']
    73/275: Reshape ['mul_2', 'aten_view_313_size_0'] -> ['view_1']
    74/275: Constant [] -> ['view_11']
    75/275: Constant [] -> ['_val_188']
    76/275: Constant [] -> ['_val_192']
    77/275: Constant [] -> ['_val_196']
    78/275: Constant [] -> ['_val_200']
    79/275: Slice ['slice_3', '_val_188', '_val_192', '_val_196', '_val_200'] -> ['slice_4']
    80/275: MatMul ['view_1', 't'] -> ['mm']
    81/275: MatMul ['view_1', 't_1'] -> ['mm_1']
    82/275: MatMul ['view_1', 't_2'] -> ['mm_2']
    83/275: Constant [] -> ['aten_expand_338_size_1']
    84/275: Expand ['slice_4', 'aten_expand_338_size_1'] -> ['expand_1']
    85/275: Constant [] -> ['aten_view_340_size_0']
    86/275: Reshape ['mm', 'aten_view_340_size_0'] -> ['view_2']
    87/275: Constant [] -> ['aten_view_342_size_0']
    88/275: Reshape ['mm_1', 'aten_view_342_size_0'] -> ['view_4']
    89/275: Constant [] -> ['aten_view_344_size_0']
    90/275: Reshape ['mm_2', 'aten_view_344_size_0'] -> ['view_6']
    91/275: MatMul ['expand_3', 'view_11'] -> ['view_12']
    92/275: Constant [] -> ['aten_view_349_size_0']
    93/275: Reshape ['view_2', 'aten_view_349_size_0'] -> ['view_7']
    94/275: Constant [] -> ['aten_view_351_size_0']
    95/275: Reshape ['view_4', 'aten_view_351_size_0'] -> ['view_8']
    96/275: Constant [] -> ['aten_view_353_size_0']
    97/275: Reshape ['view_6', 'aten_view_353_size_0'] -> ['view_9']
    98/275: Transpose ['view_12'] -> ['transpose_3']
    99/275: Constant [] -> ['_val_227']
    100/275: Constant [] -> ['_val_231']
    101/275: Constant [] -> ['_val_235']
    102/275: Constant [] -> ['_val_239']
    103/275: Slice ['expand_1', '_val_227', '_val_231', '_val_235', '_val_239'] -> ['slice_5']
    104/275: Transpose ['view_7'] -> ['transpose']
    105/275: Transpose ['view_8'] -> ['transpose_1']
    106/275: Transpose ['view_9'] -> ['transpose_2']
    107/275: Concat ['transpose_3', 'transpose_3'] -> ['cat']
    108/275: Constant [] -> ['_val_249']
    109/275: Constant [] -> ['_val_253']
    110/275: Constant [] -> ['_val_257']
    111/275: Constant [] -> ['_val_261']
    112/275: Slice ['slice_5', '_val_249', '_val_253', '_val_257', '_val_261'] -> ['slice_6']
    113/275: Constant [] -> ['_val_266']
    114/275: Constant [] -> ['_val_270']
    115/275: Constant [] -> ['_val_274']
    116/275: Constant [] -> ['_val_278']
    117/275: Slice ['transpose', '_val_266', '_val_270', '_val_274', '_val_278'] -> ['slice_24']
    118/275: Constant [] -> ['_val_283']
    119/275: Constant [] -> ['_val_287']
    120/275: Constant [] -> ['_val_291']
    121/275: Constant [] -> ['_val_295']
    122/275: Slice ['transpose', '_val_283', '_val_287', '_val_291', '_val_295'] -> ['slice_25']
    123/275: Constant [] -> ['_val_300']
    124/275: Constant [] -> ['_val_304']
    125/275: Constant [] -> ['_val_308']
    126/275: Constant [] -> ['_val_312']
    127/275: Slice ['transpose_1', '_val_300', '_val_304', '_val_308', '_val_312'] -> ['slice_26']
    128/275: Constant [] -> ['_val_317']
    129/275: Constant [] -> ['_val_321']
    130/275: Constant [] -> ['_val_325']
    131/275: Constant [] -> ['_val_329']
    132/275: Slice ['transpose_1', '_val_317', '_val_321', '_val_325', '_val_329'] -> ['slice_27']
    133/275: Constant [] -> ['aten_expand_463_size_1']
    134/275: Expand ['transpose_2', 'aten_expand_463_size_1'] -> ['expand_8']
    135/275: Cos ['cat'] -> ['cos']
    136/275: Sin ['cat'] -> ['sin']
    137/275: Constant [] -> ['_val_338']
    138/275: Constant [] -> ['_val_342']
    139/275: Constant [] -> ['_val_346']
    140/275: Constant [] -> ['_val_350']
    141/275: Slice ['slice_6', '_val_338', '_val_342', '_val_346', '_val_350'] -> ['slice_7']
    142/275: Neg ['slice_25'] -> ['neg']
    143/275: Neg ['slice_27'] -> ['neg_1']
    144/275: Constant [] -> ['aten_unsqueeze_486_dim_0']
    145/275: Unsqueeze ['cos', 'aten_unsqueeze_486_dim_0'] -> ['unsqueeze_10']
    146/275: Constant [] -> ['aten_unsqueeze_487_dim_0']
    147/275: Unsqueeze ['sin', 'aten_unsqueeze_487_dim_0'] -> ['unsqueeze_11']
    148/275: Constant [] -> ['alpha__2']
    149/275: Mul ['slice_9', 'alpha__2'] -> ['other_1__2']
    150/275: Add ['slice_7', 'other_1__2'] -> ['add']
    151/275: Concat ['neg', 'slice_24'] -> ['cat_1']
    152/275: Concat ['neg_1', 'slice_26'] -> ['cat_2']
    153/275: Constant [] -> ['aten_view_494_size_0']
    154/275: Reshape ['expand_8', 'aten_view_494_size_0'] -> ['view_17']
    155/275: Mul ['transpose', 'unsqueeze_10'] -> ['mul_3']
    156/275: Mul ['transpose_1', 'unsqueeze_10'] -> ['mul_5']
    157/275: Constant [] -> ['scalar_tensor_default_1']
    158/275: Equal ['add', 'scalar_tensor_default_1'] -> ['eq']
    159/275: Mul ['cat_1', 'unsqueeze_11'] -> ['mul_4']
    160/275: Mul ['cat_2', 'unsqueeze_11'] -> ['mul_6']
    161/275: Transpose ['view_17'] -> ['transpose_8']
    162/275: Constant [] -> ['_val_372']
    163/275: Where ['eq', '_val_372', 'slice_7'] -> ['masked_fill']
    164/275: Constant [] -> ['alpha__3']
    165/275: Mul ['mul_4', 'alpha__3'] -> ['other_1__3']
    166/275: Add ['mul_3', 'other_1__3'] -> ['add_2']
    167/275: Constant [] -> ['alpha__4']
    168/275: Mul ['mul_6', 'alpha__4'] -> ['other_1__4']
    169/275: Add ['mul_5', 'other_1__4'] -> ['add_3']
    170/275: Constant [] -> ['aten_expand_509_size_1']
    171/275: Expand ['add_2', 'aten_expand_509_size_1'] -> ['expand_5']
    172/275: Transpose ['add_3'] -> ['transpose_4']
    173/275: Constant [] -> ['_val_395']
    174/275: Transpose ['masked_fill'] -> ['_val_396']
    175/275: Transpose ['slice_6'] -> ['_val_397']
    176/275: ScatterND ['_val_397', '_val_395', '_val_396'] -> ['_val_398']
    177/275: Transpose ['_val_398'] -> ['slice_scatter']
    178/275: Constant [] -> ['aten_expand_533_size_1']
    179/275: Expand ['transpose_4', 'aten_expand_533_size_1'] -> ['expand_6']
    180/275: Constant [] -> ['_val_418']
    181/275: Transpose ['slice_scatter'] -> ['_val_419']
    182/275: Transpose ['slice_5'] -> ['_val_420']
    183/275: ScatterND ['_val_420', '_val_418', '_val_419'] -> ['_val_421']
    184/275: Transpose ['_val_421'] -> ['slice_scatter_1']
    185/275: Constant [] -> ['aten_view_555_size_0']
    186/275: Reshape ['expand_5', 'aten_view_555_size_0'] -> ['view_13']
    187/275: Constant [] -> ['_val_441']
    188/275: ScatterND ['expand_1', '_val_441', 'slice_scatter_1'] -> ['slice_scatter_2']
    189/275: Transpose ['view_13'] -> ['transpose_9']
    190/275: Constant [] -> ['aten_view_576_size_0']
    191/275: Reshape ['expand_6', 'aten_view_576_size_0'] -> ['view_14']
    192/275: Constant [] -> ['_val_449']
    193/275: Constant [] -> ['_val_453']
    194/275: Constant [] -> ['_val_457']
    195/275: Constant [] -> ['_val_461']
    196/275: Slice ['slice_scatter_2', '_val_449', '_val_453', '_val_457', '_val_461'] -> ['slice_31']
    197/275: MatMul ['view_13', 'view_14'] -> ['bmm_1']
    198/275: Transpose ['view_14'] -> ['transpose_10']
    199/275: Constant [] -> ['_val_468']
    200/275: Constant [] -> ['_val_472']
    201/275: Constant [] -> ['_val_476']
    202/275: Constant [] -> ['_val_480']
    203/275: Slice ['slice_31', '_val_468', '_val_472', '_val_476', '_val_480'] -> ['slice_32']
    204/275: Constant [] -> ['aten_view_614_size_0']
    205/275: Reshape ['bmm_1', 'aten_view_614_size_0'] -> ['view_15']
    206/275: Constant [] -> ['_val_487']
    207/275: Constant [] -> ['_val_491']
    208/275: Constant [] -> ['_val_495']
    209/275: Constant [] -> ['_val_499']
    210/275: Slice ['slice_32', '_val_487', '_val_491', '_val_495', '_val_499'] -> ['slice_33']
    211/275: Constant [] -> ['_val_501']
    212/275: Div ['view_15', '_val_501'] -> ['div']
    213/275: Constant [] -> ['alpha__5']
    214/275: Mul ['slice_33', 'alpha__5'] -> ['other_1__5']
    215/275: Add ['div', 'other_1__5'] -> ['add_4']
    216/275: Softmax ['add_4'] -> ['_softmax']
    217/275: Constant [] -> ['aten_expand_640_size_1']
    218/275: Expand ['_softmax', 'aten_expand_640_size_1'] -> ['expand_7']
    219/275: Constant [] -> ['aten_view_643_size_0']
    220/275: Reshape ['expand_7', 'aten_view_643_size_0'] -> ['view_16']
    221/275: Identity ['_softmax'] -> ['detach_13']
    222/275: MatMul ['view_16', 'view_17'] -> ['bmm_2']
    223/275: Transpose ['view_16'] -> ['transpose_7']
    224/275: Constant [] -> ['aten_view_648_size_0']
    225/275: Reshape ['bmm_2', 'aten_view_648_size_0'] -> ['view_18']
    226/275: Transpose ['view_18'] -> ['transpose_5']
    227/275: Constant [] -> ['aten_view_652_size_0']
    228/275: Reshape ['transpose_5', 'aten_view_652_size_0'] -> ['view_19']
    229/275: Constant [] -> ['aten_view_654_size_0']
    230/275: Reshape ['view_19', 'aten_view_654_size_0'] -> ['view_20']
    231/275: MatMul ['view_20', 't_3'] -> ['mm_3']
    232/275: Constant [] -> ['aten_view_657_size_0']
    233/275: Reshape ['mm_3', 'aten_view_657_size_0'] -> ['view_21']
    234/275: Constant [] -> ['alpha__6']
    235/275: Mul ['view_21', 'alpha__6'] -> ['other_1__6']
    236/275: Add ['embedding', 'other_1__6'] -> ['add_5']
    237/275: Constant [] -> ['scalar_tensor_default_2']
    238/275: Pow ['add_5', 'scalar_tensor_default_2'] -> ['pow_2']
    239/275: Constant [] -> ['_val_531']
    240/275: ReduceMean ['pow_2', '_val_531'] -> ['mean_1']
    241/275: Constant [] -> ['aten_add_665_other_1']
    242/275: Add ['mean_1', 'aten_add_665_other_1'] -> ['add_6']
    243/275: Sqrt ['add_6'] -> ['aten_rsqrt_666_tmp']
    244/275: Reciprocal ['aten_rsqrt_666_tmp'] -> ['rsqrt_1']
    245/275: Mul ['add_5', 'rsqrt_1'] -> ['mul_7']
    246/275: Mul ['primals_2', 'mul_7'] -> ['mul_8']
    247/275: Constant [] -> ['aten_view_670_size_0']
    248/275: Reshape ['mul_8', 'aten_view_670_size_0'] -> ['view_22']
    249/275: MatMul ['view_22', 't_4'] -> ['mm_4']
    250/275: MatMul ['view_22', 't_5'] -> ['mm_5']
    251/275: Constant [] -> ['aten_view_674_size_0']
    252/275: Reshape ['mm_4', 'aten_view_674_size_0'] -> ['view_23']
    253/275: Constant [] -> ['aten_view_676_size_0']
    254/275: Reshape ['mm_5', 'aten_view_676_size_0'] -> ['view_25']
    255/275: Sigmoid ['view_23'] -> ['sigmoid']
    256/275: Mul ['view_23', 'sigmoid'] -> ['mul_9']
    257/275: Mul ['mul_9', 'view_25'] -> ['mul_10']
    258/275: Constant [] -> ['aten_view_681_size_0']
    259/275: Reshape ['mul_10', 'aten_view_681_size_0'] -> ['view_26']
    260/275: MatMul ['view_26', 't_6'] -> ['mm_6']
    261/275: Constant [] -> ['aten_view_684_size_0']
    262/275: Reshape ['mm_6', 'aten_view_684_size_0'] -> ['view_27']
    263/275: Constant [] -> ['alpha__7']
    264/275: Mul ['view_27', 'alpha__7'] -> ['other_1__7']
    265/275: Add ['add_5', 'other_1__7'] -> ['add_7']
    266/275: Constant [] -> ['scalar_tensor_default_3']
    267/275: Pow ['add_7', 'scalar_tensor_default_3'] -> ['pow_3']
    268/275: Constant [] -> ['_val_558']
    269/275: ReduceMean ['pow_3', '_val_558'] -> ['mean_2']
    270/275: Constant [] -> ['aten_add_692_other_1']
    271/275: Add ['mean_2', 'aten_add_692_other_1'] -> ['add_8']
    272/275: Sqrt ['add_8'] -> ['aten_rsqrt_693_tmp']
    273/275: Reciprocal ['aten_rsqrt_693_tmp'] -> ['rsqrt_2']
    274/275: Mul ['add_7', 'rsqrt_2'] -> ['mul_11']
    275/275: Mul ['primals_3', 'mul_11'] -> ['mul_12']
    [runpythonerror]
    /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/exporter.py:137: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
      warnings.warn(
    2024-05-27 18:42:38,223 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 4194304.
    2024-05-27 18:42:38,223 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue full due to large size 4194304.
    2024-05-27 18:42:38,255 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue triu due to large size 4194304.
    2024-05-27 18:42:38,306 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue convert_element_type_default due to large size 4194304.
    2024-05-27 18:42:38,315 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue mul due to large size 4194304.
    2024-05-27 18:42:38,329 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 4194304.
    2024-05-27 18:42:38,329 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue unsqueeze_3 due to large size 4194304.
    2024-05-27 18:42:38,332 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 4194304.
    2024-05-27 18:42:38,333 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue unsqueeze_4 due to large size 4194304.
    2024-05-27 18:42:38,340 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_3 due to large size 4194304.
    2024-05-27 18:42:38,347 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_4 due to large size 4194304.
    2024-05-27 18:42:38,359 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 8388608.
    2024-05-27 18:42:38,360 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue expand_1 due to large size 8388608.
    2024-05-27 18:42:38,366 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue clone due to large size 8388608.
    2024-05-27 18:42:38,373 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_5 due to large size 8388608.
    2024-05-27 18:42:38,378 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_6 due to large size 8388608.
    2024-05-27 18:42:38,405 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_7 due to large size 8388608.
    2024-05-27 18:42:38,420 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue _val_397 due to large size 8388608.
    2024-05-27 18:42:38,430 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue _val_420 due to large size 8388608.
    2024-05-27 18:42:38,635 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 4194304.
    2024-05-27 18:42:38,635 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue full due to large size 4194304.
    2024-05-27 18:42:38,639 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue triu due to large size 4194304.
    2024-05-27 18:42:38,645 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue convert_element_type_default due to large size 4194304.
    2024-05-27 18:42:38,646 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue mul due to large size 4194304.
    2024-05-27 18:42:38,647 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue unsqueeze_3 due to large size 4194304.
    2024-05-27 18:42:38,648 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue unsqueeze_4 due to large size 4194304.
    2024-05-27 18:42:38,649 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_3 due to large size 4194304.
    2024-05-27 18:42:38,651 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_4 due to large size 4194304.
    2024-05-27 18:42:38,654 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue expand_1 due to large size 8388608.
    2024-05-27 18:42:38,659 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_5 due to large size 8388608.
    2024-05-27 18:42:38,662 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_6 due to large size 8388608.
    2024-05-27 18:42:38,668 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_7 due to large size 8388608.
    2024-05-27 18:42:38,673 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue _val_397 due to large size 8388608.
    2024-05-27 18:42:38,674 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue _val_420 due to large size 8388608.
    2024-05-27 18:42:38.818138500 [W:onnxruntime:, graph.cc:4051 CleanUnusedInitializersAndNodeArgs] Removing initializer '_val_23'. It is not used by any node and should be removed from the model.
    2024-05-27 18:42:38.818187900 [W:onnxruntime:, graph.cc:4051 CleanUnusedInitializersAndNodeArgs] Removing initializer '_val_22'. It is not used by any node and should be removed from the model.