yobx.torch.in_transformers.classes.llama_attention#

Direct ONNX converter for transformers.models.llama.modeling_llama.LlamaAttention.

This converter appends ONNX nodes to an existing GraphBuilder from a fitted transformers.models.llama.modeling_llama.LlamaAttention module, without going through torch.export.export().

Three computation backends are supported, selected automatically based on the opsets registered in the graph builder:

  • com.microsoft ("com.microsoft" domain registered in the builder): Uses the com.microsoft.MultiHeadAttention contrib op from onnxruntime. GQA key/value heads are expanded (via repeat-interleave) to match the query head count before being passed to the op. The model runs efficiently on CPU and CUDA with OnnxRuntime.

  • opset ≥ 24: Uses the standard ONNX Attention operator. Although the operator was introduced in opset 23, its opset-24 revision fixes a correctness bug, so this converter takes this path only when the main opset is at least 24.

  • opset ≤ 22 (default fallback): Uses basic ONNX ops (MatMul, Softmax, Transpose, …).
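The GQA key/value-head expansion mentioned on the com.microsoft path can be sketched in numpy (an illustration of the repeat-interleave dataflow only; the converter emits equivalent ONNX nodes, not Python calls, and the shapes below are taken from the example configuration further down):

```python
import numpy as np

# Example shapes: 2 KV heads serving 4 query heads.
batch, seq, num_kv_heads, num_heads, head_dim = 1, 3, 2, 4, 16
group = num_heads // num_kv_heads  # 2 query heads per KV head

k = np.random.randn(batch, num_kv_heads, seq, head_dim).astype(np.float32)
# Repeat-interleave along the head axis: [h0, h0, h1, h1], so the
# contrib op sees as many KV heads as query heads.
k_expanded = np.repeat(k, group, axis=1)
assert k_expanded.shape == (batch, num_heads, seq, head_dim)
```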

Supported dtypes:

float32, float16, and bfloat16 are all supported. The output dtype is inferred from the registered type of hidden_states in the graph builder.

Expected graph inputs (must be declared in the builder before calling the converter):

  • hidden_states: (batch, seq, hidden_size)

  • cos: (batch, seq, head_dim)

  • sin: (batch, seq, head_dim)

  • attention_mask (optional): (batch, 1, seq_q, total_seq)

Output returned by the converter:

  • name of the output tensor, shaped (batch, seq, hidden_size) — the caller must register it as a graph output via g.make_tensor_output
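A minimal numpy sketch of a causal mask in the expected (batch, 1, seq_q, total_seq) layout (illustrative only; it assumes an additive mask, as is conventional for attention ops, and no KV cache, so total_seq == seq_q):

```python
import numpy as np

batch, seq_q, total_seq = 1, 4, 4  # no KV cache here, so total_seq == seq_q

# Additive mask: 0.0 where attention is allowed, a large negative value
# where query i must not see position j (here: j > i, i.e. causal).
neg = np.finfo(np.float32).min
causal = np.triu(np.full((seq_q, total_seq), neg, dtype=np.float32), k=1)
mask = np.broadcast_to(causal, (batch, 1, seq_q, total_seq))
assert mask.shape == (1, 1, 4, 4)
```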

Example:

import torch
from onnx import TensorProto
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from yobx.xbuilder import GraphBuilder
from yobx.torch.in_transformers.models import llama_attention_to_onnx

config = LlamaConfig(
    hidden_size=64, num_attention_heads=4, num_key_value_heads=2, head_dim=16
)
attn = LlamaAttention(config, layer_idx=0).eval()

# opset 22 — plain ONNX ops (MatMul, Softmax, …)
g = GraphBuilder({"": 22}, verbose=0)
g.make_tensor_input("hidden_states", TensorProto.FLOAT, ("batch", "seq", 64))
g.make_tensor_input("cos", TensorProto.FLOAT, ("batch", "seq", 16))
g.make_tensor_input("sin", TensorProto.FLOAT, ("batch", "seq", 16))
out = llama_attention_to_onnx(g, attn, "hidden_states", "cos", "sin")
g.make_tensor_output(out, TensorProto.FLOAT, ("batch", "seq", 64))
model = g.to_onnx()

# opset 24 — ONNX Attention op
g = GraphBuilder({"": 24}, verbose=0)
g.make_tensor_input("hidden_states", TensorProto.FLOAT, ("batch", "seq", 64))
g.make_tensor_input("cos", TensorProto.FLOAT, ("batch", "seq", 16))
g.make_tensor_input("sin", TensorProto.FLOAT, ("batch", "seq", 16))
out = llama_attention_to_onnx(g, attn, "hidden_states", "cos", "sin")
g.make_tensor_output(out, TensorProto.FLOAT, ("batch", "seq", 64))
model = g.to_onnx()

# OnnxRuntime contrib ops
g = GraphBuilder({"": 22, "com.microsoft": 1}, verbose=0)
g.make_tensor_input("hidden_states", TensorProto.FLOAT, ("batch", "seq", 64))
g.make_tensor_input("cos", TensorProto.FLOAT, ("batch", "seq", 16))
g.make_tensor_input("sin", TensorProto.FLOAT, ("batch", "seq", 16))
out = llama_attention_to_onnx(g, attn, "hidden_states", "cos", "sin")
g.make_tensor_output(out, TensorProto.FLOAT, ("batch", "seq", 64))
model = g.to_onnx()
yobx.torch.in_transformers.classes.llama_attention.llama_attention_to_onnx(g: GraphBuilderExtendedProtocol, attn: LlamaAttention, hidden_states: str, cos: str, sin: str, attention_mask: str | None = None, name: str = 'llama_attention') → str#

Appends ONNX nodes implementing transformers.models.llama.modeling_llama.LlamaAttention to g.

The output dtype (float32, float16, or bfloat16) is inferred from the registered type of hidden_states in g. Model weights are cast to match.

The backend is chosen from the opsets registered in g:

  • com.microsoft — MultiHeadAttention contrib op (OnnxRuntime). The "com.microsoft" domain must be registered in g. GQA KV heads are expanded to match the query head count before the op.

  • opset ≥ 24 — standard ONNX Attention op.

  • opset ≤ 22 — plain ONNX ops (MatMul, Softmax, Transpose, …). This is the default path.
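The dataflow of the opset ≤ 22 fallback (MatMul, scale, Softmax, MatMul) can be sketched in numpy (a simplified illustration without RoPE, masking, or GQA; the converter emits the corresponding ONNX nodes rather than running Python):

```python
import numpy as np

# Shapes from the example config: 4 heads, head_dim 16.
batch, heads, seq, head_dim = 1, 4, 3, 16
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((batch, heads, seq, head_dim), dtype=np.float32)
           for _ in range(3))

scale = 1.0 / np.sqrt(head_dim)
scores = (q @ k.transpose(0, 1, 3, 2)) * scale      # MatMul + scale
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)           # Softmax
out = weights @ v                                   # MatMul
assert out.shape == (batch, heads, seq, head_dim)
```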

Parameters:
  • g – an existing graph builder — inputs must already be declared with their types; the function appends nodes without creating new graph inputs or outputs

  • attn – a fitted LlamaAttention module (weights must be initialised; the module may be in any of float32, float16, or bfloat16)

  • hidden_states – name of the (batch, seq, hidden_size) input tensor already declared in g

  • cos – name of the (batch, seq, head_dim) cosine embedding tensor already declared in g

  • sin – name of the (batch, seq, head_dim) sine embedding tensor already declared in g

  • attention_mask – optional name of the (batch, 1, seq_q, total_seq) attention mask tensor already declared in g; pass None (default) when no mask is needed

  • name – prefix used for all node names added to g

Returns:

name of the output tensor (batch, seq, hidden_size); the caller is responsible for registering it as a graph output via g.make_tensor_output

Example:

from onnx import TensorProto
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from yobx.xbuilder import GraphBuilder
from yobx.torch.in_transformers.models import llama_attention_to_onnx

config = LlamaConfig(
    hidden_size=64, num_attention_heads=4, num_key_value_heads=2, head_dim=16
)
attn = LlamaAttention(config, layer_idx=0).eval()
g = GraphBuilder({"": 22}, verbose=0)
g.make_tensor_input("hidden_states", TensorProto.FLOAT, ("batch", "seq", 64))
g.make_tensor_input("cos", TensorProto.FLOAT, ("batch", "seq", 16))
g.make_tensor_input("sin", TensorProto.FLOAT, ("batch", "seq", 16))
out = llama_attention_to_onnx(g, attn, "hidden_states", "cos", "sin")
g.make_tensor_output(out, TensorProto.FLOAT, ("batch", "seq", 64))
model = g.to_onnx()