Source code for onnx_diagnostic.torch_onnx.sbs

import inspect
import time
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Sequence, Tuple, Union
import onnx
import onnx.helper as oh
import numpy as np
import torch
from ..helpers import string_type, string_diff, max_diff, flatten_object
from ..helpers.onnx_helper import pretty_onnx
from ..helpers.torch_helper import to_numpy, from_numpy, to_tensor, torch_dtype_to_onnx_dtype
from ..helpers.torch_fx_graph_helper import prepare_args_kwargs, run_fx_node
from ..reference.ort_evaluator import OnnxList, OnnxruntimeEvaluator
from .sbs_dataclasses import (
    ReplayConfiguration,
    RunAlignedRecord,
    StatusRunAligned,
    make_torch_inputs,
)


def _check_tensor_(use_tensor, name, obj, flip_type=False):
    """
    Validates (and optionally converts) a value before it is stored
    among the onnx results.

    :param use_tensor: if True, the value is expected to be a
        :class:`torch.Tensor`, otherwise a :class:`numpy.ndarray`
    :param name: result name, only used in the error messages
    :param obj: value to check
    :param flip_type: if True, a value of the other kind
        (numpy array vs torch tensor) is converted into the expected one
        before the check
    :return: the value, possibly converted
    """
    if flip_type:
        if use_tensor and isinstance(obj, np.ndarray):
            obj = from_numpy(obj)
        elif not use_tensor and isinstance(obj, torch.Tensor):
            obj = to_numpy(obj)

    # OnnxList is accepted in both modes.
    if use_tensor:
        assert isinstance(obj, (torch.Tensor, OnnxList)), (
            f"Unexpected type {type(obj)} for {name!r}. "
            f"use_tensor is True so torch.Tensor is expected."
        )
    else:
        assert isinstance(obj, (np.ndarray, OnnxList)), (
            f"Unexpected type {type(obj)} for {name!r}. "
            f"use_tensor is False so np.array is expected."
        )
    return obj


def _make_node_from_initializer(proto: onnx.TensorProto) -> onnx.NodeProto:
    """
    Wraps a tensor into a ``Constant`` node without any input,
    its single output is named after the tensor.

    :param proto: tensor stored in attribute ``value`` of the node
    :return: the ``Constant`` node
    """
    no_input = []
    single_output = [proto.name]
    return oh.make_node("Constant", no_input, single_output, value=proto)


def _loop_cmp(
    mapping_onnx_to_torch: Dict[str, str],
    torch_results: Dict[str, torch.Tensor],
    onnx_results: Dict[str, Any],
    onnx_name: str,
    onnx_result: torch.Tensor,
    second_onnx_result: torch.Tensor,
    verbose: int,
    atol: Optional[float],
    rtol: Optional[float],
    i_torch: int,
    i_onnx: int,
    str_kws: Dict[str, bool],
    exc: bool,
    use_tensor: bool,
) -> Optional[RunAlignedRecord]:
    """
    Stores one onnx result and compares it to the torch result
    mapped to the same name, if there is one.

    :param mapping_onnx_to_torch: onnx name -> torch name, updated with
        ``onnx_name`` when a comparison takes place
    :param torch_results: results computed by the exported program
    :param onnx_results: results computed by the onnx runtime,
        updated with ``onnx_result``
    :param onnx_name: name of the onnx result
    :param onnx_result: value of the onnx result
    :param second_onnx_result: optional second value for the same result,
        if not None, it is compared to the torch result as well
    :param verbose: verbosity
    :param atol: absolute tolerance, no check if None
    :param rtol: relative tolerance, no check if None
    :param i_torch: index of the current node in the exported program
    :param i_onnx: index of the current node in the onnx graph
    :param str_kws: arguments sent to ``string_type``
    :param exc: raises an exception if a discrepancy is above the tolerance,
        prints a message otherwise
    :param use_tensor: torch tensors are expected if True, numpy arrays otherwise
    :return: a record holding the comparison or None
        if there is no torch result to compare to
    :raises ValueError: if *exc* is True and the discrepancies are above
        the tolerances (see the note in the code about *verbose*)
    """
    onnx_results[onnx_name] = _check_tensor_(use_tensor, onnx_name, onnx_result)
    if verbose > 1:
        print(f"[run_aligned-nx] +res: {onnx_name}={string_type(onnx_result, **str_kws)}")

    to = mapping_onnx_to_torch.get(onnx_name, onnx_name)
    if to not in torch_results:
        return None

    d = max_diff(torch_results[to], onnx_result, hist=[0.1, 0.01])
    if verbose > 1:
        if onnx_name == to:
            print(f"[run_aligned-==] cmp {to}: {string_diff(d)}")
        else:
            print(f"[run_aligned-~~] cmd {to}/{onnx_name}: {string_diff(d)}")
        # NOTE(review): the tolerance check only happens when verbose > 1,
        # this is the historical behaviour and it is kept unchanged here;
        # confirm whether exc=True should raise whatever the verbosity is.
        if not (atol is None or rtol is None or (d["abs"] <= atol and d["rel"] <= rtol)):
            if exc:
                # The torch value is stored in torch_results[to], indexing
                # the onnx tensor with a name would fail before raising.
                raise ValueError(
                    f"discrepancies detected for results [{to}/{onnx_name}]: "
                    f"{string_diff(d)}"
                    f"\n-- torch_result: {string_type(torch_results[to], **str_kws)}"
                    f"\n-- onnx_result: {string_type(onnx_result, **str_kws)}"
                    f"\n-- torch\n{torch_results[to]}"
                )
            else:
                print(f"[run_align-dx] discrepancies {string_diff(d)} - [{to}/{onnx_name}]")
    r = RunAlignedRecord(
        ep_id_node=i_torch,
        onnx_id_node=i_onnx,
        ep_name=to,
        onnx_name=onnx_name,
        ep_shape_type=string_type(torch_results[to], **str_kws),
        onnx_shape_type=string_type(onnx_result, **str_kws),
    )
    r.set_diff(d)
    if second_onnx_result is not None:
        d2 = max_diff(torch_results[to], second_onnx_result, hist=[0.1, 0.01])
        r.set_diff2(d2)
    mapping_onnx_to_torch[onnx_name] = to
    return r


