from typing import Any, Dict, List, Union

import numpy as np
import onnx
import torch

from .helper import string_type, flatten_object
from .onnx_helper import dtype_to_tensor_dtype
from .cache_helper import is_cache_dynamic_registered


def name_type_to_onnx_dtype(name: str) -> int:
    if name == "tensor(int64)":
        return onnx.TensorProto.INT64
    if name == "tensor(float)":
        return onnx.TensorProto.FLOAT
    if name == "tensor(float16)":
        return onnx.TensorProto.FLOAT16
    raise AssertionError(f"Unexpected value {name!r}")
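
# Hedged example: these type strings are what onnxruntime reports through
# ``InferenceSession.get_inputs()`` (each ``NodeArg.type``); any other string
# raises. The returned integers are ``onnx.TensorProto`` enum values:
#
#   >>> name_type_to_onnx_dtype("tensor(float)") == onnx.TensorProto.FLOAT
#   True
#   >>> name_type_to_onnx_dtype("tensor(int64)") == onnx.TensorProto.INT64
#   True
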
def make_feeds(
    proto: Union[onnx.ModelProto, List[str]],
    inputs: Any,
    use_numpy: bool = False,
    copy: bool = False,
    check_flatten: bool = True,
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
    """
    Serializes the inputs to produce feeds expected by
    :class:`onnxruntime.InferenceSession`.

    :param proto: onnx model, list of input names, or any object exposing
        ``get_inputs`` such as an :class:`onnxruntime.InferenceSession`
    :param inputs: any kind of inputs
    :param use_numpy: if True, converts torch tensors into numpy arrays
    :param copy: makes a copy of every tensor, which should be the case if
        the feeds are ingested by ``OrtValue``
    :param check_flatten: if True, checks that ``torch.utils._pytree.tree_flatten``
        returns the same number of outputs
    :return: feeds dictionary
    """
    # position_ids is a special case because ModelBuilder does not usually use it.
    # We use types to detect the best inputs.
    flat = flatten_object(inputs, drop_keys=True)
    assert (
        not check_flatten
        or not all(isinstance(obj, torch.Tensor) for obj in flat)
        or not is_cache_dynamic_registered(fast=True)
        or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0])
    ), (
        f"Unexpected number of flattened objects, "
        f"{string_type(flat, with_shape=True)} != "
        f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
    )
    if use_numpy:
        flat = [
            t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t for t in flat
        ]
    names = (
        [i.name for i in proto.graph.input]
        if isinstance(proto, onnx.ModelProto)
        else (
            [i.name for i in proto.get_inputs()] if hasattr(proto, "get_inputs") else proto
        )
    )
    assert (
        isinstance(names, list)
        and len(names) <= len(flat)
        and (
            len(names) == len(flat)
            or isinstance(proto, onnx.ModelProto)
            or hasattr(proto, "get_inputs")
        )
    ), (
        f"Not the same number of given inputs {len(flat)} "
        f"and the number of model inputs {len(names)}, "
        f"type(names)={type(names)}, type(proto)={type(proto)}"
        f"\n-- inputs={string_type(inputs, with_shape=True)}"
        f"\n-- names={names}"
    )
    if len(names) < len(flat) and (
        isinstance(proto, onnx.ModelProto) or hasattr(proto, "get_inputs")
    ):
        # The model expects fewer inputs than were given (position_ids may have
        # been dropped). Align the flattened tensors with the expected inputs
        # by matching element types in order.
        typed_names = (
            [(i.name, i.type.tensor_type.elem_type) for i in proto.graph.input]
            if isinstance(proto, onnx.ModelProto)
            else [(i.name, name_type_to_onnx_dtype(i.type)) for i in proto.get_inputs()]
        )
        new_flat = []
        pos = 0
        for _name, dtype in typed_names:
            assert isinstance(
                dtype, int
            ), f"Unexpected value for dtype={dtype!r}, type(proto)={type(proto)}"
            if pos >= len(flat):
                # no candidate left to align, let the assertion below report it
                break
            itype = dtype_to_tensor_dtype(flat[pos].dtype)
            while dtype != itype:
                pos += 1
                if pos >= len(flat):
                    break
                itype = dtype_to_tensor_dtype(flat[pos].dtype)
            if pos >= len(flat):
                break
            new_flat.append(flat[pos])
            pos += 1
        assert len(new_flat) == len(names), (
            f"Unable to align expected input {names} with the given input, "
            f"type(proto)={type(proto)}"
            f"\n-- inputs: {string_type(inputs, with_shape=True)}"
            f"\n-- typed_names: {typed_names}"
        )
        flat = new_flat
    if copy:
        # numpy arrays expose ``copy``, torch tensors expose ``clone``
        flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
    return dict(zip(names, flat))
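
# Usage sketch (assumptions: onnxruntime is installed and "model.onnx" is a
# hypothetical model taking int64 input_ids and attention_mask; the names and
# shapes below are illustrative only):
#
#   import onnxruntime
#
#   sess = onnxruntime.InferenceSession(
#       "model.onnx", providers=["CPUExecutionProvider"]
#   )
#   inputs = (
#       torch.randint(0, 100, (1, 8)),          # input_ids
#       torch.ones((1, 8), dtype=torch.int64),  # attention_mask
#   )
#   # use_numpy=True because InferenceSession.run expects numpy arrays;
#   # copy=True would be advisable if the feeds were bound as OrtValue.
#   feeds = make_feeds(sess, inputs, use_numpy=True)
#   outputs = sess.run(None, feeds)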