Mistral

The following example builds a small MistralModel from transformers with a
reduced configuration, runs it once in eager mode, and converts it to ONNX
with to_onnx using the default decomposition table.

<<<

import numpy as np
import torch
from transformers import MistralConfig
from transformers.models.mistral.modeling_mistral import MistralModel
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx, ExportOptions


def ids_tensor(shape, vocab_size):
    """Returns a random int64 tensor of the given shape with token ids drawn
    uniformly from [0, vocab_size - 2] (np.random.randint excludes the upper
    bound)."""
    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(np.random.randint(0, vocab_size - 1))

    return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


# A deliberately small configuration so the example runs quickly.
config = MistralConfig(
    hidden_size=32,
    num_hidden_layers=2,
    vocab_size=1024,
    intermediate_size=16,
    max_position_embeddings=512,
    num_attention_heads=2,
    num_key_value_heads=2,
)
# Force the eager attention implementation so the exported graph
# does not rely on scaled_dot_product_attention.
config._attn_implementation = "eager"

with torch.no_grad():

    model = MistralModel(config)

    batch, seq, vocab_size = 2, 1024, 1024

    input_ids = ids_tensor([batch, seq], vocab_size)
    input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))

    # Run the model once in eager mode to check that the inputs are valid.
    model(input_ids, input_mask)

    try:
        # The default decomposition table rewrites composite ATen
        # operators into primitive ones before the conversion.
        onx = to_onnx(
            model,
            (input_ids, input_mask),
            export_options=ExportOptions(decomposition_table="default"),
        )
        print(pretty_onnx(onx))
    except Exception as e:
        print(f"conversion is broken due to {e}")

>>>

    opset: domain='' version=18
    doc_string: large_model=False, inline=False, external_threshold=102...
    input: name='input_ids' type=dtype('int64') shape=[2, 1024]
    input: name='attention_mask' type=dtype('float32') shape=[2, 1024]
    init: name='b_layers_0_self_attn_rotary_emb_inv_freq' type=float32 shape=(8,)-- DynamoInterpret.placeholder.0
    init: name='b_layers_1_self_attn_rotary_emb_inv_freq' type=float32 shape=(8,)-- DynamoInterpret.placeholder.0
    init: name='init7_s_0' type=int64 shape=() -- array([0])              -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init7_s_1024' type=int64 shape=() -- array([1024])        -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init7_s_1' type=int64 shape=() -- array([1])              -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init7_s2_1024_1024' type=int64 shape=(2,) -- array([1024, 1024])-- Opset.make_node.1/Shape
    init: name='init7_s2_-1_1' type=int64 shape=(2,) -- array([-1,  1])   -- Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init7_s_4096' type=int64 shape=() -- array([4096])        -- shape_type_compute._cast_inputs.0
    init: name='init7_s1_1' type=int64 shape=(1,) -- array([1])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init7_s4_2_1_1024_1024' type=int64 shape=(4,) -- array([   2,    1, 1024, 1024])-- GraphBuilder.make_shape_from_results.shape
    init: name='init1_s_' type=float32 shape=() -- array([0.], dtype=float32)-- shape_type_compute._cast_inputs.0
    init: name='init1_s1_' type=float32 shape=(1,) -- array([-3.403e+38], dtype=float32)-- Opset.make_node.1/Small
    init: name='init1_s1_2' type=float32 shape=(1,) -- array([2.], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small##Opset.make_node.1/Small##Opset.make_node.1/Small##Opset.make_node.1/Small
    init: name='init7_s1_-1' type=int64 shape=(1,) -- array([-1])         -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init1_s_2' type=float32 shape=() -- array([1.e-06], dtype=float32)-- shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0
    init: name='init7_s4_2_1024_-1_16' type=int64 shape=(4,) -- array([   2, 1024,   -1,   16])-- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init1_s_3' type=float32 shape=() -- array([4.], dtype=float32)-- shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0
    init: name='init7_s3_2_1024_-1' type=int64 shape=(3,) -- array([   2, 1024,   -1])-- Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init7_s2_0_1' type=int64 shape=(2,) -- array([0, 1])      -- UnsqueezeUnsqueezePattern.apply.new_axis
    init: name='init7_s2_1_2' type=int64 shape=(2,) -- array([1, 2])      -- UnsqueezeUnsqueezePattern.apply.new_axis
    init: name='init7_s2_0_2' type=int64 shape=(2,) -- array([0, 2])      -- UnsqueezeUnsqueezePattern.apply.new_axis##UnsqueezeUnsqueezePattern.apply.new_axis
    init: name='init7_s2_0_12' type=int64 shape=(2,) -- array([0, 1])     -- UnsqueezeUnsqueezePattern.apply.new_axis
    init: name='init7_s2_-1_32' type=int64 shape=(2,) -- array([-1, 32])  -- MatMulAddPattern.new_shape.1##MatMulAddPattern.new_shape.3##MatMulAddPattern.new_shape.3##MatMulAddPattern.new_shape.1##MatMulAddPattern.new_shape.3##MatMulAddPattern.new_shape.3
    init: name='init7_s3_2_1024_-12' type=int64 shape=(3,) -- array([   2, 1024,   -1])-- MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2
    init: name='init7_s2_-1_16' type=int64 shape=(2,) -- array([-1, 16])  -- MatMulAddPattern.new_shape.1##MatMulAddPattern.new_shape.1
    init: name='init7_s2_8_8' type=int64 shape=(2,) -- array([8, 8])      -- SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits
    init: name='layers.0.input_layernorm.weight' type=float32 shape=(32,) -- DynamoInterpret.placeholder.1/P(layers.0.input_layernorm.weight)
    init: name='layers.0.post_attention_layernorm.weight' type=float32 shape=(32,)-- DynamoInterpret.placeholder.1/P(layers.0.post_attention_layernorm.weight)
    init: name='layers.0.self_attn.q_proj.weight' type=float32 shape=(32, 32)-- DynamoInterpret.placeholder.1/P(layers.0.self_attn.q_proj.weight)
    init: name='layers.0.self_attn.k_proj.weight' type=float32 shape=(32, 32)-- DynamoInterpret.placeholder.1/P(layers.0.self_attn.k_proj.weight)
    init: name='layers.0.self_attn.v_proj.weight' type=float32 shape=(32, 32)-- DynamoInterpret.placeholder.1/P(layers.0.self_attn.v_proj.weight)
    init: name='layers.0.self_attn.o_proj.weight' type=float32 shape=(32, 32)-- DynamoInterpret.placeholder.1/P(layers.0.self_attn.o_proj.weight)
    init: name='layers.0.mlp.gate_proj.weight' type=float32 shape=(16, 32)-- DynamoInterpret.placeholder.1/P(layers.0.mlp.gate_proj.weight)
    init: name='layers.0.mlp.up_proj.weight' type=float32 shape=(16, 32)  -- DynamoInterpret.placeholder.1/P(layers.0.mlp.up_proj.weight)
    init: name='layers.0.mlp.down_proj.weight' type=float32 shape=(32, 16)-- DynamoInterpret.placeholder.1/P(layers.0.mlp.down_proj.weight)
    init: name='layers.1.input_layernorm.weight' type=float32 shape=(32,) -- DynamoInterpret.placeholder.1/P(layers.1.input_layernorm.weight)
    init: name='layers.1.post_attention_layernorm.weight' type=float32 shape=(32,)-- DynamoInterpret.placeholder.1/P(layers.1.post_attention_layernorm.weight)
    init: name='layers.1.self_attn.q_proj.weight' type=float32 shape=(32, 32)-- DynamoInterpret.placeholder.1/P(layers.1.self_attn.q_proj.weight)
    init: name='layers.1.self_attn.k_proj.weight' type=float32 shape=(32, 32)-- DynamoInterpret.placeholder.1/P(layers.1.self_attn.k_proj.weight)
    init: name='layers.1.self_attn.v_proj.weight' type=float32 shape=(32, 32)-- DynamoInterpret.placeholder.1/P(layers.1.self_attn.v_proj.weight)
    init: name='layers.1.self_attn.o_proj.weight' type=float32 shape=(32, 32)-- DynamoInterpret.placeholder.1/P(layers.1.self_attn.o_proj.weight)
    init: name='layers.1.mlp.gate_proj.weight' type=float32 shape=(16, 32)-- DynamoInterpret.placeholder.1/P(layers.1.mlp.gate_proj.weight)
    init: name='layers.1.mlp.up_proj.weight' type=float32 shape=(16, 32)  -- DynamoInterpret.placeholder.1/P(layers.1.mlp.up_proj.weight)
    init: name='layers.1.mlp.down_proj.weight' type=float32 shape=(32, 16)-- DynamoInterpret.placeholder.1/P(layers.1.mlp.down_proj.weight)
    init: name='norm.weight' type=float32 shape=(32,)                     -- DynamoInterpret.placeholder.1/P(norm.weight)
    init: name='embed_tokens.weight' type=float32 shape=(1024, 32)        -- DynamoInterpret.placeholder.1/P(embed_tokens.weight)
    ConstantOfShape(init7_s2_1024_1024, value=[-3.402823...) -> full
    Gather(embed_tokens.weight, input_ids) -> embedding
      Pow(embedding, init1_s1_2) -> pow_1
        ReduceMean(pow_1, init7_s1_-1, keepdims=1) -> mean
    Range(init7_s_0, init7_s_1024, init7_s_1) -> arange
      Reshape(arange, init7_s2_-1_1) -> view
    Unsqueeze(arange, init7_s2_0_12) -> unsqueeze_9
      Cast(unsqueeze_9, to=1) -> _to_copy_3
    Range(init7_s_0, init7_s_1024, init7_s_1) -> arange_1
      Greater(arange_1, view) -> gt
    Range(init7_s_0, init7_s_1024, init7_s_1) -> arange_2
    Reshape(init7_s_4096, init7_s1_1) -> _onx_reshape0
      Sub(view, _onx_reshape0) -> sub
      LessOrEqual(arange_2, sub) -> le
        Or(gt, le) -> bitwise_or
          Cast(bitwise_or, to=1) -> _onx_cast0
      Mul(full, _onx_cast0) -> _onx_mul0
        Unsqueeze(_onx_mul0, init7_s2_0_1) -> unsqueeze_4
          Expand(unsqueeze_4, init7_s4_2_1_1024_1024) -> expand_1
    Unsqueeze(attention_mask, init7_s2_1_2) -> unsqueeze_6
      Add(expand_1, unsqueeze_6) -> add
    Reshape(init1_s_, init7_s1_1) -> _onx_reshape02
      Equal(add, _onx_reshape02) -> eq
        Where(eq, init1_s1_, expand_1) -> masked_fill
    Reshape(init1_s_2, init7_s1_1) -> _onx_reshape03
      Add(mean, _onx_reshape03) -> add_1
        Sqrt(add_1) -> _onx_sqrt0
          Reciprocal(_onx_sqrt0) -> rsqrt
      Mul(embedding, rsqrt) -> mul_1
        Mul(layers.0.input_layernorm.weight, mul_1) -> mul_2
    Transpose(layers.0.self_attn.q_proj.weight, perm=[1,0]) -> _onx_transpose0
      MatMul(mul_2, _onx_transpose0) -> linear
        Reshape(linear, init7_s4_2_1024_-1_16) -> view_2
          Transpose(view_2, perm=[0,2,1,3]) -> transpose
            Split(transpose, init7_s2_8_8, axis=3) -> slice_24, slice_25
              Neg(slice_25) -> neg
              Concat(neg, slice_24, axis=-1) -> cat_1
    Transpose(layers.0.self_attn.k_proj.weight, perm=[1,0]) -> _onx_transpose02
      MatMul(mul_2, _onx_transpose02) -> linear_1
        Reshape(linear_1, init7_s4_2_1024_-1_16) -> view_3
          Transpose(view_3, perm=[0,2,1,3]) -> transpose_1
            Split(transpose_1, init7_s2_8_8, axis=3) -> slice_26, slice_27
              Neg(slice_27) -> neg_1
              Concat(neg_1, slice_26, axis=-1) -> cat_2
    Transpose(layers.0.self_attn.v_proj.weight, perm=[1,0]) -> _onx_transpose03
      MatMul(mul_2, _onx_transpose03) -> linear_2
        Reshape(linear_2, init7_s4_2_1024_-1_16) -> view_4
          Transpose(view_4, perm=[0,2,1,3]) -> output_2
    Unsqueeze(b_layers_0_self_attn_rotary_emb_inv_freq, init7_s2_0_2) -> unsqueeze_8
      MatMul(unsqueeze_8, _to_copy_3) -> matmul
        Transpose(matmul, perm=[0,2,1]) -> transpose_3
          Concat(transpose_3, transpose_3, axis=-1) -> cat
            Cos(cat) -> cos
              Unsqueeze(cos, init7_s1_1) -> unsqueeze_10
            Mul(transpose, unsqueeze_10) -> mul_3
    Sin(cat) -> sin
      Unsqueeze(sin, init7_s1_1) -> unsqueeze_11
        Mul(cat_1, unsqueeze_11) -> mul_4
          Add(mul_3, mul_4) -> add_2
    Mul(transpose_1, unsqueeze_10) -> mul_5
    Mul(cat_2, unsqueeze_11) -> mul_6
      Add(mul_5, mul_6) -> output_1
        Transpose(output_1, perm=[0,1,3,2]) -> transpose_4
          MatMul(add_2, transpose_4) -> matmul_1
    Reshape(init1_s_3, init7_s1_1) -> _onx_reshape04
      Div(matmul_1, _onx_reshape04) -> div
        Add(div, masked_fill) -> add_4
          Softmax(add_4, axis=-1) -> softmax
            MatMul(softmax, output_2) -> matmul_2
              Transpose(matmul_2, perm=[0,2,1,3]) -> transpose_5
                Reshape(transpose_5, init7_s3_2_1024_-1) -> view_5
                  Reshape(view_5, init7_s2_-1_32) -> MatMulAddPattern--view_5
      Reshape(embedding, init7_s2_-1_32) -> MatMulAddPattern--view_52
        Gemm(MatMulAddPattern--view_5, layers.0.self_attn.o_proj.weight, MatMulAddPattern--view_52, transB=1) -> MatMulAddPattern--view_53
          Reshape(MatMulAddPattern--view_53, init7_s3_2_1024_-12) -> add_5
            Pow(add_5, init1_s1_2) -> pow_2
              ReduceMean(pow_2, init7_s1_-1, keepdims=1) -> mean_1
    Reshape(init1_s_2, init7_s1_1) -> _onx_reshape05
      Add(mean_1, _onx_reshape05) -> add_6
        Sqrt(add_6) -> _onx_sqrt02
          Reciprocal(_onx_sqrt02) -> rsqrt_1
            Mul(add_5, rsqrt_1) -> mul_7
              Mul(layers.0.post_attention_layernorm.weight, mul_7) -> mul_8
    Transpose(layers.0.mlp.gate_proj.weight, perm=[1,0]) -> _onx_transpose05
      MatMul(mul_8, _onx_transpose05) -> linear_4
        Sigmoid(linear_4) -> _onx_sigmoid0
        Mul(linear_4, _onx_sigmoid0) -> silu
    Transpose(layers.0.mlp.up_proj.weight, perm=[1,0]) -> _onx_transpose06
      MatMul(mul_8, _onx_transpose06) -> linear_5
        Mul(silu, linear_5) -> mul_9
          Reshape(mul_9, init7_s2_-1_16) -> MatMulAddPattern--mul_9
    Reshape(add_5, init7_s2_-1_32) -> MatMulAddPattern--mul_92
      Gemm(MatMulAddPattern--mul_9, layers.0.mlp.down_proj.weight, MatMulAddPattern--mul_92, transB=1) -> MatMulAddPattern--mul_93
        Reshape(MatMulAddPattern--mul_93, init7_s3_2_1024_-12) -> add_7
          Pow(add_7, init1_s1_2) -> pow_3
            ReduceMean(pow_3, init7_s1_-1, keepdims=1) -> mean_2
    Reshape(init1_s_2, init7_s1_1) -> _onx_reshape06
      Add(mean_2, _onx_reshape06) -> add_8
        Sqrt(add_8) -> _onx_sqrt03
          Reciprocal(_onx_sqrt03) -> rsqrt_2
          Mul(add_7, rsqrt_2) -> mul_10
            Mul(layers.1.input_layernorm.weight, mul_10) -> mul_11
    Transpose(layers.1.self_attn.q_proj.weight, perm=[1,0]) -> _onx_transpose08
      MatMul(mul_11, _onx_transpose08) -> linear_7
        Reshape(linear_7, init7_s4_2_1024_-1_16) -> view_6
          Transpose(view_6, perm=[0,2,1,3]) -> transpose_6
            Split(transpose_6, init7_s2_8_8, axis=3) -> slice_37, slice_38
              Neg(slice_38) -> neg_2
              Concat(neg_2, slice_37, axis=-1) -> cat_4
    Transpose(layers.1.self_attn.k_proj.weight, perm=[1,0]) -> _onx_transpose09
      MatMul(mul_11, _onx_transpose09) -> linear_8
        Reshape(linear_8, init7_s4_2_1024_-1_16) -> view_7
          Transpose(view_7, perm=[0,2,1,3]) -> transpose_7
            Split(transpose_7, init7_s2_8_8, axis=3) -> slice_39, slice_40
              Neg(slice_40) -> neg_3
              Concat(neg_3, slice_39, axis=-1) -> cat_5
    Transpose(layers.1.self_attn.v_proj.weight, perm=[1,0]) -> _onx_transpose010
      MatMul(mul_11, _onx_transpose010) -> linear_9
        Reshape(linear_9, init7_s4_2_1024_-1_16) -> view_8
          Transpose(view_8, perm=[0,2,1,3]) -> output_4
    Unsqueeze(b_layers_1_self_attn_rotary_emb_inv_freq, init7_s2_0_2) -> unsqueeze_13
      MatMul(unsqueeze_13, _to_copy_3) -> matmul_3
        Transpose(matmul_3, perm=[0,2,1]) -> transpose_9
          Concat(transpose_9, transpose_9, axis=-1) -> cat_3
            Cos(cat_3) -> cos_1
              Unsqueeze(cos_1, init7_s1_1) -> unsqueeze_15
            Mul(transpose_6, unsqueeze_15) -> mul_12
    Sin(cat_3) -> sin_1
      Unsqueeze(sin_1, init7_s1_1) -> unsqueeze_16
        Mul(cat_4, unsqueeze_16) -> mul_13
          Add(mul_12, mul_13) -> add_9
    Mul(transpose_7, unsqueeze_15) -> mul_14
    Mul(cat_5, unsqueeze_16) -> mul_15
      Add(mul_14, mul_15) -> output_3
        Transpose(output_3, perm=[0,1,3,2]) -> transpose_10
          MatMul(add_9, transpose_10) -> matmul_4
    Reshape(init1_s_3, init7_s1_1) -> _onx_reshape07
      Div(matmul_4, _onx_reshape07) -> div_1
        Add(div_1, masked_fill) -> add_11
          Softmax(add_11, axis=-1) -> softmax_1
            MatMul(softmax_1, output_4) -> matmul_5
              Transpose(matmul_5, perm=[0,2,1,3]) -> transpose_11
                Reshape(transpose_11, init7_s3_2_1024_-1) -> view_9
                  Reshape(view_9, init7_s2_-1_32) -> MatMulAddPattern--view_9
          Reshape(add_7, init7_s2_-1_32) -> MatMulAddPattern--view_92
            Gemm(MatMulAddPattern--view_9, layers.1.self_attn.o_proj.weight, MatMulAddPattern--view_92, transB=1) -> MatMulAddPattern--view_93
              Reshape(MatMulAddPattern--view_93, init7_s3_2_1024_-12) -> add_12
                Pow(add_12, init1_s1_2) -> pow_4
                  ReduceMean(pow_4, init7_s1_-1, keepdims=1) -> mean_3
    Reshape(init1_s_2, init7_s1_1) -> _onx_reshape08
      Add(mean_3, _onx_reshape08) -> add_13
        Sqrt(add_13) -> _onx_sqrt04
          Reciprocal(_onx_sqrt04) -> rsqrt_3
            Mul(add_12, rsqrt_3) -> mul_16
              Mul(layers.1.post_attention_layernorm.weight, mul_16) -> mul_17
    Transpose(layers.1.mlp.gate_proj.weight, perm=[1,0]) -> _onx_transpose012
      MatMul(mul_17, _onx_transpose012) -> linear_11
        Sigmoid(linear_11) -> _onx_sigmoid02
        Mul(linear_11, _onx_sigmoid02) -> silu_1
    Transpose(layers.1.mlp.up_proj.weight, perm=[1,0]) -> _onx_transpose013
      MatMul(mul_17, _onx_transpose013) -> linear_12
        Mul(silu_1, linear_12) -> mul_18
          Reshape(mul_18, init7_s2_-1_16) -> MatMulAddPattern--mul_18
    Reshape(add_12, init7_s2_-1_32) -> MatMulAddPattern--mul_182
      Gemm(MatMulAddPattern--mul_18, layers.1.mlp.down_proj.weight, MatMulAddPattern--mul_182, transB=1) -> MatMulAddPattern--mul_183
        Reshape(MatMulAddPattern--mul_183, init7_s3_2_1024_-12) -> add_14
          Pow(add_14, init1_s1_2) -> pow_5
            ReduceMean(pow_5, init7_s1_-1, keepdims=1) -> mean_4
    Reshape(init1_s_2, init7_s1_1) -> _onx_reshape09
      Add(mean_4, _onx_reshape09) -> add_15
        Sqrt(add_15) -> _onx_sqrt05
          Reciprocal(_onx_sqrt05) -> rsqrt_4
          Mul(add_14, rsqrt_4) -> mul_19
            Mul(norm.weight, mul_19) -> output_0
    output: name='output_0' type=dtype('float32') shape=[2, 1024, 32]
    output: name='output_1' type=dtype('float32') shape=[2, 2, 1024, 16]
    output: name='output_2' type=dtype('float32') shape=[2, 2, 1024, 16]
    output: name='output_3' type=dtype('float32') shape=[2, 2, 1024, 16]
    output: name='output_4' type=dtype('float32') shape=[2, 2, 1024, 16]
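
One quick way to validate the exported graph is to run it with onnxruntime
and compare the first output, the last hidden state, with eager PyTorch.
The sketch below continues the session above; it assumes onnxruntime is
installed and that to_onnx returned the ModelProto printed here, and the
tolerance is an arbitrary choice.

<<<

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession(
    onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
feeds = {
    "input_ids": input_ids.numpy(),
    "attention_mask": input_mask.numpy(),
}
# output_0 is the last hidden state; output_1..output_4 are the
# per-layer key/value tensors listed at the end of the graph.
got = sess.run(None, feeds)

with torch.no_grad():
    expected = model(input_ids, input_mask).last_hidden_state

np.testing.assert_allclose(expected.numpy(), got[0], atol=1e-4)

>>>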