def _duplicated_values(d):
    rev = {}
    for k, v in d.items():
        if v in rev:
            rev[v].append(k)
        else:
            rev[v] = [k]
    res = {k: v for k, v in rev.items() if len(v) > 1}
    final = set()
    for v in res.values():
        final |= set(v)
    return final


def _validation_nn_functional(
    node: onnx.NodeProto, new_feeds: Dict[str, torch.Tensor], expected: List[torch.Tensor]
) -> Optional[str]:
    """
    Runs the torch function equivalent to an onnx node when one is known
    and compares its output to the expected result.
    Only ``Gemm`` with three inputs and ``transB=1`` as its sole attribute
    among *alpha, beta, transA, transB* is handled, it is compared to
    :func:`torch.nn.functional.linear`.

    :param node: onnx node
    :param new_feeds: inputs of the node, indexed by name
    :param expected: expected outputs, only the first one is used
    :return: a string describing the discrepancies or None
        if the node is not handled
    """
    if node.op_type != "Gemm" or len(node.input) != 3:
        return None

    # attribute name -> field of AttributeProto holding its value
    fields = {"alpha": "f", "beta": "f", "transA": "i", "transB": "i"}
    atts = {
        att.name: getattr(att, fields[att.name])
        for att in node.attribute
        if att.name in fields
    }
    if atts != {"transB": 1}:
        return None

    operands = [new_feeds[i] for i in node.input]
    res = torch.nn.functional.linear(*operands)
    diff = max_diff(res, expected[0])
    return f"function.linear:{string_diff(diff)}"


