interpreter
DynamoInterpreter
- class experimental_experiment.torch_interpreter.interpreter.DynamoInterpreter(graph_builder: GraphBuilder, retriever: Callable, dispatcher: Dispatcher | None = None, use_dynamo: bool = False)
Converts a torch.fx graph into an ONNX graph by dispatching every node to the appropriate converting function.
- Parameters:
graph_builder – a graph builder
retriever – callable to help retrieve the weights in a module, see function experimental_experiment.torch_interpreter.onnx_export._retrieve
dispatcher – see experimental_experiment.torch_interpreter.Dispatcher
use_dynamo – see to_onnx
- call_function(node: torch.fx.Node)
Called for a node whose type is call_function.
- call_method(node: torch.fx.Node)
Called for a node whose type is call_method.
- call_module(node: torch.fx.Node)
Called for a node whose type is call_module.
- get_attr(node: torch.fx.Node)
Retrieves an attribute (a node whose type is get_attr).
- getitem(node: torch.fx.Node)
Called when the bracket operator something[...] appears. The index may be another variable, an integer, a slice, a tuple or a list.
- placeholder(node: torch.fx.Node)
Called for a placeholder, that is, a graph input. The interpreter adds an Identity node between the input name it wants and the name the input has in the graph module.
- run_node(node: torch.fx.Node)
Runs a node: calls the appropriate method based on the node type.
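The dispatch performed by run_node can be pictured with the minimal sketch below. It is an illustration under stated assumptions, not the DynamoInterpreter code: ToyNode is a hypothetical stand-in for torch.fx.Node, every handler only records which method was called, and routing operator.getitem targets to getitem is an assumption about how bracket access reaches that method (torch.fx itself records x[i] as a call_function node whose target is operator.getitem).

```python
import operator
from dataclasses import dataclass
from typing import Any, List, Tuple


@dataclass
class ToyNode:
    """Hypothetical stand-in for torch.fx.Node, keeping only the fields used here."""

    op: str  # node type, e.g. "placeholder", "get_attr", "call_function"
    name: str
    target: Any = None
    args: Tuple[Any, ...] = ()


class ToyInterpreter:
    """Sketch: dispatch a node to the method named after its type."""

    _NODE_TYPES = {"placeholder", "get_attr", "call_function", "call_method", "call_module"}

    def __init__(self) -> None:
        self.calls: List[str] = []

    def run_node(self, node: ToyNode) -> Any:
        # Same idea as torch.fx.Interpreter: the node type is a method name.
        if node.op not in self._NODE_TYPES:
            raise NotImplementedError(f"unsupported node type {node.op!r}")
        return getattr(self, node.op)(node)

    def _record(self, kind: str, node: ToyNode) -> str:
        self.calls.append(f"{kind}:{node.name}")
        return node.name

    def placeholder(self, node: ToyNode) -> str:
        return self._record("placeholder", node)

    def get_attr(self, node: ToyNode) -> str:
        return self._record("get_attr", node)

    def call_function(self, node: ToyNode) -> str:
        # Assumption: bracket access arrives as a call to operator.getitem.
        if node.target is operator.getitem:
            return self.getitem(node)
        return self._record("call_function", node)

    def call_method(self, node: ToyNode) -> str:
        return self._record("call_method", node)

    def call_module(self, node: ToyNode) -> str:
        return self._record("call_module", node)

    def getitem(self, node: ToyNode) -> str:
        return self._record("getitem", node)
```

Feeding a placeholder, a get_attr and two call_function nodes (the second one targeting operator.getitem) records placeholder, get_attr, call_function and getitem, in that order.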
Dispatcher
- class experimental_experiment.torch_interpreter.Dispatcher(registered_functions: Dict[str, Callable], verbose: int = 0)
Used to change the way class DynamoInterpreter selects the function translating an aten function or a module.
- Parameters:
registered_functions – dictionary mapping a name to a converting function
verbose – verbosity level
- fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) -> Callable | None
This method is called after the interpreter looks up the function converting an aten function into ONNX; fct is that function, or None if none was found. The dispatcher may change it, or set one when no mapping was found.
- Parameters:
name – object or str
fct – function found so far
args – known arguments coming from the graph module
kwargs – known named arguments coming from the graph module
builder – GraphBuilder
- Returns:
callable
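The fallback contract can be illustrated with a small stand-alone class. This is a hypothetical sketch, not the library's Dispatcher: the key format "aten::relu", the conversion of non-string names with str() and the policy of only filling in a missing converter (instead of replacing fct) are assumptions made for the example.

```python
from typing import Any, Callable, Dict, List, Optional


class ToyDispatcher:
    """Hypothetical sketch of the fallback contract, not the library's Dispatcher."""

    def __init__(self, registered_functions: Dict[str, Callable], verbose: int = 0) -> None:
        self.registered_functions = registered_functions
        self.verbose = verbose

    def fallback(
        self,
        name: Any,
        fct: Optional[Callable],
        args: List[Any],
        kwargs: Dict[str, Any],
        builder: Any,
    ) -> Optional[Callable]:
        # Keep the converting function the interpreter already found.
        if fct is not None:
            return fct
        # Otherwise look the target up by name (str() for non-string targets).
        key = name if isinstance(name, str) else str(name)
        found = self.registered_functions.get(key)
        if found is None and self.verbose:
            print(f"[ToyDispatcher.fallback] no converting function for {key!r}")
        return found
```

A dispatcher replacing fct even when it is set would only need to drop the first test; which policy fits depends on whether registered functions should take precedence.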
ForceDispatcher
- class experimental_experiment.torch_interpreter.ForceDispatcher(signatures: Dict[str, Callable] | None = None, verbose: int = 0, domain: str = 'aten.lib', version: int = 1, strict: bool = False, only_registered: bool = False)
Implements a dispatcher which adds an aten function to the ONNX graph as it is, as a node of a custom domain, when no converting function is found.
- Parameters:
signatures – dictionary mapping a name to a function; the functions are only used for their signature, to retrieve parameter names
verbose – verbose
domain – domain of the added node
version – version of the domain
strict – when an input is not a tensor, it becomes a named parameter if strict is False
only_registered – fails if a function is not found in signatures
- fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) -> Callable | None
This method is called after the interpreter looks up the function converting an aten function into ONNX; fct is that function, or None if none was found. The dispatcher may change it, or set one when no mapping was found.
- Parameters:
name – object or str
fct – function found so far
args – known arguments coming from the graph module
kwargs – known named arguments coming from the graph module
builder – GraphBuilder
- Returns:
callable
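A sketch of the behaviour described for ForceDispatcher follows, under loud assumptions: ToyBuilder is a hypothetical builder that only records nodes, tensors are represented by their names (strings), and raising TypeError for a non-tensor input when strict is True is a guess, since the description above only covers strict=False. As the signatures parameter suggests, the signature function is used only to recover parameter names through inspect.signature.

```python
import inspect
from typing import Any, Callable, Dict, List, Optional


class ToyBuilder:
    """Hypothetical builder that only records the nodes it receives."""

    def __init__(self) -> None:
        self.nodes: List[Dict[str, Any]] = []

    def add_node(self, op_type: str, inputs: List[str], outputs: List[str],
                 domain: str, attributes: Dict[str, Any]) -> List[str]:
        self.nodes.append(dict(op_type=op_type, inputs=inputs, outputs=outputs,
                               domain=domain, attributes=attributes))
        return outputs


def force_fallback(
    name: str,
    fct: Optional[Callable],
    signatures: Dict[str, Callable],
    domain: str = "aten.lib",
    strict: bool = False,
) -> Callable:
    """Keeps fct when one was found, otherwise returns a converter adding the call as-is."""
    if fct is not None:
        return fct
    signature_fct = signatures.get(name)
    # Only the signature matters: it provides the parameter names.
    param_names = list(inspect.signature(signature_fct).parameters) if signature_fct else []

    def convert(builder: ToyBuilder, outputs: List[str], *args: Any, **kwargs: Any) -> List[str]:
        inputs: List[str] = []
        attributes: Dict[str, Any] = dict(kwargs)
        for i, arg in enumerate(args):
            if isinstance(arg, str):
                # Strings stand for tensor names in this sketch.
                inputs.append(arg)
            elif strict:
                raise TypeError(f"non-tensor input {i} of {name!r} with strict=True")
            else:
                # Non-tensor input: becomes a named parameter (attribute).
                key = param_names[i] if i < len(param_names) else f"arg{i}"
                attributes[key] = arg
        return builder.add_node(name, inputs, outputs, domain, attributes)

    return convert
```

For instance, converting a _softmax call with a tensor x and the constants -1 and False produces one node in domain aten.lib with input x and the attributes dim and half_to_float, named after a hypothetical signature function softmax_signature(x, dim, half_to_float).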
OxsDispatcher
- class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDispatcher(verbose: int = 0)
If class DynamoInterpreter cannot find any converting function for a specific function, it tries to find an existing one in onnxscript. The converting function from onnxscript is run in trace-only mode. The variable op and the functions Rank and IsScalar are replaced by op = OxsOpset(), op.Rank and op.Scalar. onnxscript may have multiple overloaded functions; for now, the first one is used.
- Parameters:
verbose – verbosity level
- fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) -> Callable | None
This method is called after the interpreter looks up the function converting an aten function into ONNX; fct is that function, or None if none was found. The dispatcher may change it, or set one when no mapping was found.
- Parameters:
name – object or str
fct – function found so far
args – known arguments coming from the graph module
kwargs – known named arguments coming from the graph module
builder – GraphBuilder
- Returns:
callable
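The two mechanisms described for OxsDispatcher, rebinding the global op used by an onnxscript function and keeping the first overload, can be sketched in plain Python. Everything here is hypothetical: RecordingOpset stands in for OxsOpset, aten_add and aten_add_alpha stand in for onnxscript functions, the overload table is invented, and rebinding globals through types.FunctionType is one possible technique, not necessarily the one the class uses.

```python
import types
from typing import Callable, Dict, List, Optional, Tuple


class RecordingOpset:
    """Hypothetical stand-in for OxsOpset: op.Add(...) records a node instead of computing."""

    def __init__(self) -> None:
        self.nodes: List[Tuple[str, tuple, str]] = []

    def __getattr__(self, op_type: str) -> Callable[..., str]:
        # Only called for names not found normally, i.e. operator types.
        def make(*inputs: str) -> str:
            output = f"{op_type.lower()}_{len(self.nodes)}"
            self.nodes.append((op_type, inputs, output))
            return output

        return make


def aten_add(self, other):
    # Written like an onnxscript function: `op` is a global resolved at call time.
    return op.Add(self, other)  # noqa: F821


def aten_add_alpha(self, other, alpha):
    return op.Add(self, op.Mul(other, alpha))  # noqa: F821


# Hypothetical overload table: several candidates for the same aten function.
OVERLOADS: Dict[str, List[Callable]] = {"aten::add": [aten_add, aten_add_alpha]}


def trace_with(fct: Callable, opset: RecordingOpset) -> Callable:
    """Returns a copy of fct whose global `op` is rebound to opset."""
    new_globals = dict(fct.__globals__)
    new_globals["op"] = opset
    return types.FunctionType(
        fct.__code__, new_globals, fct.__name__, fct.__defaults__, fct.__closure__
    )


def oxs_like_fallback(name: str, fct: Optional[Callable],
                      opset: RecordingOpset) -> Optional[Callable]:
    if fct is not None:
        return fct
    candidates = OVERLOADS.get(name, [])
    # Keep only the first overload.
    return trace_with(candidates[0], opset) if candidates else None
```

Calling the returned function with two tensor names does not compute anything; it only appends an Add node to the opset, which is what trace-only execution amounts to here.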
OxsDebugDispatcher
- class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDebugDispatcher(verbose: int = 0, raise_exc: bool = True)
Tries the fallback even when it is not necessary, to check that it works.
- Parameters:
verbose – verbosity level
raise_exc – raise an exception when the fallback fails; otherwise the failure is only reported
The class can be used as follows.
```python
import torch
from experimental_experiment.torch_models.llama_helper import get_llama_model
from experimental_experiment.xbuilder import OptimizationOptions
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.torch_interpreter.oxs_dispatcher import (
    OxsDebugDispatcher,
)

with torch.no_grad():
    model, input_tensors = get_llama_model()
    # Keep the first set of inputs.
    input_tensors = input_tensors[0]

    to_onnx(
        model,
        input_tensors,
        input_names=[f"input{i}" for i in range(len(input_tensors))],
        options=OptimizationOptions(patterns=None),
        verbose=0,
        # verbose=2 prints one line per fallback attempt,
        # raise_exc=False reports failures instead of raising.
        dispatcher=OxsDebugDispatcher(verbose=2, raise_exc=False),
    )
```
Output:

```text
[2024-05-08 14:06:48,405] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.embedding', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d7910>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.arange', overload='start')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d7910>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.full', overload='default')> with e=aten_full() got an unexpected keyword argument 'device'
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.triu', overload='default')> with e=Unable to access attribute 'Trilu', you can still use this operator with method 'make_node'.
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.arange', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([-1, 1],)
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.gt', overload='Tensor')> with e='NoneType' object is not callable
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.mul', overload='Tensor')> with e=Input type mismatch: float32 != bool (operator ''Mul'', shapes (8, 8), (8, 8)).
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d6e60>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.expand', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 1, -1, -1],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.clone', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d7f40>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.alias', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d7f40>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.eq', overload='Scalar')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d7f40>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d16d5900>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.eq', overload='Scalar')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.alias', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.masked_fill', overload='Scalar')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.copy', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.pow', overload='Tensor_Scalar')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.mean', overload='dim')> with e='immutable_list' object has no attribute 'name'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.add', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.rsqrt', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.t', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([16, 16],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 16],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.t', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([16, 16],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 16],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.t', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([16, 16],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 16],)
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 2, 8],)
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.transpose', overload='int')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 2, 8],)
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.transpose', overload='int')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 2, 8],)
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.transpose', overload='int')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d5e10>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d18d28c0>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.expand', overload='default')> with e=Wrong inputs for operator 'Cast': ([1, -1, 1],)
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d5e10>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.expand', overload='default')> with e=Wrong inputs for operator 'Cast': ([1, 4, 1],)
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([1, 4, 1],)
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.expand', overload='default')> with e=Wrong inputs for operator 'Cast': ([1, 1, 8],)
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([1, 1, 8],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.bmm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e710>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([1, 4, 8],)
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.transpose', overload='int')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.cat', overload='default')> with e=Unable to access attribute 'ConcatFromSequence', you can still use this operator with method 'make_node'.
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.cos', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.sin', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d16d4430>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d5e10>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e560>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.neg', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.cat', overload='default')> with e=Unable to access attribute 'ConcatFromSequence', you can still use this operator with method 'make_node'.
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.add', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.neg', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.cat', overload='default')> with e=Unable to access attribute 'ConcatFromSequence', you can still use this operator with method 'make_node'.
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.add', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.transpose', overload='int')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.expand', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 2, 8, 8],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.clone', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e950>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten._unsafe_view', overload='default')> with e='NoneType' object is not callable
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.expand', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 2, 8, 8],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.clone', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e950>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten._unsafe_view', overload='default')> with e='NoneType' object is not callable
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.bmm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 2, 8, 8],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.div', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e950>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.add', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e950>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._softmax', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.clone', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e950>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.expand', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 2, 8, 8],)
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([4, 8, 8],)
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.expand', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 2, 8, 8],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.clone', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten._unsafe_view', overload='default')> with e='NoneType' object is not callable
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.bmm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e950>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 2, 8, 8],)
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.transpose', overload='int')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.clone', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 16],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.t', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([16, 16],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d5e10>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 16],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.add', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.pow', overload='Tensor_Scalar')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.mean', overload='dim')> with e='immutable_list' object has no attribute 'name'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.add', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.rsqrt', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.t', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([16, 16],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d5e10>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 16],)
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.silu', overload='default')> with e='NoneType' object is not callable
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.t', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([16, 16],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d5e10>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 16],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.t', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([16, 16],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d5e10>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 16],)
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.add', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.pow', overload='Tensor_Scalar')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
[OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.mean', overload='dim')> with e='immutable_list' object has no attribute 'name'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.add', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.rsqrt', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
```
- fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) -> Callable | None
This method is called after the interpreter looks up the function converting an aten function into ONNX; fct is that function, or None if none was found. The dispatcher may change it, or set one when no mapping was found.
- Parameters:
name – object or str
fct – function found so far
args – known arguments coming from the graph module
kwargs – known named arguments coming from the graph module
builder – GraphBuilder
- Returns:
callable
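The idea behind OxsDebugDispatcher, exercising the fallback even when a converter already exists, can be sketched as a plain function. It is a hypothetical illustration: candidate stands for the onnxscript-based fallback, calling it directly on args replaces whatever check the real class performs, and only two outcomes (verified, failed) are modelled, with messages that merely imitate the format printed by the example.

```python
from typing import Any, Callable, List, Optional


def debug_fallback(
    name: str,
    fct: Optional[Callable],
    candidate: Optional[Callable],
    args: List[Any],
    raise_exc: bool = True,
    log: Optional[List[str]] = None,
) -> Optional[Callable]:
    """Hypothetical sketch: always exercise `candidate`, but keep `fct` when it exists."""
    log = [] if log is None else log
    if candidate is not None:
        try:
            # Run the fallback even if it is not needed, only to check it works.
            candidate(*args)
            log.append(f"fallback verified for {name}")
        except Exception as e:
            log.append(f"fallback failed for {name} with e={e}")
            if raise_exc:
                raise
    # The converter found first still wins; the fallback is used only if nothing was found.
    return fct if fct is not None else candidate
```

With raise_exc=False a broken fallback only adds a line to log, which matches how the example above is able to finish despite the failures it reports.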
OxsOpset
- class experimental_experiment.torch_interpreter.oxs_opset.OxsOpset(builder: GraphBuilder)
Bridge with onnxscript.
- Parameters:
builder – the GraphBuilder the nodes are added to
- make_node(op_type: str, *inputs: List[str] | str | None, outputs: int | List[str] | str | None = None, domain: str = '', name: str | None = None, **kwargs)
Creates a node.
- Parameters:
op_type – operator type, for example Add
inputs – input names
outputs – number of outputs, or output names
domain – operator domain ('' is the default ONNX domain)
name – node name
kwargs – additional arguments
- Returns:
output name
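The sketch below shows how an onnxscript-style call such as op.Relu(x) can be turned into make_node("Relu", x). ToyOpset is hypothetical and only records nodes in a list; the handling of outputs (a count, a list of names, a single name or None for one output) follows the make_node signature above, while the generated names and returning a single name versus a list are assumptions. Encoding a None input as an empty name follows the ONNX convention for omitted optional inputs.

```python
from typing import Any, Dict, List, Optional, Union


class ToyOpset:
    """Hypothetical bridge: op.Relu(x) becomes make_node("Relu", x)."""

    def __init__(self) -> None:
        self.nodes: List[Dict[str, Any]] = []
        self._counter = 0

    def _unique(self, prefix: str) -> str:
        self._counter += 1
        return f"{prefix}_{self._counter}"

    def make_node(
        self,
        op_type: str,
        *inputs: Optional[str],
        outputs: Union[int, List[str], str, None] = None,
        domain: str = "",
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[str, List[str]]:
        # outputs may be a count, a list of names, a single name or None (one output).
        if outputs is None:
            outputs = 1
        if isinstance(outputs, int):
            outputs = [self._unique(op_type.lower()) for _ in range(outputs)]
        elif isinstance(outputs, str):
            outputs = [outputs]
        # None stands for an omitted optional input: ONNX uses an empty name.
        input_names = ["" if i is None else i for i in inputs]
        self.nodes.append(dict(op_type=op_type, inputs=input_names, outputs=list(outputs),
                               domain=domain, name=name or self._unique("node"),
                               attributes=kwargs))
        return outputs[0] if len(outputs) == 1 else list(outputs)

    def __getattr__(self, op_type: str):
        # Unknown attributes are treated as operator types.
        return lambda *inputs, **kwargs: self.make_node(op_type, *inputs, **kwargs)
```

Routing every unknown attribute to make_node is what lets code written against onnxscript's op run unchanged against a builder.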
Retriever
_retrieve
- experimental_experiment.torch_interpreter.onnx_export._retrieve(name: str, value: Any, weights: Dict[str, torch.Tensor], buffers: Dict[str, torch.Tensor], mapping: Dict[str, Tuple[str, bool]], graph_builder: GraphBuilder) -> torch.Tensor
Sent to the DynamoInterpreter as its retriever argument. It retrieves the weights (and buffers) of the module.
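A possible shape for a retriever is sketched below. It is hypothetical: it keeps only the name, weights, buffers and mapping arguments of _retrieve, assumes the boolean in mapping tells whether the name refers to a weight (True) or a buffer (False), and uses functools.partial to bind everything except the name before handing the callable to an interpreter. The graph names p_linear_weight and b_norm_running_mean are invented, and plain lists stand in for torch.Tensor so the example runs without torch.

```python
import functools
from typing import Any, Dict, Tuple


def toy_retrieve(
    name: str,
    weights: Dict[str, Any],
    buffers: Dict[str, Any],
    mapping: Dict[str, Tuple[str, bool]],
) -> Any:
    """Hypothetical retriever: resolves a graph name to a stored weight or buffer."""
    if name not in mapping:
        raise KeyError(f"unknown name {name!r}")
    real_name, is_weight = mapping[name]
    # Assumption: True means a weight (parameter), False means a buffer.
    source = weights if is_weight else buffers
    if real_name not in source:
        kind = "weights" if is_weight else "buffers"
        raise KeyError(f"{real_name!r} not found in {kind}")
    return source[real_name]


# Plain lists stand in for torch.Tensor so that the sketch runs without torch.
weights = {"linear.weight": [[1.0, 2.0]]}
buffers = {"norm.running_mean": [0.0]}
mapping = {
    "p_linear_weight": ("linear.weight", True),
    "b_norm_running_mean": ("norm.running_mean", False),
}
# Bind everything but the name: the result is a one-argument retriever.
retriever = functools.partial(toy_retrieve, weights=weights, buffers=buffers, mapping=mapping)
```

Binding the dictionaries once keeps the interpreter unaware of how the module stores its parameters; it only asks for a name.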