experimental_experiment.torch_interpreter

to_onnx

experimental_experiment.torch_interpreter.to_onnx(mod: torch.nn.Module | torch.fx.GraphModule, args: Sequence[torch.Tensor] | None = None, kwargs: Dict[str, torch.Tensor] | None = None, input_names: Sequence[str] | None = None, target_opset: int | Dict[str, int] | None = None, as_function: bool = False, options: OptimizationOptions | None = None, verbose: int = 0, return_builder: bool = False, raise_list: Set[str] | None = None, dynamic_shapes: Dict[str, Any] | Tuple[Any] | None = None, optimize: bool = True, dispatcher: Dispatcher | None = None, large_model: bool = False, external_threshold: int = 1024, export_options: str | ExportOptions | None = None, return_optimize_report: bool = False, filename: str | None = None, inline: bool = False, export_modules_as_functions: bool | Set[type[torch.nn.Module]] = False, function_options: FunctionOptions | None = None) → ModelProto | ModelContainer | Tuple[ModelProto | ModelContainer, GraphBuilder]

Exports a torch model into ONNX using dynamo export.

Parameters:
  • mod – torch module

  • args – input arguments

  • kwargs – keyword arguments

  • input_names – input names

  • target_opset – target opset, or a dictionary mapping domains to their target opsets

  • as_function – if True, exports a FunctionProto instead of a ModelProto

  • options – optimization options

  • verbose – verbosity level

  • return_builder – returns the builder as well

  • raise_list – the builder stops as soon as a name falls into that list; this is a debugging tool

  • dynamic_shapes – see torch.export.export

  • optimize – optimize the model before exporting into onnx

  • dispatcher – see experimental_experiment.torch_interpreter.Dispatcher

  • large_model – if True, returns an onnx.model_container.ModelContainer, which lets the user decide later whether the weights should be part of the model or saved as external weights

  • external_threshold – if large_model is True, every tensor above this limit is stored as an external tensor

  • return_optimize_report – returns statistics on the optimization as well

  • filename – if specified, stores the model into that file

  • inline – inline the model before converting it to ONNX; this is done before any optimization takes place

  • export_options – options applied to obtain the exported program, a string or an ExportOptions instance

  • export_modules_as_functions – export submodules as local functions; this parameter can also be a set of module classes to preserve as local functions, all the other submodules are exported as usual

  • function_options – specifies what to do with the initializers in local functions: add them as constants or as inputs

Returns:

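the ONNX model, an onnx.model_container.ModelContainer if large_model is True; see return_builder and return_optimize_report for additional returned values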
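The effect of large_model and external_threshold can be pictured with a small pure-Python sketch. It is not the library's implementation: split_initializers is a hypothetical helper, and the sketch assumes the limit is compared with the tensor size in bytes, a unit the parameter description does not specify.

```python
from typing import Dict, Tuple

import numpy as np


def split_initializers(
    initializers: Dict[str, np.ndarray], external_threshold: int = 1024
) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
    """Hypothetical helper: splits weights into inlined and external ones.

    Assumption: the threshold is compared with the size in bytes (nbytes).
    """
    inline, external = {}, {}
    for name, value in initializers.items():
        # Tensors above the limit would be stored outside the model.
        if value.nbytes > external_threshold:
            external[name] = value
        else:
            inline[name] = value
    return inline, external


weights = {
    "linear.bias": np.zeros((3,), dtype=np.float32),  # 12 bytes
    "linear.weight": np.zeros((300, 5), dtype=np.float32),  # 6000 bytes
}
inline, external = split_initializers(weights)
print(sorted(inline), sorted(external))  # ['linear.bias'] ['linear.weight']
```

Keeping small tensors inline avoids one external file entry per bias while large weight matrices stay out of the protobuf, which is limited to 2 GB.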

If the environment variable PRINT_GRAPH_MODULE is set to 1, information about the graph module is printed out.

The environment variable ONNXVERBOSE=1 can be used to increase verbosity in this function. The environment variable ONNX_BUILDER_PROGRESS=1 can be used to show a progress bar on big models.
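As an illustration, these variables can also be set from Python before the export. This is a sketch under one assumption: the variables are set before importing the package in case they are read early, which the documentation does not say. The export call is left as a comment since it needs a model and inputs.

```python
import os

# Print information about the graph module.
os.environ["PRINT_GRAPH_MODULE"] = "1"
# Increase verbosity inside to_onnx.
os.environ["ONNXVERBOSE"] = "1"
# Show a progress bar while building big models.
os.environ["ONNX_BUILDER_PROGRESS"] = "1"

# Set before the import in case the values are read early (assumption).
# from experimental_experiment.torch_interpreter import to_onnx
# onx = to_onnx(model, (x,))
```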

Dispatcher

class experimental_experiment.torch_interpreter.Dispatcher(registered_functions: Dict[str, Callable], verbose: int = 0)

Used to change the way the class DynamoInterpreter selects the function translating an aten function or a module.

Parameters:
  • registered_functions – registered functions

  • verbose – verbosity level

fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) → Callable | None

This method is called after the interpreter has looked for the function converting an aten function into ONNX. fct is that function, or None if no mapping was found. The method may return fct unchanged or replace it with another function.

Parameters:
  • name – object or str

  • fct – function found so far

  • args – known arguments coming from the graph module

  • kwargs – known named arguments coming from the graph module

  • builder – GraphBuilder

Returns:

callable

find_function(name: Any) → Callable | None

Finds the most suitable function to translate a function.

Parameters:

name – function name or definition

Returns:

the function or None if not found

The signature of the returned function is similar to that of a function such as aten_elu.

find_method(name: Any) → Callable | None

Finds the most suitable function to translate a method.

Parameters:

name – method name or definition

Returns:

the function or None if not found

The signature of the returned function is similar to that of a function such as aten_elu.
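The lookup performed by a dispatcher can be sketched in plain Python. The class MiniDispatcher below is hypothetical and much simpler than Dispatcher: it resolves names through a dictionary playing the role of registered_functions and exposes a fallback hook that keeps or replaces the function found so far. The key "aten::relu" and the converter signature are placeholders, not the library's conventions.

```python
from typing import Any, Callable, Dict, List, Optional


class MiniDispatcher:
    """Hypothetical, simplified dispatcher; not the library's Dispatcher."""

    def __init__(self, registered_functions: Dict[str, Callable]):
        self.registered_functions = registered_functions

    def find_function(self, name: Any) -> Optional[Callable]:
        # Accepts a string or any object with a __name__ attribute.
        key = name if isinstance(name, str) else getattr(name, "__name__", str(name))
        return self.registered_functions.get(key)

    def fallback(
        self,
        name: Any,
        fct: Optional[Callable],
        args: List[Any],
        kwargs: Dict[str, Any],
        builder: Any = None,
    ) -> Optional[Callable]:
        # Called once the default lookup is done: keeps fct when a mapping
        # was found, otherwise tries the registered functions.
        if fct is not None:
            return fct
        return self.find_function(name)


def convert_relu(*args: Any, **kwargs: Any) -> str:
    # Placeholder: a real converter would add nodes to the graph builder.
    return "Relu"


dispatcher = MiniDispatcher({"aten::relu": convert_relu})
print(dispatcher.fallback("aten::relu", None, [], {})())  # Relu
print(dispatcher.find_function("aten::elu"))  # None
```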

ExportOptions

class experimental_experiment.torch_interpreter.ExportOptions(strict: bool = True, fallback: bool = False, tracing: bool = False, jit: bool = False, decomposition_table: str | Dict[TorchOpOverload, Callable[..., Any]] | None = None, strategy: str | None = None, dynamo: bool = False, aten_as_function: bool = False, remove_inplace: bool = True)

Gathers all the options defining how a model is exported into a graph (not ONNX).

Parameters:
  • strict – strict export or not

  • fallback – falls back to other export options if the export fails, the last one being jit

  • decomposition_table – decomposition table to apply; it can be 'all', 'default' (the default decomposition table returned by get_decomposition_table) or a custom decomposition table

  • dynamo – use torch._dynamo.export instead of torch.export.export()

  • tracing – use symbolic tracing

  • jit – use jit to get a graph, then convert it into an fx graph

  • strategy – a single value overwriting all the previous parameters

  • remove_inplace – remove inplace nodes

  • aten_as_function – keep aten functions as local functions to preserve a faithful translation of the fx graph

The fallback strategy tries the following in order:

<<<

import pprint
from experimental_experiment.torch_interpreter import ExportOptions

print("-- default fallback")
pprint.pprint(ExportOptions().get_fallback_options())
print("-- default fallback with decomposition")
pprint.pprint(ExportOptions(decomposition_table="default").get_fallback_options())

>>>

    -- default fallback
    [ExportOptions(),
     ExportOptions(strict=False),
     ExportOptions(decomposition_table='default'),
     ExportOptions(strict=False, decomposition_table='default'),
     ExportOptions(dynamo=True),
     ExportOptions(decomposition_table='default', dynamo=True),
     ExportOptions(jit=True)]
    -- default fallback with decomposition
    [ExportOptions(decomposition_table='default'),
     ExportOptions(strict=False, decomposition_table='default'),
     ExportOptions(),
     ExportOptions(strict=False),
     ExportOptions(decomposition_table='default', dynamo=True),
     ExportOptions(dynamo=True),
     ExportOptions(jit=True, decomposition_table='default')]
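A list such as the one returned by get_fallback_options is meant to be tried in order until one option succeeds. The loop below sketches that idea in plain Python; export_with_fallback and toy_export are hypothetical, options are plain dictionaries instead of ExportOptions, and the real ExportOptions.export performs additional checks.

```python
from typing import Any, Callable, List, Sequence, Tuple


def export_with_fallback(
    options: Sequence[Any], export: Callable[[Any], Any]
) -> Tuple[Any, Any]:
    """Tries every option in order and returns the first successful one."""
    errors: List[Tuple[Any, Exception]] = []
    for opt in options:
        try:
            return opt, export(opt)
        except Exception as e:  # remember the failure, try the next option
            errors.append((opt, e))
    raise RuntimeError(f"All {len(errors)} export options failed: {errors}")


def toy_export(opt: dict) -> str:
    # Pretend only the non-strict export with decompositions works.
    if opt.get("strict", True) or opt.get("decomposition_table") != "default":
        raise ValueError(f"export failed with {opt}")
    return "exported program"


scenario = [
    {},
    {"strict": False},
    {"decomposition_table": "default"},
    {"strict": False, "decomposition_table": "default"},
]
chosen, result = export_with_fallback(scenario, toy_export)
print(chosen, result)
```

Collecting every exception, rather than only the last one, keeps the reason each option failed available when the whole scenario fails.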

Most models work with strict=True or strict=False and no decompositions. But if a model contains control flow (tests or loops) or inplace modifications, it may be useful to try different values for strict and to apply decompositions with decomposition_table='default'. The decompositions remove unused results coming from inplace modifications.

A graph is considered invalid if decompositions were not run and at least one node has no user. This usually indicates that an inplace operation is still part of the graph.
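The validity rule can be sketched on a toy graph description. Node, find_unused_nodes and is_valid are hypothetical stand-ins; on a real torch.fx graph one would inspect node.users instead. Only the output node is exempted here, which is an assumption about how the rule treats it.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Node:
    """Toy node: its operation type and the names of the nodes it consumes."""

    name: str
    op: str
    inputs: List[str] = field(default_factory=list)


def find_unused_nodes(nodes: List[Node]) -> List[str]:
    """Returns the names of nodes whose result is never consumed."""
    users: Dict[str, int] = {n.name: 0 for n in nodes}
    for n in nodes:
        for i in n.inputs:
            users[i] += 1
    # The output node has no user by construction, so it is skipped.
    return [n.name for n in nodes if n.op != "output" and users[n.name] == 0]


def is_valid(nodes: List[Node], decomposed: bool) -> bool:
    # Mirrors the rule above: only checked when decompositions were not run.
    return decomposed or not find_unused_nodes(nodes)


graph = [
    Node("x", "placeholder"),
    Node("add_", "call_function", ["x"]),  # inplace op whose result is dropped
    Node("relu", "call_function", ["x"]),
    Node("output", "output", ["relu"]),
]
print(find_unused_nodes(graph))  # ['add_']
print(is_valid(graph, decomposed=False))  # False
```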

export(mod: Any, args: Tuple[Any, ...] | None, kwargs: Dict[str, Any] | None, tracing_mode: bool, dynamic_shapes: Dict, same_signature: bool, input_names: List[str] | None = None, exc: bool = True, verbose: int = 0) → torch.export.ExportedProgram | torch.fx.GraphModule

Exports the model into an exported program.

get_decomposition_table() → Dict[TorchOpOverload, Callable[..., Any]]

Returns the decomposition table.

get_fallback_options(kind: str | None = None) → List[ExportOptions]

Returns the fallback scenario, a list of ExportOptions to try in order.

match_input_parameters

experimental_experiment.torch_interpreter.match_input_parameters(model: Any, names: List[str], args: Tuple[Any, ...] | None = None) → Dict[str, Any]

Maps the given names to the parameter names in the model.

Parameters:
  • model – model

  • names – names to retrieve

  • args – available inputs

Returns:

a dictionary mapping each name to its value

Example:

<<<

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx, match_input_parameters


class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int, n_targets: int):
        super(Neuron, self).__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        return torch.relu(self.linear(x))


fake_mode = FakeTensorMode()
converter = fake_mode.fake_tensor_converter

fake_x = converter.from_real_tensor(fake_mode, torch.rand(2, 5))
with fake_mode:
    model = Neuron(5, 3)
    onx = to_onnx(model, (fake_x,))

# expected values with a different model
not_fake_model = Neuron(5, 3)
x = torch.rand(2, 5)
expected = not_fake_model(x)
print(expected)

# converts the model, fill inputs with the weights
names = [i.name for i in onx.graph.input]
pfeeds = match_input_parameters(not_fake_model, names, (x,))
nfeeds = {k: v.detach().numpy() for k, v in pfeeds.items()}
ref = ExtendedReferenceEvaluator(onx)
got = ref.run(None, nfeeds)
print(got)

>>>

    tensor([[0.0000, 0.0000, 0.2516],
            [0.0000, 0.0000, 0.1396]], grad_fn=<ReluBackward0>)
    [array([[0.   , 0.   , 0.252],
           [0.   , 0.   , 0.14 ]], dtype=float32)]

Other functions

class experimental_experiment.torch_interpreter.TorchOpOverload

The class is unused; it only exists to bypass a documentation warning. The alias TorchOpOverload refers to torch._ops.OpOverload.