def _loop_onnx_node(
    onx: onnx.ModelProto,
    ep_graph_nodes: List[torch.fx.Node],
    onnx_results: Dict[str, Any],
    mapping_onnx_to_torch: Dict[str, str],
    torch_results: Dict[str, torch.Tensor],
    ep_durations,
    use_tensor: bool,
    i_torch: int,
    i_onnx: int,
    name_to_ep_node: Dict[str, int],
    run_cls_kwargs: Dict[str, Any],
    str_kws: Dict[str, bool],
    status: StatusRunAligned,
    already_run_onnx: Set[int],
    torch_names_to_onnx_names: Dict[str, str],
    verbose: int,
    exc: bool,
    atol: float,
    rtol: float,
    reset_names: Set[str],
    replay_configuration: Optional[ReplayConfiguration],
    has_cuda: bool,
    run_cls: type,
    loop: Any,
    run_onnx_with_torch_inputs: bool,
) -> Iterator[Optional[RunAlignedRecord]]:
    """
    Runs one node of the onnx graph and yields one record for every output
    which can be compared to a result of the exported program.

    :param onx: onnx model
    :param ep_graph_nodes: nodes of the exported program
    :param onnx_results: results already computed by the onnx runtime,
        it provides the inputs of the node and receives its outputs
    :param mapping_onnx_to_torch: onnx name -> torch name
    :param torch_results: results computed by the exported program
    :param ep_durations: execution time of the fx nodes, indexed by node position
    :param use_tensor: torch tensors are used if True, numpy arrays otherwise
    :param i_torch: index of the current node in the exported program
    :param i_onnx: index of the onnx node to run
    :param name_to_ep_node: torch result name -> index of the fx node
    :param run_cls_kwargs: arguments given to ``run_cls``
    :param str_kws: arguments sent to ``string_type``
    :param status: statistics updated along the way
    :param already_run_onnx: indices of the onnx nodes already executed,
        ``i_onnx`` is added to it once the node is processed
    :param torch_names_to_onnx_names: torch name -> onnx name
    :param verbose: verbosity
    :param exc: raises an exception if discrepancies are too high
    :param atol: absolute tolerance
    :param rtol: relative tolerance
    :param reset_names: torch names for which the onnx result
        is replaced by the torch result after the comparison
    :param replay_configuration: if specified, tells which part of the model
        should be dumped
    :param has_cuda: one of the inputs of the model is on cuda
    :param run_cls: runtime class, instantiated with the node
    :param loop: progress bar, only used if ``verbose == 1``
    :param run_onnx_with_torch_inputs: runs the node a second time
        with the torch results as inputs when they are all available
    :return: iterator on records, a single None if the node was already run
    """
    if i_onnx in already_run_onnx:
        # The node must not be executed twice.
        yield None
        return

    node = onx.graph.node[i_onnx]
    if verbose > 1:
        print(
            f"[run_aligned] run onx.graph.node[{i_onnx}]: "
            f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
        )
    elif verbose == 1:
        loop.set_description(
            f"ep {i_torch}/{len(ep_graph_nodes)} nx {i_onnx}/{len(onx.graph.node)} "
            f"{status.to_str()}"
        )
        loop.update(min(1, 1 + i_torch + i_onnx))

    ref = run_cls(node, **run_cls_kwargs)
    # We need to clone because the runtime maybe using dlpack to create OrtValue
    hidden_inputs = OnnxruntimeEvaluator._get_hidden_node_inputs(node)
    all_inputs = [*node.input, *hidden_inputs] if hidden_inputs else node.input
    feeds = (
        {k: onnx_results[k].clone() for k in all_inputs if k}
        if use_tensor
        else {k: onnx_results[k].copy() for k in all_inputs if k}
    )
    assert "" not in feeds, f"Unexpected feeds={string_type(feeds, **str_kws)}"
    if verbose > 1:
        print(f"[run_aligned] feeds={string_type(feeds, **str_kws)}")
    begin = time.perf_counter()
    try:
        res = ref.run(None, feeds)  # type: ignore[attr-defined]
    except Exception as e:
        raise RuntimeError(
            f"Unable to run node {node.op_type}, domain={node.domain} "
            f"with inputs={node.input}, feeds={string_type(feeds, **str_kws)}"
        ) from e
    duration = time.perf_counter() - begin
    if verbose > 1:
        print(f"[run_aligned] res={string_type(res, **str_kws)}")
    assert (
        not has_cuda
        or not any(t is not None and t.is_cuda for t in feeds.values())
        or any(t is not None and t.is_cuda for t in res)
        or node.op_type in {"Shape", "Size"}  # on CPU no matter what
        or node.op_type
        in {
            "Add",
            "Concat",
            "Div",
            "Gather",
            "Mul",
            "Range",
            "Squeeze",
            "Sub",
            "Unsqueeze",
        }  # not sure, could be about shapes
    ), (
        f"One input is on cuda but there is no float output on cuda, "
        f"feeds={string_type(feeds, with_device=True, with_shape=True)}, "
        f"res={string_type(res, with_device=True, with_shape=True)}, "
        f"node is {pretty_onnx(node)}"
    )

    comment = None
    cross = None
    if run_onnx_with_torch_inputs:
        # Let's run the operator with torch results if they are available
        new_feeds, removed = make_torch_inputs(
            node.input,
            {
                **{v: k for k, v in torch_names_to_onnx_names.items()},
                **mapping_onnx_to_torch,
            },
            onnx_results,
            torch_results,
            submodel=None,
        )
        if not removed:
            if verbose > 1:
                print(
                    f"[run_aligned] feeds for second run="
                    f"{string_type(new_feeds, **str_kws)}"
                )
            cross = ref.run(None, new_feeds)
            if verbose > 1:
                print(f"[run_aligned] got for second run={string_type(cross, **str_kws)}")
            # Gemm = torch.nn.function.linear, in that case, we just run it as well
            to = mapping_onnx_to_torch.get(node.output[0], node.output[0])
            if to in torch_results:
                comment = _validation_nn_functional(node, new_feeds, [torch_results[to]])
        elif verbose > 1:
            print(f"[run_aligned] second run not possible because of missing {removed}")

    if cross is None:
        cross = [None for _ in res]

    list_node_output = list(node.output)
    # NOTE(review): empty output names are removed before the zip,
    # this assumes the runtime does not return a value for them - confirm.
    node_output = [o for o in list_node_output if o]
    for o, r, r2 in zip(node_output, res, cross):
        if r is None or not o:
            continue
        tmp = _loop_cmp(
            mapping_onnx_to_torch,
            torch_results,
            onnx_results,
            o,
            r,
            r2,
            verbose,
            atol,
            rtol,
            i_torch,
            i_onnx,
            str_kws,
            exc,
            use_tensor,
        )
        if tmp is not None:
            if tmp.ep_name in name_to_ep_node:
                tmp.ep_id_node = name_to_ep_node[tmp.ep_name]
                tmp.ep_target = str(ep_graph_nodes[tmp.ep_id_node].target)
                tmp.ep_time_run = ep_durations[tmp.ep_id_node]
            else:
                tmp.ep_id_node = None
                tmp.ep_target = None
                tmp.ep_name = None
            tmp.onnx_op_type = onx.graph.node[tmp.onnx_id_node].op_type
            tmp.onnx_id_output = list_node_output.index(o)
            tmp.onnx_time_run = duration
            status.yielded_nodes += 1
            if tmp.err_abs is not None:
                status.update(tmp.err_abs)
            tmp.comment = comment
            yield tmp

            # do we need to dump pieces if graph the user can replay?
            if replay_configuration:
                if replay_configuration.select(
                    name=tmp.onnx_name, op_type=tmp.onnx_op_type, err_abs=tmp.err_abs
                ):
                    replay_configuration.dump(
                        name=tmp.onnx_name,
                        onnx_id_node=tmp.onnx_id_node,
                        model=onx,
                        onnx_results=onnx_results,
                        torch_results=torch_results,
                        onnx_name_to_ep_name={
                            **{v: k for k, v in torch_names_to_onnx_names.items()},
                            **mapping_onnx_to_torch,
                        },
                        verbose=max(verbose - 1, 0),
                    )
                    status.last_replay = tmp.onnx_name

            # reset_names: replaces onnx_results by torch_results to see
            # if that fixes the discrepancies problem
            if reset_names and tmp.ep_name in reset_names:
                assert (
                    tmp.ep_name in torch_results
                ), f"name {tmp.ep_name!r} set to be reset is missing in torch_results."
                assert (
                    tmp.onnx_name in onnx_results
                ), f"name {tmp.onnx_name!r} set to be reset is missing in onnx_results."
                onnx_results[tmp.onnx_name] = torch_results[tmp.ep_name]
                tmp = _loop_cmp(
                    mapping_onnx_to_torch,
                    torch_results,
                    onnx_results,
                    o,
                    torch_results[tmp.ep_name],
                    None,
                    verbose,
                    atol,
                    rtol,
                    i_torch,
                    i_onnx,
                    str_kws,
                    exc,
                    use_tensor,
                )
                if tmp is not None:
                    # The check comes after the test on None,
                    # tmp.err_abs cannot be read otherwise.
                    assert tmp.err_abs == 0, f"Reset did not happen, tmp={tmp}"
                    tmp.onnx_op_type = "reset"
                    tmp.onnx_id_output = list_node_output.index(o)
                    status.yielded_nodes += 1
                    yield tmp
    already_run_onnx.add(i_onnx)


