experimental_experiment.torch_interpreter.interpreter
- class experimental_experiment.torch_interpreter.interpreter.DynamoInterpreter(graph_builder: GraphBuilder, retriever: Callable, dispatcher: Dispatcher | None = None, example_inputs: Tuple[torch.Tensor, ...] | None = None, export_options: ExportOptions | None = None, optimize_submodules: bool = False, function_options: FunctionOptions | None = None, submodule_naming: Callable | None = None, parameter_naming: Callable | None = None, module_name: str | None = None)
Interprets a torch.fx graph and converts it into an ONNX graph, dispatching every node to the appropriate converting function.
- Parameters:
graph_builder – a graph builder
retriever – callable to help retrieve the weights in a module, see function experimental_experiment.torch_interpreter.onnx_export._retrieve.
dispatcher – see experimental_experiment.torch_interpreter.Dispatcher
export_options – see ExportOptions
optimize_submodules – optimizes submodules after they are built
submodule_naming – a function which returns a submodule name in the ONNX graph
parameter_naming – a function which returns a parameter name in the ONNX graph
module_name – module name (makes it easier to retrieve the parameter names)
- call_function(node: torch.fx.Node)
Called for a node whose op is call_function, i.e. a call to a function or an operator.
- call_method(node: torch.fx.Node)
Called for a node whose op is call_method, i.e. a method called on a value (for example x.view(...)).
- call_module(node: torch.fx.Node)
Called for a node whose op is call_module, i.e. a call to a submodule.
- flatten_inputs(x: Any) → List[torch.Tensor]
Flattens the inputs x into a flat list of tensors.
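A minimal sketch of what flattening can look like, assuming the inputs are nested tuples, lists and dicts whose leaves are tensors. The helper flatten_nested is hypothetical, uses strings as stand-ins for torch.Tensor so it runs without torch, and is not the library's implementation, which may support other containers.

```python
from typing import Any, List


def flatten_nested(x: Any) -> List[Any]:
    """Hypothetical helper: returns the leaves of a nested structure, depth first.

    Tuples and lists are walked in order, dicts through their values in
    insertion order; anything else is treated as a leaf (a tensor).
    """
    if isinstance(x, (tuple, list)):
        items = list(x)
    elif isinstance(x, dict):
        items = list(x.values())
    else:
        return [x]  # a leaf, e.g. a tensor
    flat: List[Any] = []
    for item in items:
        flat.extend(flatten_nested(item))
    return flat


# Strings stand in for tensors.
print(flatten_nested(("a", ["b", {"k": "c"}])))  # ['a', 'b', 'c']
```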
- get_attr(node: torch.fx.Node)
Called for a node whose op is get_attr: retrieves an attribute of the module (for example a parameter or a buffer).
- get_submodule_name(module_name: str, module: torch.nn.Module) → str
Returns a name for the submodule that is simple but unique.
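The docstring only promises a name that is simple but unique. One way to obtain that property, sketched here under the assumption that names are derived from the dotted module path, is to replace characters that are awkward in ONNX function names and append a counter on collision. UniqueNamer is hypothetical, takes only the module name (not the module), and does not reproduce the library's actual naming scheme.

```python
import re
from typing import Set


class UniqueNamer:
    """Hypothetical helper producing simple, unique names from dotted module paths."""

    def __init__(self) -> None:
        self.used: Set[str] = set()

    def name(self, module_name: str) -> str:
        # "layers.0.attn" -> "layers_0_attn": keep only [A-Za-z0-9_].
        base = re.sub(r"[^A-Za-z0-9_]", "_", module_name) or "module"
        candidate, i = base, 1
        # Append an increasing suffix until the name has not been used yet.
        while candidate in self.used:
            candidate = f"{base}_{i}"
            i += 1
        self.used.add(candidate)
        return candidate


namer = UniqueNamer()
print(namer.name("layers.0.attn"))  # layers_0_attn
print(namer.name("layers_0.attn"))  # layers_0_attn_1 (collision resolved)
```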
- getitem(node: torch.fx.Node)
Called when the bracket operator something[...] appears. The index may be another variable, an integer, a slice, a tuple, or a list.
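In Python, something[...] calls __getitem__ with a single index object, and that object is what has to be interpreted. The small recorder below (plain Python, no torch) shows which object arrives for each kind of index listed above; it only illustrates the Python semantics, not how the interpreter converts them to ONNX.

```python
from typing import Any


class IndexRecorder:
    """Returns whatever index object Python passes to __getitem__."""

    def __getitem__(self, index: Any) -> Any:
        return index


r = IndexRecorder()
i = 2
print(r[i])       # 2: a variable (in a traced graph it may be another node)
print(r[1:5])     # slice(1, 5, None)
print(r[0, 1:3])  # (0, slice(1, 3, None)): several indices arrive as one tuple
print(r[[0, 2]])  # [0, 2]: a list, i.e. advanced indexing on tensors
```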
- placeholder(node: torch.fx.Node)
Called for a placeholder node, i.e. a graph input. The interpreter adds an Identity node between the input name it wants and the name the input has in the graph module.
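The renaming step can be pictured with a toy node list. This sketch assumes a node is a tuple (op_type, inputs, outputs); add_input_identity and the names x and arg0_1 are hypothetical. One Identity node maps the name exposed as the ONNX graph input to the name used inside the graph module, so the downstream nodes keep referring to the internal name.

```python
from typing import List, Tuple

# Hypothetical toy representation of an ONNX node: (op_type, inputs, outputs).
Node = Tuple[str, List[str], List[str]]


def add_input_identity(nodes: List[Node], wanted: str, fx_name: str) -> List[Node]:
    """Prepends Identity(wanted) -> fx_name so downstream nodes keep using fx_name."""
    return [("Identity", [wanted], [fx_name])] + nodes


graph = [("Relu", ["arg0_1"], ["relu"])]
graph = add_input_identity(graph, wanted="x", fx_name="arg0_1")
for node in graph:
    print(node)
# ('Identity', ['x'], ['arg0_1'])
# ('Relu', ['arg0_1'], ['relu'])
```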
- register_named_modules(parent_interpreter: DynamoInterpreter | None, preserved_modules: Set[type[torch.nn.Module]] | None, named_modules: Dict[str, torch.nn.Module])
Registers a list of modules to preserve as local functions in the ONNX model. If the list is empty, the graph is almost entirely inlined. The module to convert to ONNX should be the output of torch.export.unflatten.unflatten().
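A sketch of the selection described above, assuming a module is kept as a local function exactly when its type is in preserved_modules (exact type match, not subclass) and is inlined otherwise. The classes are plain stand-ins for torch.nn.Module subclasses and split_local_functions is hypothetical; the real method also receives a parent interpreter and may apply further rules.

```python
from typing import Dict, List, Set, Tuple, Type


class Module:  # stand-in for torch.nn.Module
    pass


class Attention(Module):
    pass


class MLP(Module):
    pass


def split_local_functions(
    named_modules: Dict[str, Module], preserved: Set[Type[Module]]
) -> Tuple[List[str], List[str]]:
    """Returns (names kept as local functions, names inlined)."""
    local: List[str] = []
    inlined: List[str] = []
    for name, module in named_modules.items():
        # Exact type check: subclasses of a preserved type are not preserved.
        (local if type(module) in preserved else inlined).append(name)
    return local, inlined


mods = {"layers.0.attn": Attention(), "layers.0.mlp": MLP()}
print(split_local_functions(mods, {Attention}))  # (['layers.0.attn'], ['layers.0.mlp'])
print(split_local_functions(mods, set()))        # ([], ['layers.0.attn', 'layers.0.mlp'])
```

With an empty set every module falls into the inlined list, which mirrors the statement that the graph is then almost entirely inlined.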
- run_node(node: torch.fx.Node)
Runs a node: calls the appropriate method based on the node type (node.op).
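This dispatch follows the usual torch.fx pattern: every node carries an op string (placeholder, get_attr, call_function, call_method, call_module or output) and the interpreter calls the method with the same name. The sketch reproduces that pattern without torch; FakeNode and ToyInterpreter are hypothetical stand-ins, and the strings they return are placeholders for the ONNX nodes the real interpreter would emit.

```python
from dataclasses import dataclass, field
from typing import Any, Tuple


@dataclass
class FakeNode:
    """Hypothetical stand-in for torch.fx.Node with only the fields used here."""

    op: str  # one of the torch.fx opcodes
    name: str
    target: Any = None
    args: Tuple[Any, ...] = field(default_factory=tuple)


class ToyInterpreter:
    OPS = ("placeholder", "get_attr", "call_function", "call_method", "call_module", "output")

    def run_node(self, node: FakeNode) -> str:
        # Dispatch on the opcode: node.op names the handler method.
        if node.op not in self.OPS:
            raise ValueError(f"unexpected op {node.op!r} for node {node.name!r}")
        return getattr(self, node.op)(node)

    def placeholder(self, node: FakeNode) -> str:
        return f"input {node.name}"

    def get_attr(self, node: FakeNode) -> str:
        return f"attribute {node.target}"

    def call_function(self, node: FakeNode) -> str:
        return f"convert function {node.target}"

    def call_method(self, node: FakeNode) -> str:
        return f"convert method {node.target}"

    def call_module(self, node: FakeNode) -> str:
        return f"convert submodule {node.target}"

    def output(self, node: FakeNode) -> str:
        return f"output {node.args}"


interp = ToyInterpreter()
print(interp.run_node(FakeNode("call_method", "view", target="view")))
# convert method view
```

Rejecting unknown opcodes explicitly keeps an unsupported node from being silently skipped during conversion.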