Source code for onnx_diagnostic.helpers.rt_helper

from typing import Any, Dict, List, Union
import numpy as np
import onnx
import torch
from .helper import string_type, flatten_object
from .cache_helper import is_cache_dynamic_registered


def make_feeds(
    proto: Union[onnx.ModelProto, List[str]],
    inputs: Any,
    use_numpy: bool = False,
    copy: bool = False,
    check_flatten: bool = True,
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
    """
    Serializes the inputs to produce the feeds expected by
    :class:`onnxruntime.InferenceSession`.

    :param proto: onnx model or list of input names
    :param inputs: any kind of inputs
    :param use_numpy: if True, converts torch tensors into numpy arrays
    :param copy: a copy is made; this should be the case if the inputs
        are ingested by ``OrtValue``
    :param check_flatten: if True, checks that ``torch.utils._pytree.tree_flatten``
        returns the same number of outputs
    :return: feeds dictionary
    """
    flat = flatten_object(inputs, drop_keys=True)
    assert (
        not check_flatten
        or not all(isinstance(obj, torch.Tensor) for obj in flat)
        or not is_cache_dynamic_registered(fast=True)
        or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0])
    ), (
        f"Unexpected number of flattened objects, "
        f"{string_type(flat, with_shape=True)} != "
        f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
    )
    if use_numpy:
        flat = [
            t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t for t in flat
        ]
    names = (
        [i.name for i in proto.graph.input] if isinstance(proto, onnx.ModelProto) else proto
    )
    if copy:
        flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
    return dict(zip(names, flat))
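
# --- Hypothetical usage sketch (not part of the original module) ---
# A minimal illustration of how make_feeds could feed an
# onnxruntime.InferenceSession. The file name "model.onnx" and the
# two-tensor input structure are assumptions made for the example;
# the model's actual graph inputs determine the feed names.
if __name__ == "__main__":
    import onnxruntime

    model = onnx.load("model.onnx")  # assumed model path
    example_inputs = (torch.rand(2, 3), torch.rand(2, 3))  # assumed input structure

    # onnxruntime consumes numpy arrays, hence use_numpy=True;
    # the feed names are taken from model.graph.input.
    feeds = make_feeds(model, example_inputs, use_numpy=True)

    sess = onnxruntime.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    results = sess.run(None, feeds)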