Use the custom exporter in torch

Subject to change

File onnxruntime.py

This change enables the custom rewriter if the environment variable ``ONNXRT_CHANGE_REWRITER`` is set to ``1``. Look for the substring ``TODO:`` to find the modified lines.

def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs):
    """This function replaces GraphModule._wrapped_call in compiled model.

    The _wrapped_call is the underlying implementation of forward method. Replacing
    it means we delegate the computation to _ort_acclerated_call and therefore
    onnxruntime.InferenceSession.

    When a graph is seen for the first time and the environment variable
    ``ONNXRT_CHANGE_REWRITER`` is set to ``"1"``, the graph is converted to ONNX
    with ``experimental_experiment.torch_interpreter.to_onnx`` instead of the
    built-in ``FxOnnxInterpreter``. Everything after the conversion (graph
    transforms, optional dump, session creation, caching and execution) is
    shared by both paths.

    Args:
        graph_module: the FX graph to execute.
        *args: runtime inputs of the graph, fed to the ORT session.
        **kwargs: only forwarded to ``FakeTensorProp.propagate`` on the
            built-in path when dynamic shapes are disabled.

    Returns:
        The single output tensor if the graph returns one tensor, otherwise
        the outputs returned by ``self.run``.
    """
    cached_execution_info_per_session = (
        self._all_ort_execution_info.search_reusable_session_execution_info(
            graph_module, *args
        )
    )
    if cached_execution_info_per_session:
        onnx_session = cached_execution_info_per_session.session
        input_names = cached_execution_info_per_session.input_names
        output_names = cached_execution_info_per_session.output_names
        input_value_infos = cached_execution_info_per_session.input_value_infos
        output_value_infos = cached_execution_info_per_session.output_value_infos
        input_devices = cached_execution_info_per_session.input_devices
        output_devices = cached_execution_info_per_session.output_devices
        prim_outputs = cached_execution_info_per_session.example_outputs
    else:
        # It's the first time seeing such a graph. Let's make a new session
        # (type: onnxruntime.InferenceSession) for it.

        ##########################
        # TODO: Insert these lines
        ##########################

        # os.environ only ever holds strings, so comparing with "1" is enough.
        use_other_rewriter = os.environ.get("ONNXRT_CHANGE_REWRITER", None) == "1"
        if use_other_rewriter:
            from experimental_experiment.torch_interpreter import to_onnx
            from experimental_experiment.torch_interpreter._torch_models import create_input_names
            from experimental_experiment.xbuilder import OptimizationOptions
            from experimental_experiment.torch_interpreter.oxs_dispatcher import OxsDispatcher

            # Names used to build the model; the names given to the session
            # are read back from the ONNX graph further below.
            input_names = create_input_names(graph_module, args)
            dispatcher = OxsDispatcher()
            target_opset = self._resolved_onnx_exporter_options.onnx_registry.opset_version
            options = OptimizationOptions(
                remove_unused=True,
                constant_folding=False,
                patterns="default",
                verbose=1,
            )
            # With return_builder=True a pair is returned; only the model is used.
            onnx_model, _ = to_onnx(
                graph_module,
                tuple(args),
                input_names=input_names,
                options=options,
                verbose=1,
                target_opset=target_opset,
                return_builder=True,
                dispatcher=dispatcher,
            )

            def maybe_map_to_meta_val(value):
                if hasattr(value, "meta") and "val" in value.meta:
                    # Select outputs with "val" information. Without "val",
                    # it's not possible to access output_arg.meta["val"].device.
                    return value.meta["val"]
                return value

            # Example outputs are taken from the FX metadata; they give the
            # types and devices of the outputs expected from ORT.
            extracted_outputs = _extract_graph_module_outputs(graph_module)
            prim_outputs = _pytree.tree_map(maybe_map_to_meta_val, extracted_outputs)

        else:

        ####################################
        # TODO: end of the insertion
        # TODO: indent what follows
        ####################################

            graph_module = torch.onnx._internal.fx.passes.MovePlaceholderToFront(
                self._resolved_onnx_exporter_options.diagnostic_context,
                graph_module,
            ).run()
            # Generate reference outputs. They are used to indicate output
            # tensors' types and devices when calling ORT.
            #
            # WARNING: The downstream code should not change prim_outputs and
            # this backend should always produce output with schema identical to prim_outputs'.

            if self._resolved_onnx_exporter_options.dynamic_shapes:
                # No pre-allocation when dynamic shape is enabled.
                # NOTE(review): this sets ``self.preallocate_output`` while the
                # ``self.run`` call below reads ``self._options.preallocate_output``;
                # confirm the flag actually takes effect.
                self.preallocate_output = False
                extracted_outputs = _extract_graph_module_outputs(graph_module)

                def maybe_map_to_meta_val(value):
                    if hasattr(value, "meta") and "val" in value.meta:
                        # Select outputs with "val" information. Without "val",
                        # it's not possible to access output_arg.meta["val"].device.
                        return value.meta["val"]
                    else:
                        return value

                prim_outputs = _pytree.tree_map(
                    maybe_map_to_meta_val, extracted_outputs
                )
            else:
                try:
                    prim_outputs = FakeTensorProp(graph_module).propagate(
                        *args, **kwargs
                    )
                except Exception:
                    logger.warning("FakeTensorProp failed for %s", graph_module)
                    # When FakeTensorProp fails, it is not possible to preallocate output buffers
                    # because the output shapes are not inferred.
                    # NOTE(review): same attribute mismatch as noted above.
                    self.preallocate_output = False

                    # rethrow FakeTensorProp failure because it is not yet currently handled.
                    raise

            # Create the object to iterate through the nodes in graph one-by-one
            # and calls the corresponding ONNX exporter for each node.
            fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter(
                diagnostic_context=self._resolved_onnx_exporter_options.diagnostic_context
            )
            # Cast FX variables if they will result schema-mismatch when searching
            # for ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch,
            # but ONNX expects add(double_tensor, double_tensor).
            graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion(
                self._resolved_onnx_exporter_options.diagnostic_context, graph_module
            ).run()
            # Start the per-node exporting process. It's conceptually a for loop
            # scanning through the nodes in the graph.
            exported = fx_interpreter.run(
                fx_graph_module=graph_module,
                onnxfunction_dispatcher=self._resolved_onnx_exporter_options.onnxfunction_dispatcher,
                op_level_debug=self._resolved_onnx_exporter_options.op_level_debug,
            )
            # Convert the exported result to ONNX ModelProto.
            onnx_model = exported.to_model_proto(
                opset_version=self._resolved_onnx_exporter_options.onnx_registry.opset_version,
            )

        ####################################
        # TODO: end of the modification
        ####################################

        # Modify ONNX model using pre-registered graph transforms.
        # They are in-place modifications for avoiding unnecessary
        # copy of ONNX initializers.
        if self._options.pre_ort_model_transforms:
            for transform in self._options.pre_ort_model_transforms:
                transform(onnx_model)

        onnx_model_bytes = onnx_model.SerializeToString()
        if os.environ.get("ONNXRT_DUMP_PATH", None):
            # If not empty, environment variable ONNXRT_DUMP_PATH defined the path
            # where generated onnx files should be stored.
            # This module keeps a global variables keeping track of the
            # stored models.
            # If ONNXRT_DUMP_PATH="dumped/dumped_model_"
            # The first file name will be 'dumped/dumped_model_0.onnx'.
            # For every dumped model, a text file 'dumped/dumped_model_0.txt'
            # is created as well to contain the string representing the graph_module.
            _dump_onnx_model(onnx_model_bytes, graph_module=graph_module)

        # Initialize an ORT session to execute this ONNX model.
        # Note that TorchDynamo assumes all inputs/outputs are on the
        # same device, but it's subject to change (very likely with
        # dynamic shape support), so we add execution providers
        # based on the logic in _select_eps: (explicitly preferred EPs,
        # EPs inferred from inputs or graph, and the fallback default EP)/
        #
        # TODO(wschin): enable external allocators.
        # See https://github.com/pytorch/pytorch/issues/106867
        onnx_session = onnxruntime.InferenceSession(
            path_or_bytes=onnx_model_bytes,
            sess_options=self._options.ort_session_options,
            providers=self._select_eps(graph_module, *args),
        )

        # Cache ORT session. It's reused for the same "graph_module".
        # Generate ONNX model and extract its input and output names.
        input_names = tuple(input.name for input in onnx_model.graph.input)
        output_names = tuple(output.name for output in onnx_model.graph.output)
        input_devices = _get_onnx_devices(args)
        # Cache devices for inputs and outputs. They are used to invoke
        # ORT session. Output devices indicate where (e.g., GPU or CPU)
        # to store outputs
        if isinstance(prim_outputs, tuple):
            output_devices = _get_onnx_devices(prim_outputs)
        else:
            output_devices = _get_onnx_devices((prim_outputs,))

        input_value_infos = tuple(input for input in onnx_model.graph.input)
        output_value_infos = tuple(output for output in onnx_model.graph.output)

        execution_info_per_session = OrtExecutionInfoPerSession(
            session=onnx_session,
            input_names=input_names,
            input_value_infos=input_value_infos,
            output_names=output_names,
            output_value_infos=output_value_infos,
            input_devices=input_devices,
            output_devices=output_devices,
            example_outputs=prim_outputs,
        )

        self._all_ort_execution_info.cache_session_execution_info(
            graph_module, execution_info_per_session
        )

    self.execution_count += 1

    # ORT always returns a tuple of outputs. If the original output is a tensor,
    # ORT output's first element must be extracted and returned. Otherwise, type
    # mismatch may happen in downstream computation.
    is_single_tensor_output = isinstance(prim_outputs, torch.Tensor)
    normalized_prim_outputs = (
        (prim_outputs,) if is_single_tensor_output else prim_outputs
    )
    assert isinstance(normalized_prim_outputs, tuple)
    assert all(
        isinstance(elem, (torch.Tensor, torch.SymInt, int))
        for elem in normalized_prim_outputs
    )

    _nvtx_range_push("run_onnx_session_with_ortvaluevector")
    onnx_outputs = self.run(
        onnx_session,
        input_names,
        args,
        input_devices,
        output_names,
        normalized_prim_outputs,
        output_devices,
        self._options.preallocate_output,
        input_value_infos,
        normalized_prim_outputs,
    )
    _nvtx_range_pop()

    if self._assert_allclose_to_baseline:
        # Compute baseline.
        baseline_outputs = torch._prims.executor.execute(
            graph_module, *args, executor="aten"
        )
        normalized_baseline_outputs = (
            (baseline_outputs,) if is_single_tensor_output else baseline_outputs
        )
        # Ensure every output tensor is close to the corresponding baseline.
        for onnx_output, baseline_output in zip(
            onnx_outputs, normalized_baseline_outputs
        ):
            torch.testing.assert_close(onnx_output, baseline_output)
    return onnx_outputs[0] if is_single_tensor_output else onnx_outputs

Examples

Baseline

<<<

import os
import warnings
import numpy as np
import onnx
import torch
import torch.onnx
from experimental_experiment.torch_models.training_helper import (
    make_aot_ort,
    train_loop,
)
from experimental_experiment.torch_models.dump_helper import dump_onnx

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from transformers import LlamaConfig
    from transformers.models.llama.modeling_llama import LlamaModel


def ids_tensor(shape, vocab_size):
    """Return a random tensor of token ids.

    Every element is drawn uniformly from ``[0, vocab_size)`` using NumPy's
    global random generator, so ``np.random.seed`` makes the result reproducible.

    :param shape: sequence of ints, the shape of the returned tensor
    :param vocab_size: number of distinct token ids, must be >= 1
    :return: contiguous ``torch.long`` tensor of shape ``shape``
    """
    # np.random.randint excludes its upper bound: passing vocab_size (and not
    # vocab_size - 1) makes every id in [0, vocab_size) reachable, and
    # vocab_size == 1 no longer raises.
    values = np.random.randint(0, vocab_size, size=tuple(shape), dtype=np.int64)
    # asarray keeps a 0-d array when shape is empty.
    return torch.from_numpy(np.asarray(values)).contiguous()


config = LlamaConfig(
    hidden_size=16,
    num_hidden_layers=1,
    vocab_size=1024,
    intermediate_size=16,
    max_position_embeddings=1024,
    num_attention_heads=2,
)
config._attn_implementation = "eager"

model = LlamaModel(config)

batch, seq, vocab_size = 2, 1024, 1024

input_ids = ids_tensor([batch, seq], vocab_size)
input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))

# Run the eager model once before compiling it.
model(input_ids, input_mask)

local_aot_ort, _ = make_aot_ort(
    dynamic=True,
    rewrite=True,
    verbose=1,
)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    optimized_mod = torch.compile(model, backend=local_aot_ort, fullgraph=True)
    # The ONNX models written to dump_llama/ are read back below.
    with dump_onnx("dort-llama-ort", folder="dump_llama", clean=True):
        train_loop(optimized_mod, input_ids, input_mask)

