# Source code for onnx_diagnostic.helpers.rt_helper

from typing import Any, Dict, List, Union
import numpy as np
import onnx
import torch
from .helper import string_type, flatten_object
from .cache_helper import is_cache_dynamic_registered


def name_type_to_onnx_dtype(name: str) -> int:
    """
    Converts an onnxruntime type string (e.g. ``"tensor(float)"``) into the
    corresponding :class:`onnx.TensorProto` element type integer.

    :param name: onnxruntime type string such as ``"tensor(int64)"``
    :return: element type as an ``onnx.TensorProto`` enum value
    :raises AssertionError: if the name is not a recognized type string
    """
    # Lookup table instead of an if-chain; extended beyond the original
    # float/float16/int64 to other common element types (backward-compatible).
    mapping = {
        "tensor(float)": onnx.TensorProto.FLOAT,
        "tensor(float16)": onnx.TensorProto.FLOAT16,
        "tensor(bfloat16)": onnx.TensorProto.BFLOAT16,
        "tensor(double)": onnx.TensorProto.DOUBLE,
        "tensor(int64)": onnx.TensorProto.INT64,
        "tensor(int32)": onnx.TensorProto.INT32,
        "tensor(int8)": onnx.TensorProto.INT8,
        "tensor(uint8)": onnx.TensorProto.UINT8,
        "tensor(bool)": onnx.TensorProto.BOOL,
    }
    try:
        return mapping[name]
    except KeyError:
        # Same exception type and message as the original if-chain fallthrough.
        raise AssertionError(f"Unexpected value {name!r}") from None


def make_feeds(
    proto: Union[onnx.ModelProto, List[str]],
    inputs: Any,
    use_numpy: bool = False,
    copy: bool = False,
    check_flatten: bool = True,
    is_modelbuilder: bool = False,
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
    """
    Serializes the inputs to produce feeds expected
    by :class:`onnxruntime.InferenceSession`.

    :param proto: onnx model or list of names
    :param inputs: any kind of inputs
    :param use_numpy: if True, converts torch tensors into numpy arrays
    :param copy: a copy is made, this should be the case if the inputs is ingested
        by ``OrtValue``
    :param check_flatten: if True, checks the ``torch.utils._pytree.tree_flatten``
        returns the same number of outputs
    :param is_modelbuilder: if True, the exporter is ModelBuilder, and we need to
        reorder the past_key_values inputs to match the expected order,
        and get rid of position_ids.
    :return: feeds dictionary
    """
    # NOTE: position_ids is a special case because ModelBuilder does not usually
    # use it, because it's fused into rotary embedding in GQA.
    if is_modelbuilder and isinstance(inputs, dict):
        inputs.pop("position_ids", None)  # Ensure 'position_ids' absent before removing.

    flat = flatten_object(inputs, drop_keys=True)
    # Sanity check: the custom flattening must agree with pytree flattening
    # when every leaf is a tensor and dynamic caches are registered.
    assert (
        not check_flatten
        or not all(isinstance(obj, torch.Tensor) for obj in flat)
        or not is_cache_dynamic_registered(fast=True)
        or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0])
    ), (
        f"Unexpected number of flattened objects, "
        f"{string_type(flat, with_shape=True)} != "
        f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
    )
    if use_numpy:
        flat = [
            t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t for t in flat
        ]
    # Resolve the input names from whatever `proto` is: an onnx model,
    # an object exposing get_inputs() (InferenceSession), an object exposing
    # input_names, or already a plain list of names.
    names = (
        [i.name for i in proto.graph.input]
        if isinstance(proto, onnx.ModelProto)
        else (
            [i.name for i in proto.get_inputs()]
            if hasattr(proto, "get_inputs")
            else (proto.input_names if hasattr(proto, "input_names") else proto)
        )
    )
    # Fewer names than flattened inputs is only tolerated when the names come
    # from a model/session (optional inputs may have been removed by the exporter).
    assert (
        isinstance(names, list)
        and len(names) <= len(flat)
        and (
            len(names) == len(flat)
            or isinstance(proto, onnx.ModelProto)
            or hasattr(proto, "get_inputs")
        )
    ), (
        f"Not the same number of given inputs {len(flat)} "
        f"and the number of model inputs {len(names)}, "
        f"type(names)={type(names)}, type(proto)={type(proto)}"
        f"\n-- inputs={string_type(inputs, with_shape=True)}"
        f"\n-- names={names}"
    )
    if copy:
        flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
    # bool, int, float, onnxruntime does not support float, bool, int
    new_flat = []
    for i in flat:
        if isinstance(i, bool):
            i = np.array(i, dtype=np.bool_)
        elif isinstance(i, int):
            i = np.array(i, dtype=np.int64)
        elif isinstance(i, float):
            i = np.array(i, dtype=np.float32)
        new_flat.append(i)
    # NOTE: model builder has a different order for past_key_values
    # we need to reorder them to match the expected order
    if is_modelbuilder:
        # We assume that if "past_key_values" is in the names when it's
        # modelbuilder
        non_past_kv_input_names = [n for n in names if "past_key_values" not in n]
        past_kv_names = [n for n in names if "past_key_values" in n]
        reorder_past_kv_names = reorder_modelbuilder_cache_to_torch(past_kv_names)
        names = non_past_kv_input_names + reorder_past_kv_names
    return dict(zip(names, new_flat))
[docs] def reorder_modelbuilder_cache_to_torch(past_kv: List[Any]) -> List[Any]: """ Reorders the past_kvs for ModelBuilder to match the expected order by PyTorch exported models. .. note:: This function can take either the names or the actual tensors as long as they are in a list. Conceptually, From:: [past_key_values.0.key, past_key_values.0.value, past_key_values.1.key, past_key_values.1.value, ...] To:: [past_key_values.0.key, past_key_values.1.key, ..., past_key_values.0.value, past_key_values.1.value, ...] :param past_kv: list of flattened inputs :return: reordered list of flattened inputs """ total_len = len(past_kv) if total_len % 2 != 0: raise ValueError("The length of past_key_values should be even.") keys = [] values = [] for i in range(0, total_len, 2): keys.append(past_kv[i]) values.append(past_kv[i + 1]) return keys + values