Note
Go to the end to download the full example code.
Export LLM with dynamic shapes
We focus on the model Tiny-LLM. To avoid downloading any weights, we write a function that creates a random model based on the same architecture.
Guess the cache dimension
The first step is to guess the dummy inputs. Let's use the real, pretrained model for that, together with the example prompt from the model page.
import copy
import torch
import transformers
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_models.llms import get_tiny_llm
MODEL_NAME = "arnir0/Tiny-LLM"
# Load the tokenizer and the pretrained causal LM from the same checkpoint name;
# the tuple is evaluated left to right, so the tokenizer is loaded first.
tokenizer, model = (
    transformers.AutoTokenizer.from_pretrained(MODEL_NAME),
    transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME),
)
We rewrite the forward method to print the cache dimension.
def _forward_(*args, _f=None, **kwargs):
    """Call the wrapped forward ``_f`` and print a summary of its inputs and outputs.

    The summaries (types, shapes, min/max values) are produced by
    ``string_type``. Nothing is printed while ``torch.export`` is tracing
    the model, so the wrapper does not pollute the exported graph.

    :param args: positional arguments forwarded to ``_f``
    :param _f: the original forward method, must not be None
    :param kwargs: keyword arguments forwarded to ``_f``
    :return: whatever ``_f`` returns
    """
    assert _f is not None, "_f must be the original forward method"
    # ``torch.compiler.is_exporting`` may be missing from older torch
    # releases; in that case assume we are not exporting.
    is_exporting = getattr(torch.compiler, "is_exporting", None)
    verbose = is_exporting is None or not is_exporting()
    if verbose:
        print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
    res = _f(*args, **kwargs)
    if verbose:
        # Show what the model returned, not its inputs a second time.
        print("->", string_type(res, with_shape=True, with_min_max=True))
    return res