# os.listdir returns entries in arbitrary order; sort for a stable listing.
names = sorted(
    filename for filename in os.listdir("dump_llama") if filename.endswith(".onnx")
)
print("------------------------------------------")
print(f"exported model: {names}")
for name in names:
    print()
    # f-string prefix: without it the literal "{name!r}" was printed.
    print(f"NODES in {name!r}")
    onx = onnx.load(os.path.join("dump_llama", name))
    n_nodes = len(onx.graph.node)
    for i, node in enumerate(onx.graph.node, start=1):
        print(f"{i}/{n_nodes}: {node.op_type} {node.input} -> {node.output}")

>>>

    [2024-05-08 14:07:04,496] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
    Applied 1 of general pattern rewrite rules.
    Applied 0 of general pattern rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific pattern rewrite rules.
    Applied 0 of general pattern rewrite rules.
    Applied 0 of general pattern rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific pattern rewrite rules.
    ------------------------------------------
    exported model: ['dort-llama-ort_1.onnx', 'dort-llama-ort_0.onnx']
    
    NODES in {name!r}
    1/305: Constant [] -> ['aten_view_111_size_0']
    2/305: Reshape ['mm_5', 'aten_view_111_size_0'] -> ['view_25']
    3/305: Cos ['cat'] -> ['cos']
    4/305: Mul ['embedding', 'rsqrt'] -> ['mul_2']
    5/305: Constant [] -> ['_val_38']
    6/305: Pow ['add_6', '_val_38'] -> ['pow_5']
    7/305: Constant [] -> ['_val_41']
    8/305: Equal ['primals_13', '_val_41'] -> ['eq_2']
    9/305: Constant [] -> ['aten_view_121_size_0']
    10/305: Reshape ['mm_3', 'aten_view_121_size_0'] -> ['view_21']
    11/305: Mul ['tangents_1', 'primals_3'] -> ['mul_14']
    12/305: Constant [] -> ['_val_46']
    13/305: Pow ['embedding', '_val_46'] -> ['pow_9']
    14/305: Mul ['add_6', 'rsqrt_2'] -> ['mul_12']
    15/305: Constant [] -> ['aten_view_131_size_0']
    16/305: Reshape ['mm_4', 'aten_view_131_size_0'] -> ['view_23']
    17/305: Sin ['cat'] -> ['sin']
    18/305: Constant [] -> ['aten_unsqueeze_134_dim_0']
    19/305: Unsqueeze ['cos', 'aten_unsqueeze_134_dim_0'] -> ['unsqueeze_10']
    20/305: Constant [] -> ['_val_58']
    21/305: Mul ['pow_5', '_val_58'] -> ['mul_20']
    22/305: Constant [] -> ['aten_unsqueeze_138_dim_0']
    23/305: Unsqueeze ['eq_2', 'aten_unsqueeze_138_dim_0'] -> ['unsqueeze_12']
    24/305: Constant [] -> ['alpha__1']
    25/305: Mul ['view_21', 'alpha__1'] -> ['other_1__1']
    26/305: Add ['embedding', 'other_1__1'] -> ['add_4']
    27/305: Mul ['mul_14', 'add_6'] -> ['mul_16']
    28/305: Mul ['mul_14', 'rsqrt_2'] -> ['mul_17']
    29/305: Constant [] -> ['_val_65']
    30/305: Mul ['pow_9', '_val_65'] -> ['mul_47']
    31/305: Mul ['tangents_1', 'mul_12'] -> ['mul_15']
    32/305: Constant [] -> ['fill']
    33/305: Mul ['view_23', 'sigmoid'] -> ['mul_10']
    34/305: Constant [] -> ['aten_unsqueeze_149_dim_0']
    35/305: Unsqueeze ['sin', 'aten_unsqueeze_149_dim_0'] -> ['unsqueeze_11']
    36/305: Mul ['add_4', 'rsqrt_1'] -> ['mul_8']
    37/305: Constant [] -> ['_val_76']
    38/305: Pow ['add_4', '_val_76'] -> ['pow_7']
    39/305: Constant [] -> ['_val_78']
    40/305: ReduceSum ['mul_16', '_val_78'] -> ['sum_2']
    41/305: Constant [] -> ['_val_80']
    42/305: ReduceSum ['mul_15', '_val_80'] -> ['sum_1']
    43/305: Constant [] -> ['alpha__2']
    44/305: Mul ['sigmoid', 'alpha__2'] -> ['other_1__2']
    45/305: Sub ['fill', 'other_1__2'] -> ['sub']
    46/305: Constant [] -> ['_val_86']
    47/305: Mul ['pow_7', '_val_86'] -> ['mul_33']
    48/305: Constant [] -> ['_val_88']
    49/305: Mul ['sum_2', '_val_88'] -> ['mul_18']
    50/305: Constant [] -> ['aten_view_168_size_0']
    51/305: Reshape ['sum_1', 'aten_view_168_size_0'] -> ['view_28']
    52/305: Mul ['view_23', 'sub'] -> ['mul_24']
    53/305: Constant [] -> ['scalar_tensor_default']
    54/305: Pow ['rsqrt', 'scalar_tensor_default'] -> ['pow_8']
    55/305: Constant [] -> ['scalar_tensor_default_1']
    56/305: Pow ['rsqrt_1', 'scalar_tensor_default_1'] -> ['pow_6']
    57/305: Constant [] -> ['scalar_tensor_default_2']
    58/305: Pow ['rsqrt_2', 'scalar_tensor_default_2'] -> ['pow_4']
    59/305: Constant [] -> ['aten_add_182_other_1']
    60/305: Add ['mul_24', 'aten_add_182_other_1'] -> ['add_9']
    61/305: Mul ['mul_18', 'pow_4'] -> ['mul_19']
    62/305: Mul ['sigmoid', 'add_9'] -> ['mul_25']
    63/305: Constant [] -> ['aten_expand_186_size_1']
    64/305: Expand ['mul_19', 'aten_expand_186_size_1'] -> ['expand_9']
    65/305: Constant [] -> ['scalar_tensor_default_4']
    66/305: Div ['expand_9', 'scalar_tensor_default_4'] -> ['div_1']
    67/305: Mul ['div_1', 'mul_20'] -> ['mul_21']
    68/305: Constant [] -> ['alpha__3']
    69/305: Mul ['mul_21', 'alpha__3'] -> ['other_1__3']
    70/305: Add ['mul_17', 'other_1__3'] -> ['add_8']
    71/305: Constant [] -> ['aten_view_193_size_0']
    72/305: Reshape ['add_8', 'aten_view_193_size_0'] -> ['view_29']
    73/305: Transpose ['view_29'] -> ['t_7']
    74/305: MatMul ['view_29', 't_9'] -> ['mm_8']
    75/305: MatMul ['t_7', 'view_26'] -> ['mm_7']
    76/305: Constant [] -> ['aten_view_198_size_0']
    77/305: Reshape ['mm_8', 'aten_view_198_size_0'] -> ['view_30']
    78/305: Transpose ['mm_7'] -> ['t_8']
    79/305: Mul ['view_30', 'mul_10'] -> ['mul_22']
    80/305: Mul ['view_30', 'view_25'] -> ['mul_23']
    81/305: Transpose ['t_8'] -> ['t_10']
    82/305: Constant [] -> ['aten_view_204_size_0']
    83/305: Reshape ['mul_22', 'aten_view_204_size_0'] -> ['view_31']
    84/305: Mul ['mul_23', 'mul_25'] -> ['mul_26']
    85/305: Transpose ['view_31'] -> ['t_11']
    86/305: MatMul ['view_31', 't_13'] -> ['mm_10']
    87/305: Constant [] -> ['aten_view_209_size_0']
    88/305: Reshape ['mul_26', 'aten_view_209_size_0'] -> ['view_33']
    89/305: MatMul ['t_11', 'view_22'] -> ['mm_9']
    90/305: Constant [] -> ['aten_view_212_size_0']
    91/305: Reshape ['mm_10', 'aten_view_212_size_0'] -> ['view_32']
    92/305: Transpose ['view_33'] -> ['t_15']
    93/305: MatMul ['view_33', 't_17'] -> ['mm_12']
    94/305: Transpose ['mm_9'] -> ['t_12']
    95/305: MatMul ['t_15', 'view_22'] -> ['mm_11']
    96/305: Constant [] -> ['aten_view_218_size_0']
    97/305: Reshape ['mm_12', 'aten_view_218_size_0'] -> ['view_34']
    98/305: Transpose ['t_12'] -> ['t_14']
    99/305: Transpose ['mm_11'] -> ['t_16']
    100/305: Constant [] -> ['alpha__4']
    101/305: Mul ['view_34', 'alpha__4'] -> ['other_1__4']
    102/305: Add ['view_32', 'other_1__4'] -> ['add_10']
    103/305: Transpose ['t_16'] -> ['t_18']
    104/305: Mul ['add_10', 'primals_2'] -> ['mul_27']
    105/305: Mul ['add_10', 'mul_8'] -> ['mul_28']
    106/305: Mul ['mul_27', 'add_4'] -> ['mul_29']
    107/305: Mul ['mul_27', 'rsqrt_1'] -> ['mul_30']
    108/305: Constant [] -> ['_val_150']
    109/305: ReduceSum ['mul_28', '_val_150'] -> ['sum_3']
    110/305: Constant [] -> ['_val_152']
    111/305: ReduceSum ['mul_29', '_val_152'] -> ['sum_4']
    112/305: Constant [] -> ['alpha__5']
    113/305: Mul ['mul_30', 'alpha__5'] -> ['other_1__5']
    114/305: Add ['add_8', 'other_1__5'] -> ['add_11']
    115/305: Constant [] -> ['aten_view_233_size_0']
    116/305: Reshape ['sum_3', 'aten_view_233_size_0'] -> ['view_35']
    117/305: Constant [] -> ['_val_157']
    118/305: Mul ['sum_4', '_val_157'] -> ['mul_31']
    119/305: Mul ['mul_31', 'pow_6'] -> ['mul_32']
    120/305: Constant [] -> ['aten_expand_238_size_1']
    121/305: Expand ['mul_32', 'aten_expand_238_size_1'] -> ['expand_10']
    122/305: Constant [] -> ['scalar_tensor_default_5']
    123/305: Div ['expand_10', 'scalar_tensor_default_5'] -> ['div_2']
    124/305: Mul ['div_2', 'mul_33'] -> ['mul_34']
    125/305: Constant [] -> ['alpha__6']
    126/305: Mul ['mul_34', 'alpha__6'] -> ['other_1__6']
    127/305: Add ['add_11', 'other_1__6'] -> ['add_12']
    128/305: Constant [] -> ['aten_view_245_size_0']
    129/305: Reshape ['add_12', 'aten_view_245_size_0'] -> ['view_36']
    130/305: Transpose ['view_36'] -> ['t_19']
    131/305: MatMul ['view_36', 't_21'] -> ['mm_14']
    132/305: MatMul ['t_19', 'view_20'] -> ['mm_13']
    133/305: Constant [] -> ['aten_view_250_size_0']
    134/305: Reshape ['mm_14', 'aten_view_250_size_0'] -> ['view_37']
    135/305: Transpose ['mm_13'] -> ['t_20']
    136/305: Constant [] -> ['aten_view_253_size_0']
    137/305: Reshape ['view_37', 'aten_view_253_size_0'] -> ['view_38']
    138/305: Transpose ['t_20'] -> ['t_22']
    139/305: Transpose ['view_38'] -> ['transpose_6']
    140/305: Constant [] -> ['aten_view_258_size_0']
    141/305: Reshape ['transpose_6', 'aten_view_258_size_0'] -> ['view_39']
    142/305: MatMul ['transpose_7', 'view_39'] -> ['bmm_3']
    143/305: MatMul ['view_39', 'transpose_8'] -> ['bmm_4']
    144/305: Constant [] -> ['aten_view_262_size_0']
    145/305: Reshape ['bmm_3', 'aten_view_262_size_0'] -> ['view_40']
    146/305: Constant [] -> ['aten_view_264_size_0']
    147/305: Reshape ['bmm_4', 'aten_view_264_size_0'] -> ['view_41']
    148/305: Constant [] -> ['alpha__7']
    149/305: Mul ['view_40', 'alpha__7'] -> ['other_1__7']
    150/305: Add ['tangents_3', 'other_1__7'] -> ['add_13']
    151/305: Mul ['view_41', 'detach_13'] -> ['mul_35']
    152/305: Transpose ['add_13'] -> ['transpose_12']
    153/305: Constant [] -> ['_val_191']
    154/305: ReduceSum ['mul_35', '_val_191'] -> ['sum_5']
    155/305: Mul ['detach_13', 'sum_5'] -> ['mul_36']
    156/305: Constant [] -> ['aten_view_273_size_0']
    157/305: Reshape ['transpose_12', 'aten_view_273_size_0'] -> ['view_45']
    158/305: Constant [] -> ['alpha__8']
    159/305: Mul ['mul_36', 'alpha__8'] -> ['other_1__8']
    160/305: Sub ['mul_35', 'other_1__8'] -> ['sub_1']
    161/305: Constant [] -> ['aten_view_276_size_0']
    162/305: Reshape ['view_45', 'aten_view_276_size_0'] -> ['view_48']
    163/305: Constant [] -> ['_val_200']
    164/305: Div ['sub_1', '_val_200'] -> ['div_3']
    165/305: Transpose ['view_48'] -> ['t_23']
    166/305: MatMul ['view_48', 't_25'] -> ['mm_16']
    167/305: Constant [] -> ['aten_view_282_size_0']
    168/305: Reshape ['div_3', 'aten_view_282_size_0'] -> ['view_42']
    169/305: MatMul ['t_23', 'view_1'] -> ['mm_15']
    170/305: Constant [] -> ['aten_view_285_size_0']
    171/305: Reshape ['mm_16', 'aten_view_285_size_0'] -> ['view_49']
    172/305: MatMul ['transpose_9', 'view_42'] -> ['bmm_5']
    173/305: MatMul ['view_42', 'transpose_10'] -> ['bmm_6']
    174/305: Transpose ['mm_15'] -> ['t_24']
    175/305: Constant [] -> ['aten_view_290_size_0']
    176/305: Reshape ['bmm_5', 'aten_view_290_size_0'] -> ['view_43']
    177/305: Constant [] -> ['aten_view_292_size_0']
    178/305: Reshape ['bmm_6', 'aten_view_292_size_0'] -> ['view_44']
    179/305: Transpose ['t_24'] -> ['t_26']
    180/305: Transpose ['view_43'] -> ['transpose_11']
    181/305: Mul ['view_44', 'unsqueeze_11'] -> ['mul_39']
    182/305: Mul ['view_44', 'unsqueeze_10'] -> ['mul_40']
    183/305: Constant [] -> ['alpha__9']
    184/305: Mul ['transpose_11', 'alpha__9'] -> ['other_1__9']
    185/305: Add ['tangents_2', 'other_1__9'] -> ['add_14']
    186/305: Constant [] -> ['_val_224']
    187/305: Constant [] -> ['_val_228']
    188/305: Constant [] -> ['_val_232']
    189/305: Constant [] -> ['_val_236']
    190/305: Slice ['mul_39', '_val_224', '_val_228', '_val_232', '_val_236'] -> ['slice_22']
    191/305: Constant [] -> ['_val_241']
    192/305: Constant [] -> ['_val_245']
    193/305: Constant [] -> ['_val_249']
    194/305: Constant [] -> ['_val_253']
    195/305: Slice ['mul_39', '_val_241', '_val_245', '_val_249', '_val_253'] -> ['slice_23']
    196/305: Mul ['add_14', 'unsqueeze_11'] -> ['mul_37']
    197/305: Mul ['add_14', 'unsqueeze_10'] -> ['mul_38']
    198/305: Neg ['slice_22'] -> ['neg_3']
    199/305: Constant [] -> ['_val_263']
    200/305: Constant [] -> ['_val_267']
    201/305: Constant [] -> ['_val_271']
    202/305: Constant [] -> ['_val_275']
    203/305: Slice ['mul_37', '_val_263', '_val_267', '_val_271', '_val_275'] -> ['slice_20']
    204/305: Constant [] -> ['_val_280']
    205/305: Constant [] -> ['_val_284']
    206/305: Constant [] -> ['_val_288']
    207/305: Constant [] -> ['_val_292']
    208/305: Slice ['mul_37', '_val_280', '_val_284', '_val_288', '_val_292'] -> ['slice_21']
    209/305: Constant [] -> ['_val_311']
    210/305: Transpose ['slice_23'] -> ['_val_312']
    211/305: Constant [] -> ['_val_313']
    212/305: ScatterND ['_val_313', '_val_311', '_val_312'] -> ['_val_314']
    213/305: Transpose ['_val_314'] -> ['slice_scatter_3']
    214/305: Neg ['slice_20'] -> ['neg_2']
    215/305: Constant [] -> ['_val_334']
    216/305: Transpose ['neg_3'] -> ['_val_335']
    217/305: Constant [] -> ['_val_336']
    218/305: ScatterND ['_val_336', '_val_334', '_val_335'] -> ['_val_337']
    219/305: Transpose ['_val_337'] -> ['slice_scatter_2']
    220/305: Constant [] -> ['_val_356']
    221/305: Transpose ['slice_21'] -> ['_val_357']
    222/305: Constant [] -> ['_val_358']
    223/305: ScatterND ['_val_358', '_val_356', '_val_357'] -> ['_val_359']
    224/305: Transpose ['_val_359'] -> ['slice_scatter_1']
    225/305: Constant [] -> ['alpha__10']
    226/305: Mul ['slice_scatter_3', 'alpha__10'] -> ['other_1__10']
    227/305: Add ['slice_scatter_2', 'other_1__10'] -> ['add_17']
    228/305: Constant [] -> ['_val_377']
    229/305: Transpose ['neg_2'] -> ['_val_378']
    230/305: Constant [] -> ['_val_379']
    231/305: ScatterND ['_val_379', '_val_377', '_val_378'] -> ['_val_380']
    232/305: Transpose ['_val_380'] -> ['slice_scatter']
    233/305: Constant [] -> ['alpha__11']
    234/305: Mul ['mul_40', 'alpha__11'] -> ['other_1__11']
    235/305: Add ['add_17', 'other_1__11'] -> ['add_18']
    236/305: Constant [] -> ['alpha__12']
    237/305: Mul ['slice_scatter_1', 'alpha__12'] -> ['other_1__12']
    238/305: Add ['slice_scatter', 'other_1__12'] -> ['add_15']
    239/305: Transpose ['add_18'] -> ['transpose_14']
    240/305: Constant [] -> ['alpha__13']
    241/305: Mul ['mul_38', 'alpha__13'] -> ['other_1__13']
    242/305: Add ['add_15', 'other_1__13'] -> ['add_16']
    243/305: Transpose ['add_16'] -> ['transpose_13']
    244/305: Constant [] -> ['aten_view_466_size_0']
    245/305: Reshape ['transpose_14', 'aten_view_466_size_0'] -> ['view_47']
    246/305: Constant [] -> ['aten_view_469_size_0']
    247/305: Reshape ['view_47', 'aten_view_469_size_0'] -> ['view_52']
    248/305: Constant [] -> ['aten_view_471_size_0']
    249/305: Reshape ['transpose_13', 'aten_view_471_size_0'] -> ['view_46']
    250/305: Transpose ['view_52'] -> ['t_31']
    251/305: MatMul ['view_52', 't_33'] -> ['mm_20']
    252/305: Constant [] -> ['aten_view_475_size_0']
    253/305: Reshape ['view_46', 'aten_view_475_size_0'] -> ['view_50']
    254/305: MatMul ['t_31', 'view_1'] -> ['mm_19']
    255/305: Constant [] -> ['aten_view_478_size_0']
    256/305: Reshape ['mm_20', 'aten_view_478_size_0'] -> ['view_53']
    257/305: Transpose ['view_50'] -> ['t_27']
    258/305: MatMul ['view_50', 't_29'] -> ['mm_18']
    259/305: Transpose ['mm_19'] -> ['t_32']
    260/305: MatMul ['t_27', 'view_1'] -> ['mm_17']
    261/305: Constant [] -> ['aten_view_484_size_0']
    262/305: Reshape ['mm_18', 'aten_view_484_size_0'] -> ['view_51']
    263/305: Transpose ['t_32'] -> ['t_34']
    264/305: Transpose ['mm_17'] -> ['t_28']
    265/305: Constant [] -> ['alpha__14']
    266/305: Mul ['view_51', 'alpha__14'] -> ['other_1__14']
    267/305: Add ['view_49', 'other_1__14'] -> ['add_19']
    268/305: Transpose ['t_28'] -> ['t_30']
    269/305: Constant [] -> ['alpha__15']
    270/305: Mul ['view_53', 'alpha__15'] -> ['other_1__15']
    271/305: Add ['add_19', 'other_1__15'] -> ['add_20']
    272/305: Mul ['add_20', 'primals_1'] -> ['mul_41']
    273/305: Mul ['add_20', 'mul_2'] -> ['mul_42']
    274/305: Mul ['mul_41', 'embedding'] -> ['mul_43']
    275/305: Mul ['mul_41', 'rsqrt'] -> ['mul_44']
    276/305: Constant [] -> ['_val_417']
    277/305: ReduceSum ['mul_42', '_val_417'] -> ['sum_6']
    278/305: Constant [] -> ['_val_419']
    279/305: ReduceSum ['mul_43', '_val_419'] -> ['sum_7']
    280/305: Constant [] -> ['alpha__16']
    281/305: Mul ['mul_44', 'alpha__16'] -> ['other_1__16']
    282/305: Add ['add_12', 'other_1__16'] -> ['add_21']
    283/305: Constant [] -> ['aten_view_500_size_0']
    284/305: Reshape ['sum_6', 'aten_view_500_size_0'] -> ['view_54']
    285/305: Constant [] -> ['_val_424']
    286/305: Mul ['sum_7', '_val_424'] -> ['mul_45']
    287/305: Mul ['mul_45', 'pow_8'] -> ['mul_46']
    288/305: Constant [] -> ['aten_expand_505_size_1']
    289/305: Expand ['mul_46', 'aten_expand_505_size_1'] -> ['expand_11']
    290/305: Constant [] -> ['scalar_tensor_default_6']
    291/305: Div ['expand_11', 'scalar_tensor_default_6'] -> ['div_4']
    292/305: Mul ['div_4', 'mul_47'] -> ['mul_48']
    293/305: Constant [] -> ['alpha__17']
    294/305: Mul ['mul_48', 'alpha__17'] -> ['other_1__17']
    295/305: Add ['add_21', 'other_1__17'] -> ['add_22']
    296/305: Constant [] -> ['aten_masked_fill_512_value_cast']
    297/305: Where ['unsqueeze_12', 'aten_masked_fill_512_value_cast', 'add_22'] -> ['masked_fill_1']
    298/305: Constant [] -> ['_val_436']
    299/305: ConstantOfShape ['_val_436'] -> ['aten_new_zeros_514_result']
    300/305: SequenceConstruct ['primals_13'] -> ['438']
    301/305: Constant [] -> ['int64_0__18']
    302/305: SequenceAt ['438', 'int64_0__18'] -> ['index__18']
    303/305: Constant [] -> ['int64_m1_1d__18']
    304/305: Unsqueeze ['index__18', 'int64_m1_1d__18'] -> ['new_index__18']
    305/305: ScatterND ['aten_new_zeros_514_result', 'new_index__18', 'masked_fill_1'] -> ['_unsafe_index_put']
    
    NODES in {name!r}
    1/248: Gather ['primals_4', 'primals_13'] -> ['embedding']
    2/248: Transpose ['primals_8'] -> ['t_3']
    3/248: Constant [] -> ['_val_22']
    4/248: Constant [] -> ['_val_23']
    5/248: Constant [] -> ['size_0__1']
    6/248: Constant [] -> ['fill_value_1__1']
    7/248: Expand ['fill_value_1__1', 'size_0__1'] -> ['full']
    8/248: Constant [] -> ['_val_36']
    9/248: Constant [] -> ['_val_40']
    10/248: Constant [] -> ['_val_44']
    11/248: Constant [] -> ['_val_48']
    12/248: Slice ['primals_14', '_val_36', '_val_40', '_val_44', '_val_48'] -> ['slice_5']
    13/248: Transpose ['primals_9'] -> ['t_4']
    14/248: Transpose ['primals_10'] -> ['t_5']
    15/248: Transpose ['primals_11'] -> ['t_6']
    16/248: Transpose ['primals_5'] -> ['t']
    17/248: Transpose ['primals_6'] -> ['t_1']
    18/248: Transpose ['primals_7'] -> ['t_2']
    19/248: Constant [] -> ['aten_unsqueeze_155_dim_0']
    20/248: Unsqueeze ['primals_12', 'aten_unsqueeze_155_dim_0'] -> ['unsqueeze_7']
    21/248: Constant [] -> ['scalar_tensor_default']
    22/248: Pow ['embedding', 'scalar_tensor_default'] -> ['pow_1']
    23/248: Transpose ['t_3'] -> ['t_21']
    24/248: Constant [] -> ['aten_triu_163_diagonal']
    25/248: Trilu ['full', 'aten_triu_163_diagonal'] -> ['triu']
    26/248: Constant [] -> ['aten_unsqueeze_164_dim_0']
    27/248: Unsqueeze ['slice_5', 'aten_unsqueeze_164_dim_0'] -> ['unsqueeze_5']
    28/248: Transpose ['t_4'] -> ['t_17']
    29/248: Transpose ['t_5'] -> ['t_13']
    30/248: Transpose ['t_6'] -> ['t_9']
    31/248: Transpose ['t'] -> ['t_33']
    32/248: Transpose ['t_1'] -> ['t_29']
    33/248: Transpose ['t_2'] -> ['t_25']
    34/248: Constant [] -> ['_val_75']
    35/248: Constant [] -> ['_val_79']
    36/248: Constant [] -> ['_val_83']
    37/248: Constant [] -> ['_val_87']
    38/248: Slice ['unsqueeze_7', '_val_75', '_val_79', '_val_83', '_val_87'] -> ['slice_7']
    39/248: Constant [] -> ['gt']
    40/248: Constant [] -> ['_val_107']
    41/248: ReduceMean ['pow_1', '_val_107'] -> ['mean']
    42/248: Constant [] -> ['aten_unsqueeze_208_dim_0']
    43/248: Unsqueeze ['unsqueeze_5', 'aten_unsqueeze_208_dim_0'] -> ['unsqueeze_6']
    44/248: Constant [] -> ['aten_unsqueeze_209_dim_0']
    45/248: Unsqueeze ['slice_7', 'aten_unsqueeze_209_dim_0'] -> ['unsqueeze_8']
    46/248: Cast ['gt'] -> ['convert_element_type_default']
    47/248: Mul ['triu', 'convert_element_type_default'] -> ['mul']
    48/248: Constant [] -> ['aten_add_214_other_1']
    49/248: Add ['mean', 'aten_add_214_other_1'] -> ['add']
    50/248: Constant [] -> ['_val_119']
    51/248: Constant [] -> ['_val_123']
    52/248: Constant [] -> ['_val_127']
    53/248: Constant [] -> ['_val_131']
    54/248: Slice ['unsqueeze_6', '_val_119', '_val_123', '_val_127', '_val_131'] -> ['slice_6']
    55/248: Constant [] -> ['aten_expand_233_size_1']
    56/248: Expand ['unsqueeze_8', 'aten_expand_233_size_1'] -> ['expand_2']
    57/248: Constant [] -> ['aten_unsqueeze_251_dim_0']
    58/248: Unsqueeze ['mul', 'aten_unsqueeze_251_dim_0'] -> ['unsqueeze_3']
    59/248: Sqrt ['add'] -> ['aten_rsqrt_252_tmp']
    60/248: Reciprocal ['aten_rsqrt_252_tmp'] -> ['rsqrt']
    61/248: Constant [] -> ['_val_154']
    62/248: Equal ['slice_6', '_val_154'] -> ['eq_1']
    63/248: Constant [] -> ['aten_expand_256_size_1']
    64/248: Expand ['expand_2', 'aten_expand_256_size_1'] -> ['expand_3']
    65/248: Constant [] -> ['aten_unsqueeze_258_dim_0']
    66/248: Unsqueeze ['unsqueeze_3', 'aten_unsqueeze_258_dim_0'] -> ['unsqueeze_4']
    67/248: Mul ['embedding', 'rsqrt'] -> ['mul_2']
    68/248: Constant [] -> ['_val_168']
    69/248: Constant [] -> ['_val_172']
    70/248: Constant [] -> ['_val_176']
    71/248: Constant [] -> ['_val_180']
    72/248: Slice ['unsqueeze_4', '_val_168', '_val_172', '_val_176', '_val_180'] -> ['slice_3']
    73/248: Mul ['primals_1', 'mul_2'] -> ['mul_3']
    74/248: Constant [] -> ['view_11']
    75/248: Constant [] -> ['_val_188']
    76/248: Constant [] -> ['_val_192']
    77/248: Constant [] -> ['_val_196']
    78/248: Constant [] -> ['_val_200']
    79/248: Slice ['slice_3', '_val_188', '_val_192', '_val_196', '_val_200'] -> ['slice_4']
    80/248: Constant [] -> ['aten_view_302_size_0']
    81/248: Reshape ['mul_3', 'aten_view_302_size_0'] -> ['view_1']
    82/248: Constant [] -> ['aten_expand_305_size_1']
    83/248: Expand ['slice_4', 'aten_expand_305_size_1'] -> ['expand_1']
    84/248: MatMul ['view_1', 't'] -> ['mm']
    85/248: MatMul ['view_1', 't_1'] -> ['mm_1']
    86/248: MatMul ['view_1', 't_2'] -> ['mm_2']
    87/248: MatMul ['expand_3', 'view_11'] -> ['view_12']
    88/248: Constant [] -> ['aten_view_313_size_0']
    89/248: Reshape ['mm', 'aten_view_313_size_0'] -> ['view_2']
    90/248: Constant [] -> ['aten_view_315_size_0']
    91/248: Reshape ['mm_1', 'aten_view_315_size_0'] -> ['view_4']
    92/248: Constant [] -> ['aten_view_317_size_0']
    93/248: Reshape ['mm_2', 'aten_view_317_size_0'] -> ['view_6']
    94/248: Transpose ['view_12'] -> ['transpose_3']
    95/248: Constant [] -> ['aten_view_321_size_0']
    96/248: Reshape ['view_2', 'aten_view_321_size_0'] -> ['view_7']
    97/248: Constant [] -> ['aten_view_323_size_0']
    98/248: Reshape ['view_4', 'aten_view_323_size_0'] -> ['view_8']
    99/248: Constant [] -> ['aten_view_325_size_0']
    100/248: Reshape ['view_6', 'aten_view_325_size_0'] -> ['view_9']
    101/248: Concat ['transpose_3', 'transpose_3'] -> ['cat']
    102/248: Constant [] -> ['_val_229']
    103/248: Equal ['expand_1', '_val_229'] -> ['eq']
    104/248: Transpose ['view_7'] -> ['transpose']
    105/248: Transpose ['view_8'] -> ['transpose_1']
    106/248: Transpose ['view_9'] -> ['transpose_2']
    107/248: Cos ['cat'] -> ['cos']
    108/248: Sin ['cat'] -> ['sin']
    109/248: And ['eq', 'eq_1'] -> ['mul_1']
    110/248: Constant [] -> ['_val_240']
    111/248: Constant [] -> ['_val_244']
    112/248: Constant [] -> ['_val_248']
    113/248: Constant [] -> ['_val_252']
    114/248: Slice ['transpose', '_val_240', '_val_244', '_val_248', '_val_252'] -> ['slice_10']
    115/248: Constant [] -> ['_val_257']
    116/248: Constant [] -> ['_val_261']
    117/248: Constant [] -> ['_val_265']
    118/248: Constant [] -> ['_val_269']
    119/248: Slice ['transpose', '_val_257', '_val_261', '_val_265', '_val_269'] -> ['slice_11']
    120/248: Constant [] -> ['_val_274']
    121/248: Constant [] -> ['_val_278']
    122/248: Constant [] -> ['_val_282']
    123/248: Constant [] -> ['_val_286']
    124/248: Slice ['transpose_1', '_val_274', '_val_278', '_val_282', '_val_286'] -> ['slice_12']
    125/248: Constant [] -> ['_val_291']
    126/248: Constant [] -> ['_val_295']
    127/248: Constant [] -> ['_val_299']
    128/248: Constant [] -> ['_val_303']
    129/248: Slice ['transpose_1', '_val_291', '_val_295', '_val_299', '_val_303'] -> ['slice_13']
    130/248: Constant [] -> ['aten_expand_405_size_1']
    131/248: Expand ['transpose_2', 'aten_expand_405_size_1'] -> ['expand_8']
    132/248: Constant [] -> ['aten_unsqueeze_406_dim_0']
    133/248: Unsqueeze ['cos', 'aten_unsqueeze_406_dim_0'] -> ['unsqueeze_10']
    134/248: Constant [] -> ['aten_unsqueeze_407_dim_0']
    135/248: Unsqueeze ['sin', 'aten_unsqueeze_407_dim_0'] -> ['unsqueeze_11']
    136/248: Constant [] -> ['_val_309']
    137/248: Where ['mul_1', '_val_309', 'expand_1'] -> ['masked_fill']
    138/248: Neg ['slice_11'] -> ['neg']
    139/248: Neg ['slice_13'] -> ['neg_1']
    140/248: Mul ['transpose', 'unsqueeze_10'] -> ['mul_4']
    141/248: Mul ['transpose_1', 'unsqueeze_10'] -> ['mul_6']
    142/248: Concat ['neg', 'slice_10'] -> ['cat_1']
    143/248: Concat ['neg_1', 'slice_12'] -> ['cat_2']
    144/248: Constant [] -> ['aten_view_421_size_0']
    145/248: Reshape ['expand_8', 'aten_view_421_size_0'] -> ['view_17']
    146/248: Constant [] -> ['_val_326']
    147/248: Constant [] -> ['_val_330']
    148/248: Constant [] -> ['_val_334']
    149/248: Constant [] -> ['_val_338']
    150/248: Slice ['masked_fill', '_val_326', '_val_330', '_val_334', '_val_338'] -> ['slice_17']
    151/248: Mul ['cat_1', 'unsqueeze_11'] -> ['mul_5']
    152/248: Mul ['cat_2', 'unsqueeze_11'] -> ['mul_7']
    153/248: Transpose ['view_17'] -> ['transpose_8']
    154/248: Constant [] -> ['_val_346']
    155/248: Constant [] -> ['_val_350']
    156/248: Constant [] -> ['_val_354']
    157/248: Constant [] -> ['_val_358']
    158/248: Slice ['slice_17', '_val_346', '_val_350', '_val_354', '_val_358'] -> ['slice_18']
    159/248: Constant [] -> ['alpha__2']
    160/248: Mul ['mul_5', 'alpha__2'] -> ['other_1__2']
    161/248: Add ['mul_4', 'other_1__2'] -> ['add_1']
    162/248: Constant [] -> ['alpha__3']
    163/248: Mul ['mul_7', 'alpha__3'] -> ['other_1__3']
    164/248: Add ['mul_6', 'other_1__3'] -> ['add_2']
    165/248: Constant [] -> ['_val_365']
    166/248: Constant [] -> ['_val_369']
    167/248: Constant [] -> ['_val_373']
    168/248: Constant [] -> ['_val_377']
    169/248: Slice ['slice_18', '_val_365', '_val_369', '_val_373', '_val_377'] -> ['slice_19']
    170/248: Constant [] -> ['aten_expand_479_size_1']
    171/248: Expand ['add_1', 'aten_expand_479_size_1'] -> ['expand_5']
    172/248: Transpose ['add_2'] -> ['transpose_4']
    173/248: Constant [] -> ['aten_expand_483_size_1']
    174/248: Expand ['transpose_4', 'aten_expand_483_size_1'] -> ['expand_6']
    175/248: Constant [] -> ['aten_view_485_size_0']
    176/248: Reshape ['expand_5', 'aten_view_485_size_0'] -> ['view_13']
    177/248: Transpose ['view_13'] -> ['transpose_9']
    178/248: Constant [] -> ['aten_view_489_size_0']
    179/248: Reshape ['expand_6', 'aten_view_489_size_0'] -> ['view_14']
    180/248: MatMul ['view_13', 'view_14'] -> ['bmm_1']
    181/248: Transpose ['view_14'] -> ['transpose_10']
    182/248: Constant [] -> ['aten_view_493_size_0']
    183/248: Reshape ['bmm_1', 'aten_view_493_size_0'] -> ['view_15']
    184/248: Constant [] -> ['_val_395']
    185/248: Div ['view_15', '_val_395'] -> ['div']
    186/248: Constant [] -> ['alpha__4']
    187/248: Mul ['slice_19', 'alpha__4'] -> ['other_1__4']
    188/248: Add ['div', 'other_1__4'] -> ['add_3']
    189/248: Softmax ['add_3'] -> ['_softmax']
    190/248: Constant [] -> ['aten_expand_502_size_1']
    191/248: Expand ['_softmax', 'aten_expand_502_size_1'] -> ['expand_7']
    192/248: Constant [] -> ['aten_view_505_size_0']
    193/248: Reshape ['expand_7', 'aten_view_505_size_0'] -> ['view_16']
    194/248: Identity ['_softmax'] -> ['detach_13']
    195/248: MatMul ['view_16', 'view_17'] -> ['bmm_2']
    196/248: Transpose ['view_16'] -> ['transpose_7']
    197/248: Constant [] -> ['aten_view_510_size_0']
    198/248: Reshape ['bmm_2', 'aten_view_510_size_0'] -> ['view_18']
    199/248: Transpose ['view_18'] -> ['transpose_5']
    200/248: Constant [] -> ['aten_view_514_size_0']
    201/248: Reshape ['transpose_5', 'aten_view_514_size_0'] -> ['view_19']
    202/248: Constant [] -> ['aten_view_516_size_0']
    203/248: Reshape ['view_19', 'aten_view_516_size_0'] -> ['view_20']
    204/248: MatMul ['view_20', 't_3'] -> ['mm_3']
    205/248: Constant [] -> ['aten_view_519_size_0']
    206/248: Reshape ['mm_3', 'aten_view_519_size_0'] -> ['view_21']
    207/248: Constant [] -> ['alpha__5']
    208/248: Mul ['view_21', 'alpha__5'] -> ['other_1__5']
    209/248: Add ['embedding', 'other_1__5'] -> ['add_4']
    210/248: Constant [] -> ['scalar_tensor_default_1']
    211/248: Pow ['add_4', 'scalar_tensor_default_1'] -> ['pow_2']
    212/248: Constant [] -> ['_val_425']
    213/248: ReduceMean ['pow_2', '_val_425'] -> ['mean_1']
    214/248: Constant [] -> ['aten_add_527_other_1']
    215/248: Add ['mean_1', 'aten_add_527_other_1'] -> ['add_5']
    216/248: Sqrt ['add_5'] -> ['aten_rsqrt_528_tmp']
    217/248: Reciprocal ['aten_rsqrt_528_tmp'] -> ['rsqrt_1']
    218/248: Mul ['add_4', 'rsqrt_1'] -> ['mul_8']
    219/248: Mul ['primals_2', 'mul_8'] -> ['mul_9']
    220/248: Constant [] -> ['aten_view_532_size_0']
    221/248: Reshape ['mul_9', 'aten_view_532_size_0'] -> ['view_22']
    222/248: MatMul ['view_22', 't_4'] -> ['mm_4']
    223/248: MatMul ['view_22', 't_5'] -> ['mm_5']
    224/248: Constant [] -> ['aten_view_536_size_0']
    225/248: Reshape ['mm_4', 'aten_view_536_size_0'] -> ['view_23']
    226/248: Constant [] -> ['aten_view_538_size_0']
    227/248: Reshape ['mm_5', 'aten_view_538_size_0'] -> ['view_25']
    228/248: Sigmoid ['view_23'] -> ['sigmoid']
    229/248: Mul ['view_23', 'sigmoid'] -> ['mul_10']
    230/248: Mul ['mul_10', 'view_25'] -> ['mul_11']
    231/248: Constant [] -> ['aten_view_543_size_0']
    232/248: Reshape ['mul_11', 'aten_view_543_size_0'] -> ['view_26']
    233/248: MatMul ['view_26', 't_6'] -> ['mm_6']
    234/248: Constant [] -> ['aten_view_546_size_0']
    235/248: Reshape ['mm_6', 'aten_view_546_size_0'] -> ['view_27']
    236/248: Constant [] -> ['alpha__6']
    237/248: Mul ['view_27', 'alpha__6'] -> ['other_1__6']
    238/248: Add ['add_4', 'other_1__6'] -> ['add_6']
    239/248: Constant [] -> ['scalar_tensor_default_2']
    240/248: Pow ['add_6', 'scalar_tensor_default_2'] -> ['pow_3']
    241/248: Constant [] -> ['_val_452']
    242/248: ReduceMean ['pow_3', '_val_452'] -> ['mean_2']
    243/248: Constant [] -> ['aten_add_554_other_1']
    244/248: Add ['mean_2', 'aten_add_554_other_1'] -> ['add_7']
    245/248: Sqrt ['add_7'] -> ['aten_rsqrt_555_tmp']
    246/248: Reciprocal ['aten_rsqrt_555_tmp'] -> ['rsqrt_2']
    247/248: Mul ['add_6', 'rsqrt_2'] -> ['mul_12']
    248/248: Mul ['primals_3', 'mul_12'] -> ['mul_13']
    [runpythonerror]
    /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/exporter.py:137: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
      warnings.warn(
    2024-05-08 14:07:07,804 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 4194304.
    2024-05-08 14:07:07,805 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue full due to large size 4194304.
    2024-05-08 14:07:07,829 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue triu due to large size 4194304.
    2024-05-08 14:07:07,864 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue convert_element_type_default due to large size 4194304.
    2024-05-08 14:07:07,865 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue mul due to large size 4194304.
    2024-05-08 14:07:07,879 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 4194304.
    2024-05-08 14:07:07,880 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue unsqueeze_3 due to large size 4194304.
    2024-05-08 14:07:07,882 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 4194304.
    2024-05-08 14:07:07,882 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue unsqueeze_4 due to large size 4194304.
    2024-05-08 14:07:07,889 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_3 due to large size 4194304.
    2024-05-08 14:07:07,895 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_4 due to large size 4194304.
    2024-05-08 14:07:07,900 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 8388608.
    2024-05-08 14:07:07,901 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue expand_1 due to large size 8388608.
    2024-05-08 14:07:07,905 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue clone due to large size 8388608.
    2024-05-08 14:07:07,910 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue alias due to large size 8388608.
    2024-05-08 14:07:07,913 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue eq due to large size 2097152.
    2024-05-08 14:07:08,071 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 4194304.
    2024-05-08 14:07:08,071 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue full due to large size 4194304.
    2024-05-08 14:07:08,075 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue triu due to large size 4194304.
    2024-05-08 14:07:08,080 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue convert_element_type_default due to large size 4194304.
    2024-05-08 14:07:08,081 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue mul due to large size 4194304.
    2024-05-08 14:07:08,083 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue unsqueeze_3 due to large size 4194304.
    2024-05-08 14:07:08,084 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue unsqueeze_4 due to large size 4194304.
    2024-05-08 14:07:08,085 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_3 due to large size 4194304.
    2024-05-08 14:07:08,086 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_4 due to large size 4194304.
    2024-05-08 14:07:08,088 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue expand_1 due to large size 8388608.
    2024-05-08 14:07:08,092 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue eq due to large size 2097152.
    2024-05-08 14:07:08.190834200 [W:onnxruntime:, graph.cc:4051 CleanUnusedInitializersAndNodeArgs] Removing initializer '_val_23'. It is not used by any node and should be removed from the model.
    2024-05-08 14:07:08.190879200 [W:onnxruntime:, graph.cc:4051 CleanUnusedInitializersAndNodeArgs] Removing initializer '_val_22'. It is not used by any node and should be removed from the model.

With the custom exporter (enabled by setting the environment variable ``ONNXRT_CHANGE_REWRITER`` to ``"1"``)

<<<

import os
import warnings
import numpy as np
import onnx

# from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
import torch
import torch.onnx
from experimental_experiment.torch_models.training_helper import (
    make_aot_ort,
    train_loop,
)
from experimental_experiment.torch_models.dump_helper import dump_onnx

# from experimental_experiment.torch_interpreter import to_onnx

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from transformers import LlamaConfig
    from transformers.models.llama.modeling_llama import LlamaModel


def ids_tensor(shape, vocab_size):
    """Return a random tensor of token ids with dtype ``torch.long``.

    :param shape: sequence of ints, the shape of the returned tensor
    :param vocab_size: number of distinct tokens; ids are drawn uniformly
        from ``[0, vocab_size)``
    :return: contiguous ``torch.long`` tensor of shape *shape*
    """
    # numpy's ``randint`` excludes ``high``: passing ``vocab_size`` (not
    # ``vocab_size - 1``) keeps the last token id reachable.
    # Drawing every id in one call replaces the per-element Python loop.
    values = np.random.randint(0, vocab_size, size=tuple(shape))
    return torch.tensor(values, dtype=torch.long).contiguous()


# Tiny Llama configuration: a single layer, two heads and 16-wide hidden/MLP sizes.
config = LlamaConfig(
    num_hidden_layers=1,
    num_attention_heads=2,
    hidden_size=16,
    intermediate_size=16,
    vocab_size=1024,
    max_position_embeddings=1024,
)
# Select the "eager" attention implementation for the model built below.
config._attn_implementation = "eager"

model = LlamaModel(config)

# Inputs: random token ids and a lower-triangular float32 mask,
# both of shape (batch, seq).
batch, seq, vocab_size = 2, 1024, 1024
input_ids = ids_tensor([batch, seq], vocab_size)
input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))

