Mistral
<<<
import numpy as np
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
import torch
from transformers import MistralConfig
from transformers.models.mistral.modeling_mistral import MistralModel
from experimental_experiment.torch_interpreter import to_onnx
def ids_tensor(shape, vocab_size):
    """Return a random token-id tensor of the given *shape*.

    :param shape: sequence of ints, the desired tensor shape
    :param vocab_size: upper bound for ids; values are drawn uniformly from
        the half-open range ``[0, vocab_size - 1)``, matching the original
        per-element ``np.random.randint(0, vocab_size - 1)`` loop
    :return: a contiguous ``torch.long`` tensor of shape *shape*
    """
    # Draw all values in one vectorized call instead of a Python loop
    # appending scalars one by one.
    values = np.random.randint(0, vocab_size - 1, size=tuple(shape))
    return torch.tensor(values, dtype=torch.long).contiguous()
# Tiny Mistral configuration so the repro stays fast: 2 layers, 32-dim hidden
# state, 2 attention heads, 1024-token vocabulary.
config = MistralConfig(
    hidden_size=32,
    num_hidden_layers=2,
    vocab_size=1024,
    intermediate_size=16,
    max_position_embeddings=512,
    num_attention_heads=2,
    num_key_value_heads=2,
)
# Force the eager attention implementation; the export failure reported below
# originates in the eager causal-mask construction (in-place masked_fill_).
config._attn_implementation = "eager"
with torch.no_grad():
    model = MistralModel(config)
    batch, seq, vocab_size = 2, 1024, 1024
    input_ids = ids_tensor([batch, seq], vocab_size)
    # Lower-triangular float mask used as the attention mask for the full
    # sequence length.
    input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))
    # Sanity check: a plain eager forward pass succeeds outside of export.
    model(input_ids, input_mask)
    try:
        onx = to_onnx(model, (input_ids, input_mask))
        print(onnx_simple_text_plot(onx))
    except Exception as e:
        # Conversion currently fails: the traced graph mutates a tensor with
        # frozen storage (masked_fill_ inside _make_causal_mask); the captured
        # traceback follows this snippet.
        print(f"conversion is broken due to {e}")
>>>
conversion is broken due to cannot mutate tensors with frozen storage
While executing %masked_fill__1 : [num_users=0] = call_method[target=masked_fill_](args = (%mask_1, %context_mask, -3.4028234663852886e+38), kwargs = {})
Original traceback:
File "/home/xadupre/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 1012, in forward
attention_mask = _prepare_4d_causal_attention_mask(
File "/home/xadupre/.local/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py", line 267, in _prepare_4d_causal_attention_mask
attention_mask = attn_mask_converter.to_4d(
File "/home/xadupre/.local/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py", line 121, in to_4d
causal_4d_mask = self._make_causal_mask(
File "/home/xadupre/.local/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py", line 170, in _make_causal_mask
mask.masked_fill_(context_mask, torch.finfo(dtype).min)