Source code for onnx_diagnostic.torch_onnx.sbs

from typing import Any, Dict, Iterator, Optional, Tuple, Union
import onnx
import torch
from ..helpers import string_type, string_diff, max_diff
from ..helpers.onnx_helper import to_array_extended


def validate_fx_tensor(
    node: torch.fx.Node, tensor: torch.Tensor, expected_shape: Tuple[Any, ...]
) -> None:
    """
    Checks that *tensor* has the shape recorded for *node*.

    Ranks must match. Only dimensions given as plain ``int`` in
    *expected_shape* are compared, any other value (symbolic dimension)
    is accepted. A disagreement between 0 and 1 is tolerated.

    :param node: node which produced *tensor*, only used in error messages
    :param tensor: tensor to check
    :param expected_shape: expected shape
    :raises AssertionError: if the rank or a static dimension differs
    """

    def _where() -> str:
        # Only rendered when a check fails: node.meta may be costly to print.
        return (
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )

    actual_shape = tensor.shape
    assert len(actual_shape) == len(
        expected_shape
    ), f"Shape mismatch, got {actual_shape} expected {expected_shape}, {_where()}"
    for got, want in zip(actual_shape, expected_shape):
        if isinstance(want, int):
            assert got == want or {got, want} == {0, 1}, (
                f"Dimension mismatch, got {actual_shape} expected {expected_shape}, "
                f"{_where()}"
            )
def _validate_fx_value(node: torch.fx.Node, value: Any, expected: Any) -> None:
    """
    Checks one value produced by *node* against its expected counterpart
    (``node.meta["val"]`` or one of its elements). Tuples and lists are
    checked element by element, so they may mix tensors, numbers and None.
    """

    def _where() -> str:
        # Only rendered when a check fails: node.meta may be costly to print.
        return (
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )

    if isinstance(value, torch.Tensor):
        validate_fx_tensor(node, value, expected.shape)
        return
    if isinstance(value, (tuple, list)):
        assert isinstance(expected, (list, tuple)), (
            f"Unexpected type {string_type(expected)} for node.meta['val'], {_where()}"
        )
        assert len(value) == len(
            expected
        ), f"Length mismatch, got {len(value)} expected {len(expected)}, {_where()}"
        for item, expected_item in zip(value, expected):
            # Recursion instead of assuming tensors: elements can be None or numbers.
            _validate_fx_value(node, item, expected_item)
        return
    if isinstance(value, (int, float)):
        # bool is an int; symbolic expected values cannot be compared to concrete ones.
        label = "Int" if isinstance(value, int) else "Float"
        assert (
            isinstance(expected, (torch.SymInt, torch.SymBool, torch.SymFloat))
            or value == expected
        ), f"{label} mismatch, got {value} expected {expected}, {_where()}"
        return
    if value is None:
        assert (
            expected is None
        ), f"None mismatch, got {value} expected {expected}, {_where()}"
        return
    raise NotImplementedError(
        f"Validation for output type {type(value)} is not implemented, {_where()}"
    )


def validate_fx_outputs(node: torch.fx.Node, outputs: Tuple[Any, ...]) -> None:
    """
    Validates the outputs of a node using metadata stored in the node.

    Nothing is checked if ``node.meta`` has no ``"val"`` entry. Otherwise
    *outputs* is compared to ``node.meta["val"]``:

    * tensors: rank and static dimensions (see :func:`validate_fx_tensor`),
    * tuples and lists: same length, then every element recursively,
    * ints and floats: equality, unless the expected value is a
      ``torch.SymInt``, ``torch.SymBool`` or ``torch.SymFloat``,
    * None: the expected value must be None as well.

    :param node: node
    :param outputs: outputs
    :raises AssertionError: if a check fails
    :raises NotImplementedError: if an output type is not supported
    """
    if "val" not in node.meta:
        return
    _validate_fx_value(node, outputs, node.meta["val"])