def _preparation_with_fx_graph(
    ep_graph_nodes: List[torch.fx.Node],
    name_to_ep_node: Dict[str, int],
    torch_input_names: List[str],
    onx: onnx.ModelProto,
    torch_names_to_onnx_names: Dict[str, str],
    skip_mapping_torch_onnx,
    torch_results: Dict[str, torch.Tensor],
    placeholders: Dict[str, torch.Tensor],
    placeholders_to_state_dict: Dict[str, str],
    ep_state_dict: Dict[str, torch.Tensor],
    positions: Dict[str, Dict[str, int]],
) -> List[str]:
    """
    Walks through the nodes of the exported program once to collect
    weights, inputs and outputs. Most arguments are containers
    filled in place.

    :param ep_graph_nodes: nodes of the exported program
    :param name_to_ep_node: filled with node name -> node position
    :param torch_input_names: filled with the names of the placeholders
        which are not weights, in the order they appear
    :param onx: onnx model, the i-th input placeholder is mapped
        to the i-th onnx input
    :param torch_names_to_onnx_names: filled with input name -> onnx input name
    :param skip_mapping_torch_onnx: ambiguous names, an input cannot be one of them
    :param torch_results: filled with the weights
    :param placeholders: filled with the weights
    :param placeholders_to_state_dict: placeholder name -> name in ``ep_state_dict``
    :param ep_state_dict: weights of the exported program
    :param positions: filled with ``{name: {"fx": position}}``
    :return: names of the outputs of the exported program
    """
    torch_output_names = None
    for position, fx_node in enumerate(ep_graph_nodes):
        produced = [fx_node.name] if isinstance(fx_node.name, str) else fx_node.name
        for result_name in produced:
            positions[result_name] = dict(fx=position)

        if fx_node.op == "output":
            torch_output_names = [n.name for n in fx_node.args[0]]
        elif fx_node.op == "placeholder":
            if fx_node.name in placeholders_to_state_dict:
                # This a weight.
                weight = ep_state_dict[placeholders_to_state_dict[fx_node.name]]
                placeholders[fx_node.name] = weight
                torch_results[fx_node.name] = placeholders[fx_node.name]
                assert isinstance(torch_results[fx_node.name], torch.Tensor), (
                    f"torch_results[{fx_node.name}] not a tensor but "
                    f"{type(torch_results[fx_node.name])}"
                )
            else:
                # This is an input
                n_known_inputs = len(torch_input_names)
                assert n_known_inputs < len(onx.graph.input), (
                    f"torch_input_names={torch_input_names!r}, "
                    f"onnx_input_names={[n.name for n in onx.graph.input]}, "
                    f"node.name={fx_node.name!r} cannot be an input"
                )
                assert fx_node.name not in skip_mapping_torch_onnx, (
                    f"{fx_node.name!r} is ambiguous, cannot be mapped due to "
                    f"{skip_mapping_torch_onnx}"
                )
                torch_names_to_onnx_names[fx_node.name] = onx.graph.input[n_known_inputs].name
                torch_input_names.append(fx_node.name)

        assert isinstance(fx_node.name, str), (
            f"Unexpected type {type(fx_node.name)} for node={fx_node} "
            f"(target={fx_node.target}), "
            f"args={fx_node.args}"
        )
        name_to_ep_node[fx_node.name] = position

    assert torch_output_names is not None, "No output node ws found the graph."
    return torch_output_names