# One plain (non-compiled) forward pass before compiling the model.
model(input_ids, input_mask)

# Switch the backend to the custom rewriter for this example only; the
# ``finally`` clause switches it back off even if compilation or training fails.
os.environ["ONNXRT_CHANGE_REWRITER"] = "1"
try:
    local_aot_ort, _ = make_aot_ort(
        dynamic=True,
        verbose=1,
    )

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        optimized_mod = torch.compile(model, backend=local_aot_ort, fullgraph=True)
        # Every ONNX model produced while training is dumped into dump_llama.
        with dump_onnx("dort-llama-ort", folder="dump_llama", clean=True):
            train_loop(optimized_mod, input_ids, input_mask)
finally:
    os.environ["ONNXRT_CHANGE_REWRITER"] = "0"

# Sorted so the listing order does not depend on the filesystem.
names = sorted(f for f in os.listdir("dump_llama") if f.endswith(".onnx"))
print("------------------------------------------")
print(f"exported model: {names}")
for name in names:
    print()
    print(f"NODES in {name!r}")
    onx = onnx.load(os.path.join("dump_llama", name))
    n_nodes = len(onx.graph.node)
    for i, node in enumerate(onx.graph.node, start=1):
        print(f"{i}/{n_nodes}: {node.op_type} {node.input} -> {node.output}")

