experimental_experiment.torch_interpreter.oxs_dispatcher¶
- class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDebugDispatcher(verbose: int = 0, raise_exc: bool = True)[source]¶
Tries the fallback even if it is not necessary, to check that it is working.
- Parameters:
verbose – verbosity
raise_exc – whether to raise an exception when the fallback fails, instead of failing silently
The class can be used as follows.
<<<
import torch
from experimental_experiment.torch_models.llama_helper import get_llama_model
from experimental_experiment.xbuilder import OptimizationOptions
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.torch_interpreter.oxs_dispatcher import (
    OxsDebugDispatcher,
)

with torch.no_grad():
    model, input_tensors = get_llama_model()
    input_tensors = input_tensors[0]

    to_onnx(
        model,
        input_tensors,
        input_names=[f"input{i}" for i in range(len(input_tensors))],
        options=OptimizationOptions(patterns=None),
        verbose=0,
        dispatcher=OxsDebugDispatcher(verbose=2, raise_exc=False),
    )
>>>
[runpythonerror] Traceback (most recent call last):
  File "<stdin>", line 19, in <module>
  File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 1064, in to_onnx
    builder.process(graph_module, interpreter)
  File "~/github/experimental-experiment/experimental_experiment/xbuilder/graph_builder.py", line 5173, in process
    interpreter.run_node(node)
  File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/interpreter.py", line 229, in run_node
    res = self.get_attr(node)
          ^^^^^^^^^^^^^^^^^^^
  File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/interpreter.py", line 336, in get_attr
    raise NotImplementedError(
NotImplementedError: Unable to handle type <class 'torch.utils._pytree.TreeSpec'> for node.name='function_const_func_spec0'
Node(type=<class 'torch._higher_order_ops.flat_apply._ConstantFunction'>, leaves=0)
[... the error message continues with the GraphBuilder debug dump: known shapes, types and constants, parameter renaming, the exported program and FX graph, and the builder state at node 15/124 ...]
- fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) → Callable | None [source]¶
This method is called after the lookup of the function converting an aten function into ONNX. fct is that function; it is only set when a mapping was found, and the method may replace it by returning a different callable, as in the sketch below.
- Parameters:
name – object or str
fct – function found so far
args – known arguments coming from the graph module
kwargs – known named arguments coming from the graph module
builder – GraphBuilder
- Returns:
callable
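Because fallback receives the converter found so far and returns the one to use, the lookup can be customized by overriding it. The sketch below is a hypothetical subclass (not part of the package) that reports every function resolved through the fallback; it assumes only the signature documented above.
<<<
from experimental_experiment.torch_interpreter.oxs_dispatcher import OxsDebugDispatcher


class ReportingDispatcher(OxsDebugDispatcher):
    """Hypothetical subclass: reports which functions go through the fallback."""

    def fallback(self, name, fct, args, kwargs, builder):
        # fct is only set when a mapping was already found;
        # delegate the lookup and report what the parent class returned.
        found = super().fallback(name, fct, args, kwargs, builder)
        if fct is None:
            print(f"fallback looked up {name!r} -> {found}")
        return found
>>>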
- class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDispatcher(verbose: int = 0)[source]¶
If DynamoInterpreter cannot find any converting function for a specific function, it tries to find an existing one in onnxscript. The converting function from onnxscript is run in trace-only mode. The variables and functions op, Rank, IsScalar are replaced by op = OwsOpset(), op.Rank, op.Scalar. onnxscript may have multiple overloaded functions; right now, the first one is taken. A usage sketch follows the parameter list below.
- Parameters:
verbose – verbosity
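The dispatcher is passed to to_onnx the same way as OxsDebugDispatcher in the example above. The sketch below reuses that example and only swaps the dispatcher; depending on the installed torch and transformers versions, it may fail with the same error shown earlier.
<<<
import torch
from experimental_experiment.torch_models.llama_helper import get_llama_model
from experimental_experiment.xbuilder import OptimizationOptions
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.torch_interpreter.oxs_dispatcher import OxsDispatcher

with torch.no_grad():
    model, input_tensors = get_llama_model()
    input_tensors = input_tensors[0]

    # missing converters are searched in onnxscript through the dispatcher
    to_onnx(
        model,
        input_tensors,
        input_names=[f"input{i}" for i in range(len(input_tensors))],
        options=OptimizationOptions(patterns=None),
        verbose=0,
        dispatcher=OxsDispatcher(verbose=1),
    )
>>>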
- fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) → Callable | None [source]¶
This method is called after the lookup of the function converting an aten function into ONNX. fct is that function; it is only set when a mapping was found, and the method may replace it by returning a different callable.
- Parameters:
name – object or str
fct – function found so far
args – known arguments coming from the graph module
kwargs – known named arguments coming from the graph module
builder – GraphBuilder
- Returns:
callable