yobx.torch.in_transformers.classes.llama_attention#
Direct ONNX converter for transformers.models.llama.modeling_llama.LlamaAttention.
This converter appends ONNX nodes to an existing GraphBuilder
from a fitted transformers.models.llama.modeling_llama.LlamaAttention module,
without going through torch.export.export().
Three computation backends are supported, selected automatically from the opsets registered in the graph builder:

- com.microsoft (the "com.microsoft" domain is registered in the builder): uses the com.microsoft.MultiHeadAttention contrib op from onnxruntime. GQA key/value heads are expanded (via repeat-interleave) to match the query head count before being passed to the op. The model runs efficiently on CPU and CUDA with onnxruntime.
- opset ≥ 24 (main opset ≥ 24): uses the standard ONNX Attention operator introduced in opset 23 (revision 24 fixes a correctness bug, so this converter requires 24).
- opset ≤ 22 (default fallback): uses basic ONNX ops (MatMul, Softmax, Transpose, …).
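The GQA expansion mentioned for the com.microsoft backend can be illustrated with a small numpy sketch (this is not the converter's actual code, just the repeat-interleave shape transformation it describes):

```python
import numpy as np

# Illustrative sketch: expanding GQA key/value heads via repeat-interleave so
# they match the query head count, as described for the com.microsoft backend.
# Shapes follow the (batch, num_heads, seq, head_dim) attention layout.
batch, seq, head_dim = 2, 5, 16
num_heads, num_kv_heads = 4, 2       # 4 query heads share 2 KV heads
n_rep = num_heads // num_kv_heads    # each KV head serves 2 query heads

k = np.random.rand(batch, num_kv_heads, seq, head_dim).astype(np.float32)

# repeat-interleave along the head axis:
# (batch, num_kv_heads, seq, head_dim) -> (batch, num_heads, seq, head_dim)
k_expanded = np.repeat(k, n_rep, axis=1)

# query head i now attends against original KV head i // n_rep
assert k_expanded.shape == (batch, num_heads, seq, head_dim)
```

After this expansion, every query head has a matching key/value head, which is the layout MultiHeadAttention expects.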
- Supported dtypes: float32, float16, and bfloat16. The output dtype is inferred from the registered type of hidden_states in the graph builder.
Expected graph inputs (must be declared in the builder before calling the converter):
- hidden_states: (batch, seq, hidden_size)
- cos: (batch, seq, head_dim)
- sin: (batch, seq, head_dim)
- attention_mask (optional): (batch, 1, seq_q, total_seq)
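In practice the cos and sin inputs come from transformers' rotary embedding module; as a hypothetical, self-contained sketch, here is the standard RoPE formulation (assuming the default base theta of 10000) producing tensors of the expected (batch, seq, head_dim) shape:

```python
import numpy as np

# Hypothetical sketch of the rotary cos/sin inputs; in a real export these
# would come from transformers' LlamaRotaryEmbedding. Assumes the standard
# RoPE formulation with base theta = 10000.
batch, seq, head_dim = 2, 5, 16
theta = 10000.0

# inverse frequencies, one per pair of dimensions: shape (head_dim // 2,)
inv_freq = 1.0 / theta ** (np.arange(0, head_dim, 2) / head_dim)
positions = np.arange(seq)                     # (seq,)
freqs = np.outer(positions, inv_freq)          # (seq, head_dim // 2)
emb = np.concatenate([freqs, freqs], axis=-1)  # (seq, head_dim)

# broadcast to the (batch, seq, head_dim) layout the converter expects
cos = np.broadcast_to(np.cos(emb), (batch, seq, head_dim))
sin = np.broadcast_to(np.sin(emb), (batch, seq, head_dim))

assert cos.shape == sin.shape == (batch, seq, head_dim)
```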
Output returned by the converter:
- name of the output tensor, shape (batch, seq, hidden_size); the caller must register it as a graph output via g.make_tensor_output
Example:
import torch
from onnx import TensorProto
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from yobx.xbuilder import GraphBuilder
from yobx.torch.in_transformers.models import llama_attention_to_onnx
config = LlamaConfig(
hidden_size=64, num_attention_heads=4, num_key_value_heads=2, head_dim=16
)
attn = LlamaAttention(config, layer_idx=0).eval()
# opset 22 — plain ONNX ops (MatMul, Softmax, …)
g = GraphBuilder({"": 22}, verbose=0)
g.make_tensor_input("hidden_states", TensorProto.FLOAT, ("batch", "seq", 64))
g.make_tensor_input("cos", TensorProto.FLOAT, ("batch", "seq", 16))
g.make_tensor_input("sin", TensorProto.FLOAT, ("batch", "seq", 16))
out = llama_attention_to_onnx(g, attn, "hidden_states", "cos", "sin")
g.make_tensor_output(out, TensorProto.FLOAT, ("batch", "seq", 64))
model = g.to_onnx()
# opset 24 — ONNX Attention op
g = GraphBuilder({"": 24}, verbose=0)
g.make_tensor_input("hidden_states", TensorProto.FLOAT, ("batch", "seq", 64))
g.make_tensor_input("cos", TensorProto.FLOAT, ("batch", "seq", 16))
g.make_tensor_input("sin", TensorProto.FLOAT, ("batch", "seq", 16))
out = llama_attention_to_onnx(g, attn, "hidden_states", "cos", "sin")
g.make_tensor_output(out, TensorProto.FLOAT, ("batch", "seq", 64))
model = g.to_onnx()
# OnnxRuntime contrib ops
g = GraphBuilder({"": 22, "com.microsoft": 1}, verbose=0)
g.make_tensor_input("hidden_states", TensorProto.FLOAT, ("batch", "seq", 64))
g.make_tensor_input("cos", TensorProto.FLOAT, ("batch", "seq", 16))
g.make_tensor_input("sin", TensorProto.FLOAT, ("batch", "seq", 16))
out = llama_attention_to_onnx(g, attn, "hidden_states", "cos", "sin")
g.make_tensor_output(out, TensorProto.FLOAT, ("batch", "seq", 64))
model = g.to_onnx()
- yobx.torch.in_transformers.classes.llama_attention.llama_attention_to_onnx(g: GraphBuilderExtendedProtocol, attn: LlamaAttention, hidden_states: str, cos: str, sin: str, attention_mask: str | None = None, name: str = 'llama_attention') → str[source]#
Appends ONNX nodes implementing transformers.models.llama.modeling_llama.LlamaAttention to g.

The output dtype (float32, float16, or bfloat16) is inferred from the registered type of hidden_states in g. Model weights are cast to match.

The backend is chosen from the opsets registered in g:

- com.microsoft: MultiHeadAttention contrib op (OnnxRuntime). The "com.microsoft" domain must be registered in g. GQA KV heads are expanded to match the query head count before the op.
- opset ≥ 24: standard ONNX Attention op.
- opset ≤ 22: plain ONNX ops (MatMul, Softmax, Transpose, …). This is the default path.
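What the opset ≤ 22 fallback computes can be sketched with basic primitives. The numpy snippet below is a hedged illustration of scaled dot-product attention built only from MatMul-, Softmax-, and Transpose-style operations, not the converter's actual node graph:

```python
import numpy as np

# Sketch of the computation the opset <= 22 fallback expresses with basic
# ONNX ops (MatMul, Softmax, Transpose). Shapes follow the
# (batch, num_heads, seq, head_dim) attention layout.
def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def basic_attention(q, k, v, mask=None):
    head_dim = q.shape[-1]
    # MatMul(Q, Transpose(K)) scaled by 1/sqrt(head_dim)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)
    if mask is not None:
        scores = scores + mask               # additive attention mask
    return softmax(scores) @ v               # Softmax, then MatMul with V

q = np.random.rand(2, 4, 5, 16).astype(np.float32)
k = np.random.rand(2, 4, 5, 16).astype(np.float32)
v = np.random.rand(2, 4, 5, 16).astype(np.float32)
out = basic_attention(q, k, v)
assert out.shape == (2, 4, 5, 16)
```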
- Parameters:
  - g – an existing graph builder; inputs must already be declared with their types. The function appends nodes without creating new graph inputs or outputs.
  - attn – a fitted LlamaAttention module (weights must be initialised; the module may be in any of float32, float16, or bfloat16).
  - hidden_states – name of the (batch, seq, hidden_size) input tensor already declared in g.
  - cos – name of the (batch, seq, head_dim) cosine embedding tensor already declared in g.
  - sin – name of the (batch, seq, head_dim) sine embedding tensor already declared in g.
  - attention_mask – optional name of the (batch, 1, seq_q, total_seq) attention mask tensor already declared in g; pass None (default) when no mask is needed.
  - name – prefix used for all node names added to g.
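As an illustration of the attention_mask layout, the sketch below builds an additive causal mask in the (batch, 1, seq_q, total_seq) shape; it assumes the additive-mask convention transformers uses, where masked positions carry a large negative value so they vanish after Softmax:

```python
import numpy as np

# Illustrative additive causal mask in the (batch, 1, seq_q, total_seq)
# layout. Positions above the diagonal (the future) get a large negative
# value; visible positions get 0. This assumes the additive 4D mask
# convention used by transformers.
batch, seq_q, total_seq = 2, 5, 5
future = np.triu(np.ones((seq_q, total_seq), dtype=bool), k=1)  # True = masked
mask = np.where(future, np.float32(-1e9), np.float32(0.0))
mask = np.broadcast_to(mask, (batch, 1, seq_q, total_seq))

assert mask.shape == (batch, 1, seq_q, total_seq)
```

The name of a tensor with this shape, once declared via g.make_tensor_input, is what gets passed as the attention_mask argument.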
- Returns:
name of the output tensor, shape (batch, seq, hidden_size); the caller is responsible for registering it as a graph output via g.make_tensor_output
Example:
from onnx import TensorProto
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from yobx.xbuilder import GraphBuilder
from yobx.torch.in_transformers.models import llama_attention_to_onnx

config = LlamaConfig(
    hidden_size=64, num_attention_heads=4, num_key_value_heads=2, head_dim=16
)
attn = LlamaAttention(config, layer_idx=0).eval()

g = GraphBuilder({"": 22}, verbose=0)
g.make_tensor_input("hidden_states", TensorProto.FLOAT, ("batch", "seq", 64))
g.make_tensor_input("cos", TensorProto.FLOAT, ("batch", "seq", 16))
g.make_tensor_input("sin", TensorProto.FLOAT, ("batch", "seq", 16))
out = llama_attention_to_onnx(g, attn, "hidden_states", "cos", "sin")
g.make_tensor_output(out, TensorProto.FLOAT, ("batch", "seq", 64))
model = g.to_onnx()