>>>

    [2024-05-08 14:07:12,831] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
    Applied 1 of general pattern rewrite rules.
    Applied 0 of general pattern rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific pattern rewrite rules.
    Applied 0 of general pattern rewrite rules.
    Applied 0 of general pattern rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific function rewrite rules.
    Applied 0 of onnxruntime specific pattern rewrite rules.
    ------------------------------------------
    exported model: ['dort-llama-ort_1.onnx', 'dort-llama-ort_0.onnx']
    
    NODES in {name!r}
    1/305: Cos ['cat'] -> ['cos']
    2/305: Constant [] -> ['aten_view_112_size_0']
    3/305: Reshape ['mm_5', 'aten_view_112_size_0'] -> ['view_25']
    4/305: Constant [] -> ['aten_view_114_size_0']
    5/305: Reshape ['mm_4', 'aten_view_114_size_0'] -> ['view_23']
    6/305: Mul ['embedding', 'rsqrt'] -> ['mul_2']
    7/305: Constant [] -> ['_val_41']
    8/305: Equal ['primals_13', '_val_41'] -> ['eq_2']
    9/305: Mul ['tangents_1', 'primals_3'] -> ['mul_14']
    10/305: Constant [] -> ['_val_44']
    11/305: Pow ['embedding', '_val_44'] -> ['pow_9']
    12/305: Mul ['add_6', 'rsqrt_2'] -> ['mul_12']
    13/305: Constant [] -> ['aten_view_125_size_0']
    14/305: Reshape ['mm_3', 'aten_view_125_size_0'] -> ['view_21']
    15/305: Sin ['cat'] -> ['sin']
    16/305: Constant [] -> ['_val_51']
    17/305: Pow ['add_6', '_val_51'] -> ['pow_5']
    18/305: Constant [] -> ['aten_unsqueeze_133_dim_0']
    19/305: Unsqueeze ['cos', 'aten_unsqueeze_133_dim_0'] -> ['unsqueeze_10']
    20/305: Mul ['view_23', 'sigmoid'] -> ['mul_10']
    21/305: Constant [] -> ['aten_unsqueeze_137_dim_0']
    22/305: Unsqueeze ['eq_2', 'aten_unsqueeze_137_dim_0'] -> ['unsqueeze_12']
    23/305: Mul ['mul_14', 'add_6'] -> ['mul_16']
    24/305: Mul ['mul_14', 'rsqrt_2'] -> ['mul_17']
    25/305: Constant [] -> ['_val_63']
    26/305: Mul ['pow_9', '_val_63'] -> ['mul_47']
    27/305: Mul ['tangents_1', 'mul_12'] -> ['mul_15']
    28/305: Constant [] -> ['alpha__1']
    29/305: Mul ['view_21', 'alpha__1'] -> ['other_1__1']
    30/305: Add ['embedding', 'other_1__1'] -> ['add_4']
    31/305: Constant [] -> ['aten_unsqueeze_145_dim_0']
    32/305: Unsqueeze ['sin', 'aten_unsqueeze_145_dim_0'] -> ['unsqueeze_11']
    33/305: Constant [] -> ['_val_69']
    34/305: Mul ['pow_5', '_val_69'] -> ['mul_20']
    35/305: Constant [] -> ['fill']
    36/305: Constant [] -> ['_val_75']
    37/305: ReduceSum ['mul_16', '_val_75'] -> ['sum_2']
    38/305: Constant [] -> ['_val_77']
    39/305: ReduceSum ['mul_15', '_val_77'] -> ['sum_1']
    40/305: Mul ['add_4', 'rsqrt_1'] -> ['mul_8']
    41/305: Constant [] -> ['_val_80']
    42/305: Pow ['add_4', '_val_80'] -> ['pow_7']
    43/305: Constant [] -> ['alpha__2']
    44/305: Mul ['sigmoid', 'alpha__2'] -> ['other_1__2']
    45/305: Sub ['fill', 'other_1__2'] -> ['sub']
    46/305: Constant [] -> ['_val_86']
    47/305: Mul ['sum_2', '_val_86'] -> ['mul_18']
    48/305: Constant [] -> ['aten_view_166_size_0']
    49/305: Reshape ['sum_1', 'aten_view_166_size_0'] -> ['view_28']
    50/305: Constant [] -> ['_val_90']
    51/305: Mul ['pow_7', '_val_90'] -> ['mul_33']
    52/305: Mul ['view_23', 'sub'] -> ['mul_24']
    53/305: Constant [] -> ['scalar_tensor_default']
    54/305: Pow ['rsqrt', 'scalar_tensor_default'] -> ['pow_8']
    55/305: Constant [] -> ['scalar_tensor_default_1']
    56/305: Pow ['rsqrt_1', 'scalar_tensor_default_1'] -> ['pow_6']
    57/305: Constant [] -> ['scalar_tensor_default_2']
    58/305: Pow ['rsqrt_2', 'scalar_tensor_default_2'] -> ['pow_4']
    59/305: Constant [] -> ['aten_add_182_other_1']
    60/305: Add ['mul_24', 'aten_add_182_other_1'] -> ['add_9']
    61/305: Mul ['mul_18', 'pow_4'] -> ['mul_19']
    62/305: Mul ['sigmoid', 'add_9'] -> ['mul_25']
    63/305: Constant [] -> ['aten_expand_186_size_1']
    64/305: Expand ['mul_19', 'aten_expand_186_size_1'] -> ['expand_9']
    65/305: Constant [] -> ['scalar_tensor_default_4']
    66/305: Div ['expand_9', 'scalar_tensor_default_4'] -> ['div_1']
    67/305: Mul ['div_1', 'mul_20'] -> ['mul_21']
    68/305: Constant [] -> ['alpha__3']
    69/305: Mul ['mul_21', 'alpha__3'] -> ['other_1__3']
    70/305: Add ['mul_17', 'other_1__3'] -> ['add_8']
    71/305: Constant [] -> ['aten_view_193_size_0']
    72/305: Reshape ['add_8', 'aten_view_193_size_0'] -> ['view_29']
    73/305: Transpose ['view_29'] -> ['t_7']
    74/305: MatMul ['view_29', 't_9'] -> ['mm_8']
    75/305: MatMul ['t_7', 'view_26'] -> ['mm_7']
    76/305: Constant [] -> ['aten_view_198_size_0']
    77/305: Reshape ['mm_8', 'aten_view_198_size_0'] -> ['view_30']
    78/305: Transpose ['mm_7'] -> ['t_8']
    79/305: Mul ['view_30', 'mul_10'] -> ['mul_22']
    80/305: Mul ['view_30', 'view_25'] -> ['mul_23']
    81/305: Transpose ['t_8'] -> ['t_10']
    82/305: Constant [] -> ['aten_view_204_size_0']
    83/305: Reshape ['mul_22', 'aten_view_204_size_0'] -> ['view_31']
    84/305: Mul ['mul_23', 'mul_25'] -> ['mul_26']
    85/305: Transpose ['view_31'] -> ['t_11']
    86/305: MatMul ['view_31', 't_13'] -> ['mm_10']
    87/305: Constant [] -> ['aten_view_209_size_0']
    88/305: Reshape ['mul_26', 'aten_view_209_size_0'] -> ['view_33']
    89/305: MatMul ['t_11', 'view_22'] -> ['mm_9']
    90/305: Constant [] -> ['aten_view_212_size_0']
    91/305: Reshape ['mm_10', 'aten_view_212_size_0'] -> ['view_32']
    92/305: Transpose ['view_33'] -> ['t_15']
    93/305: MatMul ['view_33', 't_17'] -> ['mm_12']
    94/305: Transpose ['mm_9'] -> ['t_12']
    95/305: MatMul ['t_15', 'view_22'] -> ['mm_11']
    96/305: Constant [] -> ['aten_view_218_size_0']
    97/305: Reshape ['mm_12', 'aten_view_218_size_0'] -> ['view_34']
    98/305: Transpose ['t_12'] -> ['t_14']
    99/305: Transpose ['mm_11'] -> ['t_16']
    100/305: Constant [] -> ['alpha__4']
    101/305: Mul ['view_34', 'alpha__4'] -> ['other_1__4']
    102/305: Add ['view_32', 'other_1__4'] -> ['add_10']
    103/305: Transpose ['t_16'] -> ['t_18']
    104/305: Mul ['add_10', 'primals_2'] -> ['mul_27']
    105/305: Mul ['add_10', 'mul_8'] -> ['mul_28']
    106/305: Mul ['mul_27', 'add_4'] -> ['mul_29']
    107/305: Mul ['mul_27', 'rsqrt_1'] -> ['mul_30']
    108/305: Constant [] -> ['_val_150']
    109/305: ReduceSum ['mul_28', '_val_150'] -> ['sum_3']
    110/305: Constant [] -> ['_val_152']
    111/305: ReduceSum ['mul_29', '_val_152'] -> ['sum_4']
    112/305: Constant [] -> ['alpha__5']
    113/305: Mul ['mul_30', 'alpha__5'] -> ['other_1__5']
    114/305: Add ['add_8', 'other_1__5'] -> ['add_11']
    115/305: Constant [] -> ['aten_view_233_size_0']
    116/305: Reshape ['sum_3', 'aten_view_233_size_0'] -> ['view_35']
    117/305: Constant [] -> ['_val_157']
    118/305: Mul ['sum_4', '_val_157'] -> ['mul_31']
    119/305: Mul ['mul_31', 'pow_6'] -> ['mul_32']
    120/305: Constant [] -> ['aten_expand_238_size_1']
    121/305: Expand ['mul_32', 'aten_expand_238_size_1'] -> ['expand_10']
    122/305: Constant [] -> ['scalar_tensor_default_5']
    123/305: Div ['expand_10', 'scalar_tensor_default_5'] -> ['div_2']
    124/305: Mul ['div_2', 'mul_33'] -> ['mul_34']
    125/305: Constant [] -> ['alpha__6']
    126/305: Mul ['mul_34', 'alpha__6'] -> ['other_1__6']
    127/305: Add ['add_11', 'other_1__6'] -> ['add_12']
    128/305: Constant [] -> ['aten_view_245_size_0']
    129/305: Reshape ['add_12', 'aten_view_245_size_0'] -> ['view_36']
    130/305: Transpose ['view_36'] -> ['t_19']
    131/305: MatMul ['view_36', 't_21'] -> ['mm_14']
    132/305: MatMul ['t_19', 'view_20'] -> ['mm_13']
    133/305: Constant [] -> ['aten_view_250_size_0']
    134/305: Reshape ['mm_14', 'aten_view_250_size_0'] -> ['view_37']
    135/305: Transpose ['mm_13'] -> ['t_20']
    136/305: Constant [] -> ['aten_view_253_size_0']
    137/305: Reshape ['view_37', 'aten_view_253_size_0'] -> ['view_38']
    138/305: Transpose ['t_20'] -> ['t_22']
    139/305: Transpose ['view_38'] -> ['transpose_6']
    140/305: Constant [] -> ['aten_view_258_size_0']
    141/305: Reshape ['transpose_6', 'aten_view_258_size_0'] -> ['view_39']
    142/305: MatMul ['transpose_7', 'view_39'] -> ['bmm_3']
    143/305: MatMul ['view_39', 'transpose_8'] -> ['bmm_4']
    144/305: Constant [] -> ['aten_view_262_size_0']
    145/305: Reshape ['bmm_3', 'aten_view_262_size_0'] -> ['view_40']
    146/305: Constant [] -> ['aten_view_264_size_0']
    147/305: Reshape ['bmm_4', 'aten_view_264_size_0'] -> ['view_41']
    148/305: Constant [] -> ['alpha__7']
    149/305: Mul ['view_40', 'alpha__7'] -> ['other_1__7']
    150/305: Add ['tangents_3', 'other_1__7'] -> ['add_13']
    151/305: Mul ['view_41', 'detach_13'] -> ['mul_35']
    152/305: Transpose ['add_13'] -> ['transpose_12']
    153/305: Constant [] -> ['_val_191']
    154/305: ReduceSum ['mul_35', '_val_191'] -> ['sum_5']
    155/305: Mul ['detach_13', 'sum_5'] -> ['mul_36']
    156/305: Constant [] -> ['aten_view_273_size_0']
    157/305: Reshape ['transpose_12', 'aten_view_273_size_0'] -> ['view_45']
    158/305: Constant [] -> ['alpha__8']
    159/305: Mul ['mul_36', 'alpha__8'] -> ['other_1__8']
    160/305: Sub ['mul_35', 'other_1__8'] -> ['sub_1']
    161/305: Constant [] -> ['aten_view_276_size_0']
    162/305: Reshape ['view_45', 'aten_view_276_size_0'] -> ['view_48']
    163/305: Constant [] -> ['_val_200']
    164/305: Div ['sub_1', '_val_200'] -> ['div_3']
    165/305: Transpose ['view_48'] -> ['t_23']
    166/305: MatMul ['view_48', 't_25'] -> ['mm_16']
    167/305: Constant [] -> ['aten_view_282_size_0']
    168/305: Reshape ['div_3', 'aten_view_282_size_0'] -> ['view_42']
    169/305: MatMul ['t_23', 'view_1'] -> ['mm_15']
    170/305: Constant [] -> ['aten_view_285_size_0']
    171/305: Reshape ['mm_16', 'aten_view_285_size_0'] -> ['view_49']
    172/305: MatMul ['transpose_9', 'view_42'] -> ['bmm_5']
    173/305: MatMul ['view_42', 'transpose_10'] -> ['bmm_6']
    174/305: Transpose ['mm_15'] -> ['t_24']
    175/305: Constant [] -> ['aten_view_290_size_0']
    176/305: Reshape ['bmm_5', 'aten_view_290_size_0'] -> ['view_43']
    177/305: Constant [] -> ['aten_view_292_size_0']
    178/305: Reshape ['bmm_6', 'aten_view_292_size_0'] -> ['view_44']
    179/305: Transpose ['t_24'] -> ['t_26']
    180/305: Transpose ['view_43'] -> ['transpose_11']
    181/305: Mul ['view_44', 'unsqueeze_11'] -> ['mul_39']
    182/305: Mul ['view_44', 'unsqueeze_10'] -> ['mul_40']
    183/305: Constant [] -> ['alpha__9']
    184/305: Mul ['transpose_11', 'alpha__9'] -> ['other_1__9']
    185/305: Add ['tangents_2', 'other_1__9'] -> ['add_14']
    186/305: Constant [] -> ['_val_224']
    187/305: Constant [] -> ['_val_228']
    188/305: Constant [] -> ['_val_232']
    189/305: Constant [] -> ['_val_236']
    190/305: Slice ['mul_39', '_val_224', '_val_228', '_val_232', '_val_236'] -> ['slice_22']
    191/305: Constant [] -> ['_val_241']
    192/305: Constant [] -> ['_val_245']
    193/305: Constant [] -> ['_val_249']
    194/305: Constant [] -> ['_val_253']
    195/305: Slice ['mul_39', '_val_241', '_val_245', '_val_249', '_val_253'] -> ['slice_23']
    196/305: Mul ['add_14', 'unsqueeze_11'] -> ['mul_37']
    197/305: Mul ['add_14', 'unsqueeze_10'] -> ['mul_38']
    198/305: Neg ['slice_22'] -> ['neg_3']
    199/305: Constant [] -> ['_val_263']
    200/305: Constant [] -> ['_val_267']
    201/305: Constant [] -> ['_val_271']
    202/305: Constant [] -> ['_val_275']
    203/305: Slice ['mul_37', '_val_263', '_val_267', '_val_271', '_val_275'] -> ['slice_20']
    204/305: Constant [] -> ['_val_280']
    205/305: Constant [] -> ['_val_284']
    206/305: Constant [] -> ['_val_288']
    207/305: Constant [] -> ['_val_292']
    208/305: Slice ['mul_37', '_val_280', '_val_284', '_val_288', '_val_292'] -> ['slice_21']
    209/305: Constant [] -> ['_val_311']
    210/305: Transpose ['slice_23'] -> ['_val_312']
    211/305: Constant [] -> ['_val_313']
    212/305: ScatterND ['_val_313', '_val_311', '_val_312'] -> ['_val_314']
    213/305: Transpose ['_val_314'] -> ['slice_scatter_3']
    214/305: Neg ['slice_20'] -> ['neg_2']
    215/305: Constant [] -> ['_val_334']
    216/305: Transpose ['neg_3'] -> ['_val_335']
    217/305: Constant [] -> ['_val_336']
    218/305: ScatterND ['_val_336', '_val_334', '_val_335'] -> ['_val_337']
    219/305: Transpose ['_val_337'] -> ['slice_scatter_2']
    220/305: Constant [] -> ['_val_356']
    221/305: Transpose ['slice_21'] -> ['_val_357']
    222/305: Constant [] -> ['_val_358']
    223/305: ScatterND ['_val_358', '_val_356', '_val_357'] -> ['_val_359']
    224/305: Transpose ['_val_359'] -> ['slice_scatter_1']
    225/305: Constant [] -> ['alpha__10']
    226/305: Mul ['slice_scatter_3', 'alpha__10'] -> ['other_1__10']
    227/305: Add ['slice_scatter_2', 'other_1__10'] -> ['add_17']
    228/305: Constant [] -> ['_val_377']
    229/305: Transpose ['neg_2'] -> ['_val_378']
    230/305: Constant [] -> ['_val_379']
    231/305: ScatterND ['_val_379', '_val_377', '_val_378'] -> ['_val_380']
    232/305: Transpose ['_val_380'] -> ['slice_scatter']
    233/305: Constant [] -> ['alpha__11']
    234/305: Mul ['mul_40', 'alpha__11'] -> ['other_1__11']
    235/305: Add ['add_17', 'other_1__11'] -> ['add_18']
    236/305: Constant [] -> ['alpha__12']
    237/305: Mul ['slice_scatter_1', 'alpha__12'] -> ['other_1__12']
    238/305: Add ['slice_scatter', 'other_1__12'] -> ['add_15']
    239/305: Transpose ['add_18'] -> ['transpose_14']
    240/305: Constant [] -> ['alpha__13']
    241/305: Mul ['mul_38', 'alpha__13'] -> ['other_1__13']
    242/305: Add ['add_15', 'other_1__13'] -> ['add_16']
    243/305: Transpose ['add_16'] -> ['transpose_13']
    244/305: Constant [] -> ['aten_view_466_size_0']
    245/305: Reshape ['transpose_14', 'aten_view_466_size_0'] -> ['view_47']
    246/305: Constant [] -> ['aten_view_469_size_0']
    247/305: Reshape ['view_47', 'aten_view_469_size_0'] -> ['view_52']
    248/305: Constant [] -> ['aten_view_471_size_0']
    249/305: Reshape ['transpose_13', 'aten_view_471_size_0'] -> ['view_46']
    250/305: Transpose ['view_52'] -> ['t_31']
    251/305: MatMul ['view_52', 't_33'] -> ['mm_20']
    252/305: Constant [] -> ['aten_view_475_size_0']
    253/305: Reshape ['view_46', 'aten_view_475_size_0'] -> ['view_50']
    254/305: MatMul ['t_31', 'view_1'] -> ['mm_19']
    255/305: Constant [] -> ['aten_view_478_size_0']
    256/305: Reshape ['mm_20', 'aten_view_478_size_0'] -> ['view_53']
    257/305: Transpose ['view_50'] -> ['t_27']
    258/305: MatMul ['view_50', 't_29'] -> ['mm_18']
    259/305: Transpose ['mm_19'] -> ['t_32']
    260/305: MatMul ['t_27', 'view_1'] -> ['mm_17']
    261/305: Constant [] -> ['aten_view_484_size_0']
    262/305: Reshape ['mm_18', 'aten_view_484_size_0'] -> ['view_51']
    263/305: Transpose ['t_32'] -> ['t_34']
    264/305: Transpose ['mm_17'] -> ['t_28']
    265/305: Constant [] -> ['alpha__14']
    266/305: Mul ['view_51', 'alpha__14'] -> ['other_1__14']
    267/305: Add ['view_49', 'other_1__14'] -> ['add_19']
    268/305: Transpose ['t_28'] -> ['t_30']
    269/305: Constant [] -> ['alpha__15']
    270/305: Mul ['view_53', 'alpha__15'] -> ['other_1__15']
    271/305: Add ['add_19', 'other_1__15'] -> ['add_20']
    272/305: Mul ['add_20', 'primals_1'] -> ['mul_41']
    273/305: Mul ['add_20', 'mul_2'] -> ['mul_42']
    274/305: Mul ['mul_41', 'embedding'] -> ['mul_43']
    275/305: Mul ['mul_41', 'rsqrt'] -> ['mul_44']
    276/305: Constant [] -> ['_val_417']
    277/305: ReduceSum ['mul_42', '_val_417'] -> ['sum_6']
    278/305: Constant [] -> ['_val_419']
    279/305: ReduceSum ['mul_43', '_val_419'] -> ['sum_7']
    280/305: Constant [] -> ['alpha__16']
    281/305: Mul ['mul_44', 'alpha__16'] -> ['other_1__16']
    282/305: Add ['add_12', 'other_1__16'] -> ['add_21']
    283/305: Constant [] -> ['aten_view_500_size_0']
    284/305: Reshape ['sum_6', 'aten_view_500_size_0'] -> ['view_54']
    285/305: Constant [] -> ['_val_424']
    286/305: Mul ['sum_7', '_val_424'] -> ['mul_45']
    287/305: Mul ['mul_45', 'pow_8'] -> ['mul_46']
    288/305: Constant [] -> ['aten_expand_505_size_1']
    289/305: Expand ['mul_46', 'aten_expand_505_size_1'] -> ['expand_11']
    290/305: Constant [] -> ['scalar_tensor_default_6']
    291/305: Div ['expand_11', 'scalar_tensor_default_6'] -> ['div_4']
    292/305: Mul ['div_4', 'mul_47'] -> ['mul_48']
    293/305: Constant [] -> ['alpha__17']
    294/305: Mul ['mul_48', 'alpha__17'] -> ['other_1__17']
    295/305: Add ['add_21', 'other_1__17'] -> ['add_22']
    296/305: Constant [] -> ['aten_masked_fill_512_value_cast']
    297/305: Where ['unsqueeze_12', 'aten_masked_fill_512_value_cast', 'add_22'] -> ['masked_fill_1']
    298/305: Constant [] -> ['_val_436']
    299/305: ConstantOfShape ['_val_436'] -> ['aten_new_zeros_514_result']
    300/305: SequenceConstruct ['primals_13'] -> ['438']
    301/305: Constant [] -> ['int64_0__18']
    302/305: SequenceAt ['438', 'int64_0__18'] -> ['index__18']
    303/305: Constant [] -> ['int64_m1_1d__18']
    304/305: Unsqueeze ['index__18', 'int64_m1_1d__18'] -> ['new_index__18']
    305/305: ScatterND ['aten_new_zeros_514_result', 'new_index__18', 'masked_fill_1'] -> ['_unsafe_index_put']
    
    NODES in {name!r}
    1/248: Gather ['primals_4', 'primals_13'] -> ['embedding']
    2/248: Transpose ['primals_8'] -> ['t_3']
    3/248: Constant [] -> ['_val_22']
    4/248: Constant [] -> ['_val_23']
    5/248: Constant [] -> ['size_0__1']
    6/248: Constant [] -> ['fill_value_1__1']
    7/248: Expand ['fill_value_1__1', 'size_0__1'] -> ['full']
    8/248: Constant [] -> ['_val_36']
    9/248: Constant [] -> ['_val_40']
    10/248: Constant [] -> ['_val_44']
    11/248: Constant [] -> ['_val_48']
    12/248: Slice ['primals_14', '_val_36', '_val_40', '_val_44', '_val_48'] -> ['slice_5']
    13/248: Transpose ['primals_9'] -> ['t_4']
    14/248: Transpose ['primals_10'] -> ['t_5']
    15/248: Transpose ['primals_11'] -> ['t_6']
    16/248: Transpose ['primals_5'] -> ['t']
    17/248: Transpose ['primals_6'] -> ['t_1']
    18/248: Transpose ['primals_7'] -> ['t_2']
    19/248: Constant [] -> ['aten_unsqueeze_155_dim_0']
    20/248: Unsqueeze ['primals_12', 'aten_unsqueeze_155_dim_0'] -> ['unsqueeze_7']
    21/248: Constant [] -> ['scalar_tensor_default']
    22/248: Pow ['embedding', 'scalar_tensor_default'] -> ['pow_1']
    23/248: Transpose ['t_3'] -> ['t_21']
    24/248: Constant [] -> ['aten_triu_163_diagonal']
    25/248: Trilu ['full', 'aten_triu_163_diagonal'] -> ['triu']
    26/248: Constant [] -> ['aten_unsqueeze_164_dim_0']
    27/248: Unsqueeze ['slice_5', 'aten_unsqueeze_164_dim_0'] -> ['unsqueeze_5']
    28/248: Transpose ['t_4'] -> ['t_17']
    29/248: Transpose ['t_5'] -> ['t_13']
    30/248: Transpose ['t_6'] -> ['t_9']
    31/248: Transpose ['t'] -> ['t_33']
    32/248: Transpose ['t_1'] -> ['t_29']
    33/248: Transpose ['t_2'] -> ['t_25']
    34/248: Constant [] -> ['_val_75']
    35/248: Constant [] -> ['_val_79']
    36/248: Constant [] -> ['_val_83']
    37/248: Constant [] -> ['_val_87']
    38/248: Slice ['unsqueeze_7', '_val_75', '_val_79', '_val_83', '_val_87'] -> ['slice_7']
    39/248: Constant [] -> ['gt']
    40/248: Constant [] -> ['_val_107']
    41/248: ReduceMean ['pow_1', '_val_107'] -> ['mean']
    42/248: Constant [] -> ['aten_unsqueeze_208_dim_0']
    43/248: Unsqueeze ['unsqueeze_5', 'aten_unsqueeze_208_dim_0'] -> ['unsqueeze_6']
    44/248: Constant [] -> ['aten_unsqueeze_209_dim_0']
    45/248: Unsqueeze ['slice_7', 'aten_unsqueeze_209_dim_0'] -> ['unsqueeze_8']
    46/248: Cast ['gt'] -> ['convert_element_type_default']
    47/248: Mul ['triu', 'convert_element_type_default'] -> ['mul']
    48/248: Constant [] -> ['aten_add_214_other_1']
    49/248: Add ['mean', 'aten_add_214_other_1'] -> ['add']
    50/248: Constant [] -> ['_val_119']
    51/248: Constant [] -> ['_val_123']
    52/248: Constant [] -> ['_val_127']
    53/248: Constant [] -> ['_val_131']
    54/248: Slice ['unsqueeze_6', '_val_119', '_val_123', '_val_127', '_val_131'] -> ['slice_6']
    55/248: Constant [] -> ['aten_expand_233_size_1']
    56/248: Expand ['unsqueeze_8', 'aten_expand_233_size_1'] -> ['expand_2']
    57/248: Constant [] -> ['aten_unsqueeze_251_dim_0']
    58/248: Unsqueeze ['mul', 'aten_unsqueeze_251_dim_0'] -> ['unsqueeze_3']
    59/248: Sqrt ['add'] -> ['aten_rsqrt_252_tmp']
    60/248: Reciprocal ['aten_rsqrt_252_tmp'] -> ['rsqrt']
    61/248: Constant [] -> ['_val_154']
    62/248: Equal ['slice_6', '_val_154'] -> ['eq_1']
    63/248: Constant [] -> ['aten_expand_256_size_1']
    64/248: Expand ['expand_2', 'aten_expand_256_size_1'] -> ['expand_3']
    65/248: Constant [] -> ['aten_unsqueeze_258_dim_0']
    66/248: Unsqueeze ['unsqueeze_3', 'aten_unsqueeze_258_dim_0'] -> ['unsqueeze_4']
    67/248: Mul ['embedding', 'rsqrt'] -> ['mul_2']
    68/248: Constant [] -> ['_val_168']
    69/248: Constant [] -> ['_val_172']
    70/248: Constant [] -> ['_val_176']
    71/248: Constant [] -> ['_val_180']
    72/248: Slice ['unsqueeze_4', '_val_168', '_val_172', '_val_176', '_val_180'] -> ['slice_3']
    73/248: Mul ['primals_1', 'mul_2'] -> ['mul_3']
    74/248: Constant [] -> ['view_11']
    75/248: Constant [] -> ['_val_188']
    76/248: Constant [] -> ['_val_192']
    77/248: Constant [] -> ['_val_196']
    78/248: Constant [] -> ['_val_200']
    79/248: Slice ['slice_3', '_val_188', '_val_192', '_val_196', '_val_200'] -> ['slice_4']
    80/248: Constant [] -> ['aten_view_302_size_0']
    81/248: Reshape ['mul_3', 'aten_view_302_size_0'] -> ['view_1']
    82/248: Constant [] -> ['aten_expand_305_size_1']
    83/248: Expand ['slice_4', 'aten_expand_305_size_1'] -> ['expand_1']
    84/248: MatMul ['view_1', 't'] -> ['mm']
    85/248: MatMul ['view_1', 't_1'] -> ['mm_1']
    86/248: MatMul ['view_1', 't_2'] -> ['mm_2']
    87/248: MatMul ['expand_3', 'view_11'] -> ['view_12']
    88/248: Constant [] -> ['aten_view_313_size_0']
    89/248: Reshape ['mm', 'aten_view_313_size_0'] -> ['view_2']
    90/248: Constant [] -> ['aten_view_315_size_0']
    91/248: Reshape ['mm_1', 'aten_view_315_size_0'] -> ['view_4']
    92/248: Constant [] -> ['aten_view_317_size_0']
    93/248: Reshape ['mm_2', 'aten_view_317_size_0'] -> ['view_6']
    94/248: Transpose ['view_12'] -> ['transpose_3']
    95/248: Constant [] -> ['aten_view_321_size_0']
    96/248: Reshape ['view_2', 'aten_view_321_size_0'] -> ['view_7']
    97/248: Constant [] -> ['aten_view_323_size_0']
    98/248: Reshape ['view_4', 'aten_view_323_size_0'] -> ['view_8']
    99/248: Constant [] -> ['aten_view_325_size_0']
    100/248: Reshape ['view_6', 'aten_view_325_size_0'] -> ['view_9']
    101/248: Concat ['transpose_3', 'transpose_3'] -> ['cat']
    102/248: Constant [] -> ['_val_229']
    103/248: Equal ['expand_1', '_val_229'] -> ['eq']
    104/248: Transpose ['view_7'] -> ['transpose']
    105/248: Transpose ['view_8'] -> ['transpose_1']
    106/248: Transpose ['view_9'] -> ['transpose_2']
    107/248: Cos ['cat'] -> ['cos']
    108/248: Sin ['cat'] -> ['sin']
    109/248: And ['eq', 'eq_1'] -> ['mul_1']
    110/248: Constant [] -> ['_val_240']
    111/248: Constant [] -> ['_val_244']
    112/248: Constant [] -> ['_val_248']
    113/248: Constant [] -> ['_val_252']
    114/248: Slice ['transpose', '_val_240', '_val_244', '_val_248', '_val_252'] -> ['slice_10']
    115/248: Constant [] -> ['_val_257']
    116/248: Constant [] -> ['_val_261']
    117/248: Constant [] -> ['_val_265']
    118/248: Constant [] -> ['_val_269']
    119/248: Slice ['transpose', '_val_257', '_val_261', '_val_265', '_val_269'] -> ['slice_11']
    120/248: Constant [] -> ['_val_274']
    121/248: Constant [] -> ['_val_278']
    122/248: Constant [] -> ['_val_282']
    123/248: Constant [] -> ['_val_286']
    124/248: Slice ['transpose_1', '_val_274', '_val_278', '_val_282', '_val_286'] -> ['slice_12']
    125/248: Constant [] -> ['_val_291']
    126/248: Constant [] -> ['_val_295']
    127/248: Constant [] -> ['_val_299']
    128/248: Constant [] -> ['_val_303']
    129/248: Slice ['transpose_1', '_val_291', '_val_295', '_val_299', '_val_303'] -> ['slice_13']
    130/248: Constant [] -> ['aten_expand_405_size_1']
    131/248: Expand ['transpose_2', 'aten_expand_405_size_1'] -> ['expand_8']
    132/248: Constant [] -> ['aten_unsqueeze_406_dim_0']
    133/248: Unsqueeze ['cos', 'aten_unsqueeze_406_dim_0'] -> ['unsqueeze_10']
    134/248: Constant [] -> ['aten_unsqueeze_407_dim_0']
    135/248: Unsqueeze ['sin', 'aten_unsqueeze_407_dim_0'] -> ['unsqueeze_11']
    136/248: Constant [] -> ['_val_309']
    137/248: Where ['mul_1', '_val_309', 'expand_1'] -> ['masked_fill']
    138/248: Neg ['slice_11'] -> ['neg']
    139/248: Neg ['slice_13'] -> ['neg_1']
    140/248: Mul ['transpose', 'unsqueeze_10'] -> ['mul_4']
    141/248: Mul ['transpose_1', 'unsqueeze_10'] -> ['mul_6']
    142/248: Concat ['neg', 'slice_10'] -> ['cat_1']
    143/248: Concat ['neg_1', 'slice_12'] -> ['cat_2']
    144/248: Constant [] -> ['aten_view_421_size_0']
    145/248: Reshape ['expand_8', 'aten_view_421_size_0'] -> ['view_17']
    146/248: Constant [] -> ['_val_326']
    147/248: Constant [] -> ['_val_330']
    148/248: Constant [] -> ['_val_334']
    149/248: Constant [] -> ['_val_338']
    150/248: Slice ['masked_fill', '_val_326', '_val_330', '_val_334', '_val_338'] -> ['slice_17']
    151/248: Mul ['cat_1', 'unsqueeze_11'] -> ['mul_5']
    152/248: Mul ['cat_2', 'unsqueeze_11'] -> ['mul_7']
    153/248: Transpose ['view_17'] -> ['transpose_8']
    154/248: Constant [] -> ['_val_346']
    155/248: Constant [] -> ['_val_350']
    156/248: Constant [] -> ['_val_354']
    157/248: Constant [] -> ['_val_358']
    158/248: Slice ['slice_17', '_val_346', '_val_350', '_val_354', '_val_358'] -> ['slice_18']
    159/248: Constant [] -> ['alpha__2']
    160/248: Mul ['mul_5', 'alpha__2'] -> ['other_1__2']
    161/248: Add ['mul_4', 'other_1__2'] -> ['add_1']
    162/248: Constant [] -> ['alpha__3']
    163/248: Mul ['mul_7', 'alpha__3'] -> ['other_1__3']
    164/248: Add ['mul_6', 'other_1__3'] -> ['add_2']
    165/248: Constant [] -> ['_val_365']
    166/248: Constant [] -> ['_val_369']
    167/248: Constant [] -> ['_val_373']
    168/248: Constant [] -> ['_val_377']
    169/248: Slice ['slice_18', '_val_365', '_val_369', '_val_373', '_val_377'] -> ['slice_19']
    170/248: Constant [] -> ['aten_expand_479_size_1']
    171/248: Expand ['add_1', 'aten_expand_479_size_1'] -> ['expand_5']
    172/248: Transpose ['add_2'] -> ['transpose_4']
    173/248: Constant [] -> ['aten_expand_483_size_1']
    174/248: Expand ['transpose_4', 'aten_expand_483_size_1'] -> ['expand_6']
    175/248: Constant [] -> ['aten_view_485_size_0']
    176/248: Reshape ['expand_5', 'aten_view_485_size_0'] -> ['view_13']
    177/248: Transpose ['view_13'] -> ['transpose_9']
    178/248: Constant [] -> ['aten_view_489_size_0']
    179/248: Reshape ['expand_6', 'aten_view_489_size_0'] -> ['view_14']
    180/248: MatMul ['view_13', 'view_14'] -> ['bmm_1']
    181/248: Transpose ['view_14'] -> ['transpose_10']
    182/248: Constant [] -> ['aten_view_493_size_0']
    183/248: Reshape ['bmm_1', 'aten_view_493_size_0'] -> ['view_15']
    184/248: Constant [] -> ['_val_395']
    185/248: Div ['view_15', '_val_395'] -> ['div']
    186/248: Constant [] -> ['alpha__4']
    187/248: Mul ['slice_19', 'alpha__4'] -> ['other_1__4']
    188/248: Add ['div', 'other_1__4'] -> ['add_3']
    189/248: Softmax ['add_3'] -> ['_softmax']
    190/248: Constant [] -> ['aten_expand_502_size_1']
    191/248: Expand ['_softmax', 'aten_expand_502_size_1'] -> ['expand_7']
    192/248: Constant [] -> ['aten_view_505_size_0']
    193/248: Reshape ['expand_7', 'aten_view_505_size_0'] -> ['view_16']
    194/248: Identity ['_softmax'] -> ['detach_13']
    195/248: MatMul ['view_16', 'view_17'] -> ['bmm_2']
    196/248: Transpose ['view_16'] -> ['transpose_7']
    197/248: Constant [] -> ['aten_view_510_size_0']
    198/248: Reshape ['bmm_2', 'aten_view_510_size_0'] -> ['view_18']
    199/248: Transpose ['view_18'] -> ['transpose_5']
    200/248: Constant [] -> ['aten_view_514_size_0']
    201/248: Reshape ['transpose_5', 'aten_view_514_size_0'] -> ['view_19']
    202/248: Constant [] -> ['aten_view_516_size_0']
    203/248: Reshape ['view_19', 'aten_view_516_size_0'] -> ['view_20']
    204/248: MatMul ['view_20', 't_3'] -> ['mm_3']
    205/248: Constant [] -> ['aten_view_519_size_0']
    206/248: Reshape ['mm_3', 'aten_view_519_size_0'] -> ['view_21']
    207/248: Constant [] -> ['alpha__5']
    208/248: Mul ['view_21', 'alpha__5'] -> ['other_1__5']
    209/248: Add ['embedding', 'other_1__5'] -> ['add_4']
    210/248: Constant [] -> ['scalar_tensor_default_1']
    211/248: Pow ['add_4', 'scalar_tensor_default_1'] -> ['pow_2']
    212/248: Constant [] -> ['_val_425']
    213/248: ReduceMean ['pow_2', '_val_425'] -> ['mean_1']
    214/248: Constant [] -> ['aten_add_527_other_1']
    215/248: Add ['mean_1', 'aten_add_527_other_1'] -> ['add_5']
    216/248: Sqrt ['add_5'] -> ['aten_rsqrt_528_tmp']
    217/248: Reciprocal ['aten_rsqrt_528_tmp'] -> ['rsqrt_1']
    218/248: Mul ['add_4', 'rsqrt_1'] -> ['mul_8']
    219/248: Mul ['primals_2', 'mul_8'] -> ['mul_9']
    220/248: Constant [] -> ['aten_view_532_size_0']
    221/248: Reshape ['mul_9', 'aten_view_532_size_0'] -> ['view_22']
    222/248: MatMul ['view_22', 't_4'] -> ['mm_4']
    223/248: MatMul ['view_22', 't_5'] -> ['mm_5']
    224/248: Constant [] -> ['aten_view_536_size_0']
    225/248: Reshape ['mm_4', 'aten_view_536_size_0'] -> ['view_23']
    226/248: Constant [] -> ['aten_view_538_size_0']
    227/248: Reshape ['mm_5', 'aten_view_538_size_0'] -> ['view_25']
    228/248: Sigmoid ['view_23'] -> ['sigmoid']
    229/248: Mul ['view_23', 'sigmoid'] -> ['mul_10']
    230/248: Mul ['mul_10', 'view_25'] -> ['mul_11']
    231/248: Constant [] -> ['aten_view_543_size_0']
    232/248: Reshape ['mul_11', 'aten_view_543_size_0'] -> ['view_26']
    233/248: MatMul ['view_26', 't_6'] -> ['mm_6']
    234/248: Constant [] -> ['aten_view_546_size_0']
    235/248: Reshape ['mm_6', 'aten_view_546_size_0'] -> ['view_27']
    236/248: Constant [] -> ['alpha__6']
    237/248: Mul ['view_27', 'alpha__6'] -> ['other_1__6']
    238/248: Add ['add_4', 'other_1__6'] -> ['add_6']
    239/248: Constant [] -> ['scalar_tensor_default_2']
    240/248: Pow ['add_6', 'scalar_tensor_default_2'] -> ['pow_3']
    241/248: Constant [] -> ['_val_452']
    242/248: ReduceMean ['pow_3', '_val_452'] -> ['mean_2']
    243/248: Constant [] -> ['aten_add_554_other_1']
    244/248: Add ['mean_2', 'aten_add_554_other_1'] -> ['add_7']
    245/248: Sqrt ['add_7'] -> ['aten_rsqrt_555_tmp']
    246/248: Reciprocal ['aten_rsqrt_555_tmp'] -> ['rsqrt_2']
    247/248: Mul ['add_6', 'rsqrt_2'] -> ['mul_12']
    248/248: Mul ['primals_3', 'mul_12'] -> ['mul_13']
    [runpythonerror]
    /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/exporter.py:137: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
      warnings.warn(
    2024-05-08 14:07:16,000 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 4194304.
    2024-05-08 14:07:16,000 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue full due to large size 4194304.
    2024-05-08 14:07:16,025 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue triu due to large size 4194304.
    2024-05-08 14:07:16,062 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue convert_element_type_default due to large size 4194304.
    2024-05-08 14:07:16,063 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue mul due to large size 4194304.
    2024-05-08 14:07:16,078 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 4194304.
    2024-05-08 14:07:16,078 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue unsqueeze_3 due to large size 4194304.
    2024-05-08 14:07:16,081 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 4194304.
    2024-05-08 14:07:16,081 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue unsqueeze_4 due to large size 4194304.
    2024-05-08 14:07:16,088 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_3 due to large size 4194304.
    2024-05-08 14:07:16,093 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_4 due to large size 4194304.
    2024-05-08 14:07:16,098 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 8388608.
    2024-05-08 14:07:16,099 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue expand_1 due to large size 8388608.
    2024-05-08 14:07:16,103 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue clone due to large size 8388608.
    2024-05-08 14:07:16,109 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue alias due to large size 8388608.
    2024-05-08 14:07:16,112 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue eq due to large size 2097152.
    2024-05-08 14:07:16,266 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue return_val due to large size 4194304.
    2024-05-08 14:07:16,266 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue full due to large size 4194304.
    2024-05-08 14:07:16,270 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue triu due to large size 4194304.
    2024-05-08 14:07:16,275 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue convert_element_type_default due to large size 4194304.
    2024-05-08 14:07:16,276 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue mul due to large size 4194304.
    2024-05-08 14:07:16,278 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue unsqueeze_3 due to large size 4194304.
    2024-05-08 14:07:16,279 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue unsqueeze_4 due to large size 4194304.
    2024-05-08 14:07:16,280 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_3 due to large size 4194304.
    2024-05-08 14:07:16,281 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue slice_4 due to large size 4194304.
    2024-05-08 14:07:16,283 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue expand_1 due to large size 8388608.
    2024-05-08 14:07:16,286 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue eq due to large size 2097152.
    2024-05-08 14:07:16.373616600 [W:onnxruntime:, graph.cc:4051 CleanUnusedInitializersAndNodeArgs] Removing initializer '_val_23'. It is not used by any node and should be removed from the model.
    2024-05-08 14:07:16.373665800 [W:onnxruntime:, graph.cc:4051 CleanUnusedInitializersAndNodeArgs] Removing initializer '_val_22'. It is not used by any node and should be removed from the model.