def _preparation_with_onnx_model(
    default_device,
    use_tensor: bool,
    onx: onnx.ModelProto,
    already_yielded: Dict[str, Any],
    str_kws: Dict[str, bool],
    positions: Dict[str, Dict[str, int]],
    torch_names_to_onnx_names: Dict[str, str],
    torch_output_names: List[str],
    torch_results: Dict[str, torch.Tensor],
    skip_mapping_torch_onnx: Set[str],
    verbose: int,
    args: Sequence[Any],
    kwargs: Dict[str, Any],
) -> Tuple[Dict[str, str], Dict[str, torch.Tensor], float, float, List[RunAlignedRecord]]:
    """
    Prepares the execution of the onnx model: positions of the results,
    mapping between the onnx outputs and the torch outputs,
    inputs and initializers.

    :param default_device: device where an initializer with at least
        1024 elements is moved when it has no torch counterpart, ignored if None
    :param use_tensor: torch tensors are used if True, numpy arrays otherwise
    :param onx: onnx model
    :param already_yielded: given to ``RunAlignedRecord.check``
    :param str_kws: arguments sent to ``string_type``
    :param positions: updated with ``{name: {"onnx": position}}``,
        -1 is used for inputs and initializers
    :param torch_names_to_onnx_names: updated with torch name -> initializer name
        when an initializer is matched with a torch weight
    :param torch_output_names: names of the outputs of the exported program,
        there must be as many as onnx outputs
    :param torch_results: torch weights
    :param skip_mapping_torch_onnx: ambiguous torch names, never mapped
    :param verbose: verbosity
    :param args: positional inputs
    :param kwargs: named inputs, flattened and appended to *args*
    :return: mapping onnx output -> torch output, onnx results
        (inputs and initializers), memory on cpu in bytes, memory on cuda in bytes,
        records to yield for the initializers without torch counterpart
    """
    for inp in onx.graph.input:
        n = inp.name
        if n in positions:
            positions[n]["onnx"] = -1
        else:
            positions[n] = dict(onnx=-1)
    for inp in onx.graph.initializer:
        n = inp.name
        if n in positions:
            positions[n]["onnx"] = -1
        else:
            positions[n] = dict(onnx=-1)
    for i, node in enumerate(onx.graph.node):
        for n in node.output:
            if n in positions:
                positions[n]["onnx"] = i
            else:
                positions[n] = dict(onnx=i)

    onnx_outputs_names = [o.name for o in onx.graph.output]
    assert torch_output_names is not None and len(torch_output_names) == len(
        onnx_outputs_names
    ), (
        f"Unexpected number of outputs, torch_output_names={torch_output_names}, "
        f"onnx_outputs_names={onnx_outputs_names}"
    )
    mapping_onnx_to_torch = dict(zip(onnx_outputs_names, torch_output_names))

    onnx_args = list(args) if args else []
    if kwargs:
        onnx_args.extend(flatten_object(kwargs, drop_keys=True))
    if verbose:
        print(f"[run_aligned]   args: {string_type(args, **str_kws)}")
        print(f"[run_aligned] kwargs: {string_type(kwargs, **str_kws)}")
        print(f"[run_aligned]   onnx: {string_type(onnx_args, **str_kws)}")
        print(f"[run_aligned] nx: walks through {len(onx.graph.input)} onnx inputs")
    onnx_results: Dict[str, Any] = {}
    for inp, v in zip(onx.graph.input, onnx_args):
        onnx_results[inp.name] = _check_tensor_(
            use_tensor, inp.name, v if use_tensor else to_numpy(v)
        )
        if verbose:
            print(f"[run_aligned-nx] +inp: {inp.name}: {string_type(v, **str_kws)}")

    # alias for initializers
    # NOTE(review): skip_onnx_name is only filled inside the branch ``if drop``
    # which itself requires skip_onnx_name to be non empty, the branch looks
    # unreachable as written - confirm the intended detection of ambiguous aliases.
    skip_onnx_name = set()
    init_aliases: Dict[str, str] = {}
    for init in onx.graph.initializer:
        new_names = {
            n
            for n in [
                f"p_{init.name.replace('.', '_')}",
                f"p_{init.name.split('::')[0].split('--')[-1].replace('.', '_')}",
                f"{init.name.split('::')[0].split('--')[-1].replace('.', '_')}",
            ]
            if n != init.name
        }
        drop = False
        for new_name in new_names:
            if new_name in skip_onnx_name:
                drop = True
                break
        if drop:
            skip_onnx_name |= new_names | {init.name}
            for new_name in new_names:
                # The key is the alias, not the whole set of aliases
                # (a set is not hashable and cannot be looked up in a dictionary).
                if new_name in init_aliases:
                    del init_aliases[new_name]
        else:
            for new_name in new_names:
                init_aliases[new_name] = init.name
    rev_init_aliases: Dict[str, Set[str]] = {}
    for k, v in init_aliases.items():
        if v in rev_init_aliases:
            rev_init_aliases[v].add(k)
        else:
            rev_init_aliases[v] = {k}

    # initializers
    if verbose:
        print(f"[run_aligned] nx: handles {len(onx.graph.initializer)} initializers from onnx")
    memory_cpu = 0
    memory_cuda = 0
    records_to_yield = []
    for init in onx.graph.initializer:  # type: ignore
        t = None
        if init.name in torch_results:
            if init.name not in skip_mapping_torch_onnx:
                t = torch_results[init.name]
                torch_names_to_onnx_names[init.name] = init.name
        elif init.name not in skip_onnx_name and init.name in rev_init_aliases:
            new_names = [
                k
                for k in rev_init_aliases[init.name]
                if k in torch_results and k not in skip_mapping_torch_onnx
            ]
            if new_names and len(new_names) == 1:
                new_name = new_names[0]  # type: ignore[assignment, index]
                t = torch_results[new_name]
                if (
                    len(set(t.shape)) == len(t.shape)  # not repeated dimension
                    and t.shape == tuple(init.dims)
                    and torch_dtype_to_onnx_dtype(t.dtype) == init.data_type
                ):
                    torch_names_to_onnx_names[new_name] = init.name
                else:
                    t = None

                # We should check tensors and proto are the same.
        if t is None:
            # No torch counterpart, the value stored in the onnx model is used.
            t = to_tensor(init)
            if default_device and t.numel() >= 1024:
                # Let's force its way to cuda (should check the device has well).
                t = t.to(default_device)
            records_to_yield.append(
                RunAlignedRecord(
                    onnx_id_node=-1,
                    onnx_name=init.name,
                    onnx_op_type="initializer",
                    onnx_shape_type=string_type(t, **str_kws),
                ).check(already_yielded)
            )

        size = t.element_size() * t.numel()
        if t.is_cuda:
            memory_cuda += size
        else:
            memory_cpu += size
        if init.name not in onnx_results:
            # otherwise, it is an input with a default value
            onnx_results[init.name] = _check_tensor_(use_tensor, init.name, t, flip_type=True)
    return mapping_onnx_to_torch, onnx_results, memory_cpu, memory_cuda, records_to_yield


