yobx.torch.in_transformers.classes.llama_attention#
Direct ONNX converter for transformers.models.llama.modeling_llama.LlamaAttention.
This converter appends ONNX nodes to an existing GraphBuilder
from a fitted transformers.models.llama.modeling_llama.LlamaAttention module,
without going through torch.export.export().
Four computation backends are supported, selected automatically based on the opsets registered in the graph builder:

- com.microsoft (the "com.microsoft" domain is registered in the builder): uses the com.microsoft.RotaryEmbedding contrib op for rotary embeddings and the com.microsoft.MultiHeadAttention contrib op for attention, both from onnxruntime. GQA key/value heads are expanded (via repeat-interleave) to match the query head count before the attention op. The model runs efficiently on CPU and CUDA with OnnxRuntime.
- opset ≥ 24 (main opset ≥ 24): uses the standard ONNX RotaryEmbedding operator (opset ≥ 23) for rotary embeddings and the standard ONNX Attention operator introduced in opset 23 (revision 24 fixes a correctness bug; this converter therefore requires 24) for attention.
- opset 23 (main opset == 23): uses the standard ONNX RotaryEmbedding operator for rotary embeddings and basic ONNX ops (MatMul, Softmax, Transpose, …) for attention.
- opset ≤ 22 (default fallback): uses basic ONNX ops (MatMul, Softmax, Transpose, …) for both rotary embeddings and attention.
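The selection rules above can be sketched as a small pure-Python function. Note that `select_backend` and its return labels are illustrative names only, not part of the converter's API: the real choice happens inside `llama_attention_to_onnx` from the GraphBuilder's registered opsets.

```python
# Hypothetical sketch of the backend-selection rules described above.
def select_backend(opsets: dict) -> str:
    """Map registered opsets (domain -> version) to a computation backend."""
    if "com.microsoft" in opsets:
        return "contrib"          # com.microsoft RotaryEmbedding + MultiHeadAttention
    main = opsets.get("", 0)
    if main >= 24:
        return "onnx-attention"   # standard RotaryEmbedding + Attention ops
    if main == 23:
        return "onnx-rotary"      # standard RotaryEmbedding, plain ops for attention
    return "plain"                # MatMul/Softmax/Transpose fallback

print(select_backend({"": 22, "com.microsoft": 1}))  # contrib
print(select_backend({"": 24}))                      # onnx-attention
print(select_backend({"": 23}))                      # onnx-rotary
print(select_backend({"": 22}))                      # plain
```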
- Supported dtypes: float32, float16, and bfloat16. The output dtype is inferred from the registered type of hidden_states in the graph builder.
Expected graph inputs (must be declared in the builder before calling the converter):

- hidden_states: (batch, seq, hidden_size)
- cos: (batch, seq, head_dim)
- sin: (batch, seq, head_dim)
- attention_mask (optional): (batch, 1, seq_q, total_seq)

Output returned by the converter:

- tensor name of shape (batch, seq, hidden_size); the caller must register it as a graph output via g.make_tensor_output
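The cos and sin inputs can be built with NumPy in the same shape LlamaRotaryEmbedding produces: frequencies are duplicated across the two halves of head_dim, so the tables are symmetric by construction. A minimal sketch, assuming the usual Llama conventions (base 10000, the example shapes from this page); this is an illustration, not the transformers implementation:

```python
import numpy as np

batch, seq, head_dim, base = 1, 8, 16, 10000.0
inv_freq = 1.0 / base ** (np.arange(0, head_dim, 2) / head_dim)   # (head_dim/2,)
pos = np.arange(seq, dtype=np.float64)                            # (seq,)
freqs = np.outer(pos, inv_freq)                                   # (seq, head_dim/2)
emb = np.concatenate([freqs, freqs], axis=-1)                     # duplicate the halves
cos = np.cos(emb)[None].repeat(batch, axis=0).astype(np.float32)  # (batch, seq, head_dim)
sin = np.sin(emb)[None].repeat(batch, axis=0).astype(np.float32)

# The symmetry the converter expects holds by construction:
half = head_dim // 2
assert np.array_equal(cos[..., :half], cos[..., half:])
assert np.array_equal(sin[..., :half], sin[..., half:])
```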
Example:
import torch
from onnx import TensorProto
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from yobx.xbuilder import GraphBuilder
from yobx.torch.in_transformers.models import llama_attention_to_onnx
config = LlamaConfig(
hidden_size=64, num_attention_heads=4, num_key_value_heads=2, head_dim=16
)
attn = LlamaAttention(config, layer_idx=0).eval()
# opset 22 — plain ONNX ops (MatMul, Softmax, …)
g = GraphBuilder({"": 22}, verbose=0)
g.make_tensor_input("hidden_states", TensorProto.FLOAT, ("batch", "seq", 64))
g.make_tensor_input("cos", TensorProto.FLOAT, ("batch", "seq", 16))
g.make_tensor_input("sin", TensorProto.FLOAT, ("batch", "seq", 16))
out = llama_attention_to_onnx(g, attn, "hidden_states", "cos", "sin")
g.make_tensor_output(out, TensorProto.FLOAT, ("batch", "seq", 64))
model = g.to_onnx()
# opset 24 — ONNX Attention op
g = GraphBuilder({"": 24}, verbose=0)
g.make_tensor_input("hidden_states", TensorProto.FLOAT, ("batch", "seq", 64))
g.make_tensor_input("cos", TensorProto.FLOAT, ("batch", "seq", 16))
g.make_tensor_input("sin", TensorProto.FLOAT, ("batch", "seq", 16))
out = llama_attention_to_onnx(g, attn, "hidden_states", "cos", "sin")
g.make_tensor_output(out, TensorProto.FLOAT, ("batch", "seq", 64))
model = g.to_onnx()
# OnnxRuntime contrib ops
g = GraphBuilder({"": 22, "com.microsoft": 1}, verbose=0)
g.make_tensor_input("hidden_states", TensorProto.FLOAT, ("batch", "seq", 64))
g.make_tensor_input("cos", TensorProto.FLOAT, ("batch", "seq", 16))
g.make_tensor_input("sin", TensorProto.FLOAT, ("batch", "seq", 16))
out = llama_attention_to_onnx(g, attn, "hidden_states", "cos", "sin")
g.make_tensor_output(out, TensorProto.FLOAT, ("batch", "seq", 64))
model = g.to_onnx()
- yobx.torch.in_transformers.classes.llama_attention.llama_attention_to_onnx(g: GraphBuilderExtendedProtocol, attn: LlamaAttention, hidden_states: str, cos: str, sin: str, attention_mask: str | None = None, name: str = 'llama_attention') → str[source]#
Appends ONNX nodes implementing transformers.models.llama.modeling_llama.LlamaAttention to g.

The output dtype (float32, float16, or bfloat16) is inferred from the registered type of hidden_states in g. Model weights are cast to match.

The backend is chosen from the opsets registered in g:

- com.microsoft: com.microsoft.RotaryEmbedding for rotary embeddings and com.microsoft.MultiHeadAttention for attention (OnnxRuntime). The "com.microsoft" domain must be registered in g. GQA KV heads are expanded to match the query head count before the attention op.
- opset ≥ 24: standard ONNX RotaryEmbedding op (opset ≥ 23) for rotary embeddings and the standard ONNX Attention op for attention.
- opset 23: standard ONNX RotaryEmbedding op for rotary embeddings and plain ONNX ops (MatMul, Softmax, Transpose, …) for attention.
- opset ≤ 22: plain ONNX ops for both rotary embeddings and attention. This is the default fallback path.
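The GQA key/value-head expansion used on the com.microsoft path (repeat-interleave along the head axis so K/V match the query head count) can be sketched in NumPy. Shapes follow the example config on this page (4 query heads, 2 KV heads, head_dim 16); variable names are illustrative:

```python
import numpy as np

batch, seq, num_kv_heads, num_heads, head_dim = 1, 8, 2, 4, 16
n_rep = num_heads // num_kv_heads                 # 2 query heads per KV head
k = np.random.rand(batch, num_kv_heads, seq, head_dim).astype(np.float32)

# Repeat-interleave along the head axis: each KV head is repeated n_rep
# times consecutively, giving one K/V head per query head.
k_expanded = np.repeat(k, n_rep, axis=1)          # (batch, num_heads, seq, head_dim)
assert k_expanded.shape == (batch, num_heads, seq, head_dim)
# Expanded heads 0 and 1 both come from KV head 0:
assert np.array_equal(k_expanded[:, 0], k_expanded[:, 1])
```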
Note

The cos and sin inputs are expected to carry symmetric values, i.e. cos[..., :head_dim//2] == cos[..., head_dim//2:] (and likewise for sin). This matches what transformers.models.llama.modeling_llama.LlamaRotaryEmbedding produces. The dedicated ONNX/ORT RotaryEmbedding ops use only the first half of the last dimension; the plain-op fallback uses the full tensor unchanged.

- Parameters:
  - g – an existing graph builder; inputs must already be declared with their types, and the function appends nodes without creating new graph inputs or outputs
  - attn – a fitted LlamaAttention module (weights must be initialised; the module may be in any of float32, float16, or bfloat16)
  - hidden_states – name of the (batch, seq, hidden_size) input tensor already declared in g
  - cos – name of the (batch, seq, head_dim) cosine embedding tensor already declared in g
  - sin – name of the (batch, seq, head_dim) sine embedding tensor already declared in g
  - attention_mask – optional name of the (batch, 1, seq_q, total_seq) attention mask tensor already declared in g; pass None (default) when no mask is needed
  - name – prefix used for all node names added to g
- Returns:
  name of the output tensor, shaped (batch, seq, hidden_size); the caller is responsible for registering it as a graph output via g.make_tensor_output
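The plain-op fallback (opset ≤ 22) builds attention from MatMul, Softmax, and Transpose alone, with the additive attention_mask in its documented (batch, 1, seq_q, total_seq) shape. A NumPy sketch of that computation, with illustrative names and shapes rather than the converter's internal graph:

```python
import numpy as np

batch, heads, seq, head_dim = 1, 4, 8, 16
rng = np.random.default_rng(0)
q = rng.standard_normal((batch, heads, seq, head_dim)).astype(np.float32)
k = rng.standard_normal((batch, heads, seq, head_dim)).astype(np.float32)
v = rng.standard_normal((batch, heads, seq, head_dim)).astype(np.float32)
# Causal additive mask, (batch, 1, seq_q, total_seq): -inf above the diagonal.
mask = np.triu(np.full((seq, seq), -np.inf, dtype=np.float32), 1)[None, None]

scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)   # Transpose + MatMul + scale
scores = scores + mask                                      # additive mask
w = np.exp(scores - scores.max(-1, keepdims=True))
w = w / w.sum(-1, keepdims=True)                            # Softmax over total_seq
out = w @ v                                                 # MatMul
assert out.shape == (batch, heads, seq, head_dim)
```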
Example:

from onnx import TensorProto
from yobx.xbuilder import GraphBuilder
from yobx.torch.in_transformers.models import llama_attention_to_onnx

# attn is a fitted LlamaAttention module, built as in the module-level example
g = GraphBuilder({"": 22}, verbose=0)
g.make_tensor_input("hidden_states", TensorProto.FLOAT, ("batch", "seq", 64))
g.make_tensor_input("cos", TensorProto.FLOAT, ("batch", "seq", 16))
g.make_tensor_input("sin", TensorProto.FLOAT, ("batch", "seq", 16))
out = llama_attention_to_onnx(g, attn, "hidden_states", "cos", "sin")
g.make_tensor_output(out, TensorProto.FLOAT, ("batch", "seq", 64))
model = g.to_onnx()