# Source code for experimental_experiment.torch_dynamo.fast_backend

import os
import pickle
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from onnx import ModelProto, TensorProto, load
from onnx.numpy_helper import to_array
import torch
from torch._C import _from_dlpack
from onnxruntime.capi import _pybind_state as ORTC
from ..convert.ort_helper import append_custom_libraries
from ..helpers import tensor_dtype_to_np_dtype, from_array_extended, onnx_dtype_to_torch_dtype
from ..xbuilder import OptimizationOptions
from ..xoptim import get_pattern_list
from ..torch_interpreter import to_onnx, ExportOptions
from ..torch_interpreter._torch_helper import create_input_names
from .backend_helper import get_dimensions

def _get_session(
    onx: ModelProto,
    impl: str = "ort",
    providers: Optional[List[str]] = None,
    ort_optimization_level: Optional[str] = None,
    exc: bool = True,
) -> Tuple[Union["ReferenceEvaluator", "InferenceSession"], "RunOptions"]:  # noqa: F821
    """
    Creates an :epkg:`onnxruntime` InferenceSession and the RunOptions to use with it.

    :param onx: model to load
    :param impl: runtime, only ``"ort"`` is supported
    :param providers: execution providers given to the InferenceSession
    :param ort_optimization_level: None keeps onnxruntime's default, otherwise
        the name of an attribute of ``onnxruntime.GraphOptimizationLevel``
        (e.g. ``"ORT_DISABLE_ALL"``)
    :param exc: must be True, silent mode is not supported
    :return: tuple ``(InferenceSession, RunOptions)``
    """
    assert impl == "ort", f"Unexpected impl={impl!r}"
    assert exc, f"Silent mode is not allowed but exc={exc!r}"
    import onnxruntime

    run_options = onnxruntime.RunOptions()
    run_options.add_run_config_entry("disable_synchronize_execution_providers", "1")

    opts = onnxruntime.SessionOptions()
    if ort_optimization_level is not None:
        assert hasattr(onnxruntime.GraphOptimizationLevel, ort_optimization_level), (
            f"Unexpected value {ort_optimization_level!r} for GraphOptimizationLevel, "
            f"expecting one of the values in {dir(onnxruntime.GraphOptimizationLevel)}"
        )
        opts.graph_optimization_level = getattr(
            onnxruntime.GraphOptimizationLevel, ort_optimization_level
        )
        if ort_optimization_level == "ORT_DISABLE_ALL":
            # When every graph optimization is disabled, memory patterns,
            # memory reuse, the CPU arena and prepacking are disabled too.
            opts.enable_mem_pattern = False
            opts.enable_mem_reuse = False
            opts.enable_cpu_mem_arena = False
            # opts.add_session_config_entry("set_denormal_as_zero", "1")
            opts.add_session_config_entry("disable_prepacking", "1")

    opts.add_session_config_entry("session.disable_aot_function_inlining", "1")
    # Presumably registers the custom operator libraries the model needs
    # (see ..convert.ort_helper.append_custom_libraries).
    append_custom_libraries(onx, opts)

    return (
        onnxruntime.InferenceSession(onx.SerializeToString(), opts, providers=providers),
        run_options,
    )

def _post_process(output: Any, name: Optional[str], dim: bool) -> Any:
    if name is None:
        # None value required by torch
        return None
    if dim:
        # a dimension to replace
        if output.shape == (1,):
            yi = int(output[0])
            yi = int(output)
        return yi
    return output

def _serialize(args: Any) -> Any:
    if isinstance(args, torch.Tensor):
        return args
    if isinstance(args, tuple):
        return tuple(_serialize(a) for a in args)
    if isinstance(args, list):
        return [_serialize(a) for a in args]
    if isinstance(args, (int, torch.SymInt, float, torch.SymFloat)):
        return args
    raise RuntimeError(f"Unable to serialize type {type(args)}.")