# Save the original bound method: the wrapper calls it, and it is restored later.
keep_model_forward = model.forward
# ``_f`` is bound as a default argument (early binding), so the wrapper always
# calls the method captured above even if ``keep_model_forward`` is rebound.
model.forward = lambda *args, _f=keep_model_forward, **kwargs: _forward_(*args, _f=_f, **kwargs)
Let’s run the model.
prompt = "Continue: it rains..."
# Calling the tokenizer directly (rather than ``tokenizer.encode``) also returns
# the attention mask. Without it, generate() warns that results may be unreliable.
encoded = tokenizer(prompt, return_tensors="pt")
inputs = encoded["input_ids"]
outputs = model.generate(
    inputs,
    attention_mask=encoded["attention_mask"],
    # generate() otherwise falls back to eos_token_id for padding and warns about it.
    pad_token_id=tokenizer.eos_token_id,
    max_length=50,
    temperature=1,
    top_k=50,
    top_p=0.95,
    do_sample=True,
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
<- ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x1[2866,2866:A2866.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.1302779765439347]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.007744434695858352]]),input_ids:T7s1x1[2866,2866:A2866.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.1302779765439347]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.007744434695858352]]),input_ids:T7s1x1[14150,14150:A14150.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.490959167480469,6.226877689361572:A-0.1353976684111937]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.008736979494627425]]),input_ids:T7s1x1[14150,14150:A14150.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.490959167480469,6.226877689361572:A-0.1353976684111937]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.008736979494627425]]),input_ids:T7s1x1[21439,21439:A21439.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.918347358703613,6.226877689361572:A-0.13480870858852126]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.49568021297454834:A0.009022974111827132]]),input_ids:T7s1x1[21439,21439:A21439.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
Continue: it rains... Continue Reading
Let’s restore the forward as it was.
# Restore the original bound forward method saved before it was wrapped.
model.forward = keep_model_forward
Model creation
Let’s create an untrained model.
Let’s get the model, inputs and dynamic shapes.
experiment = get_tiny_llm()
# The helper returns a dictionary bundling the random model, matching dummy
# inputs and the dynamic shapes to use when exporting.
untrained_model = experiment["model"]
inputs = experiment["inputs"]
dynamic_shapes = experiment["dynamic_shapes"]
Before running the model, we make a copy of the inputs because the execution modifies the cache in place. Afterwards, the cache is no longer consistent with the original input_ids and attention_mask.
# The forward pass below grows the DynamicCache in place (compare the two
# printed input types), so keep an untouched deep copy for torch.export.
cloned_inputs = copy.deepcopy(inputs)
print("input type", string_type(inputs, with_shape=True))
expected_output = untrained_model(**inputs)
print("input after the execution", string_type(inputs, with_shape=True))
print("result type", string_type(expected_output, with_shape=True))
# Export with the pristine copy, whose cache still matches input_ids/attention_mask.
ep = torch.export.export(
    untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes
)
input type dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
input after the execution dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
result type dict(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
torch._C._set_onednn_allow_tf32(_allow_tf32)
It works.
ExportedProgram
try:
    ep = torch.export.export(
        untrained_model,
        (),
        kwargs=cloned_inputs,
        dynamic_shapes=dynamic_shapes,
    )
    # Both messages stay inside the try so any failure is reported the same way.
    print("It worked:")
    print(ep)
except Exception as e:
    # Exporting this model requires at least these transformers PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
torch._C._set_onednn_allow_tf32(_allow_tf32)
It worked:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s0, s1]", attention_mask: "i64[s0, s1 + s5]", past_key_values_key_cache_0: "f32[s0, 1, s5, 96]", past_key_values_value_cache_0: "f32[s0, 1, s5, 96]"):
#
sym_size_int_20: "Sym(s0)" = torch.ops.aten.sym_size.int(input_ids, 0)
sym_size_int_21: "Sym(s1)" = torch.ops.aten.sym_size.int(input_ids, 1)
sym_size_int_22: "Sym(s5)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
embedding: "f32[s0, s1, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids); p_model_embed_tokens_weight = input_ids = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:565 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
add: "Sym(s1 + s5)" = sym_size_int_22 + sym_size_int_21
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:564 in forward, code: cache_position = torch.arange(
arange: "i64[s1]" = torch.ops.aten.arange.start(sym_size_int_22, add, device = device(type='cpu'), pin_memory = False); sym_size_int_22 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:569 in forward, code: position_ids = cache_position.unsqueeze(0)
unsqueeze: "i64[1, s1]" = torch.ops.aten.unsqueeze.default(arange, 0)
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:571 in forward, code: causal_mask = self._update_causal_mask(
full: "f32[s1, s1 + s5]" = torch.ops.aten.full.default([sym_size_int_21, add], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
triu: "f32[s1, s1 + s5]" = torch.ops.aten.triu.default(full, 1); full = None
arange_1: "i64[s1 + s5]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
reshape: "i64[s1, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]); arange = None
gt: "b8[s1, s1 + s5]" = torch.ops.aten.gt.Tensor(arange_1, reshape); arange_1 = reshape = None
mul_: "f32[s1, s1 + s5]" = torch.ops.aten.mul_.Tensor(triu, gt); triu = gt = None
unsqueeze_1: "f32[1, s1, s1 + s5]" = torch.ops.aten.unsqueeze.default(mul_, 0); mul_ = None
unsqueeze_2: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 1); unsqueeze_1 = None
slice_1: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807); unsqueeze_2 = None
slice_2: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807); slice_1 = None
expand: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_20, 1, -1, -1]); slice_2 = None
clone: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.clone.default(expand); expand = None
slice_3: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_4: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_3, 1, 0, 9223372036854775807); slice_3 = None
slice_5: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_4, 2, 0, 9223372036854775807); slice_4 = None
slice_6: "i64[s0, s1 + s5]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807); attention_mask = None
unsqueeze_3: "i64[s0, 1, s1 + s5]" = torch.ops.aten.unsqueeze.default(slice_6, 1); slice_6 = None
unsqueeze_4: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2); unsqueeze_3 = None
slice_7: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 3, 0, 9223372036854775807); unsqueeze_4 = None
to: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.to.dtype_layout(slice_7, dtype = torch.int64, layout = torch.strided, device = device(type='cpu')); slice_7 = None
add_2: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.add.Tensor(slice_5, to); slice_5 = to = None
eq_7: "b8[s0, 1, s1, s1 + s5]" = torch.ops.aten.eq.Scalar(add_2, 0); add_2 = None
slice_8: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_9: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_8, 1, 0, 9223372036854775807); slice_8 = None
slice_10: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_9, 2, 0, 9223372036854775807); slice_9 = None
masked_fill: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.masked_fill.Scalar(slice_10, eq_7, -3.4028234663852886e+38); slice_10 = eq_7 = None
slice_11: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_12: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_11, 1, 0, 9223372036854775807); slice_11 = None
slice_13: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_12, 2, 0, 9223372036854775807); slice_12 = None
copy_: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.copy_.default(slice_13, masked_fill); slice_13 = masked_fill = copy_ = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, unsqueeze); submod_3 = b_model_rotary_emb_inv_freq = unsqueeze = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[1, s1, 96]" = wrap_with_set_grad_enabled[0]
to_7: "f32[1, s1, 96]" = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_8: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
mean: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_3: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_3); add_3 = None
mul_2: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_8, rsqrt); rsqrt = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_9: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9); p_model_layers_0_input_layernorm_weight = to_9 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight); p_model_layers_0_self_attn_q_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:277 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view: "f32[s0, s1, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_20, sym_size_int_21, -1, 96]); linear = None
transpose_1: "f32[s0, 2, s1, 96]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s0, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight); p_model_layers_0_self_attn_k_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:278 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_1: "f32[s0, s1, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_20, sym_size_int_21, -1, 96]); linear_1 = None
transpose_2: "f32[s0, 1, s1, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[s0, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight); mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:279 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_2: "f32[s0, s1, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_20, sym_size_int_21, -1, 96]); linear_2 = None
transpose_3: "f32[s0, 1, s1, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:282 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_8: "f32[1, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_6, 1); to_6 = None
unsqueeze_9: "f32[1, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1); to_7 = None
mul_4: "f32[s0, 2, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_8)
slice_17: "f32[s0, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
slice_18: "f32[s0, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807); transpose_1 = None
neg: "f32[s0, 2, s1, 48]" = torch.ops.aten.neg.default(slice_18); slice_18 = None
cat_1: "f32[s0, 2, s1, 96]" = torch.ops.aten.cat.default([neg, slice_17], -1); neg = slice_17 = None
mul_5: "f32[s0, 2, s1, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_9); cat_1 = None
add_4: "f32[s0, 2, s1, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6: "f32[s0, 1, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_8); unsqueeze_8 = None
slice_19: "f32[s0, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
slice_20: "f32[s0, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807); transpose_2 = None
neg_1: "f32[s0, 1, s1, 48]" = torch.ops.aten.neg.default(slice_20); slice_20 = None
cat_2: "f32[s0, 1, s1, 96]" = torch.ops.aten.cat.default([neg_1, slice_19], -1); neg_1 = slice_19 = None
mul_7: "f32[s0, 1, s1, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_9); cat_2 = unsqueeze_9 = None
add_5: "f32[s0, 1, s1, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:287 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_3: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2); past_key_values_key_cache_0 = add_5 = None
cat_4: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2); past_key_values_value_cache_0 = transpose_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:299 in forward, code: attn_output, attn_weights = attention_interface(
slice_21: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
slice_22: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_21, 1, 0, 9223372036854775807); slice_21 = None
unsqueeze_10: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.unsqueeze.default(slice_22, 2); slice_22 = None
slice_23: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807); unsqueeze_10 = None
slice_24: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_23, 4, 0, 9223372036854775807); slice_23 = None
expand_2: "f32[s0, 1, 2, s1 + s5, 96]" = torch.ops.aten.expand.default(slice_24, [sym_size_int_20, 1, 2, add, 96]); slice_24 = None
reshape_1: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.reshape.default(expand_2, [sym_size_int_20, 2, add, 96]); expand_2 = None
slice_25: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
slice_26: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_25, 1, 0, 9223372036854775807); slice_25 = None
unsqueeze_11: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.unsqueeze.default(slice_26, 2); slice_26 = None
slice_27: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_11, 3, 0, 9223372036854775807); unsqueeze_11 = None
slice_28: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_27, 4, 0, 9223372036854775807); slice_27 = None
expand_3: "f32[s0, 1, 2, s1 + s5, 96]" = torch.ops.aten.expand.default(slice_28, [sym_size_int_20, 1, 2, add, 96]); slice_28 = None
reshape_2: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.reshape.default(expand_3, [sym_size_int_20, 2, add, 96]); expand_3 = add = None
slice_29: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807); clone = None
slice_30: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_29, 1, 0, 9223372036854775807); slice_29 = None
slice_31: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_30, 2, 0, 9223372036854775807); slice_30 = None
contiguous: "f32[s0, 2, s1, 96]" = torch.ops.aten.contiguous.default(add_4); add_4 = None
contiguous_1: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.contiguous.default(reshape_1); reshape_1 = None
contiguous_2: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.contiguous.default(reshape_2); reshape_2 = None
scaled_dot_product_attention: "f32[s0, 2, s1, 96]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_31, scale = 0.10206207261596575); contiguous = contiguous_1 = contiguous_2 = slice_31 = None
transpose_4: "f32[s0, s1, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
contiguous_3: "f32[s0, s1, 2, 96]" = torch.ops.aten.contiguous.default(transpose_4); transpose_4 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_3: "f32[s0, s1, 192]" = torch.ops.aten.reshape.default(contiguous_3, [sym_size_int_20, sym_size_int_21, -1]); contiguous_3 = sym_size_int_20 = sym_size_int_21 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(reshape_3, p_model_layers_0_self_attn_o_proj_weight); reshape_3 = p_model_layers_0_self_attn_o_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:354 in forward, code: hidden_states = residual + hidden_states
add_7: "f32[s0, s1, 192]" = torch.ops.aten.add.Tensor(to_8, linear_3); to_8 = linear_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_10: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32); add_7 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean_1: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_8: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_8); add_8 = None
mul_8: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1); rsqrt_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_11: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_8, torch.float32); mul_8 = None
mul_9: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11); p_model_layers_0_post_attention_layernorm_weight = to_11 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[s0, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_proj_weight); p_model_layers_0_mlp_gate_proj_weight = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:432 in forward, code: return F.silu(input, inplace=self.inplace)
silu: "f32[s0, s1, 1024]" = torch.ops.aten.silu.default(linear_4); linear_4 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[s0, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_up_proj_weight); mul_9 = p_model_layers_0_mlp_up_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:197 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_10: "f32[s0, s1, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight); mul_10 = p_model_layers_0_mlp_down_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:360 in forward, code: hidden_states = residual + hidden_states
add_9: "f32[s0, s1, 192]" = torch.ops.aten.add.Tensor(to_10, linear_6); to_10 = linear_6 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_12: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32); add_9 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_2: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_10: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_10); add_10 = None
mul_11: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2); to_12 = rsqrt_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_13: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_11, torch.float32); mul_11 = None
mul_12: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_13); p_model_norm_weight = to_13 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:870 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_32: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(mul_12, 0, 0, 9223372036854775807); mul_12 = None
slice_33: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(slice_32, 1, 0, 9223372036854775807); slice_32 = None
slice_34: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(slice_33, 2, 0, 9223372036854775807); slice_33 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[s0, s1, 32000]" = torch.ops.aten.linear.default(slice_34, p_lm_head_weight); slice_34 = p_lm_head_weight = None
return (linear_7, cat_3, cat_4)
class submod_1(torch.nn.Module):
def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", unsqueeze: "i64[1, s1]"):
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
unsqueeze_5: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0); b_model_rotary_emb_inv_freq = None
slice_14: "f32[1, 48]" = torch.ops.aten.slice.Tensor(unsqueeze_5, 1, 0, 9223372036854775807); unsqueeze_5 = None
unsqueeze_6: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(slice_14, 2); slice_14 = None
to_1: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_6, torch.float32); unsqueeze_6 = None
expand_1: "f32[1, 48, 1]" = torch.ops.aten.expand.default(to_1, [1, -1, 1]); to_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:134 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
slice_15: "i64[1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze, 0, 0, 9223372036854775807); unsqueeze = None
unsqueeze_7: "i64[1, 1, s1]" = torch.ops.aten.unsqueeze.default(slice_15, 1); slice_15 = None
slice_16: "i64[1, 1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze_7, 2, 0, 9223372036854775807); unsqueeze_7 = None
to_2: "f32[1, 1, s1]" = torch.ops.aten.to.dtype(slice_16, torch.float32); slice_16 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, expand_1, to_2); submod_3 = expand_1 = to_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
cos: "f32[1, s1, 96]" = wrap_with_autocast[0]
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
sin: "f32[1, s1, 96]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:145 in forward, code: cos = cos * self.attention_scaling
mul: "f32[1, s1, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:146 in forward, code: sin = sin * self.attention_scaling
mul_1: "f32[1, s1, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[1, s1, 96]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
to_7: "f32[1, s1, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
return (to_6, to_7)
class submod_1(torch.nn.Module):
def forward(self, expand_1: "f32[1, 48, 1]", to_2: "f32[1, 1, s1]"):
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:139 in forward, code: freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
to_3: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(expand_1, torch.float32); expand_1 = None
to_4: "f32[1, 48, 1]" = torch.ops.aten.to.dtype_layout(to_3, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); to_3 = None
to_5: "f32[1, 1, s1]" = torch.ops.aten.to.dtype(to_2, torch.float32); to_2 = None
matmul: "f32[1, 48, s1]" = torch.ops.aten.matmul.default(to_4, to_5); to_4 = to_5 = None
transpose: "f32[1, s1, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:140 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
cat: "f32[1, s1, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
cos: "f32[1, s1, 96]" = torch.ops.aten.cos.default(cat)
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
sin: "f32[1, s1, 96]" = torch.ops.aten.sin.default(cat); cat = None
return (cos, sin)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_embed_tokens_weight'), target='model.embed_tokens.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_q_proj_weight'), target='model.layers.0.self_attn.q_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_k_proj_weight'), target='model.layers.0.self_attn.k_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_v_proj_weight'), target='model.layers.0.self_attn.v_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_o_proj_weight'), target='model.layers.0.self_attn.o_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_gate_proj_weight'), target='model.layers.0.mlp.gate_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_up_proj_weight'), target='model.layers.0.mlp.up_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_down_proj_weight'), target='model.layers.0.mlp.down_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_input_layernorm_weight'), target='model.layers.0.input_layernorm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_post_attention_layernorm_weight'), target='model.layers.0.post_attention_layernorm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_norm_weight'), target='model.norm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, 
arg=TensorArgument(name='p_lm_head_weight'), target='lm_head.weight', persistent=None), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_model_rotary_emb_inv_freq'), target='model.rotary_emb.inv_freq', persistent=False), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input_ids'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='attention_mask'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='past_key_values_key_cache_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='past_key_values_value_cache_0'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='linear_7'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat_3'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat_4'), target=None)])
Range constraints: {s0: VR[1, 1024], s1: VR[1, 4096], s1 + s5: VR[4, 8192], s5: VR[1, 4096]}
Back to the original model¶
Let’s reuse the same dummy inputs, but this time with the downloaded model.
# Export the downloaded model, reusing the ``cloned_inputs`` and
# ``dynamic_shapes`` defined earlier in this example. The export is wrapped in
# a try/except so the example keeps running and reports the reason when the
# installed transformers version cannot be exported (see the PRs listed below).
try:
    ep = torch.export.export(
        model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes
    )
