yobx.helpers.rt_helper

yobx.helpers.rt_helper.make_feeds(proto: ModelProto | Sequence[str], inputs: Any, use_numpy: bool = False, copy: bool = False, is_modelbuilder: bool = False) → Dict[str, torch.Tensor | ndarray]

Serializes the inputs to produce feeds expected by onnxruntime.InferenceSession.

Parameters:
  • proto – ONNX model, or list of input names

  • inputs – any kind of inputs

  • use_numpy – if True, converts torch tensors into numpy arrays

  • copy – if True, the inputs are copied; this should be the case when the inputs are ingested by OrtValue

  • is_modelbuilder – if True, the model was exported with ModelBuilder: the past_key_values inputs are reordered to match the expected order and position_ids is dropped.

Returns:

feeds dictionary mapping each input name to a torch.Tensor or numpy.ndarray
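As a rough illustration of what "feeds" means here, the sketch below pairs a flattened structure of inputs with input names by position. It is a hypothetical helper, not the yobx implementation. The names sketch_make_feeds and flatten_inputs are illustrative. It assumes nested tuples, lists and dicts are flattened depth-first in order, and it ignores is_modelbuilder and any cache objects. For an onnx.ModelProto, the names would typically come from [i.name for i in proto.graph.input].

```python
from typing import Any, Dict, List, Sequence

import numpy as np


def flatten_inputs(inputs: Any) -> List[Any]:
    """Flattens nested tuples, lists and dicts into a flat list of leaves (depth-first)."""
    if isinstance(inputs, (list, tuple)):
        return [leaf for item in inputs for leaf in flatten_inputs(item)]
    if isinstance(inputs, dict):
        return [leaf for item in inputs.values() for leaf in flatten_inputs(item)]
    return [inputs]


def sketch_make_feeds(
    names: Sequence[str], inputs: Any, use_numpy: bool = False, copy: bool = False
) -> Dict[str, Any]:
    """Hypothetical helper: pairs flattened inputs (arrays or tensors) with input names by position."""
    flat = flatten_inputs(inputs)
    if len(flat) != len(names):
        raise ValueError(f"expected {len(names)} inputs, got {len(flat)}")
    feeds: Dict[str, Any] = {}
    for name, value in zip(names, flat):
        # Duck-typed torch check so the sketch does not require torch to be installed.
        if use_numpy and hasattr(value, "detach"):
            value = value.detach().cpu().numpy()
        if copy:
            # np.ndarray exposes .copy(); torch.Tensor exposes .clone().
            value = value.copy() if isinstance(value, np.ndarray) else value.clone()
        feeds[name] = value
    return feeds


ids = np.array([[1, 2]], dtype=np.int64)
mask = np.ones_like(ids)
feeds = sketch_make_feeds(["input_ids", "attention_mask"], (ids, {"mask": mask}), copy=True)
print(sorted(feeds))  # ['attention_mask', 'input_ids']
```

Because this sketch matches purely by position, a mismatch between the number of leaves and the number of names raises an error instead of guessing.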

yobx.helpers.rt_helper.onnx_generate(model_or_path: str | ModelProto, input_ids: ndarray | torch.Tensor, attention_mask: ndarray | torch.Tensor | None = None, eos_token_id: int | None = None, max_new_tokens: int = 20, do_sample: bool = False, return_session: bool = False, verbose: int = 0) → ndarray | torch.Tensor

Performs auto-regressive token generation using an exported ONNX model.

The function mimics the generate method of Hugging Face transformers models. It calls the ONNX forward pass in a loop, appends one new token at each step (the most likely one under the default greedy decoding), and feeds the updated past key/value tensors back into the model on each subsequent call.

Models that do not expose past-key-value inputs/outputs are also supported: in that case the full input_ids sequence is fed on every step (simpler but less efficient).

Parameters:
  • model_or_path – path to an .onnx file or an onnx.ModelProto loaded into memory.

  • input_ids – initial prompt token IDs, integer array/tensor of shape [batch, seq_len].

  • attention_mask – optional attention mask of shape [batch, seq_len]. When None, an all-ones mask matching input_ids is created automatically.

  • eos_token_id – when set, generation stops as soon as all batch items have produced this token.

  • max_new_tokens – upper bound on the number of tokens to generate (not counting the original input_ids).

  • do_sample – when True, sample the next token from the softmax distribution; when False (default), use greedy argmax.

  • return_session – when True return a 3-tuple (tokens, session, last_feeds) instead of just the tokens.

  • verbose – verbosity level (0 = silent).
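The following sketch illustrates three of the parameter behaviors listed above: the all-ones default for attention_mask, greedy argmax versus softmax sampling for do_sample, and the "all batch items" stopping rule for eos_token_id. The helper names are hypothetical. The sampling shown is plain multinomial sampling from the softmax with no temperature, top-k or top-p; whether onnx_generate applies any of those is not stated here.

```python
from typing import Optional

import numpy as np


def default_attention_mask(input_ids: np.ndarray) -> np.ndarray:
    """All-ones mask with the same shape as input_ids, used when attention_mask is None."""
    return np.ones_like(input_ids)


def select_next_tokens(
    last_logits: np.ndarray,
    do_sample: bool = False,
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """Picks one token per batch row from logits of shape [batch, vocab]."""
    if not do_sample:
        # Greedy decoding: the highest-scoring token wins.
        return last_logits.argmax(axis=-1)
    rng = rng if rng is not None else np.random.default_rng()
    # Numerically stable softmax over the vocabulary axis, computed in float64.
    shifted = last_logits.astype(np.float64)
    shifted -= shifted.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Draw one token index per row from its categorical distribution.
    return np.array([rng.choice(probs.shape[-1], p=row) for row in probs])


def update_finished(finished: np.ndarray, next_tokens: np.ndarray, eos_token_id: int) -> np.ndarray:
    """A row stays finished once it has produced eos_token_id."""
    return finished | (next_tokens == eos_token_id)


logits = np.array([[0.0, 5.0, 1.0], [2.0, 0.0, 0.0]])
print(select_next_tokens(logits))  # [1 0]
finished = update_finished(np.zeros(2, dtype=bool), np.array([1, 0]), eos_token_id=1)
print(finished.all())  # False: the second row has not produced the EOS token yet
```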

Returns:

integer array/tensor of shape [batch, seq_len + generated_tokens] containing the original prompt followed by the generated tokens. The type matches input_ids: numpy.ndarray when the caller passed NumPy arrays, torch.Tensor otherwise.

Example with a tiny synthetic ONNX decoder (no KV cache):

import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from yobx.helpers.rt_helper import onnx_generate

TINT64 = onnx.TensorProto.INT64
TFLOAT = onnx.TensorProto.FLOAT
VOCAB  = 8

# A minimal "LM head": always returns the same logits so that the
# argmax always picks token 3.
fixed_logits = np.zeros((1, 1, VOCAB), dtype=np.float32)
fixed_logits[0, 0, 3] = 10.0   # token 3 always wins

model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node(
                "Constant",
                [],
                ["logits"],
                value=onh.from_array(fixed_logits),
            ),
        ],
        "tiny_lm",
        [oh.make_tensor_value_info("input_ids", TINT64, [1, None])],
        [oh.make_tensor_value_info("logits", TFLOAT, [1, 1, VOCAB])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=9,
)

prompt = np.array([[1, 2]], dtype=np.int64)
tokens = onnx_generate(model, prompt, max_new_tokens=3, eos_token_id=3)
# tokens == [[1, 2, 3]]  (stops after the first EOS token)

Note

When the ONNX model exposes past key/value inputs, the function automatically creates zero-filled tensors for the initial call and feeds back the corresponding outputs on every subsequent step. The KV-cache heuristic treats any input whose name is not in {input_ids, attention_mask, position_ids, token_type_ids, cache_position} as a KV-cache slot. Present key/value outputs are mapped back to past key/value inputs by position (i.e. outputs[1] → cache_inputs[0], outputs[2] → cache_inputs[1], etc.).
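The name heuristic and positional mapping from the note can be sketched in plain Python as follows. The input and output names in the usage lines are hypothetical examples. The sketch assumes outputs[0] holds the logits and that there is exactly one present output per cache input. It does not create the zero-filled initial tensors.

```python
from typing import Dict, List, Sequence

# Inputs that are never treated as KV-cache slots by the heuristic.
NON_CACHE_INPUTS = {"input_ids", "attention_mask", "position_ids", "token_type_ids", "cache_position"}


def cache_input_names(input_names: Sequence[str]) -> List[str]:
    """Every input outside NON_CACHE_INPUTS counts as a KV-cache slot, kept in model order."""
    return [name for name in input_names if name not in NON_CACHE_INPUTS]


def present_to_past(output_names: Sequence[str], cache_inputs: Sequence[str]) -> Dict[str, str]:
    """Positional mapping: outputs[1] -> cache_inputs[0], outputs[2] -> cache_inputs[1], ..."""
    presents = list(output_names[1:])  # outputs[0] is assumed to be the logits
    if len(presents) != len(cache_inputs):
        raise ValueError(f"{len(presents)} present outputs for {len(cache_inputs)} cache inputs")
    return dict(zip(presents, cache_inputs))


inputs = ["input_ids", "attention_mask", "past_key_values.0.key", "past_key_values.0.value"]
outputs = ["logits", "present.0.key", "present.0.value"]
mapping = present_to_past(outputs, cache_input_names(inputs))
print(mapping)
# {'present.0.key': 'past_key_values.0.key', 'present.0.value': 'past_key_values.0.value'}
```

Matching by position rather than by name keeps the heuristic independent of any particular exporter's naming scheme. The trade-off is that it relies on the outputs being emitted in the same order as the cache inputs.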