class OrtBackend:
    """
    Wraps method ``run_with_ortvaluevector`` from ``InferenceSession``
    to implement a backend for ``torch.dynamo``.

    :param sess: onnxruntime InferenceSession
    :param run_options: RunOptions given to every run, a default one
        is created if None
    :param devices: mapping ``device index -> OrtDevice``, ``-1`` is the CPU,
        built from the available CUDA devices if None
    :param input_names: input names, taken from *sess* if None
    :param output_names: output names, taken from *sess* if None
    :param is_dimension_in: one tuple ``(is_dimension, rank, name, onnx element type)``
        per input, inferred from *sess* if None (an input whose name
        contains ``_dim_`` is a dimension)
    :param is_dimension_out: same for the outputs, inferred from *sess* if None,
        the name is None when the output name contains ``_NONE_``
    :param dump_first_inputs: if not empty, the inputs of the first call
        are pickled into ``<dump_first_inputs>.pkl``
    :param stor: if not empty, every call appends its inputs and outputs
        to ``stor["inputs"]`` and ``stor["outputs"]``
    :param onnx_model: the model loaded by *sess*, required by :meth:`dump_for_debug`
    """

    # onnxruntime type string -> onnx element type
    ORT_STR_TYPE_TO_TENSOR_TYPE = {
        "tensor(int64)": TensorProto.INT64,
        "tensor(int32)": TensorProto.INT32,
        "tensor(int16)": TensorProto.INT16,
        "tensor(int8)": TensorProto.INT8,
        "tensor(uint64)": TensorProto.UINT64,
        "tensor(uint32)": TensorProto.UINT32,
        "tensor(uint16)": TensorProto.UINT16,
        "tensor(uint8)": TensorProto.UINT8,
        "tensor(float)": TensorProto.FLOAT,
        "tensor(float16)": TensorProto.FLOAT16,
        "tensor(double)": TensorProto.DOUBLE,
        "tensor(bool)": TensorProto.BOOL,
    }

    # torch dtype -> numpy dtype, used to describe the inputs to onnxruntime
    TORCH_DTYPE_TO_NUMPY_DTYPE = {
        torch.float16: np.float16,
        torch.float32: np.float32,
        torch.float64: np.float64,
        torch.uint8: np.uint8,
        torch.int8: np.int8,
        torch.int16: np.int16,
        torch.int32: np.int32,
        torch.int64: np.int64,
        torch.bool: np.bool_,
    }

    # numpy type or dtype -> torch dtype
    NUMPY_DTYPE_TO_TORCH_DTYPE = {
        np.float16: torch.float16,
        np.float32: torch.float32,
        np.int64: torch.int64,
        np.int32: torch.int32,
        np.dtype("float16"): torch.float16,
        np.dtype("float32"): torch.float32,
        np.dtype("int64"): torch.int64,
        np.dtype("int32"): torch.int32,
    }

    def __init__(
        self,
        sess: "onnxruntime.InferenceSession",  # noqa: F821
        run_options: Optional["onnxruntime.RunOptions"] = None,  # noqa: F821
        devices: Optional[Dict[int, Any]] = None,
        input_names: Optional[List[str]] = None,
        output_names: Optional[List[str]] = None,
        is_dimension_in: Optional[List[Tuple[bool, int, str, int]]] = None,
        is_dimension_out: Optional[List[Tuple[bool, int, Optional[str], int]]] = None,
        dump_first_inputs: Optional[str] = None,
        stor: Optional[Dict[str, Any]] = None,
        onnx_model: Optional[ModelProto] = None,
    ):
        self.sess = sess
        self.input_names = input_names
        self.output_names = output_names
        self.is_dimension_in = is_dimension_in
        self.is_dimension_out = is_dimension_out
        self.dump_first_inputs = dump_first_inputs
        self.stor = stor
        self.run_options = run_options
        self.devices = devices
        self.OrtValueVector = ORTC.OrtValueVector
        self.from_dlpack = _from_dlpack
        self.onnx_model = onnx_model

        if self.devices is None:
            DEVICES = {
                -1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)
            }
            for i in range(torch.cuda.device_count()):
                DEVICES[i] = ORTC.OrtDevice(
                    ORTC.OrtDevice.cuda(), ORTC.OrtDevice.default_memory(), i
                )
            self.devices = DEVICES

        if self.run_options is None:
            import onnxruntime

            self.run_options = onnxruntime.RunOptions()
            self.run_options.add_run_config_entry(
                "disable_synchronize_execution_providers", "1"
            )

        if self.input_names is None:
            self.input_names = [i.name for i in sess.get_inputs()]
        if self.output_names is None:
            self.output_names = [i.name for i in sess.get_outputs()]

        if self.is_dimension_in is None:
            # An input whose name contains '_dim_' receives a python number
            # (a dimension) instead of a tensor.
            self.is_dimension_in = []
            for o in sess.get_inputs():
                b = "_dim_" in o.name
                rk = len(o.shape)
                dt = self.ORT_STR_TYPE_TO_TENSOR_TYPE[o.type]
                self.is_dimension_in.append((b, rk, o.name, dt))

        if self.is_dimension_out is None:
            # Same for the outputs, a name containing '_NONE_' is replaced by None
            # so that _post_process returns None for this output.
            self.is_dimension_out = []
            for o in sess.get_outputs():
                b = "_dim_" in o.name
                rk = len(o.shape)
                dt = self.ORT_STR_TYPE_TO_TENSOR_TYPE[o.type]
                self.is_dimension_out.append(
                    (b, rk, None if "_NONE_" in o.name else o.name, dt)
                )

    def __call__(self, *inputs):
        """
        Runs the model and returns a tuple with one element per output.
        """
        if self.dump_first_inputs:
            name = self.dump_first_inputs
            # Only the first call is dumped.
            self.dump_first_inputs = None
            with open(name + ".pkl", "wb") as f:
                pickle.dump([self.input_names, _serialize(inputs), self.output_names], f)

        res, dimensions = self._run_onnx_session_with_ortvaluevector(inputs)

        for x, name in zip(res, self.output_names):
            if isinstance(x, (torch.SymInt, int, float, torch.SymFloat)):
                if x == 0 and self.onnx_model is not None:
                    # Saves the failing case before the assertion below fails.
                    self.dump_for_debug("debug_data", *inputs)
                assert (
                    x != 0
                ), f"Dimension is null for name={name!r}, input dimensions={dimensions}"

        if self.stor:
            self.stor["inputs"].append(inputs)
            self.stor["outputs"].append(res)
        return res

    def _get_ortvalues_from_torch_tensors(
        self,
        tensors: Tuple["torch.Tensor", ...],  # noqa: F821
    ) -> Tuple[Any, List[Any], List["torch.Tensor"]]:  # noqa: F821
        """
        Builds the ``OrtValueVector`` given to onnxruntime from the tensors'
        data pointers. Dimensions (python numbers) are first converted into
        CPU tensors.

        :return: the ``OrtValueVector``, the device of every output (CPU for
            dimensions, the highest device index among the input tensors
            otherwise), the tensors created for the dimensions
        """
        assert tensors is not None, "tensors cannot be None"
        ortvalues = self.OrtValueVector()
        ortvalues.reserve(len(tensors))
        dtypes = []
        shapes = []
        data_ptrs = []
        devices = []
        dimensions = []
        new_tensors = []
        max_device = -1
        dim_types = (int, torch.SymInt, float, torch.SymFloat)
        int_types = {
            TensorProto.INT64,
            TensorProto.INT32,
            TensorProto.UINT64,
            TensorProto.UINT32,
        }

        for tensor, (dim, rk, name, dt) in zip(tensors, self.is_dimension_in):
            if dim:
                assert isinstance(
                    tensor, dim_types
                ), f"Unexpected type {type(tensor)} for name={name!r}."
                dtypes.append(tensor_dtype_to_np_dtype(dt))
                ti = int(tensor) if dt in int_types else float(tensor)
                assert ti != 0, (
                    f"Null value for a dimension ti={ti}, "
                    f"tensor={tensor}, rk={rk}, name={name!r}, "
                    f"type(tensor)={type(tensor)}, "
                    f"dimension={[t for t in tensors if isinstance(t, dim_types)]}"
                )
                t = torch.tensor([ti] if rk == 1 else ti, dtype=onnx_dtype_to_torch_dtype(dt))
                devices.append(self.devices[-1])
                new_tensors.append(t)
                dimensions.append(t)
                shapes.append(t.size())
                data_ptrs.append(t.data_ptr())
            else:
                assert isinstance(tensor, torch.Tensor), (
                    f"Unexpected type {type(tensor)}, "
                    f"dim={dim}, rk={rk}, name={name!r}, dt={dt}, "
                    f"len(tensors)={len(tensors)}, "
                    f"len(is_dimension_in)={len(self.is_dimension_in)}"
                )
                d = tensor.get_device()
                dtypes.append(self.TORCH_DTYPE_TO_NUMPY_DTYPE[tensor.dtype])
                shapes.append(tensor.size())
                data_ptrs.append(tensor.data_ptr())
                devices.append(self.devices[d])
                new_tensors.append(tensor)
                max_device = max(max_device, d)

        ortvalues.push_back_batch(new_tensors, data_ptrs, dtypes, shapes, devices)
        output_devices = [
            self.devices[-1] if dim else self.devices[max_device]
            for dim, _rk, _name, _dt in self.is_dimension_out
        ]
        return ortvalues, output_devices, dimensions

    def _ortvalues_to_torch_tensor(
        self,
        ortvalues: "onnxruntime.OrtValueVector",  # noqa: F821
    ) -> Tuple[Any, ...]:
        """
        Converts the outputs into torch tensors through dlpack and applies
        ``_post_process``: dimensions become python ints, outputs whose name
        is None in ``is_dimension_out`` become None.
        """
        if len(ortvalues) == 0:
            return tuple()
        res = ortvalues.to_dlpacks(self.from_dlpack)
        return tuple(_post_process(r, d[2], d[0]) for r, d in zip(res, self.is_dimension_out))

    def _run_onnx_session_with_ortvaluevector(
        self,
        inputs: Tuple["torch.Tensor", ...],  # noqa: F821
    ) -> Tuple[Tuple[Any, ...], List["torch.Tensor"]]:  # noqa: F821
        """
        Runs the session and returns the outputs and the tensors created
        for the input dimensions.
        """
        # Data pointers are handed to onnxruntime, the tensors must be contiguous.
        contiguous_inputs = tuple(
            (a.contiguous() if isinstance(a, torch.Tensor) else a) for a in inputs
        )
        ort_inputs, output_devices, dimensions = self._get_ortvalues_from_torch_tensors(
            contiguous_inputs
        )
        ort_outputs = self.OrtValueVector()
        self.sess.run_with_ortvaluevector(
            self.run_options,
            self.input_names,
            ort_inputs,
            self.output_names,
            ort_outputs,
            output_devices,
        )
        # Map ORTValue to torch.Tensor.
        pth_outputs = self._ortvalues_to_torch_tensor(ort_outputs)
        return pth_outputs, dimensions

    def to_tensor_proto(self, value: Any) -> TensorProto:
        """
        Converts a numpy array, a python int (stored as int64 with shape (1,))
        or a torch tensor into a TensorProto.
        """
        if isinstance(value, np.ndarray):
            return from_array_extended(value)
        if isinstance(value, int):
            return from_array_extended(np.array([value], dtype=np.int64))
        if isinstance(value, torch.Tensor):
            return self.to_tensor_proto(value.detach().cpu().numpy())
        raise RuntimeError(f"Unexpected type {type(value)}, unable to convert to TensorProto")

    def dump_for_debug(self, folder: str, *inputs, test_case: int = 0):
        """
        Dumps everything in a folder: the model into ``<folder>/model.onnx``,
        the inputs into ``<folder>/test_case_<test_case>/input_<i>.pb``.
        :meth:`replay_dumped_data` loads them back.
        """
        assert self.onnx_model is not None, "Cannot dump if the onnx model is not here"
        os.makedirs(folder, exist_ok=True)
        with open(os.path.join(folder, "model.onnx"), "wb") as f:
            f.write(self.onnx_model.SerializeToString())
        case = os.path.join(folder, f"test_case_{test_case}")
        os.makedirs(case, exist_ok=True)
        assert len(inputs) > 0, f"Empty sequence of inputs, cannot save into {folder!r}."
        for i, inp in enumerate(inputs):
            pb = self.to_tensor_proto(inp)
            with open(os.path.join(case, f"input_{i}.pb"), "wb") as f:
                f.write(pb.SerializeToString())

    @classmethod
    def replay_dumped_data(
        cls,
        folder: str,
        test_case: int = 0,
        providers: Optional[List[str]] = None,
        impl: str = "ort",
        ort_optimization_level: Optional[str] = None,
    ) -> Tuple["OrtBackend", List[Any]]:
        """
        Loads the data saved by :meth:`dump_for_debug`.

        :param folder: folder given to :meth:`dump_for_debug`
        :param test_case: test case to load
        :param providers: execution providers (names or ``(name, options)``
            tuples), CUDA is used by default if available
        :param impl: runtime, see ``_get_session``
        :param ort_optimization_level: see ``_get_session``
        :return: the backend and the list of inputs to call it with
        """
        onx = load(os.path.join(folder, "model.onnx"))
        test = os.path.join(folder, f"test_case_{test_case}")
        inputs = []
        i = 0
        name = os.path.join(test, f"input_{i}.pb")
        while os.path.exists(name):
            with open(name, "rb") as f:
                t = TensorProto()
                t.ParseFromString(f.read())
            inputs.append(to_array(t))
            i += 1
            name = os.path.join(test, f"input_{i}.pb")

        if providers is None:
            has_cuda = torch.cuda.device_count() > 0
            providers = (
                [("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]
                if has_cuda
                else ["CPUExecutionProvider"]
            )
            device = 0 if has_cuda else -1
        else:
            # Providers may be given as names or as (name, options) tuples.
            names = [p[0] if isinstance(p, tuple) else p for p in providers]
            device = 0 if "CUDAExecutionProvider" in names else -1

        sess, run_options = _get_session(
            onx,
            impl,
            providers,
            exc=True,
            ort_optimization_level=ort_optimization_level,
        )
        bck = OrtBackend(sess, run_options=run_options, onnx_model=onx)

        new_inputs = []
        for value, dim in zip(inputs, bck.is_dimension_in):
            if dim[0]:
                v = int(value[0]) if value.shape == (1,) else int(value)
            else:
                # torch.from_numpy keeps the dtype, torch.Tensor(...) would
                # go through float32 and lose precision on large integers.
                v = torch.from_numpy(value.copy())
                if device >= 0:
                    v = v.to(f"cuda:{device}")
            new_inputs.append(v)
        return bck, new_inputs
def _default_export(
    graph_module,
    args,
    verbose,
    target_opset,
    dispatcher,
    optimize,
    enable_pattern,
    disable_pattern,
    rename_inputs,
    processor,
    order_algorithm=None,
    dump_patterns=None,
    options=None,
    export_options: Optional[Union[str, ExportOptions]] = None,
):
    """
    Exports a graph module into ONNX with :func:`to_onnx`.

    :param verbose: an int or a pair ``(exporter verbosity, backend verbosity)``,
        only the exporter verbosity is used here
    :param order_algorithm: name of an ``OrderAlgorithm`` attribute (case insensitive),
        only used when *options* is None
    :param options: OptimizationOptions, if None they are built from
        *enable_pattern*, *disable_pattern*, *processor*, *order_algorithm*
        and *dump_patterns*
    :return: tuple ``(ModelProto, builder)``, the builder is the one
        returned by ``to_onnx(..., return_builder=True)``
    """
    input_names = create_input_names(graph_module, args) if rename_inputs else None
    verbose_onnx, _verbose_backend = verbose if isinstance(verbose, tuple) else (verbose, verbose)

    if options is None:
        patterns = get_pattern_list(enable_pattern, disable_pattern, verbose=verbose_onnx)
        if order_algorithm is not None:
            from ..xoptim import OrderAlgorithm

            order_algorithm = getattr(OrderAlgorithm, order_algorithm.upper())
        options = OptimizationOptions(
            remove_unused=True,
            constant_folding=False,
            patterns=patterns,
            verbose=verbose_onnx,
            processor=processor,
            order=order_algorithm,
            dump_applied_patterns=dump_patterns,
        )

    onx, builder = to_onnx(
        graph_module,
        tuple(args),
        input_names=input_names,
        options=options,
        verbose=verbose_onnx,
        target_opset=target_opset,
        return_builder=True,
        dispatcher=dispatcher,
        optimize=optimize,
        export_options=export_options,
    )
    return onx, builder


def _print_memory(max_device: int):
    """
    Prints the CUDA memory allocated and reserved on device *max_device*,
    prints nothing if *max_device* is negative (CPU).
    """
    if max_device < 0:
        return
    print(
        f"[onnx_custom_backend] CUDA memory "
        f"allocated={torch.cuda.memory_allocated(max_device)}, "
        f"reserved={torch.cuda.memory_reserved(max_device)}, "
        f"max_device={max_device}"
    )
def onnx_custom_backend(
    graph_module: "torch.fx.GraphModule",  # noqa: F821
    args: List["torch.Tensor"],  # noqa: F821
    target_opset: Optional[int] = None,
    backend: str = "ort",
    verbose: Union[int, Tuple[int, int]] = 0,
    dump_prefix: Optional[str] = None,
    dump_patterns: Optional[str] = None,
    providers: Optional[Tuple[str]] = None,
    raise_exc: bool = True,
    storage: Optional[Dict[str, Any]] = None,
    enable_pattern: Optional[Union[str, List[Union[str, type]]]] = "default",
    disable_pattern: Optional[Union[str, List[Union[str, type]]]] = None,
    pre_ort_model_transforms: Optional[
        Union[Callable[[ModelProto], ModelProto], List[Callable[[ModelProto], ModelProto]]]
    ] = None,
    ort_optimization_level: Optional[str] = None,
    dispatcher: Optional["Dispatcher"] = None,  # noqa: F821
    rename_inputs: bool = True,
    optimize: bool = True,
    exporter: Optional[str] = None,
    processor: str = "CPU",
    order_algorithm: Optional[str] = None,
    options: Optional[OptimizationOptions] = None,
    export_options: Optional[Union[str, ExportOptions]] = None,
) -> Callable:
    """
    Custom backend to export torch models into onnx
    (see :epkg:`torch.compiler`).
    This backend relies on :epkg:`onnxruntime` and tries to be
    as efficient as possible.

    :param graph_module: graph to export
    :param args: arguments
    :param target_opset: opset to use for the conversion
    :param backend: only `'ort'` is allowed
    :param verbose: adjust verbosity, if tuple, it gives different verbosity level
        to the exporter and the runtime
    :param dump_prefix: to dump the models and the inputs
    :param dump_patterns: dump the patterns as well
    :param providers: where to run the model, by default CUDA if one input
        is on a CUDA device, CPU otherwise
    :param raise_exc: raise an exception whenever something goes wrong
    :param storage: to store any interesting objects during the process
    :param enable_pattern: optimization patterns to enable
    :param disable_pattern: optimization patterns to disable
    :param pre_ort_model_transforms: list of transformations applied on the final ModelProto
    :param ort_optimization_level: graph optimization level for onnxruntime,
        the default value is the same as what :epkg:`onnxruntime` defines
    :param dispatcher: see :class:`experimental_experiment.torch_interpreter.Dispatcher`
    :param rename_inputs: rename the inputs
    :param optimize: enable or disable the optimization
    :param exporter: use a different exporter
    :param processor: optimization should be made for this processor
        or this list of processors (comma separated value)
    :param order_algorithm: algorithm optimizing the order the onnx node,
        none by default
    :param options: to define custom Optimization options, in that case,
        any other optimization parameter is ignored
    :param export_options: see :class:`ExportOptions
        <experimental_experiment.torch_interpreter.ExportOptions>`
    :return: Callable

    See :ref:`l-plot-onnxrt-diff` or :ref:`l-plot-custom-backend` for examples.
    If not empty, `storage` keeps the memory of the data generated,
    onnx models, graph module as well the inputs and outputs when
    the model is run.

    The environment variable ``ONNXRT_DUMP_PATH``, if set, overrides *dump_prefix*.

    The following example shows how to use the custom backend (based on :epkg:`onnxruntime`).

    .. runpython::
        :showcode:

        import torch
        from experimental_experiment.torch_dynamo import onnx_custom_backend


        class MLP(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layers = torch.nn.Sequential(
                    torch.nn.Linear(10, 32),
                    torch.nn.Sigmoid(),
                    torch.nn.Linear(32, 1),
                )

            def forward(self, x):
                return self.layers(x)


        x = torch.randn(3, 10, dtype=torch.float32)

        mlp = MLP()
        expected = mlp(x)

        compiled_model = torch.compile(
            mlp,
            backend=lambda *args, **kwargs: onnx_custom_backend(*args, verbose=1, **kwargs),
            dynamic=False,
            fullgraph=True,
        )

        try:
            got = compiled_model(x)
            diff = (expected - got).max()
            print(f"discrepancies: {diff}")
        except (ImportError, AttributeError) as e:
            print("onnxruntime-training is not installed", e)
    """
    assert dump_patterns is None or isinstance(
        dump_patterns, str
    ), f"Unexpected type {type(dump_patterns)} for dump_patterns."
    assert storage is None or isinstance(
        storage, dict
    ), f"Unexpected type {type(storage)} for storage"

    # Determines the devices, -1 is the CPU.
    DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)}
    default_providers = ["CPUExecutionProvider"]
    max_device = -1
    if torch.cuda.device_count() > 0:
        for i in range(torch.cuda.device_count()):
            DEVICES[i] = ORTC.OrtDevice(
                ORTC.OrtDevice.cuda(), ORTC.OrtDevice.default_memory(), i
            )
        # Highest device index among the tensor inputs, -1 when no input
        # is a tensor (only dimensions) or when all of them are on CPU.
        max_device = max(
            (a.get_device() for a in args if hasattr(a, "get_device")), default=-1
        )
        if max_device >= 0:
            default_providers = [("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]
    # Providers given by the caller take precedence over the default ones.
    if providers is None:
        providers = default_providers

    # Conversion to onnx
    begin = time.perf_counter()
    if verbose:
        _print_memory(max_device)
        print("[onnx_custom_backend] starts conversion to onnx.")
    if exporter is None:
        onx, builder = _default_export(
            graph_module,
            args,
            verbose,
            target_opset,
            dispatcher,
            optimize,
            enable_pattern,
            disable_pattern,
            rename_inputs,
            processor,
            order_algorithm=order_algorithm,
            dump_patterns=dump_patterns,
            options=options,
            export_options=export_options,
        )
    elif exporter == "dynamo":
        from ._dynamo_exporter import _dynamo_export

        onx, builder = _dynamo_export(
            graph_module,
            args,
            verbose,
            target_opset,
            dispatcher,
            optimize,
            enable_pattern,
            disable_pattern,
            rename_inputs,
            processor,
            order_algorithm=order_algorithm,
            dump_patterns=dump_patterns,
        )
    else:
        raise NotImplementedError(f"Unknown exporter {exporter!r}")

    if verbose:
        print(
            f"[onnx_custom_backend] to_onnx done in {time.perf_counter() - begin} with "
            f"{len(onx.graph.node)} nodes and {len(onx.functions)} local functions."
        )
        _print_memory(max_device)

    # Applies other transformation.
    if pre_ort_model_transforms is not None:
        if not isinstance(pre_ort_model_transforms, list):
            pre_ort_model_transforms = [pre_ort_model_transforms]
        for tr in pre_ort_model_transforms:
            begin = time.perf_counter()
            if verbose:
                _print_memory(max_device)
                print(f"[onnx_custom_backend] starts pre_ort_model_transforms {tr}")

            onx = tr(onx)

            if verbose:
                print(
                    f"[onnx_custom_backend] pre_ort_model_transforms "
                    f"done in {time.perf_counter() - begin} with "
                    f"{len(onx.graph.node)} nodes and {len(onx.functions)} local functions."
                )
                _print_memory(max_device)

    # Checks for variable ONNXRT_DUMP_PATH
    value = os.environ.get("ONNXRT_DUMP_PATH", None)
    if value:
        dump_prefix = value

    dump_first_inputs = None
    if dump_prefix:
        # Finds the first unused counter so previous dumps are not overwritten.
        counter = 0
        name = f"{dump_prefix}_{counter}.onnx"
        while os.path.exists(name):
            counter += 1
            name = f"{dump_prefix}_{counter}.onnx"

        with open(name, "wb") as f:
            f.write(onx.SerializeToString())
        name = f"{dump_prefix}_{counter}.txt"
        with open(name, "w", encoding="utf-8") as f:
            f.write(str(graph_module.graph))
            f.write("\n")
        dump_first_inputs = name

    # InferenceSession
    begin = time.perf_counter()
    if verbose:
        _print_memory(max_device)
        print("[onnx_custom_backend] starts creating InferenceSession")

    sess, run_options = _get_session(
        onx,
        backend,
        providers,
        exc=raise_exc,
        ort_optimization_level=ort_optimization_level,
    )

    if verbose:
        print(f"[onnx_custom_backend] InferenceSession done in {time.perf_counter() - begin}")
        _print_memory(max_device)

    input_names = [i.name for i in onx.graph.input]
    output_names = [i.name for i in onx.graph.output]

    is_dimension_in, is_dimension_out = get_dimensions(onx)

    # Storage
    if storage is not None:
        stor = {}
        storage.setdefault("instance", []).append(stor)
        stor["graph_module"] = graph_module
        stor["onnx"] = onx
        stor["is_dimension_in"] = is_dimension_in
        stor["is_dimension_out"] = is_dimension_out
        stor["builder"] = builder
        stor["sess"] = sess
        stor["inputs"] = []
        stor["outputs"] = []
        stor["providers"] = providers
    else:
        stor = None

    # Creates the backend.
    run = OrtBackend(
        sess=sess,
        run_options=run_options,
        stor=stor,
        input_names=input_names,
        output_names=output_names,
        dump_first_inputs=dump_first_inputs,
        is_dimension_in=is_dimension_in,
        is_dimension_out=is_dimension_out,
        devices=DEVICES,
        onnx_model=onx,
    )
    return run