experimental_experiment.torch_interpreter.interpreter
- class experimental_experiment.torch_interpreter.interpreter.DynamoInterpreter(graph_builder: GraphBuilder, retriever: Callable, dispatcher: Dispatcher | None = None, example_inputs: Tuple[torch.Tensor, ...] | None = None, export_options: ExportOptions | None = None, optimize_submodules: bool = False, function_options: FunctionOptions | None = None, submodule_naming: Callable | None = None, parameter_naming: Callable | None = None, module_name: str | None = None, default_values: Dict[str, Any] | None = None)
Converts a torch.fx graph into an ONNX graph by dispatching every node to the appropriate conversion function.
- Parameters:
graph_builder – a graph builder
retriever – callable to help retrieve the weights in a module, see function experimental_experiment.torch_interpreter.onnx_export._retrieve.
dispatcher – see experimental_experiment.torch_interpreter.Dispatcher
export_options – see ExportOptions
optimize_submodules – optimizes submodules after they are built
submodule_naming – a function that returns a submodule name in the ONNX graph
parameter_naming – a function that returns a parameter name in the ONNX graph
module_name – module name (makes it easier to retrieve the parameter names)
- add_aten_as_function(name_fct: str, fct: Callable, can_set: Dict[str, Any] | None, output_names: List[str], args: List[Any], kwargs: Dict[str, Any], domain: str = 'aten') → str | Tuple[str]
Converts a function into a local function and adds this local function to the graph.
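To illustrate the idea of a local function that is defined once and then called wherever it is needed, here is a minimal pure-Python sketch. `ToyGraph` and `add_local_function` are hypothetical names, not part of experimental_experiment. In ONNX itself, a local function is stored as a FunctionProto and invoked by a node whose op_type and domain match the function's name and domain (`aten` by default here).

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Tuple

# (domain, op_type, inputs, outputs): a simplified node representation
ToyNode = Tuple[str, str, List[str], List[str]]


@dataclass
class ToyGraph:
    # local function bodies, stored once per (domain, name)
    functions: Dict[Tuple[str, str], Callable] = field(default_factory=dict)
    # call sites, in the order they were added
    nodes: List[ToyNode] = field(default_factory=list)

    def add_local_function(
        self,
        name: str,
        fct: Callable,
        inputs: List[str],
        outputs: List[str],
        domain: str = "aten",
    ) -> List[str]:
        key = (domain, name)
        if key not in self.functions:
            # first use: register the function body
            self.functions[key] = fct
        # every use: add a node calling the function by domain and name
        self.nodes.append((domain, name, list(inputs), list(outputs)))
        return outputs


g = ToyGraph()
g.add_local_function("linear", lambda x: x, ["x"], ["y1"])
g.add_local_function("linear", lambda x: x, ["y1"], ["y2"])
print(len(g.functions), len(g.nodes))  # 1 2
```

The function body is registered only once, while each call adds its own node, which keeps the graph compact when the same function is used many times.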
- call_module(node: torch.fx.Node)
Called for a node that calls a submodule (node.op == "call_module").
- flatten_inputs(x: Any) → List[torch.Tensor]
Flattens nested inputs into a flat list of tensors.
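A minimal sketch of such flattening, assuming the inputs are arbitrarily nested lists, tuples and dicts whose leaves are tensors. `flatten_nested` is a hypothetical helper, and plain integers stand in for tensors so the example runs without torch; the method itself may handle other containers.

```python
from typing import Any, List


def flatten_nested(x: Any) -> List[Any]:
    """Recursively collects the leaves of nested lists, tuples and dicts."""
    if isinstance(x, dict):
        # dicts preserve insertion order, so values are visited in that order
        x = list(x.values())
    if isinstance(x, (list, tuple)):
        out: List[Any] = []
        for item in x:
            out.extend(flatten_nested(item))
        return out
    # anything else (a tensor in practice) is a leaf
    return [x]


print(flatten_nested((1, [2, (3,)], {"a": 4, "b": [5]})))  # [1, 2, 3, 4, 5]
```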
- get_attr(node: torch.fx.Node)
Retrieves the attribute (for example a parameter or a buffer) referenced by a get_attr node.
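In torch.fx, the target of a get_attr node is a dotted path such as `encoder.layer.weight`. A minimal sketch of resolving such a path by walking attributes one level at a time, similar in spirit to `torch.fx.Interpreter.fetch_attr`; the `fetch_attr` function below is a hypothetical standalone helper, and `SimpleNamespace` objects stand in for modules.

```python
from types import SimpleNamespace
from typing import Any


def fetch_attr(root: Any, target: str) -> Any:
    """Resolves a dotted path such as 'encoder.layer.weight' starting from root."""
    obj = root
    atoms = target.split(".")
    for i, atom in enumerate(atoms):
        if not hasattr(obj, atom):
            # report which prefix of the path was missing the attribute
            owner = ".".join(atoms[:i]) or "<root>"
            raise AttributeError(f"{owner} has no attribute {atom!r}")
        obj = getattr(obj, atom)
    return obj


layer = SimpleNamespace(weight=[1.0, 2.0])
root = SimpleNamespace(encoder=SimpleNamespace(layer=layer))
print(fetch_attr(root, "encoder.layer.weight"))  # [1.0, 2.0]
```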
- get_submodule_name(module_name: str, module: torch.nn.Module) → str
Returns a simple but unique name for a submodule.
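A minimal sketch of one way to produce names that are simple but unique, assuming the name is derived from the module class and its dotted path. `UniqueNamer` is hypothetical; the actual naming scheme of get_submodule_name may differ.

```python
import re
from typing import Set


class UniqueNamer:
    """Hypothetical helper producing simple, unique names for submodules."""

    def __init__(self) -> None:
        self._used: Set[str] = set()

    def name(self, module_name: str, class_name: str) -> str:
        # replace characters such as '.' that are awkward in identifiers
        base = re.sub(r"[^0-9A-Za-z_]", "_", f"{class_name}_{module_name}").strip("_")
        candidate, i = base, 1
        # append a numeric suffix until the name has not been used yet
        while candidate in self._used:
            candidate = f"{base}_{i}"
            i += 1
        self._used.add(candidate)
        return candidate


namer = UniqueNamer()
print(namer.name("encoder.layers.0", "Linear"))  # Linear_encoder_layers_0
print(namer.name("encoder.layers.0", "Linear"))  # Linear_encoder_layers_0_1
```

Tracking every name already handed out, rather than only a counter per base name, guarantees uniqueness even when a suffixed name happens to equal another base name.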
- getitem(node: torch.fx.Node)
Called when the bracket operator something[...] appears. The index may be another variable, an integer, a slice, a tuple, or a list.
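The kinds of index listed above can be told apart with plain isinstance checks. This sketch only classifies the index; `Var` is a hypothetical stand-in for a graph value such as a torch.fx.Node, and the returned labels are illustrative rather than anything the interpreter produces.

```python
from typing import Any


class Var:
    """Hypothetical stand-in for a graph value such as a torch.fx.Node."""

    def __init__(self, name: str) -> None:
        self.name = name


def classify_index(index: Any) -> str:
    """Labels the kind of index used in something[index]."""
    if isinstance(index, Var):
        return "variable"  # only known when the graph runs
    if isinstance(index, int):
        return "integer"
    if isinstance(index, slice):
        return "slice"
    if isinstance(index, tuple):
        # a tuple combines several indices, one per indexed dimension
        return "tuple(" + ", ".join(classify_index(i) for i in index) + ")"
    if isinstance(index, list):
        return "list"
    raise TypeError(f"unsupported index type {type(index).__name__}")


print(classify_index((0, slice(1, None))))  # tuple(integer, slice)
```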
- placeholder(node: torch.fx.Node)
Handles a placeholder node, i.e. a graph input. The interpreter adds an Identity node between the input name it wants to expose and the name the input has in the graph module.
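A minimal sketch of this renaming step, with ONNX nodes represented as plain `(op_type, inputs, outputs)` tuples. `add_placeholder` is hypothetical, and skipping the Identity node when both names already match is an assumption of this sketch.

```python
from typing import List, Tuple

# (op_type, inputs, outputs): a simplified ONNX node
Node = Tuple[str, List[str], List[str]]


def add_placeholder(
    graph_inputs: List[str], nodes: List[Node], wanted_name: str, graph_name: str
) -> None:
    """Declares a graph input and links it to the name used by the graph module."""
    graph_inputs.append(wanted_name)  # name exposed by the ONNX model
    if wanted_name != graph_name:  # assumption: no Identity needed if names match
        nodes.append(("Identity", [wanted_name], [graph_name]))


inputs: List[str] = []
nodes: List[Node] = []
add_placeholder(inputs, nodes, "input_ids", "arg0_1")
print(inputs, nodes)  # ['input_ids'] [('Identity', ['input_ids'], ['arg0_1'])]
```

The rest of the graph keeps referring to the graph-module name, so only the inputs of the model need to change.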
- register_named_modules(parent_interpreter: DynamoInterpreter | None, preserved_modules: Set[type[torch.nn.Module]] | None, named_modules: Dict[str, torch.nn.Module])
Registers a list of modules to preserve as local functions in the ONNX model. If the list is empty, the graph is almost entirely inlined. The module to convert to ONNX should be the output of the function torch.export.unflatten.unflatten().
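A minimal sketch of how a set of preserved module types could split named modules into local functions and inlined code. `Module`, `Linear`, `Attention` and `split_modules` are hypothetical stand-ins for torch.nn classes and for the registration logic, and matching with isinstance (so that subclasses count as preserved) is an assumption.

```python
from typing import Dict, List, Optional, Set, Tuple, Type


class Module:  # hypothetical stand-in for torch.nn.Module
    pass


class Linear(Module):
    pass


class Attention(Module):
    pass


def split_modules(
    named_modules: Dict[str, Module], preserved: Optional[Set[Type[Module]]]
) -> Tuple[List[str], List[str]]:
    """Returns (names kept as local functions, names inlined)."""
    types = tuple(preserved or ())
    kept: List[str] = []
    inlined: List[str] = []
    for name, module in named_modules.items():
        # isinstance(x, ()) is False, so an empty or missing set inlines everything
        (kept if isinstance(module, types) else inlined).append(name)
    return kept, inlined


mods = {"layers.0.attn": Attention(), "layers.0.proj": Linear()}
print(split_modules(mods, {Attention}))  # (['layers.0.attn'], ['layers.0.proj'])
print(split_modules(mods, None))  # ([], ['layers.0.attn', 'layers.0.proj'])
```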
- run_node(node: torch.fx.Node)
Runs a node: calls the appropriate method based on the node type.
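A torch.fx node carries one of six `op` values: placeholder, get_attr, call_function, call_method, call_module and output. A minimal sketch of dispatching on that value by looking up a method with the same name, which is also the pattern torch.fx.Interpreter.run_node follows. `ToyNode` and `ToyInterpreter` are hypothetical and only log what they see.

```python
from dataclasses import dataclass
from typing import Any, List

# the six node kinds of a torch.fx graph
FX_OPS = ("placeholder", "get_attr", "call_function", "call_method", "call_module", "output")


@dataclass
class ToyNode:
    """Hypothetical stand-in for torch.fx.Node with only the fields used here."""

    op: str
    name: str
    target: Any = None


class ToyInterpreter:
    def __init__(self) -> None:
        self.log: List[str] = []

    def run_node(self, node: ToyNode) -> Any:
        if node.op not in FX_OPS:
            raise ValueError(f"unknown node type {node.op!r}")
        # a method named after node.op handles each kind of node
        handler = getattr(self, node.op, None)
        if handler is None:
            raise NotImplementedError(f"no handler for {node.op!r}")
        return handler(node)

    def placeholder(self, node: ToyNode) -> None:
        self.log.append(f"input {node.name}")

    def call_function(self, node: ToyNode) -> None:
        self.log.append(f"call {node.target}")

    def output(self, node: ToyNode) -> None:
        self.log.append(f"output {node.name}")


interp = ToyInterpreter()
for n in [ToyNode("placeholder", "x"), ToyNode("call_function", "add_1", "add"), ToyNode("output", "out")]:
    interp.run_node(n)
print(interp.log)  # ['input x', 'call add', 'output out']
```

Checking `node.op` against the known kinds first keeps an unexpected value from accidentally resolving to an unrelated attribute such as `log`.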