Source code for onnx_diagnostic.helpers.torch_fx_graph_helper

from typing import Any, Dict, Optional, Tuple
import torch
from .helper import string_type


[docs] def validate_fx_tensor( node: torch.fx.Node, tensor: torch.Tensor, expected_shape: Tuple[Any, ...] ) -> None: """ Validates the shape of tensor is expected. :param node: node :param tensor: tensor :param expected_shape: expected shape """ assert len(tensor.shape) == len(expected_shape), ( f"Shape mismatch, got {tensor.shape} expected {expected_shape}, " f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, " f"node.args={node.args}, node.kwargs={node.kwargs}, " f"node.meta={node.meta}" ) for a, b in zip(tensor.shape, expected_shape): assert not isinstance(b, int) or a == b or {a, b} == {0, 1}, ( f"Dimension mismatch, got {tensor.shape} expected {expected_shape}, " f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, " f"node.args={node.args}, node.kwargs={node.kwargs}, " f"node.meta={node.meta}" )
[docs] def validate_fx_outputs(node: torch.fx.Node, outputs: Tuple[Any, ...]) -> None: """ Validates the outputs of a node using metadata stored in the node. :param node: node :param outputs: outputs """ if "val" not in node.meta: return if isinstance(outputs, torch.Tensor): validate_fx_tensor(node, outputs, node.meta["val"].shape) return if isinstance(outputs, (tuple, list)): assert isinstance(node.meta["val"], (list, tuple)), ( f"Unexpected type {string_type(node.meta['val'])} for node.meta['val'], " f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, " f"node.args={node.args}, node.kwargs={node.kwargs}, " f"node.meta={node.meta}" ) assert len(outputs) == len(node.meta["val"]), ( f"Length mismatch, got {len(outputs)} expected {len(node.meta['val'])}, " f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, " f"node.args={node.args}, node.kwargs={node.kwargs}, " f"node.meta={node.meta}" ) for a, b in zip(outputs, node.meta["val"]): validate_fx_tensor(node, a, b.shape) return if isinstance(outputs, int): assert ( isinstance(node.meta["val"], (torch.SymInt, torch.SymBool, torch.SymFloat)) or outputs == node.meta["val"] ), ( f"Int mismatch, got {outputs} expected {node.meta['val']}, " f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, " f"node.args={node.args}, node.kwargs={node.kwargs}, " f"node.meta={node.meta}" ) return if outputs is None: assert node.meta["val"] is None, ( f"None mismatch, got {outputs} expected {node.meta['val']}, " f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, " f"node.args={node.args}, node.kwargs={node.kwargs}, " f"node.meta={node.meta}" ) return raise NotImplementedError( f"Validation for output type {type(outputs)} is not implemented, " f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, " f"node.args={node.args}, node.kwargs={node.kwargs}, " f"node.meta={node.meta}" )
[docs] def run_fx_node( node: torch.fx.Node, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None ) -> Tuple[Any, ...]: """ Executes a node :param node: runs a node :param args: unnamed inputs to the node :param kwargs: named inputs to the node :return: results """ if node.op == "output": assert len(args) == 1 and not kwargs, ( f"Unexpected inputs: args={string_type(args, limit=20)} " f"kwargs={string_type(kwargs, limit=20)}" ) return args if node.op == "call_function": assert callable(node.target), f"{node.target!r} not callable in node {node!r}" for a, ea in zip(args, node.args): if isinstance(a, torch.Tensor) and hasattr(ea, "meta") and "val" in ea.meta: ta = ea.meta["val"] assert ( isinstance(ta, torch.Tensor) and len(a.shape) == len(ta.shape) and a.dtype == ta.dtype ), ( f"Unable to run node {node!r}, target={node.target!r}, " f"node.args={node.args!r}, node.kwargs={node.kwargs!r}, " f"args={string_type(args, with_shape=True, with_device=True)}, " f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}" ) try: outputs = node.target(*args, **(kwargs or {})) except RuntimeError as e: raise RuntimeError( f"Unable to run node {node!r}, target={node.target!r}, " f"args={string_type(args, with_shape=True, with_device=True)}, " f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}" ) from e validate_fx_outputs(node, outputs) return outputs raise NotImplementedError( f"node.op={node.op!r} is not implemented, node.name={node.name!r}" )
def _pick_result(torch_results: Dict[str, Any], ref: Any) -> Any: "See :func:`prepare_args_kwargs`." if isinstance(ref, torch.fx.Node): return torch_results[ref.name] if isinstance(ref, list): return [_pick_result(torch_results, n) for n in ref] if isinstance(ref, tuple): return tuple(_pick_result(torch_results, n) for n in ref) if isinstance(ref, dict): return {k: _pick_result(torch_results, v) for k, v in ref.items()} if isinstance(ref, (bool, int, float, str, torch.device, torch.dtype)): return ref if ref is None: return None if isinstance(ref, torch.layout): return ref raise NotImplementedError(f"Unable to process args type {type(ref)}")
[docs] def prepare_args_kwargs( torch_results: Dict[str, Any], node: torch.fx.Node ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: """ Prepares args and kwargs before executing a fx node. :param torch_results: existing results :param node: node to execute :return: new args and kwargs """ new_args = _pick_result(torch_results, node.args) new_kwargs = _pick_result(torch_results, node.kwargs) return new_args, new_kwargs