def run_fx_node(
    node: torch.fx.Node, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
) -> Tuple[Any, ...]:
    """
    Executes a single fx node with already computed inputs.

    ``call_function`` nodes call their target and the result is checked
    with :func:`validate_fx_outputs`. ``output`` nodes return *args*
    unchanged. Any other kind of node is not supported.

    :param node: node to execute
    :param args: positional inputs of the node
    :param kwargs: named inputs of the node
    :return: whatever the node produces
    :raises NotImplementedError: for unsupported ``node.op``
    """
    op = node.op
    if op == "call_function":
        assert callable(node.target), f"{node.target!r} not callable in node {node!r}"
        results = node.target(*args, **(kwargs or {}))
        validate_fx_outputs(node, results)
        return results
    if op == "output":
        assert len(args) == 1 and not kwargs, (
            f"Unexpected inputs: args={string_type(args, limit=20)} "
            f"kwargs={string_type(kwargs, limit=20)}"
        )
        return args
    raise NotImplementedError(
        f"node.op={node.op!r} is not implemented, node.name={node.name!r}"
    )
def _pick_result(torch_results: Dict[str, Any], ref: Any) -> Any: "See :func:`prepare_args_kwargs`." if isinstance(ref, torch.fx.Node): return torch_results[ref.name] if isinstance(ref, list): return [_pick_result(torch_results, n) for n in ref] if isinstance(ref, tuple): return tuple(_pick_result(torch_results, n) for n in ref) if isinstance(ref, dict): return {k: _pick_result(torch_results, v) for k, v in ref.items()} if isinstance(ref, (bool, int, float, str, torch.device, torch.dtype)): return ref if ref is None: return None raise NotImplementedError(f"Unable to process args type {type(ref)}")
def prepare_args_kwargs(
    torch_results: Dict[str, Any], node: torch.fx.Node
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Builds the actual inputs of a fx node from the values already computed.

    Every node referenced in ``node.args`` or ``node.kwargs`` is replaced
    by its value stored in *torch_results* (see :func:`_pick_result`).

    :param torch_results: values computed so far, indexed by node name
    :param node: node about to be executed
    :return: a pair ``(args, kwargs)`` ready to be given to the node
    """
    return tuple(_pick_result(torch_results, part) for part in (node.args, node.kwargs))
def run_aligned(
    ep: torch.export.ExportedProgram,
    onx: Union[onnx.ModelProto, onnx.FunctionProto],
    args: Tuple[torch.Tensor, ...],
    check_conversion_cls: Union[Dict[str, Any], type],
    kwargs: Optional[Dict[str, Any]] = None,
    verbose: int = 0,
) -> Iterator[Tuple[Any, ...]]:
    """
    Runs in parallel both the exported program
    and the onnx proto and looks for discrepancies.
    The function does match on result names so it assumes
    the exported program and the onnx model have the same names
    for equivalent results (onnx outputs are mapped to the fx outputs
    by position).

    :param ep: exported program
    :param onx: model or function proto
    :param args: input args
    :param check_conversion_cls: defines the runtime to use for this task,
        either a class built from a single NodeProto, or a dictionary
        ``dict(cls=..., atol=..., rtol=...)`` to also enable the tolerance check
    :param kwargs: input kwargs (not supported yet, must be empty)
    :param verbose: verbosity level
    :return: an iterator over tuples
        ``(fx_node_index, onnx_node_index, onnx_name, torch_name, diff)``,
        one for every onnx result having a torch counterpart,
        ``diff`` is the dictionary returned by ``max_diff``
    :raises ValueError: if ``atol`` and ``rtol`` are given and a discrepancy
        exceeds them

    Example:

    .. runpython::
        :showcode:
        :warningout: UserWarning

        import pprint
        import pandas
        import torch
        from onnx_diagnostic.reference import (
            # This can be replaced by any runtime taking NodeProto as an input.
            ExtendedReferenceEvaluator as ReferenceEvaluator,
        )
        from onnx_diagnostic.torch_onnx.sbs import run_aligned


        class Model(torch.nn.Module):
            def forward(self, x):
                ry = x.abs()
                rz = ry.exp()
                rw = rz + 1
                ru = rw.log() + rw
                return ru


        def post_process(obs):
            dobs = dict(zip(["ep_id_node", "onnx_id_node", "onnx_name", "ep_name"], obs))
            dobs["err_abs"] = obs[-1]["abs"]
            dobs["err_rel"] = obs[-1]["rel"]
            return dobs


        x = torch.randn((5, 4))
        Model()(x)  # to make sure the model is running
        ep = torch.export.export(
            Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
        )
        onx = torch.onnx.export(
            Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},), dynamo=True
        ).model_proto
        results = list(
            map(
                post_process,
                run_aligned(
                    ep,
                    onx,
                    (x,),
                    check_conversion_cls=dict(cls=ReferenceEvaluator, atol=1e-5, rtol=1e-5),
                    verbose=1,
                ),
            ),
        )
        print("------------")
        print("final results")
        df = pandas.DataFrame(results)
        print(df)
    """
    assert not kwargs, f"Not implemented when kwargs={string_type(kwargs,with_shape=True)}"
    cls, atol, rtol = (
        (
            check_conversion_cls["cls"],
            check_conversion_cls["atol"],
            check_conversion_cls["rtol"],
        )
        if isinstance(check_conversion_cls, dict)
        else (check_conversion_cls, None, None)
    )

    # A FunctionProto stores nodes, inputs and outputs directly (no graph,
    # no initializers), a ModelProto stores them in its graph.
    if isinstance(onx, onnx.FunctionProto):
        onnx_nodes = list(onx.node)
        onnx_inputs_names = list(onx.input)
        onnx_outputs_names = list(onx.output)
        onnx_initializers = []
    else:
        onnx_nodes = list(onx.graph.node)
        onnx_inputs_names = [i.name for i in onx.graph.input]
        onnx_outputs_names = [o.name for o in onx.graph.output]
        onnx_initializers = list(onx.graph.initializer)

    # retrieve the positions: name -> {"fx": index in ep.graph, "onnx": node index}
    positions: Dict[str, Dict[str, int]] = {}
    for idx, fx_node in enumerate(ep.graph.nodes):
        fx_names = [fx_node.name] if isinstance(fx_node.name, str) else fx_node.name
        for n in fx_names:
            positions[n] = dict(fx=idx)
    for idx, onnx_node in enumerate(onnx_nodes):
        for n in onnx_node.output:
            positions.setdefault(n, {})["onnx"] = idx

    # Initializers are not produced by any node, so they get no entry in
    # positions: only "onnx" entries are used to decide which nodes to run.
    onnx_results: Dict[str, Any] = {}
    for init in onnx_initializers:
        onnx_results[init.name] = to_array_extended(init)
        # alias presumably matching the placeholder name torch.export gives
        # to the corresponding parameter
        param_name = f"p_{init.name.replace('.', '_')}"
        if param_name == init.name:
            continue
        assert param_name not in onnx_results, (
            f"Some confusion may happen because {init.name!r} -> {param_name!r} "
            f"and onnx_results has {sorted(onnx_results)}"
        )
        onnx_results[param_name] = onnx_results[init.name]

    torch_results: Dict[str, Any] = {
        k: torch.from_numpy(v.copy())
        for k, v in onnx_results.items()
        if not k.startswith("init")
    }

    torch_output_names = None
    for fx_node in ep.graph.nodes:
        if fx_node.op == "output":
            torch_output_names = [n.name for n in fx_node.args[0]]
    assert torch_output_names is not None and len(torch_output_names) == len(
        onnx_outputs_names
    ), (
        f"Unexpected number of outputs, torch_output_names={torch_output_names}, "
        f"onnx_outputs_names={onnx_outputs_names}"
    )
    mapping_onnx_to_torch = dict(zip(onnx_outputs_names, torch_output_names))

    if verbose:
        for k, v in torch_results.items():
            print(
                f"[run_aligned] +torch-cst: {k}: "
                f"{string_type(v, with_shape=True, with_min_max=True)}"
            )
        for k, v in onnx_results.items():
            print(
                f"[run_aligned] +onnx-init: {k}: "
                f"{string_type(v, with_shape=True, with_min_max=True)}"
            )

    for name, value in zip(onnx_inputs_names, args):
        # detach/cpu: numpy() fails on tensors requiring grad or not on CPU
        onnx_results[name] = value.detach().cpu().numpy()
        if verbose:
            print(
                f"[run_aligned] +onnx-input: {name}: "
                f"{string_type(value, with_shape=True, with_min_max=True)}"
            )

    def _run_onnx_node(i_fx: int, i_onnx: int, suffix: str) -> Iterator[Tuple[Any, ...]]:
        # Runs one onnx node, stores its outputs and compares every output
        # having a torch counterpart; suffix ("" or "*") tags the final
        # catch-up phase in messages.
        onnx_node = onnx_nodes[i_onnx]
        if verbose:
            print(
                f"[run_aligned] run onx.graph.node[{i_onnx}]: "
                f"{onnx_node.op_type}({', '.join(onnx_node.input)}) "
                f"-> {', '.join(onnx_node.output)}"
            )
        ref = cls(onnx_node)
        # An empty name stands for an omitted optional input: nothing to feed.
        feeds = {k: onnx_results[k] for k in onnx_node.input if k}
        res = ref.run(None, feeds)
        for o, r in zip(onnx_node.output, res):
            if not o:
                # omitted optional output
                continue
            onnx_results[o] = r
            if verbose:
                print(
                    f"[run_aligned] +onnx {o}="
                    f"{string_type(r, with_shape=True, with_min_max=True)}"
                )
            to = mapping_onnx_to_torch.get(o, o)
            if to not in torch_results:
                continue
            d = max_diff(torch_results[to], r)
            if verbose:
                if o == to:
                    print(f"[run_aligned] =common results{suffix} {to}: {string_diff(d)}")
                else:
                    print(
                        f"[run_aligned] =common results{suffix} {to}/{o}: {string_diff(d)}"
                    )
            if not (atol is None or rtol is None or (d["abs"] <= atol and d["rel"] <= rtol)):
                skw = dict(with_shape=True, with_min_max=True)
                raise ValueError(
                    f"discrepancies detected for results{suffix} [{to}/{o}]: "
                    f"{string_diff(d)}"
                    f"\n-- torch_results: {string_type(torch_results[to], **skw)}"
                    f"\n-- onnx_results: {string_type(r, **skw)}"
                    f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}"
                )
            yield (i_fx, i_onnx, o, to, d)

    last_position = 0
    for i_fx, fx_node in enumerate(ep.graph.nodes):
        if verbose:
            if fx_node.op == "call_function":
                print(
                    f"[run_aligned] run ep.graph.nodes[{i_fx}]: "
                    f"{fx_node.op}[{fx_node.target}] -> {fx_node.name!r}"
                )
            else:
                print(
                    f"[run_aligned] run ep.graph.nodes[{i_fx}]: "
                    f"{fx_node.op} -> {fx_node.name!r}"
                )

        if fx_node.op == "placeholder":
            if fx_node.name in onnx_results:
                torch_results[fx_node.name] = torch.from_numpy(
                    onnx_results[fx_node.name].copy()
                )
                if verbose:
                    t = torch_results[fx_node.name]
                    print(
                        f"[run_aligned] +torch {fx_node.name}="
                        f"{string_type(t, with_shape=True, with_min_max=True)}"
                    )
                continue
            raise AssertionError(
                f"unable to process node {fx_node.op} -> {fx_node.name!r} "
                f"not in {sorted(onnx_results)}, len(args)={len(args)}, "
                f"onx.graph.input={onnx_inputs_names}"
            )

        fx_names = [fx_node.name] if isinstance(fx_node.name, str) else list(fx_node.name)
        node_args, node_kwargs = prepare_args_kwargs(torch_results, fx_node)
        new_outputs = run_fx_node(fx_node, node_args, node_kwargs)
        if new_outputs is None:
            # Probably an assert.
            continue
        if fx_node.op == "output":
            # run_fx_node returns its args tuple for output nodes, keep the only element
            new_outputs = new_outputs[0]

        if isinstance(fx_node.name, str):
            # A fx node holds its whole value, tuples included:
            # getitem nodes extract the elements later.
            produced = [(fx_node.name, new_outputs)]
        else:
            produced = list(zip(fx_names, new_outputs))
        for k, v in produced:
            torch_results[k] = v
        if verbose:
            for k, v in produced:
                print(
                    f"[run_aligned] +torch {k}="
                    f"{string_type(v, with_shape=True, with_min_max=True)}"
                )

        onnx_positions = [
            positions[n]["onnx"] for n in fx_names if "onnx" in positions.get(n, {})
        ]
        if not onnx_positions:
            # we skip.
            continue
        max_pos = max(onnx_positions)
        for i_onnx in range(last_position, max_pos + 1):
            yield from _run_onnx_node(i_fx, i_onnx, "")
        # never move backwards, it would run (and report) onnx nodes twice
        last_position = max(last_position, max_pos + 1)

    # complete the execution of the onnx graph
    for i_onnx in range(last_position, len(onnx_nodes)):
        yield from _run_onnx_node(i_fx, i_onnx, "*")