experimental_experiment.torch_interpreter.oxs_dispatcher

class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDebugDispatcher(verbose: int = 0, raise_exc: bool = True)[source]

Tries the fallback even when it is not necessary, to check that it is working.

Parameters:
  • verbose – verbosity

  • raise_exc – raise an exception on failure instead of continuing

The class can be used as follows.

<<<

import torch
from experimental_experiment.torch_models.llama_helper import get_llama_model
from experimental_experiment.xbuilder import OptimizationOptions
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.torch_interpreter.oxs_dispatcher import (
    OxsDebugDispatcher,
)

with torch.no_grad():
    model, input_tensors = get_llama_model()
    input_tensors = input_tensors[0]

    to_onnx(
        model,
        input_tensors,
        input_names=[f"input{i}" for i in range(len(input_tensors))],
        options=OptimizationOptions(patterns=None),
        verbose=0,
        dispatcher=OxsDebugDispatcher(verbose=2, raise_exc=False),
    )

>>>

    
    [runpythonerror]
    Traceback (most recent call last):
      File "<stdin>", line 17, in <module>
      File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 793, in to_onnx
        graph_module, builder, interpreter, mask_outputs = _make_builder_interpreter(
      File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 469, in _make_builder_interpreter
        exported_program = export_options.export(
      File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/export_options.py", line 380, in export
        modified = CustomTracer().remove_inplace(exported_program.graph)
      File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/tracing.py", line 556, in remove_inplace
        assert node.target in {
    AssertionError: Unsupported target <OpOverload(op='aten.copy_', overload='default')> at position 45/135
    --graph
    graph():
        %p_model_layers_0_input_layernorm_weight : [num_users=1] = placeholder[target=p_model_layers_0_input_layernorm_weight]
        %p_model_layers_0_post_attention_layernorm_weight : [num_users=1] = placeholder[target=p_model_layers_0_post_attention_layernorm_weight]
        %p_model_norm_weight : [num_users=1] = placeholder[target=p_model_norm_weight]
        %p_model_embed_tokens_weight : [num_users=1] = placeholder[target=p_model_embed_tokens_weight]
        %p_model_layers_0_self_attn_q_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_q_proj_weight]
        %p_model_layers_0_self_attn_k_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_k_proj_weight]
        %p_model_layers_0_self_attn_v_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_v_proj_weight]
        %p_model_layers_0_self_attn_o_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_o_proj_weight]
        %p_model_layers_0_mlp_gate_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_mlp_gate_proj_weight]
        %p_model_layers_0_mlp_up_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_mlp_up_proj_weight]
        %p_model_layers_0_mlp_down_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_mlp_down_proj_weight]
        %b_model_rotary_emb_inv_freq : [num_users=1] = placeholder[target=b_model_rotary_emb_inv_freq]
        %input_ids : [num_users=1] = placeholder[target=input_ids]
        %attention_mask : [num_users=1] = placeholder[target=attention_mask]
        %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_model_embed_tokens_weight, %input_ids), kwargs = {})
        %arange : [num_users=2] = call_function[target=torch.ops.aten.arange.start](args = (0, 8), kwargs = {device: cpu, pin_memory: False})
        %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arange, 0), kwargs = {})
        %full : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 8], -3.4028234663852886e+38), kwargs = {dtype: torch.float32, device: cpu, pin_memory: False})
        %triu : [num_users=1] = call_function[target=torch.ops.aten.triu.default](args = (%full, 1), kwargs = {})
        %arange_1 : [num_users=1] = call_function[target=torch.ops.aten.arange.default](args = (8,), kwargs = {device: cpu, pin_memory: False})
        %reshape : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%arange, [-1, 1]), kwargs = {})
        %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Tensor](args = (%arange_1, %reshape), kwargs = {})
        %mul_ : [num_users=1] = call_function[target=torch.ops.aten.mul_.Tensor](args = (%triu, %gt), kwargs = {})
        %unsqueeze_1 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%mul_, 0), kwargs = {})
        %unsqueeze_2 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_1, 1), kwargs = {})
        %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%unsqueeze_2, 2, 0, 9223372036854775807), kwargs = {})
        %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 3, 0, 9223372036854775807), kwargs = {})
        %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%slice_2, [2, 1, -1, -1]), kwargs = {})
        %clone : [num_users=4] = call_function[target=torch.ops.aten.clone.default](args = (%expand,), kwargs = {})
        %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
        %slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 9223372036854775807), kwargs = {})
        %slice_5 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_4, 2, 0, 9223372036854775807), kwargs = {})
        %slice_6 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%attention_mask, 0, 0, 9223372036854775807), kwargs = {})
        %unsqueeze_3 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%slice_6, 1), kwargs = {})
        %unsqueeze_4 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_3, 2), kwargs = {})
        %slice_7 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%unsqueeze_4, 3, 0, 9223372036854775807), kwargs = {})
        %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_5, %slice_7), kwargs = {})
        %eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%add, 0), kwargs = {})
        %slice_8 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
        %slice_9 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_8, 1, 0, 9223372036854775807), kwargs = {})
        %slice_10 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_9, 2, 0, 9223372036854775807), kwargs = {})
        %masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%slice_10, %eq, -3.4028234663852886e+38), kwargs = {})
        %slice_11 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
        %slice_12 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_11, 1, 0, 9223372036854775807), kwargs = {})
        %slice_13 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_12, 2, 0, 9223372036854775807), kwargs = {})
        %copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%slice_13, %masked_fill), kwargs = {})
        %unsqueeze_5 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%b_model_rotary_emb_inv_freq, 0), kwargs = {})
        %slice_14 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%unsqueeze_5, 1, 0, 9223372036854775807), kwargs = {})
        %unsqueeze_6 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%slice_14, 2), kwargs = {})
        %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%unsqueeze_6, torch.float32), kwargs = {})
        %expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%to, [1, -1, 1]), kwargs = {})
        %slice_15 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%unsqueeze, 0, 0, 9223372036854775807), kwargs = {})
        %unsqueeze_7 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%slice_15, 1), kwargs = {})
        %slice_16 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%unsqueeze_7, 2, 0, 9223372036854775807), kwargs = {})
        %to_1 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%slice_16, torch.float32), kwargs = {})
        %submod_3 : [num_users=1] = get_attr[target=submod_1]
        %wrap_with_autocast : [num_users=2] = call_function[target=torch.ops.higher_order.wrap_with_autocast](args = (cpu, None, False, None, %submod_3, %expand_1, %to_1), kwargs = {})
        %cos : [num_users=1] = call_function[target=operator.getitem](args = (%wrap_with_autocast, 0), kwargs = {})
        %sin : [num_users=1] = call_function[target=operator.getitem](args = (%wrap_with_autocast, 1), kwargs = {})
        %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cos, 1.0), kwargs = {})
        %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sin, 1.0), kwargs = {})
        %to_4 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul, torch.float32), kwargs = {})
        %to_5 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_1, torch.float32), kwargs = {})
        %to_6 : [num_users=3] = call_function[target=torch.ops.aten.to.dtype](args = (%embedding, torch.float32), kwargs = {})
        %pow_1 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%to_6, 2), kwargs = {})
        %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})
        %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-06), kwargs = {})
        %rsqrt : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_1,), kwargs = {})
        %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%to_6, %rsqrt), kwargs = {})
        %to_7 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_2, torch.float32), kwargs = {})
        %mul_3 : [num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_model_layers_0_input_layernorm_weight, %to_7), kwargs = {})
        %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_3, %p_model_layers_0_self_attn_q_proj_weight), kwargs = {})
        %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_3, %p_model_layers_0_self_attn_k_proj_weight), kwargs = {})
        %linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_3, %p_model_layers_0_self_attn_v_proj_weight), kwargs = {})
        %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%linear, [2, 8, 2, 8]), kwargs = {})
        %transpose_1 : [num_users=3] = call_function[target=torch.ops.aten.transpose.int](args = (%view, 1, 2), kwargs = {})
        %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%linear_1, [2, 8, 2, 8]), kwargs = {})
        %transpose_2 : [num_users=3] = call_function[target=torch.ops.aten.transpose.int](args = (%view_1, 1, 2), kwargs = {})
        %view_2 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%linear_2, [2, 8, 2, 8]), kwargs = {})
        %transpose_3 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%view_2, 1, 2), kwargs = {})
        %unsqueeze_8 : [num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%to_4, 1), kwargs = {})
        %unsqueeze_9 : [num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%to_5, 1), kwargs = {})
        %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%transpose_1, %unsqueeze_8), kwargs = {})
        %slice_17 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_1, 3, 0, 4), kwargs = {})
        %slice_18 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_1, 3, 4, 9223372036854775807), kwargs = {})
        %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_18,), kwargs = {})
        %cat_1 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_17], -1), kwargs = {})
        %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_9), kwargs = {})
        %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_4, %mul_5), kwargs = {})
        %mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%transpose_2, %unsqueeze_8), kwargs = {})
        %slice_19 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_2, 3, 0, 4), kwargs = {})
        %slice_20 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_2, 3, 4, 9223372036854775807), kwargs = {})
        %neg_1 : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_20,), kwargs = {})
        %cat_2 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_19], -1), kwargs = {})
        %mul_7 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_2, %unsqueeze_9), kwargs = {})
        %add_3 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_6, %mul_7), kwargs = {})
        %transpose_4 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%add_3, 2, 3), kwargs = {})
        %matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%add_2, %transpose_4), kwargs = {})
        %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%matmul_1, 2.8284271247461903), kwargs = {})
        %slice_21 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
        %slice_22 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_21, 1, 0, 9223372036854775807), kwargs = {})
        %slice_23 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_22, 2, 0, 9223372036854775807), kwargs = {})
        %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%div, %slice_23), kwargs = {})
        %softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%add_4, -1, torch.float32), kwargs = {})
        %to_8 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%softmax, torch.float32), kwargs = {})
        %dropout : [num_users=1] = call_function[target=torch.ops.aten.dropout.default](args = (%to_8, 0.0, True), kwargs = {})
        %matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%dropout, %transpose_3), kwargs = {})
        %transpose_5 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%matmul_2, 1, 2), kwargs = {})
        %contiguous : [num_users=1] = call_function[target=torch.ops.aten.contiguous.default](args = (%transpose_5,), kwargs = {})
        %reshape_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%contiguous, [2, 8, -1]), kwargs = {})
        %linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%reshape_1, %p_model_layers_0_self_attn_o_proj_weight), kwargs = {})
        %add_5 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%to_6, %linear_3), kwargs = {})
        %to_9 : [num_users=3] = call_function[target=torch.ops.aten.to.dtype](args = (%add_5, torch.float32), kwargs = {})
        %pow_2 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%to_9, 2), kwargs = {})
        %mean_1 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [-1], True), kwargs = {})
        %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_1, 1e-06), kwargs = {})
        %rsqrt_1 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_6,), kwargs = {})
        %mul_8 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%to_9, %rsqrt_1), kwargs = {})
        %to_10 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_8, torch.float32), kwargs = {})
        %mul_9 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_model_layers_0_post_attention_layernorm_weight, %to_10), kwargs = {})
        %linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_9, %p_model_layers_0_mlp_gate_proj_weight), kwargs = {})
        %silu : [num_users=1] = call_function[target=torch.ops.aten.silu.default](args = (%linear_4,), kwargs = {})
        %linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_9, %p_model_layers_0_mlp_up_proj_weight), kwargs = {})
        %mul_10 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%silu, %linear_5), kwargs = {})
        %linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_10, %p_model_layers_0_mlp_down_proj_weight), kwargs = {})
        %add_7 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%to_9, %linear_6), kwargs = {})
        %to_11 : [num_users=2] = call_function[target=torch.ops.aten.to.dtype](args = (%add_7, torch.float32), kwargs = {})
        %pow_3 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%to_11, 2), kwargs = {})
        %mean_2 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_3, [-1], True), kwargs = {})
        %add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_2, 1e-06), kwargs = {})
        %rsqrt_2 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_8,), kwargs = {})
        %mul_11 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%to_11, %rsqrt_2), kwargs = {})
        %to_12 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_11, torch.float32), kwargs = {})
        %mul_12 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_model_norm_weight, %to_12), kwargs = {})
        return (mul_12,)
fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) Callable | None[source]

This function is called after looking up the function that converts an aten function into ONNX. fct is that function; the dispatcher may replace it, and it is only set when a mapping was found.

Parameters:
  • name – object or str

  • fct – function found so far

  • args – known arguments coming from the graph module

  • kwargs – known named arguments coming from the graph module

  • builder – GraphBuilder

Returns:

callable
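The fallback mechanism above can be sketched in plain Python. The names below (`SketchDispatcher`, the substitute table) are hypothetical and only illustrate the pattern: the dispatcher is consulted after the default lookup, may keep the converter that was found, or may supply one when nothing was found.

```python
from typing import Any, Callable, Dict, List, Optional


class SketchDispatcher:
    """Illustrative dispatcher: maps unhandled names to substitute converters."""

    def __init__(self, substitutes: Dict[str, Callable]):
        self.substitutes = substitutes

    def fallback(
        self,
        name: Any,
        fct: Optional[Callable],
        args: List[Any],
        kwargs: Dict[str, Any],
        builder: Any,
    ) -> Optional[Callable]:
        # Keep the converter found by the default lookup when there is one.
        if fct is not None:
            return fct
        # Otherwise consult the substitute table; None means "still no mapping".
        return self.substitutes.get(str(name))


disp = SketchDispatcher({"aten.relu": lambda g, x: ("Relu", x)})
assert disp.fallback("aten.relu", None, [], {}, None) is not None
assert disp.fallback("aten.unknown", None, [], {}, None) is None
```

A subclass following this shape can chain several lookup strategies by returning the first non-None result.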

class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDispatcher(verbose: int = 0)[source]

If DynamoInterpreter cannot find any converting function for a specific function, it tries to find an existing one in onnxscript. The converting function from onnxscript is run in trace-only mode. The variables and functions op, Rank, IsScalar are replaced by op = OwsOpset(), op.Rank, op.Scalar. onnxscript may define multiple overloads of a function; right now the first one is taken.

Parameters:

verbose – verbosity
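"Trace-only mode" means the onnxscript function body is executed, but the opset object records the operators it would emit instead of computing values. The `Recorder` class below is purely illustrative (it stands in for the role OwsOpset plays; the names are assumptions, not the library's API):

```python
from typing import List, Tuple


class Recorder:
    """Toy opset: records every operator call as (op_name, args)."""

    def __init__(self):
        self.calls: List[Tuple[str, tuple]] = []

    def __getattr__(self, op_name: str):
        def record(*args):
            # Append the call and return a symbolic result name.
            self.calls.append((op_name, args))
            return f"{op_name}_out{len(self.calls)}"

        return record


def scripted_gelu(op, x):
    # An onnxscript-style function body: it only issues opset calls,
    # so tracing it yields the sequence of ONNX operators.
    half = op.Mul(x, 0.5)
    erf = op.Erf(op.Div(x, 1.4142135623730951))
    return op.Mul(half, op.Add(erf, 1.0))


op = Recorder()
scripted_gelu(op, "X")
assert [name for name, _ in op.calls] == ["Mul", "Div", "Erf", "Add", "Mul"]
```

Because the body never branches on tensor values, a single trace captures the full operator sequence.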

fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) Callable | None[source]

This function is called after looking up the function that converts an aten function into ONNX. fct is that function; the dispatcher may replace it, and it is only set when a mapping was found.

Parameters:
  • name – object or str

  • fct – function found so far

  • args – known arguments coming from the graph module

  • kwargs – known named arguments coming from the graph module

  • builder – GraphBuilder

Returns:

callable

property submodules: Dict[str, Callable]

Returns the submodules implementing torch functions.
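As a rough sketch of what such a property exposes (names and internal layout below are assumptions, not the library's actual attributes), a dispatcher can present its converters as a mapping from torch function names to callables:

```python
from typing import Callable, Dict


class TinyDispatcher:
    """Illustrative dispatcher holding a table of converters."""

    def __init__(self):
        self._subs: Dict[str, Callable] = {"aten.cos": lambda op, x: op.Cos(x)}

    @property
    def submodules(self) -> Dict[str, Callable]:
        # Expose a copy so callers cannot mutate the internal table.
        return dict(self._subs)


d = TinyDispatcher()
assert "aten.cos" in d.submodules
assert callable(d.submodules["aten.cos"])
```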