interpreter

DynamoInterpreter

class experimental_experiment.torch_interpreter.interpreter.DynamoInterpreter(graph_builder: GraphBuilder, retriever: Callable, dispatcher: Dispatcher | None = None, use_dynamo: bool = False)

Converts a torch graph (a torch.fx graph module) into an ONNX graph by dispatching every node to the appropriate converting function.

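Parameters:
  • graph_builder – GraphBuilder receiving the ONNX nodes

  • retriever – callable retrieving the weights, such as _retrieve

  • dispatcher – optional Dispatcher changing the way converting functions are selected

  • use_dynamo – boolean flag, False by default
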
call_function(node: torch.fx.Node)

Called for a call_function node, that is, a node calling a function.

call_method(node: torch.fx.Node)

Called for a call_method node, that is, a node calling a method.

call_module(node: torch.fx.Node)

Called for a call_module node, that is, a node calling a submodule.

get_attr(node: torch.fx.Node)

Retrieves an attribute.

getitem(node: torch.fx.Node)

Called when brackets appear, as in something[...]. The index may be another variable, an integer, a slice, a tuple or a list.
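
The index kinds listed above can be told apart with a few isinstance checks. The helper below is a hypothetical sketch, not the interpreter's code; it assumes another graph variable is represented by its name as a string, whereas the interpreter handles torch.fx.Node objects.

```python
from typing import Any


def index_kind(index: Any) -> str:
    """Hypothetical helper classifying the index used in something[index]."""
    if isinstance(index, int):
        return "integer"
    if isinstance(index, slice):
        return "slice"
    if isinstance(index, tuple):
        return "tuple"
    if isinstance(index, list):
        return "list"
    if isinstance(index, str):
        # assumption: a string stands for the name of another graph variable
        return "variable"
    raise TypeError(f"unsupported index type {type(index).__name__}")


print([index_kind(i) for i in (0, slice(1, 3), (0, slice(None)), [0, 2], "x")])
# ['integer', 'slice', 'tuple', 'list', 'variable']
```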

output(node)

Adds an output to the graph.

placeholder(node: torch.fx.Node)

Called for a placeholder node, that is, an input. The interpreter adds an Identity node between the input name it wants and the name the input has in the graph module.
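
The renaming can be pictured with a toy graph stored as a list of (op_type, inputs, outputs) tuples. The function add_input and the names input0 and arg0_1 are hypothetical; the interpreter itself goes through the GraphBuilder. The sketch skips the Identity node when both names are equal, an assumption made here to avoid a node whose input and output share a name.

```python
from typing import List, Tuple

# a toy node: (op_type, inputs, outputs)
Node = Tuple[str, List[str], List[str]]


def add_input(nodes: List[Node], wanted_name: str, graph_name: str) -> str:
    """Hypothetical helper exposing the graph-module input graph_name as wanted_name."""
    if wanted_name != graph_name:
        # Identity connects the wanted input name to the internal name
        nodes.append(("Identity", [wanted_name], [graph_name]))
    # the following nodes keep using the name from the graph module
    return graph_name


nodes: List[Node] = []
name = add_input(nodes, "input0", "arg0_1")
print(name, nodes)  # arg0_1 [('Identity', ['input0'], ['arg0_1'])]
```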

run_node(node: torch.fx.Node)

Runs a node: calls the appropriate method based on the node type.
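
run_node follows the torch.fx convention: every node has an op attribute among placeholder, get_attr, call_function, call_method, call_module and output, and the method with the same name handles it. The sketch below uses a hypothetical FakeNode dataclass instead of torch.fx.Node so that it runs without torch; it only illustrates the dispatch on node.op.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class FakeNode:
    """Hypothetical stand-in for torch.fx.Node, only the attributes used here."""
    op: str
    name: str
    args: tuple = ()


class MiniInterpreter:
    """Toy interpreter mimicking the dispatch done by run_node."""

    def __init__(self) -> None:
        self.trace: List[str] = []

    def run_node(self, node: FakeNode) -> Any:
        # node.op is the name of the method handling the node
        method = getattr(self, node.op, None)
        if method is None:
            raise ValueError(f"unexpected node type {node.op!r}")
        return method(node)

    def placeholder(self, node: FakeNode) -> str:
        self.trace.append(f"input {node.name}")
        return node.name

    def call_function(self, node: FakeNode) -> str:
        self.trace.append(f"function {node.name}")
        return node.name

    def output(self, node: FakeNode) -> None:
        self.trace.append(f"output {node.args}")


interp = MiniInterpreter()
graph = [
    FakeNode("placeholder", "x"),
    FakeNode("call_function", "relu"),
    FakeNode("output", "output", ("relu",)),
]
for n in graph:
    interp.run_node(n)
print(interp.trace)  # ['input x', 'function relu', "output ('relu',)"]
```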

Dispatcher

class experimental_experiment.torch_interpreter.Dispatcher(registered_functions: Dict[str, Callable], verbose: int = 0)

Used to change the way DynamoInterpreter selects the function translating an aten function or a module.

Parameters:
  • registered_functions – dictionary mapping a name to the function converting it into ONNX

  • verbose – verbosity
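
Conceptually, a dispatcher built on registered_functions is a lookup table from a name to a converting function. The sketch below is a hypothetical simplification: ToyDispatcher, the key format aten::relu and the converter aten_relu, which records nodes in a list instead of a GraphBuilder, are assumptions made for illustration.

```python
from typing import Any, Callable, Dict, List, Optional


def aten_relu(g: List[str], x: str) -> str:
    """Hypothetical converting function: records a Relu node in g."""
    g.append(f"Relu({x})")
    return f"relu_{x}"


class ToyDispatcher:
    """Hypothetical dispatcher: a lookup table over registered_functions."""

    def __init__(self, registered_functions: Dict[str, Callable], verbose: int = 0):
        self.registered_functions = registered_functions
        self.verbose = verbose

    def find_function(self, name: Any) -> Optional[Callable]:
        # name may be a string or an object; both are looked up through str()
        fct = self.registered_functions.get(str(name))
        if fct is None and self.verbose:
            print(f"[ToyDispatcher] no function registered for {name!r}")
        return fct


dispatcher = ToyDispatcher({"aten::relu": aten_relu})
graph: List[str] = []
fct = dispatcher.find_function("aten::relu")
result = fct(graph, "x")
print(result, graph)  # relu_x ['Relu(x)']
print(dispatcher.find_function("aten::elu"))  # None
```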

fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) → Callable | None

This method is called once the function converting an aten function into ONNX has been looked up. fct is that function, or None if no mapping was found. The method may return fct unchanged or replace it with another function.

Parameters:
  • name – object or str

  • fct – function found so far

  • args – known arguments coming from the graph module

  • kwargs – known named arguments coming from the graph module

  • builder – GraphBuilder

Returns:

callable
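
The contract of fallback can be sketched as follows: the method receives the function found so far, possibly None, and returns either that function or a replacement. ToyFallbackDispatcher and its generic converter are hypothetical, and a plain list stands in for the GraphBuilder.

```python
from typing import Any, Callable, Dict, List, Optional


class ToyFallbackDispatcher:
    """Hypothetical dispatcher illustrating the fallback contract."""

    def fallback(
        self,
        name: Any,
        fct: Optional[Callable],
        args: List[Any],
        kwargs: Dict[str, Any],
        builder: List[Any],
    ) -> Optional[Callable]:
        if fct is not None:
            # a mapping was found: keep it unchanged
            return fct

        # no mapping was found: return a generic converter recording the call
        def generic(*inputs: Any, **attributes: Any) -> str:
            builder.append((str(name), inputs, attributes))
            return f"{name}_output"

        return generic


builder: List[Any] = []  # stand-in for a GraphBuilder
dispatcher = ToyFallbackDispatcher()
converter = dispatcher.fallback("aten::erf", None, ["x"], {}, builder)
print(converter("x"), builder)  # aten::erf_output [('aten::erf', ('x',), {})]
print(dispatcher.fallback("aten::abs", len, [], {}, builder) is len)  # True
```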

find_function(name: Any) → Callable | None

Finds the most suitable function to translate a function.

Parameters:

name – function name or definition

Returns:

the function or None if not found

The returned function has the same kind of signature as a converting function such as aten_elu.

find_method(name: Any) → Callable | None

Finds the most suitable function to translate a method.

Parameters:

name – method name or definition

Returns:

the function or None if not found

The returned function has the same kind of signature as a converting function such as aten_elu.

ForceDispatcher

class experimental_experiment.torch_interpreter.ForceDispatcher(signatures: Dict[str, Callable] | None = None, verbose: int = 0, domain: str = 'aten.lib', version: int = 1, strict: bool = False, only_registered: bool = False)

Implements a dispatcher which, when no converting function is found, adds the function as it is, as a node of the custom domain given by domain (aten.lib by default).

Parameters:
  • signatures – dictionary mapping a name to a function; the functions are only used for their signature, to retrieve parameter names

  • verbose – verbosity

  • domain – domain of the added node

  • version – version of the domain

  • strict – when an input is not a tensor, it becomes a named parameter if strict is False

  • only_registered – fails if a function is not found in signatures
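
The behaviour described by these parameters can be sketched as a function describing the node that would be added. force_node is hypothetical and only models strict=False; it assumes a tensor is represented by its name as a string, and uses inspect.signature to read parameter names from signatures.

```python
import inspect
from typing import Any, Callable, Dict, List


def force_node(
    name: str,
    args: List[Any],
    signatures: Dict[str, Callable],
    domain: str = "aten.lib",
    version: int = 1,
    only_registered: bool = False,
) -> Dict[str, Any]:
    """Hypothetical sketch of the node a forcing dispatcher adds (strict=False only)."""
    sig_fct = signatures.get(name)
    if sig_fct is None:
        if only_registered:
            raise KeyError(f"{name!r} not found in signatures")
        # no signature available: fall back on positional names
        param_names = [f"arg{i}" for i in range(len(args))]
    else:
        # the function is only used to recover parameter names
        param_names = list(inspect.signature(sig_fct).parameters)
    inputs: List[str] = []
    attributes: Dict[str, Any] = {}
    for pname, value in zip(param_names, args):
        if isinstance(value, str):
            # assumption: a string is the name of a tensor in the graph
            inputs.append(value)
        else:
            # a non-tensor input becomes a named parameter (strict=False)
            attributes[pname] = value
    return dict(op_type=name, domain=domain, version=version,
                inputs=inputs, attributes=attributes)


def elu_signature(x, alpha=1.0):
    """Only the signature matters; the body is never executed."""
    raise NotImplementedError


node = force_node("elu", ["X", 0.5], {"elu": elu_signature})
print(node)
# {'op_type': 'elu', 'domain': 'aten.lib', 'version': 1, 'inputs': ['X'], 'attributes': {'alpha': 0.5}}
```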

fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) → Callable | None

This method is called once the function converting an aten function into ONNX has been looked up. fct is that function, or None if no mapping was found. The method may return fct unchanged or replace it with another function.

Parameters:
  • name – object or str

  • fct – function found so far

  • args – known arguments coming from the graph module

  • kwargs – known named arguments coming from the graph module

  • builder – GraphBuilder

Returns:

callable

OxsDispatcher

class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDispatcher(verbose: int = 0)

If DynamoInterpreter cannot find any converting function for a specific function, it tries to find an existing one in onnxscript. The converting function from onnxscript is run in trace-only mode. The variable op and the functions Rank and IsScalar are replaced by op = OxsOpset(), op.Rank and op.IsScalar. onnxscript may have multiple overloaded functions; for now, the first one is taken.

Parameters:

verbose – verbosity
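
The overload selection described above, taking the first overloaded function, amounts to the lookup below. The registry REGISTRY and both overloads are hypothetical stand-ins for onnxscript functions.

```python
from typing import Callable, Dict, List, Optional


def add_v1(x, y):
    """Hypothetical first overload."""
    return ("Add", x, y)


def add_v2(x, y, alpha=1):
    """Hypothetical second overload with a scaling factor."""
    return ("Add", x, ("Mul", y, alpha))


# hypothetical registry: one name may map to several overloaded functions
REGISTRY: Dict[str, List[Callable]] = {"aten::add": [add_v1, add_v2]}


def pick_overload(name: str) -> Optional[Callable]:
    overloads = REGISTRY.get(name, [])
    # the arguments are not inspected: the first overload is taken
    return overloads[0] if overloads else None


print(pick_overload("aten::add")("a", "b"))  # ('Add', 'a', 'b')
print(pick_overload("aten::sub"))  # None
```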

fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) → Callable | None

This method is called once the function converting an aten function into ONNX has been looked up. fct is that function, or None if no mapping was found. The method may return fct unchanged or replace it with another function.

Parameters:
  • name – object or str

  • fct – function found so far

  • args – known arguments coming from the graph module

  • kwargs – known named arguments coming from the graph module

  • builder – GraphBuilder

Returns:

callable

property submodules: Dict[str, Callable]

Returns the submodules implementing torch functions.

OxsDebugDispatcher

class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDebugDispatcher(verbose: int = 0, raise_exc: bool = True)

Tries the fallback even if it is not necessary, in order to check that it works.

Parameters:
  • verbose – verbosity

  • raise_exc – if True, raises an exception when the fallback fails; if False, only reports the failure

The class can be used the following way.

<<<

import torch
from experimental_experiment.torch_models.llama_helper import get_llama_model
from experimental_experiment.xbuilder import OptimizationOptions
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.torch_interpreter.oxs_dispatcher import (
    OxsDebugDispatcher,
)

with torch.no_grad():
    model, input_tensors = get_llama_model()
    input_tensors = input_tensors[0]

    to_onnx(
        model,
        input_tensors,
        input_names=[f"input{i}" for i in range(len(input_tensors))],
        options=OptimizationOptions(patterns=None),
        verbose=0,
        dispatcher=OxsDebugDispatcher(verbose=2, raise_exc=False),
    )

>>>

    [2024-05-08 14:06:48,405] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.embedding', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d7910>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.arange', overload='start')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d7910>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.full', overload='default')> with e=aten_full() got an unexpected keyword argument 'device'
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.triu', overload='default')> with e=Unable to access attribute 'Trilu', you can still use this operator with method 'make_node'.
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.arange', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([-1, 1],)
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.gt', overload='Tensor')> with e='NoneType' object is not callable
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.mul', overload='Tensor')> with e=Input type mismatch: float32 != bool (operator ''Mul'', shapes (8, 8), (8, 8)).
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d6e60>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.expand', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 1, -1, -1],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.clone', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d7f40>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.alias', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d7f40>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.eq', overload='Scalar')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d7f40>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d16d5900>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.eq', overload='Scalar')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.alias', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.masked_fill', overload='Scalar')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.copy', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.pow', overload='Tensor_Scalar')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.mean', overload='dim')> with e='immutable_list' object has no attribute 'name'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.add', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.rsqrt', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.t', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([16, 16],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 16],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.t', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([16, 16],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 16],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.t', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([16, 16],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 16],)
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 2, 8],)
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.transpose', overload='int')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 2, 8],)
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.transpose', overload='int')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 2, 8],)
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.transpose', overload='int')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d5e10>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d18d28c0>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.expand', overload='default')> with e=Wrong inputs for operator 'Cast': ([1, -1, 1],)
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d5e10>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.expand', overload='default')> with e=Wrong inputs for operator 'Cast': ([1, 4, 1],)
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([1, 4, 1],)
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.expand', overload='default')> with e=Wrong inputs for operator 'Cast': ([1, 1, 8],)
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([1, 1, 8],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.bmm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e710>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([1, 4, 8],)
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.transpose', overload='int')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.cat', overload='default')> with e=Unable to access attribute 'ConcatFromSequence', you can still use this operator with method 'make_node'.
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.cos', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d1936290>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.sin', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d16d4430>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d5e10>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.unsqueeze', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e560>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.neg', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.cat', overload='default')> with e=Unable to access attribute 'ConcatFromSequence', you can still use this operator with method 'make_node'.
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.add', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.neg', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.cat', overload='default')> with e=Unable to access attribute 'ConcatFromSequence', you can still use this operator with method 'make_node'.
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.add', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.transpose', overload='int')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.expand', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 2, 8, 8],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.clone', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e950>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten._unsafe_view', overload='default')> with e='NoneType' object is not callable
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.expand', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 2, 8, 8],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.clone', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e950>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten._unsafe_view', overload='default')> with e='NoneType' object is not callable
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.bmm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 2, 8, 8],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.div', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e950>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.add', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e950>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._softmax', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.clone', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e950>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.expand', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 2, 8, 8],)
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([4, 8, 8],)
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.expand', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 2, 8, 8],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.clone', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten._unsafe_view', overload='default')> with e='NoneType' object is not callable
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.bmm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e950>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 2, 8, 8],)
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.transpose', overload='int')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.clone', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d0b2e5f0>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 16],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.t', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([16, 16],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d5e10>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 16],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.add', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.pow', overload='Tensor_Scalar')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.mean', overload='dim')> with e='immutable_list' object has no attribute 'name'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.add', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.rsqrt', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.t', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([16, 16],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d5e10>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 16],)
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.silu', overload='default')> with e='NoneType' object is not callable
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.t', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([16, 16],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d5e10>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 16],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.t', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([16, 16],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mm', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d5e10>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.view', overload='default')> with e=Wrong inputs for operator 'Cast': ([2, 8, 16],)
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.add', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.pow', overload='Tensor_Scalar')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
    [OxsDebugDispatcher.fallback] fallback unverified for <OpOverload(op='aten.mean', overload='dim')> with e='immutable_list' object has no attribute 'name'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.add', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.rsqrt', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
    [OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten._to_copy', overload='default')> with e='TracedOnnxFunction' object has no attribute 'function'
    [OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.mul', overload='Tensor')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d14304c0>
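
The log is easier to read once grouped by operator and status. The helper below is a sketch, not part of the library; it parses lines in the format shown above, and the sample lines are copied from that output.

```python
import re
from collections import Counter
from typing import Iterable

PATTERN = re.compile(
    r"fallback (verified|unverified|failed) for "
    r"<OpOverload\(op='([^']+)', overload='([^']+)'\)>"
)


def summarize(lines: Iterable[str]) -> Counter:
    """Counts (operator.overload, status) pairs found in the log lines."""
    counts: Counter = Counter()
    for line in lines:
        match = PATTERN.search(line)
        if match:
            status, op, overload = match.groups()
            counts[f"{op}.{overload}", status] += 1
    return counts


LOG = [
    "[OxsDebugDispatcher.fallback] fallback verified for <OpOverload(op='aten.embedding', overload='default')> with <function OxsDispatcher.fallback.<locals>.wrapper at 0x7fe1d15d7910>",
    "[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'",
    "[OxsDebugDispatcher.fallback] fallback failed for <OpOverload(op='aten.slice', overload='Tensor')> with e='TracedOnnxFunction' object has no attribute 'function'",
]
for key, count in sorted(summarize(LOG).items()):
    print(key, count)
# ('aten.embedding.default', 'verified') 1
# ('aten.slice.Tensor', 'failed') 2
```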

fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) → Callable | None

This method is called once the function converting an aten function into ONNX has been looked up. fct is that function, or None if no mapping was found. The method may return fct unchanged or replace it with another function.

Parameters:
  • name – object or str

  • fct – function found so far

  • args – known arguments coming from the graph module

  • kwargs – known named arguments coming from the graph module

  • builder – GraphBuilder

Returns:

callable

OxsOpset

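class experimental_experiment.torch_interpreter.oxs_opset.OxsOpset(builder: GraphBuilder)

Bridge with onnxscript: it plays the role of the onnxscript op object so that converting functions taken from onnxscript add their nodes to the GraphBuilder.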

Parameters:

builder – GraphBuilder receiving the created nodes

make_node(op_type: str, *inputs: List[str] | str | None, outputs: int | List[str] | str | None = None, domain: str = '', name: str | None = None, **kwargs)

Creates a node.

Parameters:
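  • op_type – operator type, such as Add or Relu

  • inputs – input names

  • outputs – number of outputs or output names

  • domain – operator domain, '' for the main ONNX domain

  • name – node name

  • kwargs – additional arguments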

Returns:

output name
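
A toy version of make_node shows how outputs can be given either as a count or as names. ToyOpset and its naming scheme are hypothetical, and returning a list when there are several outputs is an assumption; the sketch only accepts string inputs.

```python
from typing import Any, Dict, List, Optional, Union


class ToyOpset:
    """Hypothetical stand-in for OxsOpset storing nodes in a list."""

    def __init__(self) -> None:
        self.nodes: List[Dict[str, Any]] = []

    def make_node(
        self,
        op_type: str,
        *inputs: str,
        outputs: Union[int, List[str], str, None] = None,
        domain: str = "",
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[str, List[str]]:
        node_name = name or f"{op_type}_{len(self.nodes)}"
        if outputs is None:
            outputs = 1
        if isinstance(outputs, int):
            # only a count is given: generate output names from the node name
            outputs = [f"{node_name}_out{i}" for i in range(outputs)]
        elif isinstance(outputs, str):
            outputs = [outputs]
        self.nodes.append(
            dict(op_type=op_type, inputs=list(inputs), outputs=outputs,
                 domain=domain, name=node_name, attributes=kwargs)
        )
        # assumption: one output is returned as a string, several as a list
        return outputs[0] if len(outputs) == 1 else outputs


op = ToyOpset()
y = op.make_node("Relu", "x")
a, b = op.make_node("Split", y, outputs=2, axis=0)
print(y, a, b)  # Relu_0_out0 Split_1_out0 Split_1_out1
```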

Retriever

_retrieve

experimental_experiment.torch_interpreter.onnx_export._retrieve(name: str, value: Any, weights: Dict[str, torch.Tensor], buffers: Dict[str, torch.Tensor], mapping: Dict[str, Tuple[str, bool]], graph_builder: GraphBuilder) → torch.Tensor

Passed to DynamoInterpreter as its retriever. It retrieves the weights and buffers used by the graph.
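
A retriever in the spirit of _retrieve looks a name up among the weights and buffers. The sketch below is a hypothetical simplification: it stores plain lists instead of torch.Tensor, drops the value and graph_builder arguments, and assumes mapping associates a name with (key, is_weight), where the meaning of the boolean is a guess made for illustration only.

```python
from typing import Any, Dict, Tuple


def retrieve(
    name: str,
    weights: Dict[str, Any],
    buffers: Dict[str, Any],
    mapping: Dict[str, Tuple[str, bool]],
) -> Any:
    """Hypothetical retriever returning the weight or buffer behind name."""
    # names missing from mapping are assumed to be weight names already
    key, is_weight = mapping.get(name, (name, True))
    # assumption: the boolean tells which dictionary holds the value
    source = weights if is_weight else buffers
    if key not in source:
        raise KeyError(f"unable to retrieve {name!r} (key {key!r})")
    return source[key]


weights = {"fc.weight": [[1.0, 2.0]]}
buffers = {"bn.running_mean": [0.0]}
mapping = {
    "p_fc_weight": ("fc.weight", True),
    "b_bn_running_mean": ("bn.running_mean", False),
}
print(retrieve("p_fc_weight", weights, buffers, mapping))  # [[1.0, 2.0]]
print(retrieve("b_bn_running_mean", weights, buffers, mapping))  # [0.0]
```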