Source code for onnx_diagnostic.export.validate

import inspect
import itertools
import time
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from ..helpers import string_type, max_diff, string_diff
from ..helpers.torch_test_helper import torch_deepcopy
from .dynamic_shapes import CoupleInputsDynamicShapes


[docs] def compare_modules( modep: torch.nn.Module, mod: Optional[torch.nn.Module] = None, args: Optional[Tuple[Any, ...]] = None, kwargs: Optional[Dict[str, Any]] = None, copy: bool = False, exc: bool = True, verbose: int = 0, atol: float = 1e-2, rtol: float = 1e-1, ) -> Dict[str, Any]: """ Compares two torch modules, usually one coming from an exported program, the other being the origin model. :param model: first module :param mod: second module (it produces the expected values) :param args: positional arguments :param kwargs: named arguments :param copy: copy the inputs before executing the model (they may modify them inplace) :param exc: raise exception if discrepancies are too high :param verbose: verbosity level :param atol: absolute tolerance :param rtol: relative tolerance :return: dictionary with inputs, outputs and tolerance Example: .. runpython:: :showcode: import torch from onnx_diagnostic.export import validate_ep, CoupleInputsDynamicShapes class Model(torch.nn.Module): def forward(self, x, y): return x + y model = Model() x = torch.randn((5, 6)) y = torch.randn((1, 6)) model(x, y) # to make it is running ds = ({0: "a", 1: "b"}, {1: "b"}) cpl = CoupleInputsDynamicShapes((x, y), {}, ds) ep = torch.export.export(model, (x, y), dynamic_shapes=cpl.replace_string_by()) validate_ep( ep, model, args=(x, y), verbose=2, copy=True, dynamic_shapes=ds, values_to_try={"a": [5, 10], "b": [10, 20]}, ) """ args = args or () kwargs = kwargs or {} def _get(a): return torch_deepcopy(a) if copy else a if verbose: begin = time.perf_counter() print( f"[compare_modules] check ep with " f"args={string_type(args, with_shape=True, with_device=True)}, " f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}..." ) got = modep(*_get(args), **_get(kwargs)) if verbose: d = time.perf_counter() - begin print(f"[compare_modules] done in {d} with output={string_type(got, with_shape=True)}") if mod: if verbose: begin = time.perf_counter() print("[compare_modules] run torch module...") expected = mod(*_get(args), **_get(kwargs)) diff = max_diff(expected, got) if verbose: d = time.perf_counter() - begin print( f"[compare_modules] done in {d} with " f"output={string_type(expected, with_shape=True)}" ) print(f"[compare_modules] discrepancies={string_diff(diff)}") assert not exc or ( diff["abs"] <= atol and diff["rel"] <= rtol ), f"Discrepancies={string_diff(diff)} higher than expected." return dict(args=args, kwargs=kwargs, expected=expected, got=got, diff=diff) return dict(args=args, kwargs=kwargs, got=got)
[docs] def validate_ep( ep: Union[torch.nn.Module, torch.export.ExportedProgram], mod: Optional[torch.nn.Module] = None, args: Optional[Tuple[Any, ...]] = None, kwargs: Optional[Dict[str, Any]] = None, copy: bool = False, dynamic_shapes: Optional[Any] = None, values_to_try: Optional[Dict[str, List[int]]] = None, exc: bool = True, verbose: int = 0, atol: float = 1e-2, rtol: float = 1e-1, ) -> List[Dict[str, Any]]: """ Validates an exported program. :param model: first module :param mod: second module (it produces the expected values) :param args: positional arguments :param kwargs: named arguments :param copy: copy the inputs before executing the model (they may modify them inplace) :param dynamic_shapes: dynamic shapes, string should be used not ``torch.export.Dim`` :param values_to_try: dictionary with the values to try for every dynamic dimension :param exc: raise exception if discrepancies are too high :param verbose: verbosity level :param atol: absolute tolerance :param rtol: relative tolerance :return: dictionary with inputs, outputs and tolerance """ modep = ep.module() if isinstance(ep, torch.export.ExportedProgram) else ep results = [ compare_modules( modep, mod, args, kwargs, copy=copy, verbose=verbose, atol=atol, rtol=rtol ) ] assert (dynamic_shapes and values_to_try) or ( not dynamic_shapes and not values_to_try ), "Either both dynamic_shapes and values_to_try are specified, either none." if not dynamic_shapes or not values_to_try: return results items = list(values_to_try.items()) keys = [_[0] for _ in items] values = [_[1] for _ in items] all_vals = list(itertools.product(*values)) cpl = CoupleInputsDynamicShapes( args or (), kwargs or {}, dynamic_shapes, args_names=( list(inspect.signature(modep.forward).parameters) if args and kwargs else None ), ) for i, vals in enumerate(all_vals): change_dims = dict(zip(keys, vals)) if verbose: print(f"[validate_ep] try {i}/{len(all_vals)}: {change_dims}") new_params = cpl.change_dynamic_dimensions(change_dims, args_kwargs=True) na, nkw = new_params c = compare_modules( modep, mod, na, nkw, copy=copy, verbose=max(verbose - 1, 0), atol=atol, rtol=rtol ) results.append(c) return results