def run_aligned(
    ep: torch.export.ExportedProgram,
    onx: Union[onnx.ModelProto, onnx.FunctionProto],
    run_cls: Callable[
        [
            Union[
                onnx.ModelProto,
                onnx.FunctionProto,
                onnx.GraphProto,
                onnx.NodeProto,
            ]
        ],
        List[Union[np.ndarray, torch.Tensor]],
    ],
    args: Optional[Tuple[torch.Tensor, ...]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    use_tensor: bool = False,
    atol: Optional[float] = None,
    rtol: Optional[float] = None,
    verbose: int = 0,
    exc: bool = True,
    reset_names: Optional[List[str]] = None,
    replay_configuration: Optional[ReplayConfiguration] = None,
    run_onnx_with_torch_inputs: bool = False,
) -> Iterator[RunAlignedRecord]:
    """
    Runs in parallel both the exported program
    and the onnx proto and looks for discrepancies.
    The function does match on result names so it assumes
    the exported program and the onnx model have the same names
    for equivalent results.

    :param ep: exported program
    :param onx: model or function proto
    :param run_cls: defines the runtime to use for this task
    :param args: input args
    :param kwargs: input kwargs
    :param use_tensor: use torch tensors instead of numpy arrays
        for the onnx runtime
    :param atol: absolute tolerance
    :param rtol: relative tolerance
    :param verbose: verbosity level
    :param exc: stops if an exception
    :param reset_names: list of names, the onnx execution takes the torch outputs
        instead of its own result if the names falls into that set
    :param replay_configuration: configuration to let the user dump any problematic
        piece of the onnx graph he wants to replay in order to investigate later,
        see :class:`ReplayConfiguration
        <onnx_diagnostic.torch_onnx.sbs.ReplayConfiguration>`
    :param run_onnx_with_torch_inputs: run an onnx operator with torch results
        if they are available
    :return: an iterator on :class:`RunAlignedRecord
        <onnx_diagnostic.torch_onnx.sbs_dataclasses.RunAlignedRecord>`

    Example:

    .. runpython::
        :showcode:
        :warningout: UserWarning

        import pandas
        import torch
        from onnx_diagnostic.reference import (
            # This can be replaced by any runtime taking NodeProto as an input.
            ExtendedReferenceEvaluator as ReferenceEvaluator,
        )
        from onnx_diagnostic.torch_onnx.sbs import run_aligned


        class Model(torch.nn.Module):
            def forward(self, x):
                ry = x.abs()
                rz = ry.exp()
                rw = rz + 1
                ru = rw.log() + rw
                return ru


        x = torch.randn((5, 4))
        Model()(x)  # to make sure the model is running
        ep = torch.export.export(
            Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
        )
        onx = torch.onnx.export(
            Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
        ).model_proto
        results = list(
            run_aligned(ep, onx, ReferenceEvaluator, (x,), atol=1e-5, rtol=1e-5, verbose=1)
        )
        print("------------")
        print("final results")
        df = pandas.DataFrame(results)
        df = df.apply(lambda col: col.fillna("") if col.dtype == "object" else col)
        print(df)

    This example uses :class:`onnx.reference.ReferenceEvaluator` to run the onnx model
    but onnxruntime can also be used through
    :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`.
    It relies on :epkg:`onnxruntime` and selects CPU or CUDA depending
    on the device where the inputs are located.

    The :class:`torch.export.ExportedProgram` can be saved on disk
    with ``ep.save("<filename>.pt")`` and restored with
    ``torch.export.load("<filename>.pt")``. That leaves the inputs to save.
    We can decouple the export and the alignment.

    .. runpython::
        :showcode:
        :warningout: UserWarning

        import onnx
        import torch
        from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


        class Model(torch.nn.Module):
            def forward(self, x):
                ry = x.abs()
                rz = ry.exp()
                rw = rz + 1
                ru = rw.log() + rw
                return ru


        x = torch.randn((5, 4))
        dynamic_shapes = ({0: "batch"},)
        Model()(x)  # to make sure the model is running
        ep = torch.export.export(
            Model(), (x,), dynamic_shapes=use_dyn_not_str(dynamic_shapes)
        )
        onx = torch.onnx.export(
            Model(), (x,), dynamic_shapes=dynamic_shapes
        ).model_proto

        torch.export.save(ep, "test_doc_sbs_example.pt2")
        onnx.save(onx, "test_doc_sbs_example.onnx")
        torch.save((x,), "test_doc_sbs_example.pt")

    Then we can restore all of them and run it.

    .. runpython::
        :showcode:
        :warningout: UserWarning

        import pandas
        import onnx
        import torch
        from onnx_diagnostic.torch_onnx.sbs import run_aligned
        from onnx_diagnostic.reference import OnnxruntimeEvaluator


        ep = torch.export.load("test_doc_sbs_example.pt2")
        onx = onnx.load("test_doc_sbs_example.onnx")
        inputs = torch.load("test_doc_sbs_example.pt")

        results = list(
            run_aligned(
                ep,
                onx,
                OnnxruntimeEvaluator,
                inputs,
                atol=1e-5,
                rtol=1e-5,
                verbose=1,
                use_tensor=True,
            )
        )
        print("------------")
        print("final results")
        df = pandas.DataFrame(results)
        df = df.apply(lambda col: col.fillna("") if col.dtype == "object" else col)
        print(df)

    A command line can also be run:

    .. code-block:: bash

        python -m onnx_diagnostic sbs -i <tensors>.input.pt \\
                                      --ep <exported_program>.pt2 \\
                                      -m <model>.onnx \\
                                      -o results.xlsx \\
                                      -v 1 --atol=0.1 --rtol=1
    """
    assert callable(run_cls), f"run_cls={run_cls} not a callable"
    already_yielded = {}
    reset_names = set(reset_names) if reset_names else set()
    str_kws = dict(with_shape=True, with_device=True)
    has_cuda = any(
        (isinstance(t, torch.Tensor) and t.is_cuda)
        for t in flatten_object([args, kwargs], drop_keys=True)
    )
    default_device = None
    if has_cuda:
        for t in flatten_object([args, kwargs], drop_keys=True):
            # Inputs may contain values which are not tensors (int, float, ...),
            # they have no attribute is_cuda.
            if isinstance(t, torch.Tensor) and t.is_cuda:
                default_device = t.device
                break
    run_cls_kwargs = {
        "ir_version": onx.ir_version,
        "opsets": {d.domain: d.version for d in onx.opset_import},
        "verbose": max(verbose - 1, 0),
        "providers": (
            ["CUDAExecutionProvider", "CPUExecutionProvider"]
            if has_cuda
            else ["CPUExecutionProvider"]
        ),
    }
    # Only the arguments the runtime class accepts are kept.
    accepted_parameters = set(inspect.signature(run_cls).parameters)
    run_cls_kwargs = {k: v for k, v in run_cls_kwargs.items() if k in accepted_parameters}
    if verbose:
        print(f"[run_aligned] run_cls={run_cls}")
        print(f"[run_aligned] run_cls_kwargs={run_cls_kwargs}")
        if replay_configuration:
            print(f"[run_aligned] replay={replay_configuration}")

    # preparation with ep.graph.nodes
    ep_state_dict = {**ep.state_dict, **dict(ep.named_buffers(), **ep.tensor_constants)}
    placeholders_to_state_dict = {
        **{f"p_{name.replace('.', '_')}": name for name in ep.state_dict},
        **{f"b_{name.replace('.', '_')}": name for name, _ in ep.named_buffers()},
        **{f"c_{name.replace('.', '_')}": name for name in ep.tensor_constants},
    }
    skip_mapping_torch_onnx = _duplicated_values(placeholders_to_state_dict)
    placeholders = {}
    if verbose:
        print(f"[run_aligned] ep: model has {len(ep_state_dict)} torch constants or weights.")

    if verbose:
        print(f"[run_aligned] ep: walks through {len(ep.graph.nodes)} nodes from torch")

    # dictionary mapping result names and their position in both graphs.
    positions: Dict[str, Dict[str, int]] = {}
    ep_graph_nodes = list(ep.graph.nodes)
    torch_results: Dict[str, Any] = {}
    torch_input_names: List[str] = []
    name_to_ep_node = {}
    torch_names_to_onnx_names = {}
    torch_output_names = _preparation_with_fx_graph(
        ep_graph_nodes,
        name_to_ep_node,
        torch_input_names,
        onx,
        torch_names_to_onnx_names,
        skip_mapping_torch_onnx,
        torch_results,
        placeholders,
        placeholders_to_state_dict,
        ep_state_dict,
        positions,
    )

    # preparation for onnx
    if verbose:
        print(f"[run_aligned] ep: found {len(torch_results)} torch constants or weights.")
        print(f"[run_aligned] ep: found inputs {torch_input_names}")
        print(f"[run_aligned] ep: found outputs {torch_output_names}")
        print(f"[run_aligned] nx: walks through {len(onx.graph.node)} nodes from onnx")

    mapping_onnx_to_torch, onnx_results, memory_cpu, memory_cuda, records_to_yield = (
        _preparation_with_onnx_model(
            default_device,
            use_tensor,
            onx,
            already_yielded,
            str_kws,
            positions,
            torch_names_to_onnx_names,
            torch_output_names,
            torch_results,
            skip_mapping_torch_onnx,
            verbose,
            args,
            kwargs,
        )
    )
    for record in records_to_yield:
        yield record

    if verbose:
        print(f"[run_aligned] nx: handled {len(onnx_results)} initializers from onnx")
        print(f"[run_aligned] nx: memory cpu {memory_cpu / 2**20:.3f} Mb")
        print(f"[run_aligned] nx: memory cuda {memory_cuda / 2**20:.3f} Mb")
        print(f"[run_aligned] nx: {len(onnx_results)} constants")
        print(f"[run_aligned] nx: {len(onx.graph.input)} inputs")
        print(f"[run_aligned] nx: {len(onx.graph.output)} outputs")
        print(f"[run_aligned] bo: {len(mapping_onnx_to_torch)} outputs")
        print(f"[run_aligned] run_cls_kwargs={run_cls_kwargs}")
        if verbose > 1:
            for k, v in torch_results.items():
                print(f"[run_aligned-ep] +cst: {k}: {string_type(v, **str_kws)}")
            for k, v in onnx_results.items():
                print(f"[run_aligned-nx] +ini: {k}: {string_type(v, **str_kws)}")

    # starts the side-by-side
    if verbose:
        print(
            f"[run_aligned] ep: starts side-by-side with {len(ep_graph_nodes)} "
            f"fx nodes and {len(onx.graph.node)} onnx nodes"
        )
    if verbose == 1:
        import tqdm

        loop = tqdm.tqdm(total=len(ep_graph_nodes) + len(onx.graph.node))
    else:
        loop = None

    already_run: Set[int] = set()
    ep_durations = {}
    status = StatusRunAligned()
    last_position = 0
    for i_torch, node in enumerate(ep_graph_nodes):
        if verbose > 1:
            if node.op == "call_function":
                print(
                    f"[run_aligned] run ep.graph.nodes[{i_torch}]: "
                    f"{node.op}[{node.target}] -> {node.name!r}"
                )
            else:
                print(
                    f"[run_aligned] run ep.graph.nodes[{i_torch}]: {node.op} -> {node.name!r}"
                )
        elif verbose == 1:
            loop.set_description(
                f"ep {i_torch}/{len(ep_graph_nodes)} nx {last_position}/{len(onx.graph.node)} "
                f"{status.to_str()}"
            )
            loop.update(min(1, 1 + i_torch + last_position))

        if node.op == "placeholder":
            is_input = node.name not in placeholders
            if is_input:
                # An input: the torch value is the one given to the onnx model.
                torch_results[node.name] = (
                    onnx_results[torch_names_to_onnx_names[node.name]]
                    if use_tensor
                    else from_numpy(onnx_results[torch_names_to_onnx_names[node.name]])
                )
                assert isinstance(torch_results[node.name], torch.Tensor), (
                    f"torch_results[{node.name}] not a tensor but "
                    f"{type(torch_results[node.name])}, use_tensor={use_tensor}"
                )
                t = torch_results[node.name]
                if verbose > 1:
                    print(f"[run_aligned-ep] =ags: {node.name}={string_type(t, **str_kws)}")
                record = RunAlignedRecord(
                    ep_id_node=i_torch,
                    onnx_id_node=-1,
                    ep_name=node.name,
                    onnx_name=torch_names_to_onnx_names[node.name],
                    ep_target="input",
                    onnx_op_type="input",
                    ep_shape_type=string_type(t, **str_kws),
                    onnx_shape_type=string_type(
                        onnx_results[torch_names_to_onnx_names[node.name]], **str_kws
                    ),
                )
                yield record.check(already_yielded)
            else:
                # A weight or a constant.
                assert node.name in placeholders_to_state_dict, (
                    f"Unable to find placeholder {node.name!r} (node.op={node.op!r}), "
                    f"existing: {sorted(placeholders_to_state_dict)}"
                )
                assert node.name in torch_results, (
                    f"placeholder {node.name!r} (node.op={node.op!r}), "
                    f"should have been added to torch_results: {sorted(torch_results)}"
                )
                t = torch_results[node.name]
                if (
                    node.name in torch_names_to_onnx_names
                    and node.name not in skip_mapping_torch_onnx
                ):
                    if verbose > 1:
                        print(
                            f"[run_aligned-ep] =plh: "
                            f"{node.name}={string_type(t, **str_kws)}"
                        )
                    record = RunAlignedRecord(
                        ep_id_node=i_torch,
                        onnx_id_node=-1,
                        ep_name=node.name,
                        onnx_name=torch_names_to_onnx_names[node.name],
                        ep_target="placeholder",
                        onnx_op_type="initializer",
                        ep_shape_type=string_type(t, **str_kws),
                        onnx_shape_type=string_type(
                            onnx_results[torch_names_to_onnx_names[node.name]], **str_kws
                        ),
                    )
                    if not is_input:
                        record.set_diff(
                            max_diff(
                                t,
                                onnx_results[torch_names_to_onnx_names[node.name]],
                                hist=[0.1, 0.01],
                            )
                        )
                    yield record.check(already_yielded)
                else:
                    if verbose > 1:
                        print(
                            f"[run_aligned-ep] +plh: {node.name}={string_type(t, **str_kws)}"
                        )
                    yield RunAlignedRecord(
                        ep_id_node=i_torch,
                        ep_name=node.name,
                        ep_target="placeholder",
                        ep_shape_type=string_type(t, **str_kws),
                    ).check(already_yielded)
            continue

        outputs = [node.name] if isinstance(node.name, str) else list(node.name)
        # Local names, the parameters args and kwargs of the function are not overwritten.
        fx_args, fx_kwargs = prepare_args_kwargs(torch_results, node)
        begin = time.perf_counter()
        new_outputs = run_fx_node(node, fx_args, fx_kwargs)
        duration = time.perf_counter() - begin
        ep_durations[i_torch] = duration
        if isinstance(new_outputs, (torch.Tensor, int, float, list, tuple)):
            new_outputs = (new_outputs,)
        if new_outputs is None:
            # Probably an assert.
            continue

        for k, v in zip(outputs, new_outputs):
            torch_results[k] = v
        if verbose > 1:
            for k, v in zip(outputs, new_outputs):
                print(f"[run_aligned-ep] +res: {k}={string_type(v, **str_kws)}")

        # Last onnx node to run to get all the results this fx node produces,
        # -2 means there is nothing to run.
        max_pos = -2
        for n in outputs:
            if n in positions:
                if "onnx" in positions[n]:
                    max_pos = max(max_pos, positions[n]["onnx"])
                if "fx" in positions[n]:
                    if positions[n]["fx"] > i_torch:
                        max_pos = -2
                        break
        if max_pos == -2:
            # we skip.
            continue

        next_to_visit = last_position
        for i_onnx in range(last_position, max_pos + 1):
            if i_onnx in already_run:
                continue
            # The onnx node may produce more than one output, in that
            # case, we need to check the exported program is not behind.
            onnx_node = onx.graph.node[i_onnx]
            ep_behind = False
            for iname in onnx_node.output:
                if iname in positions and "fx" in positions[iname]:
                    if positions[iname]["fx"] > i_torch:
                        ep_behind = True
                        break
            if ep_behind:
                break
            for r in _loop_onnx_node(
                onx,
                ep_graph_nodes,
                onnx_results,
                mapping_onnx_to_torch,
                torch_results,
                ep_durations,
                use_tensor,
                i_torch,
                i_onnx,
                name_to_ep_node,
                run_cls_kwargs,
                str_kws,
                status,
                already_run,
                torch_names_to_onnx_names,
                verbose,
                exc,
                atol,
                rtol,
                reset_names,
                replay_configuration,
                has_cuda,
                run_cls,
                loop,
                run_onnx_with_torch_inputs,
            ):
                if r:
                    yield r.check(already_yielded)
            next_to_visit = i_onnx + 1

        last_position = next_to_visit

    # complete the execution of the onnx graph
    if verbose > 1:
        print(
            f"[run_aligned] complete execution of onnx graph from pos={last_position} "
            f"to {len(onx.graph.node)}"
        )
    for i_onnx in range(last_position, len(onx.graph.node)):
        if i_onnx in already_run:
            continue
        for r in _loop_onnx_node(
            onx,
            ep_graph_nodes,
            onnx_results,
            mapping_onnx_to_torch,
            torch_results,
            ep_durations,
            use_tensor,
            i_torch,
            i_onnx,
            name_to_ep_node,
            run_cls_kwargs,
            str_kws,
            status,
            already_run,
            torch_names_to_onnx_names,
            verbose,
            exc,
            atol,
            rtol,
            reset_names,
            replay_configuration,
            has_cuda,
            run_cls,
            loop,
            run_onnx_with_torch_inputs,
        ):
            if r:
                yield r.check(already_yielded)

    if loop is not None:
        loop.close()
    if verbose:
        print(f"[run_aligned] done with status={status.to_str()}")