Source code for experimental_experiment.torch_bench.export_model_helper

import os
import sys
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import onnx
from onnx import (
    AttributeProto,
    FunctionProto,
    GraphProto,
    ModelProto,
    NodeProto,
    TensorProto,
)
from ..convert.convert_helper import optimize_model_proto_oxs
from ..bench_run import measure_discrepancies


def size_type(dtype: Any) -> int:
    if isinstance(dtype, int):
        # It is a TensorProto.DATATYPE
        if dtype in {TensorProto.DOUBLE, TensorProto.INT64, TensorProto.UINT64}:
            return 8
        if dtype in {TensorProto.FLOAT, TensorProto.INT32, TensorProto.UINT32}:
            return 4
        if dtype in {
            TensorProto.FLOAT16,
            TensorProto.BFLOAT16,
            TensorProto.INT16,
            TensorProto.UINT16,
        }:
            return 2
        if dtype in {TensorProto.INT8, TensorProto.UINT8, TensorProto.BOOL}:
            return 1
        raise AssertionError(f"Unable to return the element size for type {dtype}")

    import torch

    if dtype in {torch.float64, torch.int64}:
        return 8
    if dtype in {torch.float32, torch.int32}:
        return 4
    if dtype in {torch.float16, torch.int16, torch.bfloat16}:
        return 2
    if dtype in {torch.int8, torch.uint8, torch.bool}:
        return 1
    if hasattr(torch, "uint64"):
        # it fails on mac
        if dtype in {torch.uint64}:
            return 8
    if hasattr(torch, "uint32"):
        # it fails on mac
        if dtype in {torch.uint32}:
            return 4
    if hasattr(torch, "uint16"):
        # it fails on mac
        if dtype in {torch.uint16}:
            return 2
    raise AssertionError(f"Unexpected dtype={dtype}")


def obj_size(obj: Any) -> int:
    import torch

    if isinstance(obj, torch.Tensor):
        assert not obj.is_sparse, "Sparse tensors not supported yet."
        return int(np.prod(list(obj.shape)) * size_type(obj.dtype))
    if isinstance(obj, (tuple, list)):
        return sum(obj_size(o) for o in obj)
    if isinstance(obj, dict):
        return sum(obj_size(o) for o in obj.values())
    if obj is None:
        return 0
    if obj.__class__.__name__.endswith("KeyedJaggedTensor"):
        # Not implemented yet.
        return 0
    if isinstance(obj, (int, float, str, bytes)):
        return sys.getsizeof(obj)
    if hasattr(obj, "_fields") and isinstance(obj._fields, dict):
        # detectron2.structures.instances.Instances
        return obj_size(obj._fields)
    if hasattr(obj, "tensor") and isinstance(obj.tensor, torch.Tensor):
        # detectron2.structures.instances.Bowes
        return obj_size(obj.tensor)
    if "SquashedNormal" in obj.__class__.__name__:
        return sys.getsizeof(obj)
    if obj.__class__.__name__ == "MambaCache":
        return obj_size(obj.conv_states) + obj_size(obj.ssm_states)
    if obj.__class__.__name__ == "DynamicCache":
        return obj_size(obj.key_cache) + obj_size(obj.value_cache)
    raise AssertionError(f"input_size not implemented for type {type(obj)}")


