Export LLM with dynamic shapes
We focus on the model Tiny-LLM. To avoid downloading any weights, we use a function that creates a random model based on the same architecture.
Guess the cache dimension
The first step is to guess the dummy inputs. Let’s use the true model for that, with the dummy example from the model page.
import copy
import torch
import transformers
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_models.llms import get_tiny_llm
MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)
We rewrite the forward method to print the cache dimensions before and after each call.
def _forward_(*args, _f=None, **kwargs):
    assert _f is not None
    if not torch.compiler.is_exporting():
        # show the inputs, including the cache, before the call
        print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
    res = _f(*args, **kwargs)
    if not torch.compiler.is_exporting():
        # show the inputs again: the DynamicCache was updated in place
        print("->", string_type((args, kwargs), with_shape=True, with_min_max=True))
    return res


keep_model_forward = model.forward
model.forward = lambda *args, _f=keep_model_forward, **kwargs: _forward_(
    *args, _f=_f, **kwargs
)
Let’s run the model.
prompt = "Continue: it rains..."
inputs = tokenizer.encode(prompt, return_tensors="pt")
outputs = model.generate(
    inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
<- ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x1[2866,2866:A2866.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.1302779765439347]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.007744434695858352]]),input_ids:T7s1x1[2866,2866:A2866.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.1302779765439347]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.007744434695858352]]),input_ids:T7s1x1[14150,14150:A14150.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.490959167480469,6.226877689361572:A-0.1353976684111937]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.008736979494627425]]),input_ids:T7s1x1[14150,14150:A14150.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.490959167480469,6.226877689361572:A-0.1353976684111937]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.008736979494627425]]),input_ids:T7s1x1[21439,21439:A21439.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.918347358703613,6.226877689361572:A-0.13480870858852126]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.49568021297454834:A0.009022974111827132]]),input_ids:T7s1x1[21439,21439:A21439.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
Continue: it rains... Continue Reading
The traces show the cache tensors have shape (batch, 1, past_length, 96): the third dimension grows by one at each generated token, so that is the dimension which must be dynamic at export time. Let’s restore the forward method.
model.forward = keep_model_forward
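With the dimensions observed above (a batch dimension, a single key/value head, a growing past length, a head dimension of 96, one layer), a dummy cache could also be built by hand. The sketch below is purely illustrative and uses hypothetical sizes; the next section relies on get_tiny_llm to produce consistent inputs.

from transformers.cache_utils import DynamicCache

past_length = 30  # hypothetical number of past tokens
dummy_cache = DynamicCache()
dummy_cache.update(
    torch.randn(2, 1, past_length, 96),  # keys for layer 0: (batch, kv_heads, past, head_dim)
    torch.randn(2, 1, past_length, 96),  # values for layer 0
    0,  # layer index
)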
The model creation
Let’s create an untrained model and get the model, its inputs and the dynamic shapes.
experiment = get_tiny_llm()
untrained_model, inputs, dynamic_shapes = (
    experiment["model"],
    experiment["inputs"],
    experiment["dynamic_shapes"],
)
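For reference, a dynamic-shape specification for these inputs could look like the sketch below (hypothetical names, recent torch assumed; the exact spec, including how the cache entries are nested, is the one returned by get_tiny_llm).

batch = torch.export.Dim("batch", min=1, max=1024)
seq = torch.export.Dim("seq", min=1, max=4096)
cache = torch.export.Dim("cache", min=1, max=4096)

dynamic_shapes_sketch = {
    "input_ids": {0: batch, 1: seq},
    # the mask covers past and new tokens, its last dimension is left for the exporter to infer
    "attention_mask": {0: batch, 1: torch.export.Dim.DYNAMIC},
    # one entry for the key cache and one for the value cache of the unique layer
    "past_key_values": [[{0: batch, 2: cache}], [{0: batch, 2: cache}]],
}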
Before we run it, we make a copy of the inputs, as the cache gets modified by the execution and is then no longer consistent with the previous input_ids and mask.

cloned_inputs = copy.deepcopy(inputs)

print("input type", string_type(inputs, with_shape=True))
expected_output = untrained_model(**inputs)
print("input after the execution", string_type(inputs, with_shape=True))
print("result type", string_type(expected_output, with_shape=True))
input type dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
input after the execution dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
result type dict(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
~/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
torch._C._set_onednn_allow_tf32(_allow_tf32)
It works. The cache length grew from 30 to 33 because the forward pass updates the DynamicCache in place, which is why the export below uses the deep copy of the inputs.
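As a quick sanity check, a hedged sketch reusing the names defined above compares the cache kept in the copy with the one mutated by the call:

# the original inputs now carry 33 past tokens, the deep copy still carries 30
print(inputs["past_key_values"].key_cache[0].shape)
print(cloned_inputs["past_key_values"].key_cache[0].shape)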
ExportedProgram
try:
    ep = torch.export.export(
        untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes
    )
    print("It worked:")
    print(ep)
except Exception as e:
    # To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
~/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
torch._C._set_onednn_allow_tf32(_allow_tf32)
It worked:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s0, s1]", attention_mask: "i64[s0, s1 + s5]", past_key_values_key_cache_0: "f32[s0, 1, s5, 96]", past_key_values_value_cache_0: "f32[s0, 1, s5, 96]"):
#
sym_size_int_20: "Sym(s0)" = torch.ops.aten.sym_size.int(input_ids, 0)
sym_size_int_21: "Sym(s1)" = torch.ops.aten.sym_size.int(input_ids, 1)
sym_size_int_22: "Sym(s5)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
embedding: "f32[s0, s1, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids); p_model_embed_tokens_weight = input_ids = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:565 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
add: "Sym(s1 + s5)" = sym_size_int_22 + sym_size_int_21
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:564 in forward, code: cache_position = torch.arange(
arange: "i64[s1]" = torch.ops.aten.arange.start(sym_size_int_22, add, device = device(type='cpu'), pin_memory = False); sym_size_int_22 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:569 in forward, code: position_ids = cache_position.unsqueeze(0)
unsqueeze: "i64[1, s1]" = torch.ops.aten.unsqueeze.default(arange, 0)
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:571 in forward, code: causal_mask = self._update_causal_mask(
full: "f32[s1, s1 + s5]" = torch.ops.aten.full.default([sym_size_int_21, add], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
triu: "f32[s1, s1 + s5]" = torch.ops.aten.triu.default(full, 1); full = None
arange_1: "i64[s1 + s5]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
reshape: "i64[s1, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]); arange = None
gt: "b8[s1, s1 + s5]" = torch.ops.aten.gt.Tensor(arange_1, reshape); arange_1 = reshape = None
mul_: "f32[s1, s1 + s5]" = torch.ops.aten.mul_.Tensor(triu, gt); triu = gt = None
unsqueeze_1: "f32[1, s1, s1 + s5]" = torch.ops.aten.unsqueeze.default(mul_, 0); mul_ = None
unsqueeze_2: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 1); unsqueeze_1 = None
slice_1: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807); unsqueeze_2 = None
slice_2: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807); slice_1 = None
expand: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_20, 1, -1, -1]); slice_2 = None
clone: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.clone.default(expand); expand = None
slice_3: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_4: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_3, 1, 0, 9223372036854775807); slice_3 = None
slice_5: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_4, 2, 0, 9223372036854775807); slice_4 = None
slice_6: "i64[s0, s1 + s5]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807); attention_mask = None
unsqueeze_3: "i64[s0, 1, s1 + s5]" = torch.ops.aten.unsqueeze.default(slice_6, 1); slice_6 = None
unsqueeze_4: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2); unsqueeze_3 = None
slice_7: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 3, 0, 9223372036854775807); unsqueeze_4 = None
to: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.to.dtype_layout(slice_7, dtype = torch.int64, layout = torch.strided, device = device(type='cpu')); slice_7 = None
add_2: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.add.Tensor(slice_5, to); slice_5 = to = None
eq_7: "b8[s0, 1, s1, s1 + s5]" = torch.ops.aten.eq.Scalar(add_2, 0); add_2 = None
slice_8: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_9: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_8, 1, 0, 9223372036854775807); slice_8 = None
slice_10: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_9, 2, 0, 9223372036854775807); slice_9 = None
masked_fill: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.masked_fill.Scalar(slice_10, eq_7, -3.4028234663852886e+38); slice_10 = eq_7 = None
slice_11: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_12: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_11, 1, 0, 9223372036854775807); slice_11 = None
slice_13: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_12, 2, 0, 9223372036854775807); slice_12 = None
copy_: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.copy_.default(slice_13, masked_fill); slice_13 = masked_fill = copy_ = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, unsqueeze); submod_3 = b_model_rotary_emb_inv_freq = unsqueeze = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[1, s1, 96]" = wrap_with_set_grad_enabled[0]
to_7: "f32[1, s1, 96]" = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_8: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
mean: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_3: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_3); add_3 = None
mul_2: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_8, rsqrt); rsqrt = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_9: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9); p_model_layers_0_input_layernorm_weight = to_9 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight); p_model_layers_0_self_attn_q_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:277 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view: "f32[s0, s1, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_20, sym_size_int_21, -1, 96]); linear = None
transpose_1: "f32[s0, 2, s1, 96]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s0, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight); p_model_layers_0_self_attn_k_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:278 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_1: "f32[s0, s1, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_20, sym_size_int_21, -1, 96]); linear_1 = None
transpose_2: "f32[s0, 1, s1, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[s0, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight); mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:279 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_2: "f32[s0, s1, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_20, sym_size_int_21, -1, 96]); linear_2 = None
transpose_3: "f32[s0, 1, s1, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:282 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_8: "f32[1, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_6, 1); to_6 = None
unsqueeze_9: "f32[1, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1); to_7 = None
mul_4: "f32[s0, 2, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_8)
slice_17: "f32[s0, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
slice_18: "f32[s0, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807); transpose_1 = None
neg: "f32[s0, 2, s1, 48]" = torch.ops.aten.neg.default(slice_18); slice_18 = None
cat_1: "f32[s0, 2, s1, 96]" = torch.ops.aten.cat.default([neg, slice_17], -1); neg = slice_17 = None
mul_5: "f32[s0, 2, s1, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_9); cat_1 = None
add_4: "f32[s0, 2, s1, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6: "f32[s0, 1, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_8); unsqueeze_8 = None
slice_19: "f32[s0, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
slice_20: "f32[s0, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807); transpose_2 = None
neg_1: "f32[s0, 1, s1, 48]" = torch.ops.aten.neg.default(slice_20); slice_20 = None
cat_2: "f32[s0, 1, s1, 96]" = torch.ops.aten.cat.default([neg_1, slice_19], -1); neg_1 = slice_19 = None
mul_7: "f32[s0, 1, s1, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_9); cat_2 = unsqueeze_9 = None
add_5: "f32[s0, 1, s1, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:287 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_3: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2); past_key_values_key_cache_0 = add_5 = None
cat_4: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2); past_key_values_value_cache_0 = transpose_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:299 in forward, code: attn_output, attn_weights = attention_interface(
slice_21: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
slice_22: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_21, 1, 0, 9223372036854775807); slice_21 = None
unsqueeze_10: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.unsqueeze.default(slice_22, 2); slice_22 = None
slice_23: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807); unsqueeze_10 = None
slice_24: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_23, 4, 0, 9223372036854775807); slice_23 = None
expand_2: "f32[s0, 1, 2, s1 + s5, 96]" = torch.ops.aten.expand.default(slice_24, [sym_size_int_20, 1, 2, add, 96]); slice_24 = None
reshape_1: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.reshape.default(expand_2, [sym_size_int_20, 2, add, 96]); expand_2 = None
slice_25: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
slice_26: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_25, 1, 0, 9223372036854775807); slice_25 = None
unsqueeze_11: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.unsqueeze.default(slice_26, 2); slice_26 = None
slice_27: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_11, 3, 0, 9223372036854775807); unsqueeze_11 = None
slice_28: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_27, 4, 0, 9223372036854775807); slice_27 = None
expand_3: "f32[s0, 1, 2, s1 + s5, 96]" = torch.ops.aten.expand.default(slice_28, [sym_size_int_20, 1, 2, add, 96]); slice_28 = None
reshape_2: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.reshape.default(expand_3, [sym_size_int_20, 2, add, 96]); expand_3 = add = None
slice_29: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807); clone = None
slice_30: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_29, 1, 0, 9223372036854775807); slice_29 = None
slice_31: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_30, 2, 0, 9223372036854775807); slice_30 = None
contiguous: "f32[s0, 2, s1, 96]" = torch.ops.aten.contiguous.default(add_4); add_4 = None
contiguous_1: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.contiguous.default(reshape_1); reshape_1 = None
contiguous_2: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.contiguous.default(reshape_2); reshape_2 = None
scaled_dot_product_attention: "f32[s0, 2, s1, 96]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_31, scale = 0.10206207261596575); contiguous = contiguous_1 = contiguous_2 = slice_31 = None
transpose_4: "f32[s0, s1, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
contiguous_3: "f32[s0, s1, 2, 96]" = torch.ops.aten.contiguous.default(transpose_4); transpose_4 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_3: "f32[s0, s1, 192]" = torch.ops.aten.reshape.default(contiguous_3, [sym_size_int_20, sym_size_int_21, -1]); contiguous_3 = sym_size_int_20 = sym_size_int_21 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(reshape_3, p_model_layers_0_self_attn_o_proj_weight); reshape_3 = p_model_layers_0_self_attn_o_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:354 in forward, code: hidden_states = residual + hidden_states
add_7: "f32[s0, s1, 192]" = torch.ops.aten.add.Tensor(to_8, linear_3); to_8 = linear_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_10: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32); add_7 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean_1: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_8: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_8); add_8 = None
mul_8: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1); rsqrt_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_11: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_8, torch.float32); mul_8 = None
mul_9: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11); p_model_layers_0_post_attention_layernorm_weight = to_11 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[s0, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_proj_weight); p_model_layers_0_mlp_gate_proj_weight = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:432 in forward, code: return F.silu(input, inplace=self.inplace)
silu: "f32[s0, s1, 1024]" = torch.ops.aten.silu.default(linear_4); linear_4 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[s0, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_up_proj_weight); mul_9 = p_model_layers_0_mlp_up_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:197 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_10: "f32[s0, s1, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight); mul_10 = p_model_layers_0_mlp_down_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:360 in forward, code: hidden_states = residual + hidden_states
add_9: "f32[s0, s1, 192]" = torch.ops.aten.add.Tensor(to_10, linear_6); to_10 = linear_6 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_12: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32); add_9 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_2: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_10: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_10); add_10 = None
mul_11: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2); to_12 = rsqrt_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_13: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_11, torch.float32); mul_11 = None
mul_12: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_13); p_model_norm_weight = to_13 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:870 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_32: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(mul_12, 0, 0, 9223372036854775807); mul_12 = None
slice_33: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(slice_32, 1, 0, 9223372036854775807); slice_32 = None
slice_34: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(slice_33, 2, 0, 9223372036854775807); slice_33 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[s0, s1, 32000]" = torch.ops.aten.linear.default(slice_34, p_lm_head_weight); slice_34 = p_lm_head_weight = None
return (linear_7, cat_3, cat_4)
class submod_1(torch.nn.Module):
def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", unsqueeze: "i64[1, s1]"):
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
unsqueeze_5: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0); b_model_rotary_emb_inv_freq = None
slice_14: "f32[1, 48]" = torch.ops.aten.slice.Tensor(unsqueeze_5, 1, 0, 9223372036854775807); unsqueeze_5 = None
unsqueeze_6: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(slice_14, 2); slice_14 = None
to_1: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_6, torch.float32); unsqueeze_6 = None
expand_1: "f32[1, 48, 1]" = torch.ops.aten.expand.default(to_1, [1, -1, 1]); to_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:134 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
slice_15: "i64[1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze, 0, 0, 9223372036854775807); unsqueeze = None
unsqueeze_7: "i64[1, 1, s1]" = torch.ops.aten.unsqueeze.default(slice_15, 1); slice_15 = None
slice_16: "i64[1, 1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze_7, 2, 0, 9223372036854775807); unsqueeze_7 = None
to_2: "f32[1, 1, s1]" = torch.ops.aten.to.dtype(slice_16, torch.float32); slice_16 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, expand_1, to_2); submod_3 = expand_1 = to_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
cos: "f32[1, s1, 96]" = wrap_with_autocast[0]
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
sin: "f32[1, s1, 96]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:145 in forward, code: cos = cos * self.attention_scaling
mul: "f32[1, s1, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:146 in forward, code: sin = sin * self.attention_scaling
mul_1: "f32[1, s1, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[1, s1, 96]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
to_7: "f32[1, s1, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
return (to_6, to_7)
class submod_1(torch.nn.Module):
def forward(self, expand_1: "f32[1, 48, 1]", to_2: "f32[1, 1, s1]"):
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:139 in forward, code: freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
to_3: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(expand_1, torch.float32); expand_1 = None
to_4: "f32[1, 48, 1]" = torch.ops.aten.to.dtype_layout(to_3, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); to_3 = None
to_5: "f32[1, 1, s1]" = torch.ops.aten.to.dtype(to_2, torch.float32); to_2 = None
matmul: "f32[1, 48, s1]" = torch.ops.aten.matmul.default(to_4, to_5); to_4 = to_5 = None
transpose: "f32[1, s1, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:140 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
cat: "f32[1, s1, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
cos: "f32[1, s1, 96]" = torch.ops.aten.cos.default(cat)
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
sin: "f32[1, s1, 96]" = torch.ops.aten.sin.default(cat); cat = None
return (cos, sin)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_embed_tokens_weight'), target='model.embed_tokens.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_q_proj_weight'), target='model.layers.0.self_attn.q_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_k_proj_weight'), target='model.layers.0.self_attn.k_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_v_proj_weight'), target='model.layers.0.self_attn.v_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_o_proj_weight'), target='model.layers.0.self_attn.o_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_gate_proj_weight'), target='model.layers.0.mlp.gate_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_up_proj_weight'), target='model.layers.0.mlp.up_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_down_proj_weight'), target='model.layers.0.mlp.down_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_input_layernorm_weight'), target='model.layers.0.input_layernorm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_post_attention_layernorm_weight'), target='model.layers.0.post_attention_layernorm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_norm_weight'), target='model.norm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_lm_head_weight'), target='lm_head.weight', persistent=None), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_model_rotary_emb_inv_freq'), target='model.rotary_emb.inv_freq', persistent=False), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input_ids'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='attention_mask'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='past_key_values_key_cache_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='past_key_values_value_cache_0'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='linear_7'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat_3'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat_4'), target=None)])
Range constraints: {s0: VR[1, 1024], s1: VR[1, 4096], s1 + s5: VR[4, 8192], s5: VR[1, 4096]}
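The range constraints confirm the batch, sequence and cache dimensions remained dynamic. As a hedged sketch reusing the names above, the exported graph can be exercised on the preserved inputs and compared with the eager run:

# run the exported graph module on the preserved inputs
got = ep.module()(**cloned_inputs)
print("exported output type", string_type(got, with_shape=True))
# the logits are expected to match the eager result computed earlier, e.g.
# torch.testing.assert_close(expected_output.logits, got.logits)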
Back to the original model
Let’s use the same dummy inputs, but with the downloaded model.
try:
    ep = torch.export.export(model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes)
    print("It worked:")
    print(ep)
except Exception as e:
    # To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
~/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
torch._C._set_onednn_allow_tf32(_allow_tf32)
It worked:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s0, s1]", attention_mask: "i64[s0, s1 + s5]", past_key_values_key_cache_0: "f32[s0, 1, s5, 96]", past_key_values_value_cache_0: "f32[s0, 1, s5, 96]"):
#
sym_size_int_20: "Sym(s0)" = torch.ops.aten.sym_size.int(input_ids, 0)
sym_size_int_21: "Sym(s1)" = torch.ops.aten.sym_size.int(input_ids, 1)
sym_size_int_22: "Sym(s5)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
embedding: "f32[s0, s1, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids); p_model_embed_tokens_weight = input_ids = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:565 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
add: "Sym(s1 + s5)" = sym_size_int_22 + sym_size_int_21
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:564 in forward, code: cache_position = torch.arange(
arange: "i64[s1]" = torch.ops.aten.arange.start(sym_size_int_22, add, device = device(type='cpu'), pin_memory = False); sym_size_int_22 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:569 in forward, code: position_ids = cache_position.unsqueeze(0)
unsqueeze: "i64[1, s1]" = torch.ops.aten.unsqueeze.default(arange, 0)
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:571 in forward, code: causal_mask = self._update_causal_mask(
full: "f32[s1, s1 + s5]" = torch.ops.aten.full.default([sym_size_int_21, add], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
triu: "f32[s1, s1 + s5]" = torch.ops.aten.triu.default(full, 1); full = None
arange_1: "i64[s1 + s5]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
reshape: "i64[s1, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]); arange = None
gt: "b8[s1, s1 + s5]" = torch.ops.aten.gt.Tensor(arange_1, reshape); arange_1 = reshape = None
mul_: "f32[s1, s1 + s5]" = torch.ops.aten.mul_.Tensor(triu, gt); triu = gt = None
unsqueeze_1: "f32[1, s1, s1 + s5]" = torch.ops.aten.unsqueeze.default(mul_, 0); mul_ = None
unsqueeze_2: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 1); unsqueeze_1 = None
slice_1: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807); unsqueeze_2 = None
slice_2: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807); slice_1 = None
expand: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_20, 1, -1, -1]); slice_2 = None
clone: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.clone.default(expand); expand = None
slice_3: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_4: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_3, 1, 0, 9223372036854775807); slice_3 = None
slice_5: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_4, 2, 0, 9223372036854775807); slice_4 = None
slice_6: "i64[s0, s1 + s5]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807); attention_mask = None
unsqueeze_3: "i64[s0, 1, s1 + s5]" = torch.ops.aten.unsqueeze.default(slice_6, 1); slice_6 = None
unsqueeze_4: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2); unsqueeze_3 = None
slice_7: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 3, 0, 9223372036854775807); unsqueeze_4 = None
to: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.to.dtype_layout(slice_7, dtype = torch.int64, layout = torch.strided, device = device(type='cpu')); slice_7 = None
add_2: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.add.Tensor(slice_5, to); slice_5 = to = None
eq_7: "b8[s0, 1, s1, s1 + s5]" = torch.ops.aten.eq.Scalar(add_2, 0); add_2 = None
slice_8: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_9: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_8, 1, 0, 9223372036854775807); slice_8 = None
slice_10: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_9, 2, 0, 9223372036854775807); slice_9 = None
masked_fill: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.masked_fill.Scalar(slice_10, eq_7, -3.4028234663852886e+38); slice_10 = eq_7 = None
slice_11: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_12: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_11, 1, 0, 9223372036854775807); slice_11 = None
slice_13: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_12, 2, 0, 9223372036854775807); slice_12 = None
copy_: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.copy_.default(slice_13, masked_fill); slice_13 = masked_fill = copy_ = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, unsqueeze); submod_3 = b_model_rotary_emb_inv_freq = unsqueeze = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[1, s1, 96]" = wrap_with_set_grad_enabled[0]
to_7: "f32[1, s1, 96]" = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_8: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
mean: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_3: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_3); add_3 = None
mul_2: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_8, rsqrt); rsqrt = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_9: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9); p_model_layers_0_input_layernorm_weight = to_9 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight); p_model_layers_0_self_attn_q_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:277 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view: "f32[s0, s1, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_20, sym_size_int_21, -1, 96]); linear = None
transpose_1: "f32[s0, 2, s1, 96]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s0, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight); p_model_layers_0_self_attn_k_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:278 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_1: "f32[s0, s1, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_20, sym_size_int_21, -1, 96]); linear_1 = None
transpose_2: "f32[s0, 1, s1, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[s0, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight); mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:279 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_2: "f32[s0, s1, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_20, sym_size_int_21, -1, 96]); linear_2 = None
transpose_3: "f32[s0, 1, s1, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:282 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_8: "f32[1, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_6, 1); to_6 = None
unsqueeze_9: "f32[1, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1); to_7 = None
mul_4: "f32[s0, 2, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_8)
slice_17: "f32[s0, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
slice_18: "f32[s0, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807); transpose_1 = None
neg: "f32[s0, 2, s1, 48]" = torch.ops.aten.neg.default(slice_18); slice_18 = None
cat_1: "f32[s0, 2, s1, 96]" = torch.ops.aten.cat.default([neg, slice_17], -1); neg = slice_17 = None
mul_5: "f32[s0, 2, s1, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_9); cat_1 = None
add_4: "f32[s0, 2, s1, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6: "f32[s0, 1, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_8); unsqueeze_8 = None
slice_19: "f32[s0, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
slice_20: "f32[s0, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807); transpose_2 = None
neg_1: "f32[s0, 1, s1, 48]" = torch.ops.aten.neg.default(slice_20); slice_20 = None
cat_2: "f32[s0, 1, s1, 96]" = torch.ops.aten.cat.default([neg_1, slice_19], -1); neg_1 = slice_19 = None
mul_7: "f32[s0, 1, s1, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_9); cat_2 = unsqueeze_9 = None
add_5: "f32[s0, 1, s1, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:287 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_3: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2); past_key_values_key_cache_0 = add_5 = None
cat_4: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2); past_key_values_value_cache_0 = transpose_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:299 in forward, code: attn_output, attn_weights = attention_interface(
slice_21: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
slice_22: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_21, 1, 0, 9223372036854775807); slice_21 = None
unsqueeze_10: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.unsqueeze.default(slice_22, 2); slice_22 = None
slice_23: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807); unsqueeze_10 = None
slice_24: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_23, 4, 0, 9223372036854775807); slice_23 = None
expand_2: "f32[s0, 1, 2, s1 + s5, 96]" = torch.ops.aten.expand.default(slice_24, [sym_size_int_20, 1, 2, add, 96]); slice_24 = None
reshape_1: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.reshape.default(expand_2, [sym_size_int_20, 2, add, 96]); expand_2 = None
slice_25: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
slice_26: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_25, 1, 0, 9223372036854775807); slice_25 = None
unsqueeze_11: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.unsqueeze.default(slice_26, 2); slice_26 = None
slice_27: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_11, 3, 0, 9223372036854775807); unsqueeze_11 = None
slice_28: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_27, 4, 0, 9223372036854775807); slice_27 = None
expand_3: "f32[s0, 1, 2, s1 + s5, 96]" = torch.ops.aten.expand.default(slice_28, [sym_size_int_20, 1, 2, add, 96]); slice_28 = None
reshape_2: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.reshape.default(expand_3, [sym_size_int_20, 2, add, 96]); expand_3 = add = None
slice_29: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807); clone = None
slice_30: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_29, 1, 0, 9223372036854775807); slice_29 = None
slice_31: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_30, 2, 0, 9223372036854775807); slice_30 = None
contiguous: "f32[s0, 2, s1, 96]" = torch.ops.aten.contiguous.default(add_4); add_4 = None
contiguous_1: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.contiguous.default(reshape_1); reshape_1 = None
contiguous_2: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.contiguous.default(reshape_2); reshape_2 = None
scaled_dot_product_attention: "f32[s0, 2, s1, 96]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_31, scale = 0.10206207261596575); contiguous = contiguous_1 = contiguous_2 = slice_31 = None
transpose_4: "f32[s0, s1, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
contiguous_3: "f32[s0, s1, 2, 96]" = torch.ops.aten.contiguous.default(transpose_4); transpose_4 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_3: "f32[s0, s1, 192]" = torch.ops.aten.reshape.default(contiguous_3, [sym_size_int_20, sym_size_int_21, -1]); contiguous_3 = sym_size_int_20 = sym_size_int_21 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(reshape_3, p_model_layers_0_self_attn_o_proj_weight); reshape_3 = p_model_layers_0_self_attn_o_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:354 in forward, code: hidden_states = residual + hidden_states
add_7: "f32[s0, s1, 192]" = torch.ops.aten.add.Tensor(to_8, linear_3); to_8 = linear_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_10: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32); add_7 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean_1: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_8: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_8); add_8 = None
mul_8: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1); rsqrt_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_11: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_8, torch.float32); mul_8 = None
mul_9: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11); p_model_layers_0_post_attention_layernorm_weight = to_11 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[s0, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_proj_weight); p_model_layers_0_mlp_gate_proj_weight = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:432 in forward, code: return F.silu(input, inplace=self.inplace)
silu: "f32[s0, s1, 1024]" = torch.ops.aten.silu.default(linear_4); linear_4 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[s0, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_up_proj_weight); mul_9 = p_model_layers_0_mlp_up_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:197 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_10: "f32[s0, s1, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight); mul_10 = p_model_layers_0_mlp_down_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:360 in forward, code: hidden_states = residual + hidden_states
add_9: "f32[s0, s1, 192]" = torch.ops.aten.add.Tensor(to_10, linear_6); to_10 = linear_6 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_12: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32); add_9 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_2: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_10: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_10); add_10 = None
mul_11: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2); to_12 = rsqrt_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_13: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_11, torch.float32); mul_11 = None
mul_12: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_13); p_model_norm_weight = to_13 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:870 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_32: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(mul_12, 0, 0, 9223372036854775807); mul_12 = None
slice_33: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(slice_32, 1, 0, 9223372036854775807); slice_32 = None
slice_34: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(slice_33, 2, 0, 9223372036854775807); slice_33 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[s0, s1, 32000]" = torch.ops.aten.linear.default(slice_34, p_lm_head_weight); slice_34 = p_lm_head_weight = None
return (linear_7, cat_3, cat_4)
class submod_1(torch.nn.Module):
def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", unsqueeze: "i64[1, s1]"):
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
unsqueeze_5: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0); b_model_rotary_emb_inv_freq = None
slice_14: "f32[1, 48]" = torch.ops.aten.slice.Tensor(unsqueeze_5, 1, 0, 9223372036854775807); unsqueeze_5 = None
unsqueeze_6: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(slice_14, 2); slice_14 = None
to_1: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_6, torch.float32); unsqueeze_6 = None
expand_1: "f32[1, 48, 1]" = torch.ops.aten.expand.default(to_1, [1, -1, 1]); to_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:134 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
slice_15: "i64[1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze, 0, 0, 9223372036854775807); unsqueeze = None
unsqueeze_7: "i64[1, 1, s1]" = torch.ops.aten.unsqueeze.default(slice_15, 1); slice_15 = None
slice_16: "i64[1, 1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze_7, 2, 0, 9223372036854775807); unsqueeze_7 = None
to_2: "f32[1, 1, s1]" = torch.ops.aten.to.dtype(slice_16, torch.float32); slice_16 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, expand_1, to_2); submod_3 = expand_1 = to_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
cos: "f32[1, s1, 96]" = wrap_with_autocast[0]
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
sin: "f32[1, s1, 96]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:145 in forward, code: cos = cos * self.attention_scaling
mul: "f32[1, s1, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:146 in forward, code: sin = sin * self.attention_scaling
mul_1: "f32[1, s1, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[1, s1, 96]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
to_7: "f32[1, s1, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
return (to_6, to_7)
class submod_1(torch.nn.Module):
def forward(self, expand_1: "f32[1, 48, 1]", to_2: "f32[1, 1, s1]"):
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:139 in forward, code: freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
to_3: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(expand_1, torch.float32); expand_1 = None
to_4: "f32[1, 48, 1]" = torch.ops.aten.to.dtype_layout(to_3, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); to_3 = None
to_5: "f32[1, 1, s1]" = torch.ops.aten.to.dtype(to_2, torch.float32); to_2 = None
matmul: "f32[1, 48, s1]" = torch.ops.aten.matmul.default(to_4, to_5); to_4 = to_5 = None
transpose: "f32[1, s1, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:140 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
cat: "f32[1, s1, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
cos: "f32[1, s1, 96]" = torch.ops.aten.cos.default(cat)
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
sin: "f32[1, s1, 96]" = torch.ops.aten.sin.default(cat); cat = None
return (cos, sin)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_embed_tokens_weight'), target='model.embed_tokens.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_q_proj_weight'), target='model.layers.0.self_attn.q_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_k_proj_weight'), target='model.layers.0.self_attn.k_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_v_proj_weight'), target='model.layers.0.self_attn.v_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_o_proj_weight'), target='model.layers.0.self_attn.o_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_gate_proj_weight'), target='model.layers.0.mlp.gate_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_up_proj_weight'), target='model.layers.0.mlp.up_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_down_proj_weight'), target='model.layers.0.mlp.down_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_input_layernorm_weight'), target='model.layers.0.input_layernorm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_post_attention_layernorm_weight'), target='model.layers.0.post_attention_layernorm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_norm_weight'), target='model.norm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_lm_head_weight'), target='lm_head.weight', persistent=None), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_model_rotary_emb_inv_freq'), target='model.rotary_emb.inv_freq', persistent=False), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input_ids'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='attention_mask'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='past_key_values_key_cache_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='past_key_values_value_cache_0'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='linear_7'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat_3'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat_4'), target=None)])
Range constraints: {s0: VR[1, 1024], s1: VR[1, 4096], s1 + s5: VR[4, 8192], s5: VR[1, 4096]}
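The symbols s0 and s1 in the graph above are the symbolic sizes created for the dimensions declared dynamic at export time, and the range constraints printed last are stored on the exported program as its range_constraints attribute. The following is a minimal, self-contained sketch with a toy module (not the Tiny-LLM export above) showing how torch.export.Dim declarations turn into such symbols and value ranges; the module name Toy and the bounds are illustrative assumptions.
import torch


class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2


# Each Dim becomes a symbolic size (s0, s1, ...) with a value range.
batch = torch.export.Dim("batch", max=1024)
seq = torch.export.Dim("seq", max=4096)

ep = torch.export.export(
    Toy(),
    (torch.randn(2, 8, 192),),
    dynamic_shapes={"x": {0: batch, 1: seq}},
)
print(ep.range_constraints)  # e.g. {s0: VR[0, 1024], s1: VR[0, 4096]}
Inspecting ep.range_constraints this way is a quick check that the dimensions you intended to be dynamic were not specialized to constants during export.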
Total running time of the script: (0 minutes 12.233 seconds)
Related examples

Use DYNAMIC or AUTO when dynamic shapes has constraints