Mistral

This example creates a small MistralModel with random weights, runs it once in
eager mode, and then tries to convert it to ONNX with
experimental_experiment.torch_interpreter.to_onnx. The conversion fails and the
raised exception is printed after the code.

<<<

import numpy as np
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
import torch
from transformers import MistralConfig
from transformers.models.mistral.modeling_mistral import MistralModel
from experimental_experiment.torch_interpreter import to_onnx


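# Builds an int64 tensor of random token ids with the requested shape.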
def ids_tensor(shape, vocab_size):
    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(np.random.randint(0, vocab_size - 1))

    return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


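# a reduced configuration so that the model stays small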
config = MistralConfig(
    hidden_size=32,
    num_hidden_layers=2,
    vocab_size=1024,
    intermediate_size=16,
    max_position_embeddings=512,
    num_attention_heads=2,
    num_key_value_heads=2,
)
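# use the eager implementation of attention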
config._attn_implementation = "eager"

with torch.no_grad():

    model = MistralModel(config)

    batch, seq, vocab_size = 2, 1024, 1024

    input_ids = ids_tensor([batch, seq], vocab_size)
    input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))

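    # run the model once in eager mode before trying to export it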
    model(input_ids, input_mask)

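    # the conversion currently fails, the error is printed below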
    try:
        onx = to_onnx(model, (input_ids, input_mask))
        print(onnx_simple_text_plot(onx))
    except Exception as e:
        print(f"conversion is broken due to {e}")

>>>

    conversion is broken due to cannot mutate tensors with frozen storage
    
    While executing %masked_fill__1 : [num_users=0] = call_method[target=masked_fill_](args = (%mask_1, %context_mask, -3.4028234663852886e+38), kwargs = {})
    Original traceback:
      File "/home/xadupre/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 1012, in forward
        attention_mask = _prepare_4d_causal_attention_mask(
      File "/home/xadupre/.local/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py", line 267, in _prepare_4d_causal_attention_mask
        attention_mask = attn_mask_converter.to_4d(
      File "/home/xadupre/.local/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py", line 121, in to_4d
        causal_4d_mask = self._make_causal_mask(
      File "/home/xadupre/.local/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py", line 170, in _make_causal_mask
        mask.masked_fill_(context_mask, torch.finfo(dtype).min)
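
The exception points at the in-place masked_fill_ call in _make_causal_mask:
the causal mask is created first and then mutated in place, and the exporter
refuses to mutate a tensor whose storage it considers frozen. The snippet
below is only a minimal sketch of that pattern, not the transformers
implementation: it builds a causal mask with the in-place masked_fill_ shown
in the traceback, then builds the same mask with the out-of-place masked_fill,
which involves no mutation.

    import torch


    def causal_mask_inplace(seq_len, dtype=torch.float32):
        # start from a mask filled with the most negative value of the dtype
        mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
        # boolean mask of the positions a token is allowed to attend to
        rows = torch.arange(seq_len).unsqueeze(1)
        cols = torch.arange(seq_len).unsqueeze(0)
        context_mask = cols <= rows
        # in-place mutation: the kind of call the exporter rejects above
        mask.masked_fill_(context_mask, 0.0)
        return mask


    def causal_mask_out_of_place(seq_len, dtype=torch.float32):
        # same mask computed without mutating any tensor
        mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
        rows = torch.arange(seq_len).unsqueeze(1)
        cols = torch.arange(seq_len).unsqueeze(0)
        return mask.masked_fill(cols <= rows, 0.0)


    # both versions build the same additive attention mask
    assert torch.equal(causal_mask_inplace(4), causal_mask_out_of_place(4))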