experimental_experiment.torch_interpreter.interpreter

class experimental_experiment.torch_interpreter.interpreter.DynamoInterpreter(graph_builder: GraphBuilder, retriever: Callable, dispatcher: Dispatcher | None = None, example_inputs: Tuple[torch.Tensor, ...] | None = None, export_options: ExportOptions | None = None, optimize_submodules: bool = False, function_options: FunctionOptions | None = None, submodule_naming: Callable | None = None, parameter_naming: Callable | None = None, module_name: str | None = None, default_values: Dict[str, Any] | None = None)

Interprets a torch graph and converts it into an ONNX graph. Dispatches every node to the appropriate conversion function.

Parameters:
  • graph_builder – a graph builder

  • retriever – callable that helps retrieve the weights in a module, see experimental_experiment.torch_interpreter.onnx_export._retrieve

  • dispatcher – see experimental_experiment.torch_interpreter.Dispatcher

  • export_options – see ExportOptions

  • optimize_submodules – optimizes submodules after they are built

  • submodule_naming – a function which returns the name of a submodule in the ONNX graph

  • parameter_naming – a function which returns the name of a parameter in the ONNX graph

  • module_name – module name (makes it easier to retrieve the parameter names)

add_aten_as_function(name_fct: str, fct: Callable, can_set: Dict[str, Any] | None, output_names: List[str], args: List[Any], kwargs: Dict[str, Any], domain: str = 'aten') → str | Tuple[str]

Converts a function into a local function and adds this local function to the graph.

call_function(node: torch.fx.Node) → str | Tuple[str]

Called for a node that calls a function (node type call_function).

call_method(node: torch.fx.Node) → str | Tuple[str]

Called for a node that calls a method (node type call_method).

call_module(node: torch.fx.Node)

Called for a node that calls a module (node type call_module).

flatten_inputs(x: Any) → List[torch.Tensor]

Flattens the inputs into a list of tensors.
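The idea of flattening can be illustrated with a minimal sketch. It assumes the inputs are nested lists, tuples, and dictionaries, and treats every other object as a leaf. The name flatten_inputs_sketch is hypothetical, and the real method may handle other containers or types differently.

```python
from typing import Any, List


def flatten_inputs_sketch(x: Any) -> List[Any]:
    """Flattens nested lists, tuples and dicts into a flat list of leaves."""
    if isinstance(x, (list, tuple)):
        flat: List[Any] = []
        for item in x:
            flat.extend(flatten_inputs_sketch(item))
        return flat
    if isinstance(x, dict):
        flat = []
        # Dictionary values are visited in insertion order.
        for value in x.values():
            flat.extend(flatten_inputs_sketch(value))
        return flat
    # Assumption: anything else is a leaf (a tensor in the real interpreter).
    return [x]


nested = ("a", ["b", ("c", "d")], {"k": "e"})
print(flatten_inputs_sketch(nested))  # ['a', 'b', 'c', 'd', 'e']
```

Strings stand in for tensors here so that the sketch runs without torch.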

get_attr(node: torch.fx.Node)

Retrieves an attribute.

get_submodule_name(module_name: str, module: torch.nn.Module) → str

Gets a submodule name, simple but unique.
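One possible way to obtain a name that is simple but unique is sketched below. The helper unique_submodule_name, its class_name argument, and the used set are hypothetical and are not part of DynamoInterpreter. The actual naming scheme may differ.

```python
from typing import Set


def unique_submodule_name(module_name: str, class_name: str, used: Set[str]) -> str:
    """Builds a readable name and appends a counter if it is already taken."""
    base = f"{module_name}_{class_name}" if module_name else class_name
    # Dots in module paths are replaced to keep the name a plain identifier.
    base = base.replace(".", "_")
    name, suffix = base, 2
    while name in used:
        name = f"{base}_{suffix}"
        suffix += 1
    used.add(name)
    return name


used_names: Set[str] = set()
print(unique_submodule_name("layers.0", "Linear", used_names))  # layers_0_Linear
print(unique_submodule_name("layers.0", "Linear", used_names))  # layers_0_Linear_2
```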

getitem(node: torch.fx.Node)

Called when an indexing expression something[...] appears. The index may be another variable, an integer, a slice, a tuple, or a list.
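The kinds of index listed above correspond to what Python passes to `__getitem__`. The sketch below captures that object and classifies it. The Capture class and describe_index function are hypothetical illustrations, and a string is used as a stand-in for "another variable"; the real method receives a torch.fx.Node and may treat these cases differently.

```python
from typing import Any


class Capture:
    """Returns whatever object Python passes as the index."""

    def __getitem__(self, index: Any) -> Any:
        return index


def describe_index(index: Any) -> str:
    """Names the kind of index, mirroring the cases listed in the docstring."""
    if isinstance(index, int):
        return "integer"
    if isinstance(index, slice):
        return "slice"
    if isinstance(index, tuple):
        return "tuple(" + ", ".join(describe_index(i) for i in index) + ")"
    if isinstance(index, list):
        return "list"
    # Assumption: anything else stands for another variable.
    return "variable"


something = Capture()
print(describe_index(something[0]))        # integer
print(describe_index(something[1:3]))      # slice
print(describe_index(something[0, 1:3]))   # tuple(integer, slice)
print(describe_index(something[[0, 2]]))   # list
print(describe_index(something["i"]))      # variable
```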

output(node)

Adds an output to the graph.

placeholder(node: torch.fx.Node)

Called for a placeholder, that is, an input of the graph. The interpreter adds an Identity node between the input name it wants and the name the input has in the graph module.

register_named_modules(parent_interpreter: DynamoInterpreter | None, preserved_modules: Set[type[torch.nn.Module]] | None, named_modules: Dict[str, torch.nn.Module])

Registers a list of modules to preserve as local functions in the ONNX model. If the list is empty, the graph is almost inlined. The module to convert to ONNX should be the output of torch.export.unflatten.unflatten().

run_node(node: torch.fx.Node)

Runs a node: calls the appropriate method based on the node type.
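Dispatching on the node type can be sketched as follows. The sketch relies on the fact that a torch.fx.Node stores its type as a string in the attribute op (placeholder, get_attr, call_function, call_module, call_method, output). FakeNode and TinyInterpreter are hypothetical stand-ins so that the example runs without torch. Looking up the method with getattr is an assumption, not necessarily how DynamoInterpreter does it.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class FakeNode:
    """Stand-in for torch.fx.Node with only the fields the sketch needs."""

    op: str
    name: str
    args: Tuple[Any, ...] = ()


class TinyInterpreter:
    """Routes each node to the method named after its type."""

    def run_node(self, node: FakeNode) -> str:
        method = getattr(self, node.op, None)
        if method is None:
            raise NotImplementedError(f"Unsupported node type {node.op!r}")
        return method(node)

    def placeholder(self, node: FakeNode) -> str:
        return f"input:{node.name}"

    def call_function(self, node: FakeNode) -> str:
        return f"function:{node.name}"

    def output(self, node: FakeNode) -> str:
        return f"output:{node.name}"


interpreter = TinyInterpreter()
for n in (FakeNode("placeholder", "x"), FakeNode("call_function", "add"), FakeNode("output", "y")):
    print(interpreter.run_node(n))
```

A node type without a matching method raises NotImplementedError in this sketch, which makes unsupported cases visible early.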