yobx.torch.in_transformers.classes.llama_attention

Direct ONNX converter for transformers.models.llama.modeling_llama.LlamaAttention.

This converter appends ONNX nodes to an existing GraphBuilder from a fitted transformers.models.llama.modeling_llama.LlamaAttention module, without going through torch.export.export().

Four computation backends are supported, selected automatically based on the opsets registered in the graph builder:

  • com.microsoft ("com.microsoft" domain registered in the builder): Uses the com.microsoft.RotaryEmbedding contrib op for rotary embeddings and the com.microsoft.MultiHeadAttention contrib op for attention from onnxruntime. GQA key/value heads are expanded (via repeat-interleave) to match the query head count before the attention op. The model runs efficiently on CPU and CUDA with OnnxRuntime.

  • opset ≥ 24 (main opset ≥ 24): Uses the standard ONNX RotaryEmbedding operator (available from opset 23) for rotary embeddings and the standard ONNX Attention operator for attention. Attention was introduced in opset 23, but revision 24 fixes a correctness bug, so this converter requires opset 24 for that path.

  • opset 23 (main opset == 23): Uses the standard ONNX RotaryEmbedding operator for rotary embeddings and basic ONNX ops (MatMul, Softmax, Transpose, …) for attention.

  • opset ≤ 22 (default fallback): Uses basic ONNX ops (MatMul, Softmax, Transpose, …) for both rotary embeddings and attention.
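
The selection rules above can be sketched as a small function. This is only an illustration: select_backend and its return labels are hypothetical names, and the sketch assumes that the "com.microsoft" domain takes precedence over the main opset whenever it is registered, which the list order suggests but does not state.

```python
from typing import Dict


def select_backend(opsets: Dict[str, int]) -> str:
    """Return a label for the backend implied by the registered opsets.

    Hypothetical helper: it mirrors the documented rules, not the
    converter's actual source code.
    """
    # Assumption: the contrib-op path wins whenever the domain is registered.
    if "com.microsoft" in opsets:
        return "com.microsoft"
    main_opset = opsets.get("", 0)
    if main_opset >= 24:
        return "onnx-attention"    # RotaryEmbedding op + Attention op
    if main_opset == 23:
        return "onnx-rotary-only"  # RotaryEmbedding op, plain-op attention
    return "plain-ops"             # MatMul, Softmax, Transpose, ...


for opsets in ({"": 22}, {"": 23}, {"": 24}, {"": 22, "com.microsoft": 1}):
    print(opsets, "->", select_backend(opsets))
```

The dictionaries passed in have the same form as the first argument of GraphBuilder in the example further down.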

Supported dtypes:

float32, float16, and bfloat16 are all supported. The output dtype is inferred from the registered type of hidden_states in the graph builder.

Expected graph inputs (must be declared in the builder before calling the converter):

  • hidden_states: (batch, seq, hidden_size)

  • cos: (batch, seq, head_dim)

  • sin: (batch, seq, head_dim)

  • attention_mask (optional): (batch, 1, seq_q, total_seq)
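
To make the mask shape concrete, here is a minimal sketch that builds a causal mask with the (batch, 1, seq_q, total_seq) layout using nested lists. It assumes an additive float mask (0.0 where attention is allowed, -inf where it is blocked), which is a common convention but is not confirmed by this page; causal_mask is a hypothetical helper, not part of yobx.

```python
from typing import List

NEG_INF = float("-inf")


def causal_mask(batch: int, seq_q: int, total_seq: int) -> List[List[List[List[float]]]]:
    """Build a nested-list mask of shape (batch, 1, seq_q, total_seq)."""
    if seq_q > total_seq:
        raise ValueError("seq_q cannot exceed total_seq")
    past = total_seq - seq_q  # positions assumed to be already cached
    plane = [
        [0.0 if key <= past + query else NEG_INF for key in range(total_seq)]
        for query in range(seq_q)
    ]
    # Copy the plane per batch entry; the singleton axis broadcasts over heads.
    return [[[row[:] for row in plane]] for _ in range(batch)]


mask = causal_mask(batch=2, seq_q=3, total_seq=5)
for row in mask[0][0]:
    print(row)
```

The second axis has size 1 so that a single mask plane is shared by every attention head.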

Output returned by the converter:

  • tensor name (batch, seq, hidden_size) — the caller must register it as a graph output via g.make_tensor_output

Example:

import torch
from onnx import TensorProto
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from yobx.xbuilder import GraphBuilder
from yobx.torch.in_transformers.models import llama_attention_to_onnx

config = LlamaConfig(
    hidden_size=64, num_attention_heads=4, num_key_value_heads=2, head_dim=16
)
attn = LlamaAttention(config, layer_idx=0).eval()

# opset 22 — plain ONNX ops (MatMul, Softmax, …)
g = GraphBuilder({"": 22}, verbose=0)
g.make_tensor_input("hidden_states", TensorProto.FLOAT, ("batch", "seq", 64))
g.make_tensor_input("cos", TensorProto.FLOAT, ("batch", "seq", 16))
g.make_tensor_input("sin", TensorProto.FLOAT, ("batch", "seq", 16))
out = llama_attention_to_onnx(g, attn, "hidden_states", "cos", "sin")
g.make_tensor_output(out, TensorProto.FLOAT, ("batch", "seq", 64))
model = g.to_onnx()

# opset 24 — ONNX Attention op
g = GraphBuilder({"": 24}, verbose=0)
g.make_tensor_input("hidden_states", TensorProto.FLOAT, ("batch", "seq", 64))
g.make_tensor_input("cos", TensorProto.FLOAT, ("batch", "seq", 16))
g.make_tensor_input("sin", TensorProto.FLOAT, ("batch", "seq", 16))
out = llama_attention_to_onnx(g, attn, "hidden_states", "cos", "sin")
g.make_tensor_output(out, TensorProto.FLOAT, ("batch", "seq", 64))
model = g.to_onnx()

# OnnxRuntime contrib ops
g = GraphBuilder({"": 22, "com.microsoft": 1}, verbose=0)
g.make_tensor_input("hidden_states", TensorProto.FLOAT, ("batch", "seq", 64))
g.make_tensor_input("cos", TensorProto.FLOAT, ("batch", "seq", 16))
g.make_tensor_input("sin", TensorProto.FLOAT, ("batch", "seq", 16))
out = llama_attention_to_onnx(g, attn, "hidden_states", "cos", "sin")
g.make_tensor_output(out, TensorProto.FLOAT, ("batch", "seq", 64))
model = g.to_onnx()

yobx.torch.in_transformers.classes.llama_attention.llama_attention_to_onnx(g: GraphBuilderExtendedProtocol, attn: LlamaAttention, hidden_states: str, cos: str, sin: str, attention_mask: str | None = None, name: str = 'llama_attention') → str

Appends ONNX nodes implementing transformers.models.llama.modeling_llama.LlamaAttention to g.

The output dtype (float32, float16, or bfloat16) is inferred from the registered type of hidden_states in g. Model weights are cast to match.

The backend is chosen from the opsets registered in g:

  • com.microsoft — com.microsoft.RotaryEmbedding for rotary embeddings and com.microsoft.MultiHeadAttention for attention (OnnxRuntime). The "com.microsoft" domain must be registered in g. GQA KV heads are expanded to match the query head count before the attention op.

  • opset ≥ 24 — standard ONNX RotaryEmbedding op (opset ≥ 23) for rotary embeddings and the standard ONNX Attention op for attention.

  • opset 23 — standard ONNX RotaryEmbedding op for rotary embeddings and plain ONNX ops (MatMul, Softmax, Transpose, …) for attention.

  • opset ≤ 22 — plain ONNX ops for both rotary embeddings and attention. This is the default fallback path.
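
The GQA expansion mentioned for the com.microsoft path can be illustrated on head labels alone. The sketch below is an assumption-level illustration of repeat-interleave, not the converter's code; repeat_interleave_heads is a hypothetical helper, and the head counts match the example configuration (4 query heads, 2 key/value heads).

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def repeat_interleave_heads(kv_heads: Sequence[T], num_query_heads: int) -> List[T]:
    """Repeat each key/value head so the result has num_query_heads entries."""
    num_kv_heads = len(kv_heads)
    if num_query_heads % num_kv_heads != 0:
        raise ValueError("num_query_heads must be a multiple of num_kv_heads")
    n_rep = num_query_heads // num_kv_heads
    # Repeat-interleave: [k0, k1] -> [k0, k0, k1, k1] (not [k0, k1, k0, k1]).
    return [head for head in kv_heads for _ in range(n_rep)]


expanded = repeat_interleave_heads(["kv0", "kv1"], num_query_heads=4)
print(expanded)  # ['kv0', 'kv0', 'kv1', 'kv1']
```

With this ordering, query heads 0 and 1 share the first key/value head, and query heads 2 and 3 share the second.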

Note

The cos and sin inputs are expected to carry symmetric values, i.e. cos[..., :head_dim//2] == cos[..., head_dim//2:] (and likewise for sin). This matches what transformers.models.llama.modeling_llama.LlamaRotaryEmbedding produces. The dedicated ONNX/ORT RotaryEmbedding ops use only the first half of the last dimension; the plain-op fallback uses the full tensor unchanged.
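
The symmetry requirement can be checked on a single position with plain Python. The sketch below builds cos and sin tables whose two halves are identical, then compares a rotation that uses the full tensors (the rotate-half form) with one that reads only the first half. The function names are hypothetical, base=10000.0 is an assumed rotary base, and the half-based formula is one common non-interleaved formulation rather than a transcription of the ONNX or ORT operators.

```python
import math
from typing import List, Tuple


def rotary_tables(position: int, head_dim: int, base: float = 10000.0) -> Tuple[List[float], List[float]]:
    """Return (cos, sin) of length head_dim with two identical halves."""
    half = head_dim // 2
    angles = [position * base ** (-2.0 * i / head_dim) for i in range(half)]
    cos_half = [math.cos(a) for a in angles]
    sin_half = [math.sin(a) for a in angles]
    return cos_half + cos_half, sin_half + sin_half


def rope_full(x: List[float], cos: List[float], sin: List[float]) -> List[float]:
    """Rotate-half form: x * cos + rotate_half(x) * sin, using full tensors."""
    half = len(x) // 2
    rotated = [-v for v in x[half:]] + list(x[:half])
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rotated, sin)]


def rope_first_half(x: List[float], cos: List[float], sin: List[float]) -> List[float]:
    """Same rotation, reading only the first half of cos and sin."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    c, s = cos[:half], sin[:half]
    first = [a * ci - b * si for a, b, ci, si in zip(x1, x2, c, s)]
    second = [b * ci + a * si for a, b, ci, si in zip(x1, x2, c, s)]
    return first + second


cos, sin = rotary_tables(position=3, head_dim=16)
x = [0.1 * (i + 1) for i in range(16)]
full = rope_full(x, cos, sin)
half_based = rope_first_half(x, cos, sin)
print(all(math.isclose(a, b) for a, b in zip(full, half_based)))
```

If the two halves of cos or sin differed, the two functions would no longer agree, which is why the backends are only interchangeable for symmetric inputs.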

Parameters:
  • g – an existing graph builder — inputs must already be declared with their types; the function appends nodes without creating new graph inputs or outputs

  • attn – a fitted LlamaAttention module (weights must be initialised; the module may be in any of float32, float16, or bfloat16)

  • hidden_states – name of the (batch, seq, hidden_size) input tensor already declared in g

  • cos – name of the (batch, seq, head_dim) cosine embedding tensor already declared in g

  • sin – name of the (batch, seq, head_dim) sine embedding tensor already declared in g

  • attention_mask – optional name of the (batch, 1, seq_q, total_seq) attention mask tensor already declared in g; pass None (default) when no mask is needed

  • name – prefix used for all node names added to g

Returns:

name of the output tensor (batch, seq, hidden_size); the caller is responsible for registering it as a graph output via g.make_tensor_output

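Example:

from onnx import TensorProto
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from yobx.xbuilder import GraphBuilder
from yobx.torch.in_transformers.models import llama_attention_to_onnx

# Same small configuration as in the module-level example above
config = LlamaConfig(
    hidden_size=64, num_attention_heads=4, num_key_value_heads=2, head_dim=16
)
attn = LlamaAttention(config, layer_idx=0).eval()

# opset 22: plain ONNX ops (MatMul, Softmax, …)
g = GraphBuilder({"": 22}, verbose=0)
g.make_tensor_input("hidden_states", TensorProto.FLOAT, ("batch", "seq", 64))
g.make_tensor_input("cos", TensorProto.FLOAT, ("batch", "seq", 16))
g.make_tensor_input("sin", TensorProto.FLOAT, ("batch", "seq", 16))
out = llama_attention_to_onnx(g, attn, "hidden_states", "cos", "sin")
g.make_tensor_output(out, TensorProto.FLOAT, ("batch", "seq", 64))
model = g.to_onnx()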