experimental_experiment.torch_interpreter.oxs_dispatcher¶
- class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDebugDispatcher(verbose: int = 0, raise_exc: bool = True)[source]¶
Tries the fallback even when it is not necessary, to check that it is working.
- Parameters:
verbose – verbosity
raise_exc – if True, raise an exception when the fallback fails; otherwise continue
The class can be used as follows:
<<<
import torch
from experimental_experiment.torch_models.llama_helper import get_llama_model
from experimental_experiment.xbuilder import OptimizationOptions
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.torch_interpreter.oxs_dispatcher import (
    OxsDebugDispatcher,
)

with torch.no_grad():
    model, input_tensors = get_llama_model()
    input_tensors = input_tensors[0]

    to_onnx(
        model,
        input_tensors,
        input_names=[f"input{i}" for i in range(len(input_tensors))],
        options=OptimizationOptions(patterns=None),
        verbose=0,
        dispatcher=OxsDebugDispatcher(verbose=2, raise_exc=False),
    )
>>>
[runpythonerror]
Traceback (most recent call last):
  File "<stdin>", line 17, in <module>
  File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 793, in to_onnx
    graph_module, builder, interpreter, mask_outputs = _make_builder_interpreter(
  File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 469, in _make_builder_interpreter
    exported_program = export_options.export(
  File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/export_options.py", line 380, in export
    modified = CustomTracer().remove_inplace(exported_program.graph)
  File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/tracing.py", line 556, in remove_inplace
    assert node.target in {
AssertionError: Unsupported target <OpOverload(op='aten.copy_', overload='default')> at position 45/135
--graph
graph():
    %p_model_layers_0_input_layernorm_weight : [num_users=1] = placeholder[target=p_model_layers_0_input_layernorm_weight]
    %p_model_layers_0_post_attention_layernorm_weight : [num_users=1] = placeholder[target=p_model_layers_0_post_attention_layernorm_weight]
    %p_model_norm_weight : [num_users=1] = placeholder[target=p_model_norm_weight]
    %p_model_embed_tokens_weight : [num_users=1] = placeholder[target=p_model_embed_tokens_weight]
    %p_model_layers_0_self_attn_q_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_q_proj_weight]
    %p_model_layers_0_self_attn_k_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_k_proj_weight]
    %p_model_layers_0_self_attn_v_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_v_proj_weight]
    %p_model_layers_0_self_attn_o_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_o_proj_weight]
    %p_model_layers_0_mlp_gate_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_mlp_gate_proj_weight]
    %p_model_layers_0_mlp_up_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_mlp_up_proj_weight]
    %p_model_layers_0_mlp_down_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_mlp_down_proj_weight]
    %b_model_rotary_emb_inv_freq : [num_users=1] = placeholder[target=b_model_rotary_emb_inv_freq]
    %input_ids : [num_users=1] = placeholder[target=input_ids]
    %attention_mask : [num_users=1] = placeholder[target=attention_mask]
    %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_model_embed_tokens_weight, %input_ids), kwargs = {})
    %arange : [num_users=2] = call_function[target=torch.ops.aten.arange.start](args = (0, 8), kwargs = {device: cpu, pin_memory: False})
    %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arange, 0), kwargs = {})
    %full : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 8], -3.4028234663852886e+38), kwargs = {dtype: torch.float32, device: cpu, pin_memory: False})
    %triu : [num_users=1] = call_function[target=torch.ops.aten.triu.default](args = (%full, 1), kwargs = {})
    %arange_1 : [num_users=1] = call_function[target=torch.ops.aten.arange.default](args = (8,), kwargs = {device: cpu, pin_memory: False})
    %reshape : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%arange, [-1, 1]), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Tensor](args = (%arange_1, %reshape), kwargs = {})
    %mul_ : [num_users=1] = call_function[target=torch.ops.aten.mul_.Tensor](args = (%triu, %gt), kwargs = {})
    %unsqueeze_1 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%mul_, 0), kwargs = {})
    %unsqueeze_2 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_1, 1), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%unsqueeze_2, 2, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 3, 0, 9223372036854775807), kwargs = {})
    %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%slice_2, [2, 1, -1, -1]), kwargs = {})
    %clone : [num_users=4] = call_function[target=torch.ops.aten.clone.default](args = (%expand,), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
    %slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 9223372036854775807), kwargs = {})
    %slice_5 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_4, 2, 0, 9223372036854775807), kwargs = {})
    %slice_6 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%attention_mask, 0, 0, 9223372036854775807), kwargs = {})
    %unsqueeze_3 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%slice_6, 1), kwargs = {})
    %unsqueeze_4 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_3, 2), kwargs = {})
    %slice_7 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%unsqueeze_4, 3, 0, 9223372036854775807), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_5, %slice_7), kwargs = {})
    %eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%add, 0), kwargs = {})
    %slice_8 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
    %slice_9 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_8, 1, 0, 9223372036854775807), kwargs = {})
    %slice_10 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_9, 2, 0, 9223372036854775807), kwargs = {})
    %masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%slice_10, %eq, -3.4028234663852886e+38), kwargs = {})
    %slice_11 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
    %slice_12 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_11, 1, 0, 9223372036854775807), kwargs = {})
    %slice_13 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_12, 2, 0, 9223372036854775807), kwargs = {})
    %copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%slice_13, %masked_fill), kwargs = {})
    %unsqueeze_5 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%b_model_rotary_emb_inv_freq, 0), kwargs = {})
    %slice_14 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%unsqueeze_5, 1, 0, 9223372036854775807), kwargs = {})
    %unsqueeze_6 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%slice_14, 2), kwargs = {})
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%unsqueeze_6, torch.float32), kwargs = {})
    %expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%to, [1, -1, 1]), kwargs = {})
    %slice_15 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%unsqueeze, 0, 0, 9223372036854775807), kwargs = {})
    %unsqueeze_7 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%slice_15, 1), kwargs = {})
    %slice_16 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%unsqueeze_7, 2, 0, 9223372036854775807), kwargs = {})
    %to_1 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%slice_16, torch.float32), kwargs = {})
    %submod_3 : [num_users=1] = get_attr[target=submod_1]
    %wrap_with_autocast : [num_users=2] = call_function[target=torch.ops.higher_order.wrap_with_autocast](args = (cpu, None, False, None, %submod_3, %expand_1, %to_1), kwargs = {})
    %cos : [num_users=1] = call_function[target=operator.getitem](args = (%wrap_with_autocast, 0), kwargs = {})
    %sin : [num_users=1] = call_function[target=operator.getitem](args = (%wrap_with_autocast, 1), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cos, 1.0), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sin, 1.0), kwargs = {})
    %to_4 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul, torch.float32), kwargs = {})
    %to_5 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_1, torch.float32), kwargs = {})
    %to_6 : [num_users=3] = call_function[target=torch.ops.aten.to.dtype](args = (%embedding, torch.float32), kwargs = {})
    %pow_1 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%to_6, 2), kwargs = {})
    %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-06), kwargs = {})
    %rsqrt : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_1,), kwargs = {})
    %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%to_6, %rsqrt), kwargs = {})
    %to_7 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_2, torch.float32), kwargs = {})
    %mul_3 : [num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_model_layers_0_input_layernorm_weight, %to_7), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_3, %p_model_layers_0_self_attn_q_proj_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_3, %p_model_layers_0_self_attn_k_proj_weight), kwargs = {})
    %linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_3, %p_model_layers_0_self_attn_v_proj_weight), kwargs = {})
    %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%linear, [2, 8, 2, 8]), kwargs = {})
    %transpose_1 : [num_users=3] = call_function[target=torch.ops.aten.transpose.int](args = (%view, 1, 2), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%linear_1, [2, 8, 2, 8]), kwargs = {})
    %transpose_2 : [num_users=3] = call_function[target=torch.ops.aten.transpose.int](args = (%view_1, 1, 2), kwargs = {})
    %view_2 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%linear_2, [2, 8, 2, 8]), kwargs = {})
    %transpose_3 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%view_2, 1, 2), kwargs = {})
    %unsqueeze_8 : [num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%to_4, 1), kwargs = {})
    %unsqueeze_9 : [num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%to_5, 1), kwargs = {})
    %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%transpose_1, %unsqueeze_8), kwargs = {})
    %slice_17 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_1, 3, 0, 4), kwargs = {})
    %slice_18 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_1, 3, 4, 9223372036854775807), kwargs = {})
    %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_18,), kwargs = {})
    %cat_1 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_17], -1), kwargs = {})
    %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_9), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_4, %mul_5), kwargs = {})
    %mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%transpose_2, %unsqueeze_8), kwargs = {})
    %slice_19 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_2, 3, 0, 4), kwargs = {})
    %slice_20 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_2, 3, 4, 9223372036854775807), kwargs = {})
    %neg_1 : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_20,), kwargs = {})
    %cat_2 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_19], -1), kwargs = {})
    %mul_7 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_2, %unsqueeze_9), kwargs = {})
    %add_3 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_6, %mul_7), kwargs = {})
    %transpose_4 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%add_3, 2, 3), kwargs = {})
    %matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%add_2, %transpose_4), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%matmul_1, 2.8284271247461903), kwargs = {})
    %slice_21 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
    %slice_22 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_21, 1, 0, 9223372036854775807), kwargs = {})
    %slice_23 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_22, 2, 0, 9223372036854775807), kwargs = {})
    %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%div, %slice_23), kwargs = {})
    %softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%add_4, -1, torch.float32), kwargs = {})
    %to_8 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%softmax, torch.float32), kwargs = {})
    %dropout : [num_users=1] = call_function[target=torch.ops.aten.dropout.default](args = (%to_8, 0.0, True), kwargs = {})
    %matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%dropout, %transpose_3), kwargs = {})
    %transpose_5 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%matmul_2, 1, 2), kwargs = {})
    %contiguous : [num_users=1] = call_function[target=torch.ops.aten.contiguous.default](args = (%transpose_5,), kwargs = {})
    %reshape_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%contiguous, [2, 8, -1]), kwargs = {})
    %linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%reshape_1, %p_model_layers_0_self_attn_o_proj_weight), kwargs = {})
    %add_5 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%to_6, %linear_3), kwargs = {})
    %to_9 : [num_users=3] = call_function[target=torch.ops.aten.to.dtype](args = (%add_5, torch.float32), kwargs = {})
    %pow_2 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%to_9, 2), kwargs = {})
    %mean_1 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [-1], True), kwargs = {})
    %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_1, 1e-06), kwargs = {})
    %rsqrt_1 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_6,), kwargs = {})
    %mul_8 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%to_9, %rsqrt_1), kwargs = {})
    %to_10 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_8, torch.float32), kwargs = {})
    %mul_9 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_model_layers_0_post_attention_layernorm_weight, %to_10), kwargs = {})
    %linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_9, %p_model_layers_0_mlp_gate_proj_weight), kwargs = {})
    %silu : [num_users=1] = call_function[target=torch.ops.aten.silu.default](args = (%linear_4,), kwargs = {})
    %linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_9, %p_model_layers_0_mlp_up_proj_weight), kwargs = {})
    %mul_10 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%silu, %linear_5), kwargs = {})
    %linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_10, %p_model_layers_0_mlp_down_proj_weight), kwargs = {})
    %add_7 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%to_9, %linear_6), kwargs = {})
    %to_11 : [num_users=2] = call_function[target=torch.ops.aten.to.dtype](args = (%add_7, torch.float32), kwargs = {})
    %pow_3 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%to_11, 2), kwargs = {})
    %mean_2 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_3, [-1], True), kwargs = {})
    %add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_2, 1e-06), kwargs = {})
    %rsqrt_2 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_8,), kwargs = {})
    %mul_11 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%to_11, %rsqrt_2), kwargs = {})
    %to_12 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_11, torch.float32), kwargs = {})
    %mul_12 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_model_norm_weight, %to_12), kwargs = {})
    return (mul_12,)
- fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) Callable | None [source]¶
This method is called after the lookup of the function converting an aten function into ONNX; fct is the function found so far, or None if none was found. The fallback may return a different callable, or return fct unchanged when a mapping was already found.
- Parameters:
name – object or str
fct – function found so far
args – known arguments coming from the graph module
kwargs – known named arguments coming from the graph module
builder – GraphBuilder
- Returns:
the callable to use for the conversion, or None
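The dispatch order described above (use the converting function found by the interpreter when there is one, otherwise consult the fallback, and in debug mode consult the fallback even when it is not needed) can be sketched in plain Python. Everything below, including the registry and the lookup names, is illustrative and not the library's actual API:

```python
from typing import Callable, Dict, Optional

# Hypothetical registry of aten -> ONNX converting functions
# (stand-in for the interpreter's own lookup, not the real API).
REGISTRY: Dict[str, Callable] = {"aten::add": lambda *a: ("Add",) + a}


def oxs_fallback(name: str) -> Optional[Callable]:
    """Stand-in for the lookup in onnxscript (illustrative only)."""
    known = {"aten::mul": lambda *a: ("Mul",) + a}
    return known.get(name)


def dispatch(name: str, debug: bool = False) -> Optional[Callable]:
    """Return the converting function for `name`.

    The fallback is consulted when no converter was found or, in debug
    mode, always, to check that it keeps working (the behaviour
    OxsDebugDispatcher exercises)."""
    fct = REGISTRY.get(name)
    if debug or fct is None:
        fb = oxs_fallback(name)
        if fct is None:
            fct = fb  # only used when nothing was registered
    return fct


print(dispatch("aten::add")("x", "y"))  # ('Add', 'x', 'y')
print(dispatch("aten::mul")("x", "y"))  # ('Mul', 'x', 'y')
print(dispatch("aten::neg"))            # None: no converter anywhere
```

With debug=True, the fallback is exercised on every call, which is how a dispatcher can verify its fallback path without waiting for a missing converter.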
- class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDispatcher(verbose: int = 0)[source]¶
If
DynamoInterpreter
cannot find any converting function for a specific function, it tries to find an existing one in onnxscript. The converting function from onnxscript is run in trace-only mode. The variable and functions op, Rank, IsScalar are replaced by op = OwsOpset(), op.Rank, op.IsScalar. onnxscript may define multiple overloaded functions; right now, the first one is taken.
- Parameters:
verbose – verbosity
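"Trace-only mode" above means the onnxscript function body is executed once with symbolic values: each call on the op object records a node for the graph being built instead of computing anything. The class below is a minimal, hypothetical sketch of that recording pattern (the name mirrors the OwsOpset mentioned above, but the implementation is illustrative only):

```python
class OwsOpsetSketch:
    """Records node descriptions instead of executing ops (illustrative)."""

    def __init__(self):
        self.nodes = []

    def __getattr__(self, op_type):
        # Any attribute access (op.Add, op.Mul, ...) yields a recorder.
        def record(*inputs):
            out = f"r{len(self.nodes)}"
            self.nodes.append((op_type, inputs, out))
            return out  # a symbolic result name, fed to later calls
        return record


def onnxscript_like_body(op, x, y):
    # A converting function written against the opset object: in trace
    # mode each call appends a node rather than computing a tensor.
    s = op.Add(x, y)
    return op.Mul(s, x)


op = OwsOpsetSketch()
result = onnxscript_like_body(op, "x", "y")
print(op.nodes)  # [('Add', ('x', 'y'), 'r0'), ('Mul', ('r0', 'x'), 'r1')]
print(result)    # 'r1'
```

The same substitution idea lets a function written for eager evaluation produce a graph: the body runs once, and the recorded nodes are what the dispatcher hands to the graph builder.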
- fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) Callable | None [source]¶
This method is called after the lookup of the function converting an aten function into ONNX; fct is the function found so far, or None if none was found. The fallback may return a different callable, or return fct unchanged when a mapping was already found.
- Parameters:
name – object or str
fct – function found so far
args – known arguments coming from the graph module
kwargs – known named arguments coming from the graph module
builder – GraphBuilder
- Returns:
the callable to use for the conversion, or None