[docs] def compute_weight_size(model: Any) -> int: """ Returns the model size for a torch model or an onnx model. That includes the weights. """ if isinstance(model, ModelProto): size = compute_weight_size(model.graph) for f in model.functions: size += compute_weight_size(f) assert isinstance(size, int), f"Unexpected type {type(size)}: {size}" return size if isinstance(model, GraphProto): size = 0 for init in model.initializer: size += compute_weight_size(init) for init in model.sparse_initializer: size += compute_weight_size(init) for node in model.node: size += compute_weight_size(node) assert isinstance(size, int), f"Unexpected type {type(size)}: {size}" return size if isinstance(model, FunctionProto): size = 0 for node in model.node: size += compute_weight_size(node) assert isinstance(size, int), f"Unexpected type {type(size)}: {size}" return size if isinstance(model, TensorProto): p = int(np.prod(model.dims)) size = p * size_type(model.data_type) assert isinstance(size, int), f"Unexpected type {type(size)}: {size}" return size if isinstance(model, NodeProto): if model.op_type == "Constant": return len(model.SerializeToString()) size = 0 for att in model.attribute: if att.type == AttributeProto.GRAPH: size += compute_weight_size(att.g) assert isinstance(size, int), f"Unexpected type {type(size)}: {size}" return size if hasattr(model, "parameters"): import torch size = 0 for p in model.parameters(): assert isinstance(p, torch.Tensor), f"Unexpected type {type(p)}" size += obj_size(p) assert isinstance(size, int), f"Unexpected type {type(size)}: {size}" return size if hasattr(model, "buffer"): # executorch return len(model.buffer) return np.nan
[docs] def common_export( model: Any, inputs: List[Any], exporter: str = "custom", target_opset: int = 18, folder: str = "", filename: str = "model.onnx", dynamic_shapes: Optional[Any] = None, verbose: int = 0, optimize_oxs: str = "", ort_optimize: bool = False, large_model: bool = False, order: bool = False, enable_pattern: Optional[Union[str, List[str]]] = None, disable_pattern: Optional[Union[str, List[str]]] = None, stats: Optional[Dict[str, Any]] = None, ): """ Exports a model into a folder. :param model: model :param exporter: torchscript, onnx_dynamo, dynamo_export, custom, ... :param folder: folder to export into :param filename: onnx filename :param inputs: inputs :param dynamic_shapes: dynamic shapes :param target_opset: target opset :param optimize_oxs: run optimization with onnxscript :param enable_pattern: patterns to apply :param disable_pattern: patterns not to apply :param verbose: verbosity :param stats: if not None, populates this dictionary with statistics about time :param optimize_oxs: optimize :param ort_optimize: produces a file showing onnxruntime optimizations :param large_model: save weights as external :param order: optimize order :returns: onnx proto """ import torch.onnx if folder: if folder and not os.path.exists(folder): os.makedirs(folder) filename = os.path.join(folder, filename) if verbose: print( f"[common_export] start exporting with {exporter!r}, " f"{len(inputs)} inputs in {filename!r}" ) begin = time.perf_counter() if exporter == "torch_script": assert isinstance(inputs, tuple), f"{type(inputs)}" assert len(inputs) == 2 torch.onnx.export( model, inputs, filename, do_constant_folding=False, input_names=[f"input{i}" for i in range(len(inputs))], opset_version=target_opset, dynamic_axes=dynamic_shapes, ) elif exporter == "onnx_dynamo": assert ( dynamic_shapes is None ), f"dynamic_shapes={dynamic_shapes} is not implemented yet" torch.onnx.export( model, inputs, filename, do_constant_folding=False, input_names=[f"input{i}" for i in range(len(inputs))], opset_version=target_opset, dynamic_axes=dynamic_shapes, dynamo=True, ) elif exporter == "dynamo_export": with torch.no_grad(): prog = torch.onnx.dynamo_export(model, *inputs) onnx.save(prog.model_proto, filename) elif exporter in ("custom", "custom-fallback"): from ..torch_interpreter import to_onnx from ..xbuilder import OptimizationOptions from ..xoptim import get_pattern_list patterns = get_pattern_list(enable_pattern, disable_pattern, verbose=verbose) onx = to_onnx( model, inputs, input_names=[f"input{i}" for i in range(len(inputs))], options=OptimizationOptions(patterns=patterns, order=order), verbose=verbose, target_opset=target_opset, optimize=bool(enable_pattern), large_model=large_model, ) print([i.name for i in onx.graph.input]) with open(filename, "wb") as f: f.write(onx.SerializeToString()) else: raise ValueError(f"Unknown exporter {exporter!r}") if stats is not None: stats["export_time"] = time.perf_counter() - begin stats["filesize"] = os.stat(filename).st_size if verbose: print(f"[common_export] exporter done in {time.perf_counter() - begin}s") print(f"[common_export] size of the export: {os.stat(filename).st_size / 2**20} Mb") with open(filename, "rb") as f: onx = onnx.load(f) if optimize_oxs: if verbose: print("[common_export] start optimization with onnxscript") begin = time.perf_counter() optimized_model = optimize_model_proto_oxs(onx, verbose=verbose, stats=stats) end = time.perf_counter() - begin if stats is not None: stats["optimization_time"] = end if verbose: print(f"[common_export] optimization done in {end}") print(f"[common_export] saves the model in {filename!r}") begin = time.perf_counter() onnx.save(optimized_model, filename) if verbose: print(f"[common_export] done saving in {time.perf_counter() - begin}") if ort_optimize and filename: output = f"{filename}.opt.onnx" if verbose: print(f"[common_export] onnxruntime optimize in {output!r}") from ..convert.convert_helper import ort_optimize as fopt is_cuda = next(model.parameters()).is_cuda if is_cuda: device_id = next(model.parameters()).get_device() providers = [ ("CUDAExecutionProvider", {"device_id": device_id}), ("CPUExecutionProvider", {}), ] else: providers = ["CPUExecutionProvider"] fopt(onx, output, providers=providers, disable_aot=False) if verbose: print("[common_export] done") return onx
[docs] def run_inference( model: Any, example_inputs: List[Any], warmup: int = 5, repeat: int = 5, verbose: int = 0, ) -> dict[str, Any]: """ Runs multiple times the same inference. Args: model: torch model to run example_inputs: dummy inputs warmup: number of iterations to warmup repeat: number of iterations to repeat verbose: verbosity Returns: statistcs """ if verbose: print(f"[run_inference] start {warmup} warmup iterations") stats: dict[str, Any] = {} iterations: list[float] = [] begin = time.perf_counter() for i in range(warmup): t0 = time.perf_counter() model(*example_inputs[i % len(example_inputs)]) iterations.append(time.perf_counter() - t0) end = time.perf_counter() - begin stats["warmup"] = warmup stats["warmup_time"] = end stats["warmup_iter"] = iterations if verbose: print(f"[run_inference] warmup done in {time.perf_counter() - begin}") print(f"[run_inference] start {repeat} iterations") iterations = [] begin = time.perf_counter() for i in range(warmup): t0 = time.perf_counter() model(*example_inputs[i % len(example_inputs)]) iterations.append(time.perf_counter() - t0) end = time.perf_counter() - begin stats["repeat"] = repeat stats["repeat_time"] = end stats["repeat_iter"] = iterations if verbose: print(f"[run_inference] measure done in {time.perf_counter() - begin}") return stats
[docs] class WrapForTorch: """Wraps a torch model.""" def __init__(self, torch_model: Any): if hasattr(torch_model, "graph_module"): self.model = torch_model.graph_module else: self.model = torch_model def run(self, inputs): if isinstance(inputs, dict): return self.model(**inputs) return self.model(*inputs) @property def input_names(self): res = [] for node in self.model.graph.nodes: if node.op == "placeholder": res.append(node.target) return res
[docs] class WrapExecutorchForTorch(WrapForTorch): """Wraps a executorch model.""" def __init__(self, model: Any, forward_method: Callable): self.model = model self.forward_method = forward_method def run(self, inputs): return self.forward_method.execute(inputs) @property def input_names(self): raise NotImplementedError(f"Not implemented yet {dir(self.model)}.") res = [] for node in self.model.graph.nodes: if node.op == "placeholder": res.append(node.target) return res
[docs] class WrapInferenceSessionForTorch: """ Wraps an `onnxruntime.InferenceSession` to overload method `run` to support :class:`torch.Tensor`. """ def __init__(self, sess: Any, nvtx: bool = False): # onnxruntime is importing when needed as it takes a # couple of seconds if it contains CUDA EP. import onnxruntime import torch from onnxruntime.capi import _pybind_state as ORTC # noqa: N812 self.sess = sess self.input_names = [i.name for i in sess.get_inputs()] self.output_names = [i.name for i in sess.get_outputs()] self.OrtValue = ORTC.OrtValue self.ORTC = ORTC self.torch = torch self.nvtx = nvtx self.run_options = onnxruntime.RunOptions() self.dlpack = False self.TORCH_DTYPE_TO_NUMPY_DTYPE = { torch.float16: np.float16, torch.float32: np.float32, torch.float64: np.float64, torch.uint8: np.uint8, torch.int8: np.int8, torch.int16: np.int16, torch.int32: np.int32, torch.int64: np.int64, torch.bool: np.bool_, } self.TORCH_DTYPE_TO_ONNX_DTYPE = { torch.float16: TensorProto.FLOAT16, torch.bfloat16: TensorProto.BFLOAT16, torch.float32: TensorProto.FLOAT, torch.float64: TensorProto.DOUBLE, torch.uint8: TensorProto.UINT8, torch.int8: TensorProto.INT8, torch.int16: TensorProto.INT16, torch.int32: TensorProto.INT32, torch.int64: TensorProto.INT64, torch.bool: TensorProto.BOOL, } DEVICES = { -1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0) } if torch.cuda.is_available(): for i in range(torch.cuda.device_count()): DEVICES[i] = ORTC.OrtDevice( ORTC.OrtDevice.cuda(), ORTC.OrtDevice.default_memory(), i ) self.DEVICES = DEVICES def _get_ortvalues_from_torch_tensors( self, tensors: tuple[Any, ...], # tuple["torch.Tensor", ...], n_outputs: int, ) -> tuple[Any, Any]: # tuple[tuple["torch.Tensor", ...], tuple["OrtDevice", ...]]: assert tensors is not None, "tensors cannot be None" ortvalues = self.ORTC.OrtValueVector() ortvalues.reserve(len(tensors)) dtypes = [] shapes = [] data_ptrs = [] devices = [] if self.nvtx: self.torch.cuda.nvtx.range_push("_get_ortvalues_from_torch_tensors.1") max_device = -1 new_tensors = [] for tensor in tensors: assert isinstance(tensor, self.torch.Tensor), f"Unexpected type {type(tensor)}" dtypes.append(self.TORCH_DTYPE_TO_NUMPY_DTYPE[tensor.dtype]) shapes.append(tensor.size()) data_ptrs.append(tensor.data_ptr()) d = tensor.get_device() devices.append(self.DEVICES[d]) new_tensors.append(tensor) max_device = max(max_device, d) if self.nvtx: self.torch.cuda.nvtx.range_pop() self.torch.cuda.nvtx.range_push("_get_ortvalues_from_torch_tensors.2") assert isinstance(max_device, int), f"unexpected type for device={max_device!r}" ortvalues.push_back_batch(new_tensors, data_ptrs, dtypes, shapes, devices) output_devices = [] for _ in range(n_outputs): dev = self.DEVICES[max_device] output_devices.append(dev) if self.nvtx: self.torch.cuda.nvtx.range_pop() return ortvalues, output_devices def _ortvalues_to_torch_tensor( self, ortvalues: Any, # "onnxruntime.OrtValueVector", ) -> tuple[Any, ...]: # tuple["torch.Tensor", ...]: if len(ortvalues) == 0: return tuple() from torch._C import _from_dlpack if all(ortvalues[i].has_value() for i in range(len(ortvalues))): if self.nvtx: self.torch.cuda.nvtx.range_push("_ortvalues_to_torch_tensor.1") res = ortvalues.to_dlpacks(_from_dlpack) if self.nvtx: self.torch.cuda.nvtx.range_pop() else: if self.nvtx: self.torch.cuda.nvtx.range_push("_ortvalues_to_torch_tensor.2") res = [] for i in range(len(ortvalues)): res.append( _from_dlpack(ortvalues[i].to_dlpack()) if ortvalues[i].has_value() else None ) if self.nvtx: self.torch.cuda.nvtx.range_pop() return tuple(res) def run(self, output_names, feeds): inputs = [feeds[i] for i in self.input_names] if self.dlpack: return self.run_dlpack(*inputs, output_names=output_names) return self.run_ort_inference(*inputs) def _bind_torch_tensors( self, tensors: tuple[Any, ...], # tuple["torch.Tensor", ...], output_names: List[str], ) -> "SessionIBinding": # noqa: F821 assert tensors is not None, "tensors cannot be None" assert len(tensors) == len(self.input_names), ( f"Mismatch got {len(tensors)}, {len(self.input_names)} are expected, " f"names={self.input_names}" ) bind = self.ORTC.SessionIOBinding(self.sess._sess) max_device = -1 for name, tensor in zip(self.input_names, tensors): assert isinstance(tensor, self.torch.Tensor), f"Unexpected type {type(tensor)}" d = tensor.get_device() max_device = max(d, max_device) bind.bind_input( name, self.DEVICES[d], self.TORCH_DTYPE_TO_NUMPY_DTYPE.get( # it works on CI tensor.dtype, # it does not seem to write for all releases self.TORCH_DTYPE_TO_ONNX_DTYPE[tensor.dtype], ), tensor.shape, tensor.data_ptr(), ) device = self.DEVICES[max_device] for o in output_names: bind.bind_output(o, device) return bind def run_ort_inference(self, *inputs, output_names=None): if output_names is None: output_names = self.output_names if self.nvtx: self.torch.cuda.nvtx.range_push("_bind_torch_tensors") bind = self._bind_torch_tensors(inputs, output_names=output_names) if self.nvtx: self.torch.cuda.nvtx.range_pop() self.torch.cuda.nvtx.range_push("run_with_iobinding") self.sess._sess.run_with_iobinding(bind, self.run_options) ort_vector_outputs = bind.get_outputs() # The function returns OrtValue, the code computing the discrepancies will # have to convert (so the necessary copy is not included here). # DlPack mechanism should be implemented in onnxruntime # (not only in onnxruntime-training). ort_outputs = [ort_vector_outputs[i] for i in range(len(ort_vector_outputs))] if self.nvtx: self.torch.cuda.nvtx.range_pop() return ort_outputs def run_dlpack(self, *inputs, output_names=None): if output_names is None: output_names = self.output_names ortvalues, output_devices = self._get_ortvalues_from_torch_tensors( inputs, len(output_names) ) if self.nvtx: self.torch.cuda.nvtx.range_push("run_with_ortvaluevector") ort_outputs = self.ORTC.OrtValueVector() self.sess.run_with_ortvaluevector( self.run_options, self.input_names, ortvalues, output_names, ort_outputs, output_devices, ) if self.nvtx: self.torch.cuda.nvtx.range_pop() pth_outputs = self._ortvalues_to_torch_tensor(ort_outputs) return pth_outputs
[docs] def run_onnx_inference( model: ModelProto, example_inputs: List[Any], warmup: int = 5, repeat: int = 5, verbose: int = 0, ort_optimize: bool = True, torch_model: Optional[Any] = None, ) -> Optional[Dict[str, Any]]: """ Runs multiple times the same inference with onnxruntime. Args: model: torch model to run example_inputs: dummy inputs warmup: number of iterations to warmup repeat: number of iterations to repeat verbose: verbosity ort_optimize: enable, disable onnxruntime optimizations torch_model: if not empty, measure the discrepancies Returns: statistcs """ stats: dict[str, Any] = {} device = example_inputs[0][0].get_device() providers = ( [("CUDAExecutionProvider", {"device_id": device}), "CPUExecutionProvider"] if device >= 0 else ["CPUExecutionProvider"] ) stats["providers"] = ",".join(providers) if verbose: print(f"[run_inference] create session with providers {providers!r}") begin = time.perf_counter() # onnxruntime is importing when needed as it # takes a couple of seconds if it contains CUDA EP. import onnxruntime so = onnxruntime.SessionOptions() if ort_optimize: so.add_session_config_entry("session.disable_aot_function_inlining", "0") so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL else: so.add_session_config_entry("session.disable_aot_function_inlining", "1") so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL sess = onnxruntime.InferenceSession(model.SerializeToString(), so, providers) wrapped_session = WrapInferenceSessionForTorch(sess) end = time.perf_counter() - begin stats["ort_session_create_time"] = end if verbose: print(f"[run_inference] created session in {end}") print(f"[run_inference] start {warmup} warmup iterations") if torch_model: expected = [ torch_model(*example_inputs[i % len(example_inputs)]) for i in range(warmup) ] got = [] iterations = [] begin = time.perf_counter() for i in range(warmup): t0 = time.perf_counter() got.append(wrapped_session.run_dlpack(*example_inputs[i % len(example_inputs)])) iterations.append(time.perf_counter() - t0) end = time.perf_counter() - begin stats["warmup"] = warmup stats["warmup_time"] = end / warmup stats["warmup_iter"] = iterations if torch_model: d = measure_discrepancies(expected, got) stats["discrepancies_abs"] = d["abs"] stats["discrepancies_rel"] = d["rel"] stats["discrepancies_avg"] = d["avg"] if verbose: print(f"[run_inference] warmup done in {time.perf_counter() - begin}") print(f"[run_inference] start {repeat} iterations") iterations = [] begin = time.perf_counter() for i in range(repeat): t0 = time.perf_counter() wrapped_session.run_dlpack(*example_inputs[i % len(example_inputs)]) iterations.append(time.perf_counter() - t0) end = time.perf_counter() - begin stats["repeat"] = repeat stats["repeat_time"] = end / repeat stats["repeat_iter"] = iterations if verbose: print(f"[run_inference] measure done in {time.perf_counter() - begin}") return stats
def str_shape(shape: Tuple[Any, ...]) -> str: s = "x".join(str(i) for i in shape) return s def str_dtype(dtype: Any) -> str: s = str(dtype) return s.replace("torch.", "").replace("'", "").replace("<", "").replace(">", "")