except Exception as e:
    # To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
else:
    # Only the export call sits in the ``try`` body: an error raised while
    # printing the exported program must not be reported as an export failure.
    print("It worked:")
    print(ep)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
torch._C._set_onednn_allow_tf32(_allow_tf32)
It worked:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s0, s1]", attention_mask: "i64[s0, s1 + s5]", past_key_values_key_cache_0: "f32[s0, 1, s5, 96]", past_key_values_value_cache_0: "f32[s0, 1, s5, 96]"):
#
sym_size_int_20: "Sym(s0)" = torch.ops.aten.sym_size.int(input_ids, 0)
sym_size_int_21: "Sym(s1)" = torch.ops.aten.sym_size.int(input_ids, 1)
sym_size_int_22: "Sym(s5)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
embedding: "f32[s0, s1, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids); p_model_embed_tokens_weight = input_ids = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:565 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
add: "Sym(s1 + s5)" = sym_size_int_22 + sym_size_int_21
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:564 in forward, code: cache_position = torch.arange(
arange: "i64[s1]" = torch.ops.aten.arange.start(sym_size_int_22, add, device = device(type='cpu'), pin_memory = False); sym_size_int_22 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:569 in forward, code: position_ids = cache_position.unsqueeze(0)
unsqueeze: "i64[1, s1]" = torch.ops.aten.unsqueeze.default(arange, 0)
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:571 in forward, code: causal_mask = self._update_causal_mask(
full: "f32[s1, s1 + s5]" = torch.ops.aten.full.default([sym_size_int_21, add], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
triu: "f32[s1, s1 + s5]" = torch.ops.aten.triu.default(full, 1); full = None
arange_1: "i64[s1 + s5]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
reshape: "i64[s1, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]); arange = None
gt: "b8[s1, s1 + s5]" = torch.ops.aten.gt.Tensor(arange_1, reshape); arange_1 = reshape = None
mul_: "f32[s1, s1 + s5]" = torch.ops.aten.mul_.Tensor(triu, gt); triu = gt = None
unsqueeze_1: "f32[1, s1, s1 + s5]" = torch.ops.aten.unsqueeze.default(mul_, 0); mul_ = None
unsqueeze_2: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 1); unsqueeze_1 = None
slice_1: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807); unsqueeze_2 = None
slice_2: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807); slice_1 = None
expand: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_20, 1, -1, -1]); slice_2 = None
clone: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.clone.default(expand); expand = None
slice_3: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_4: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_3, 1, 0, 9223372036854775807); slice_3 = None
slice_5: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_4, 2, 0, 9223372036854775807); slice_4 = None
slice_6: "i64[s0, s1 + s5]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807); attention_mask = None
unsqueeze_3: "i64[s0, 1, s1 + s5]" = torch.ops.aten.unsqueeze.default(slice_6, 1); slice_6 = None
unsqueeze_4: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2); unsqueeze_3 = None
slice_7: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 3, 0, 9223372036854775807); unsqueeze_4 = None
to: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.to.dtype_layout(slice_7, dtype = torch.int64, layout = torch.strided, device = device(type='cpu')); slice_7 = None
add_2: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.add.Tensor(slice_5, to); slice_5 = to = None
eq_7: "b8[s0, 1, s1, s1 + s5]" = torch.ops.aten.eq.Scalar(add_2, 0); add_2 = None
slice_8: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_9: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_8, 1, 0, 9223372036854775807); slice_8 = None
slice_10: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_9, 2, 0, 9223372036854775807); slice_9 = None
masked_fill: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.masked_fill.Scalar(slice_10, eq_7, -3.4028234663852886e+38); slice_10 = eq_7 = None
slice_11: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_12: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_11, 1, 0, 9223372036854775807); slice_11 = None
slice_13: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_12, 2, 0, 9223372036854775807); slice_12 = None
copy_: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.copy_.default(slice_13, masked_fill); slice_13 = masked_fill = copy_ = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, unsqueeze); submod_3 = b_model_rotary_emb_inv_freq = unsqueeze = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[1, s1, 96]" = wrap_with_set_grad_enabled[0]
to_7: "f32[1, s1, 96]" = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_8: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
mean: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_3: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_3); add_3 = None
mul_2: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_8, rsqrt); rsqrt = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_9: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9); p_model_layers_0_input_layernorm_weight = to_9 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight); p_model_layers_0_self_attn_q_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:277 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view: "f32[s0, s1, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_20, sym_size_int_21, -1, 96]); linear = None
transpose_1: "f32[s0, 2, s1, 96]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s0, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight); p_model_layers_0_self_attn_k_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:278 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_1: "f32[s0, s1, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_20, sym_size_int_21, -1, 96]); linear_1 = None
transpose_2: "f32[s0, 1, s1, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[s0, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight); mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:279 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_2: "f32[s0, s1, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_20, sym_size_int_21, -1, 96]); linear_2 = None
transpose_3: "f32[s0, 1, s1, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:282 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_8: "f32[1, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_6, 1); to_6 = None
unsqueeze_9: "f32[1, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1); to_7 = None
mul_4: "f32[s0, 2, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_8)
slice_17: "f32[s0, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
slice_18: "f32[s0, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807); transpose_1 = None
neg: "f32[s0, 2, s1, 48]" = torch.ops.aten.neg.default(slice_18); slice_18 = None
cat_1: "f32[s0, 2, s1, 96]" = torch.ops.aten.cat.default([neg, slice_17], -1); neg = slice_17 = None
mul_5: "f32[s0, 2, s1, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_9); cat_1 = None
add_4: "f32[s0, 2, s1, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6: "f32[s0, 1, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_8); unsqueeze_8 = None
slice_19: "f32[s0, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
slice_20: "f32[s0, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807); transpose_2 = None
neg_1: "f32[s0, 1, s1, 48]" = torch.ops.aten.neg.default(slice_20); slice_20 = None
cat_2: "f32[s0, 1, s1, 96]" = torch.ops.aten.cat.default([neg_1, slice_19], -1); neg_1 = slice_19 = None
mul_7: "f32[s0, 1, s1, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_9); cat_2 = unsqueeze_9 = None
add_5: "f32[s0, 1, s1, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:287 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_3: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2); past_key_values_key_cache_0 = add_5 = None
cat_4: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2); past_key_values_value_cache_0 = transpose_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:299 in forward, code: attn_output, attn_weights = attention_interface(
slice_21: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
slice_22: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_21, 1, 0, 9223372036854775807); slice_21 = None
unsqueeze_10: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.unsqueeze.default(slice_22, 2); slice_22 = None
slice_23: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807); unsqueeze_10 = None
slice_24: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_23, 4, 0, 9223372036854775807); slice_23 = None
expand_2: "f32[s0, 1, 2, s1 + s5, 96]" = torch.ops.aten.expand.default(slice_24, [sym_size_int_20, 1, 2, add, 96]); slice_24 = None
reshape_1: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.reshape.default(expand_2, [sym_size_int_20, 2, add, 96]); expand_2 = None
slice_25: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
slice_26: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_25, 1, 0, 9223372036854775807); slice_25 = None
unsqueeze_11: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.unsqueeze.default(slice_26, 2); slice_26 = None
slice_27: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_11, 3, 0, 9223372036854775807); unsqueeze_11 = None
slice_28: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_27, 4, 0, 9223372036854775807); slice_27 = None
expand_3: "f32[s0, 1, 2, s1 + s5, 96]" = torch.ops.aten.expand.default(slice_28, [sym_size_int_20, 1, 2, add, 96]); slice_28 = None
reshape_2: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.reshape.default(expand_3, [sym_size_int_20, 2, add, 96]); expand_3 = add = None
slice_29: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807); clone = None
slice_30: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_29, 1, 0, 9223372036854775807); slice_29 = None
slice_31: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_30, 2, 0, 9223372036854775807); slice_30 = None
contiguous: "f32[s0, 2, s1, 96]" = torch.ops.aten.contiguous.default(add_4); add_4 = None
contiguous_1: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.contiguous.default(reshape_1); reshape_1 = None
contiguous_2: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.contiguous.default(reshape_2); reshape_2 = None
scaled_dot_product_attention: "f32[s0, 2, s1, 96]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_31, scale = 0.10206207261596575); contiguous = contiguous_1 = contiguous_2 = slice_31 = None
transpose_4: "f32[s0, s1, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
contiguous_3: "f32[s0, s1, 2, 96]" = torch.ops.aten.contiguous.default(transpose_4); transpose_4 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_3: "f32[s0, s1, 192]" = torch.ops.aten.reshape.default(contiguous_3, [sym_size_int_20, sym_size_int_21, -1]); contiguous_3 = sym_size_int_20 = sym_size_int_21 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(reshape_3, p_model_layers_0_self_attn_o_proj_weight); reshape_3 = p_model_layers_0_self_attn_o_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:354 in forward, code: hidden_states = residual + hidden_states
add_7: "f32[s0, s1, 192]" = torch.ops.aten.add.Tensor(to_8, linear_3); to_8 = linear_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_10: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32); add_7 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean_1: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_8: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_8); add_8 = None
mul_8: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1); rsqrt_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_11: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_8, torch.float32); mul_8 = None
mul_9: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11); p_model_layers_0_post_attention_layernorm_weight = to_11 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[s0, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_proj_weight); p_model_layers_0_mlp_gate_proj_weight = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:432 in forward, code: return F.silu(input, inplace=self.inplace)
silu: "f32[s0, s1, 1024]" = torch.ops.aten.silu.default(linear_4); linear_4 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[s0, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_up_proj_weight); mul_9 = p_model_layers_0_mlp_up_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:197 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_10: "f32[s0, s1, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight); mul_10 = p_model_layers_0_mlp_down_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:360 in forward, code: hidden_states = residual + hidden_states
add_9: "f32[s0, s1, 192]" = torch.ops.aten.add.Tensor(to_10, linear_6); to_10 = linear_6 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_12: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32); add_9 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_2: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_10: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_10); add_10 = None
mul_11: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2); to_12 = rsqrt_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_13: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_11, torch.float32); mul_11 = None
mul_12: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_13); p_model_norm_weight = to_13 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:870 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_32: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(mul_12, 0, 0, 9223372036854775807); mul_12 = None
slice_33: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(slice_32, 1, 0, 9223372036854775807); slice_32 = None
slice_34: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(slice_33, 2, 0, 9223372036854775807); slice_33 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[s0, s1, 32000]" = torch.ops.aten.linear.default(slice_34, p_lm_head_weight); slice_34 = p_lm_head_weight = None
return (linear_7, cat_3, cat_4)
class submod_1(torch.nn.Module):
def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", unsqueeze: "i64[1, s1]"):
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
unsqueeze_5: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0); b_model_rotary_emb_inv_freq = None
slice_14: "f32[1, 48]" = torch.ops.aten.slice.Tensor(unsqueeze_5, 1, 0, 9223372036854775807); unsqueeze_5 = None
unsqueeze_6: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(slice_14, 2); slice_14 = None
to_1: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_6, torch.float32); unsqueeze_6 = None
expand_1: "f32[1, 48, 1]" = torch.ops.aten.expand.default(to_1, [1, -1, 1]); to_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:134 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
slice_15: "i64[1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze, 0, 0, 9223372036854775807); unsqueeze = None
unsqueeze_7: "i64[1, 1, s1]" = torch.ops.aten.unsqueeze.default(slice_15, 1); slice_15 = None
slice_16: "i64[1, 1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze_7, 2, 0, 9223372036854775807); unsqueeze_7 = None
to_2: "f32[1, 1, s1]" = torch.ops.aten.to.dtype(slice_16, torch.float32); slice_16 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, expand_1, to_2); submod_3 = expand_1 = to_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
cos: "f32[1, s1, 96]" = wrap_with_autocast[0]
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
sin: "f32[1, s1, 96]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:145 in forward, code: cos = cos * self.attention_scaling
mul: "f32[1, s1, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:146 in forward, code: sin = sin * self.attention_scaling
mul_1: "f32[1, s1, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[1, s1, 96]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
to_7: "f32[1, s1, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
return (to_6, to_7)
class submod_1(torch.nn.Module):
    # Subgraph traced from the rotary-embedding forward in modeling_llama.py
    # (source lines 139-142 per the "# File:" comments below). The parent graph
    # calls it through torch.ops.higher_order.wrap_with_autocast('cpu',
    # torch.bfloat16, False, False, submod, expand_1, to_2) and reads the two
    # outputs back as `cos` and `sin`.
    # NOTE(review): the two False flags presumably mean autocast is disabled
    # for this region (keeping the math in float32) - confirm against the
    # wrap_with_autocast signature.
    def forward(self, expand_1: "f32[1, 48, 1]", to_2: "f32[1, 1, s1]"):
        # expand_1 plays the role of inv_freq_expanded and to_2 the role of
        # position_ids_expanded in the traced expression below. s1 is a
        # symbolic (dynamic) dimension, presumably the sequence length - confirm.
        # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:139 in forward, code: freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
        # Per the annotations the inputs are already f32, so these casts keep
        # the dtype unchanged; to_4 additionally pins the device to cpu. They
        # mirror the explicit .float() / .to(x.device) calls in the source.
        to_3: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(expand_1, torch.float32); expand_1 = None
        to_4: "f32[1, 48, 1]" = torch.ops.aten.to.dtype_layout(to_3, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); to_3 = None
        to_5: "f32[1, 1, s1]" = torch.ops.aten.to.dtype(to_2, torch.float32); to_2 = None
        # [1, 48, 1] @ [1, 1, s1] -> [1, 48, s1], then swap the last two axes.
        matmul: "f32[1, 48, s1]" = torch.ops.aten.matmul.default(to_4, to_5); to_4 = to_5 = None
        transpose: "f32[1, s1, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
        # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:140 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
        # Two copies of the 48 frequencies side by side -> 96 columns.
        cat: "f32[1, s1, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
        # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
        # `cat` is still needed for sin, so it is only released after that use.
        cos: "f32[1, s1, 96]" = torch.ops.aten.cos.default(cat)
        # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
        sin: "f32[1, s1, 96]" = torch.ops.aten.sin.default(cat); cat = None
        # Both outputs are f32[1, s1, 96]; the caller indexes them as [0] and [1].
        return (cos, sin)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_embed_tokens_weight'), target='model.embed_tokens.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_q_proj_weight'), target='model.layers.0.self_attn.q_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_k_proj_weight'), target='model.layers.0.self_attn.k_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_v_proj_weight'), target='model.layers.0.self_attn.v_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_o_proj_weight'), target='model.layers.0.self_attn.o_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_gate_proj_weight'), target='model.layers.0.mlp.gate_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_up_proj_weight'), target='model.layers.0.mlp.up_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_down_proj_weight'), target='model.layers.0.mlp.down_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_input_layernorm_weight'), target='model.layers.0.input_layernorm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_post_attention_layernorm_weight'), target='model.layers.0.post_attention_layernorm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_norm_weight'), target='model.norm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, 
arg=TensorArgument(name='p_lm_head_weight'), target='lm_head.weight', persistent=None), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_model_rotary_emb_inv_freq'), target='model.rotary_emb.inv_freq', persistent=False), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input_ids'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='attention_mask'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='past_key_values_key_cache_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='past_key_values_value_cache_0'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='linear_7'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat_3'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat_4'), target=None)])
Range constraints: {s0: VR[1, 1024], s1: VR[1, 4096], s1 + s5: VR[4, 8192], s5: VR[1, 4096]}
Total running time of the script: (0 minutes 12.233 seconds)
Related examples

Use DYNAMIC or AUTO when dynamic shapes have constraints