onnx_diagnostic.torch_onnx.sbs

onnx_diagnostic.torch_onnx.sbs.prepare_args_kwargs(torch_results: Dict[str, Any], node: Node) -> Tuple[Tuple[Any, ...], Dict[str, Any]]

Prepares the args and kwargs needed to execute an fx node.

Parameters:
  • torch_results – results computed so far, indexed by name

  • node – node to execute

Returns:

the args and kwargs to pass to the node
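A minimal sketch of this preparation step, assuming `torch_results` maps node names to already computed values: every `torch.fx.Node` found in `node.args` or `node.kwargs` is replaced by its value with `torch.fx.node.map_arg`. The name `prepare_args_kwargs_sketch` is hypothetical and the actual implementation may differ.

```python
from typing import Any, Dict, Tuple

import torch
import torch.fx
from torch.fx.node import map_arg


def prepare_args_kwargs_sketch(
    torch_results: Dict[str, Any], node: torch.fx.Node
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    # Assumption: torch_results is keyed by node name (node.name).
    # map_arg walks nested tuples/lists/dicts and applies the lambda to each fx.Node.
    new_args = map_arg(node.args, lambda n: torch_results[n.name])
    new_kwargs = map_arg(node.kwargs, lambda n: torch_results[n.name])
    # fx stores kwargs as an immutable dict, convert it to a regular one.
    return tuple(new_args), dict(new_kwargs)
```

Iterating over the graph nodes in order and calling this helper before each `call_function` node is enough to replay a simple graph eagerly.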

onnx_diagnostic.torch_onnx.sbs.run_aligned(ep: ExportedProgram, onx: ModelProto | FunctionProto, args: Tuple[Tensor, ...], check_conversion_cls: Dict[str, Any] | type, kwargs: Dict[str, Any] | None = None, verbose: int = 0) -> Iterator[Tuple[Any, ...]]

Runs the exported program and the ONNX proto side by side and looks for discrepancies. The function matches results by name, so it assumes the exported program and the ONNX model use the same names for equivalent results.

Parameters:
  • ep – exported program

  • onx – model or function proto

  • args – input args

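  • check_conversion_cls – runtime used to execute the ONNX model, either a class or a dictionary such as dict(cls=ReferenceEvaluator, atol=1e-5, rtol=1e-5) that also sets the tolerances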

  • kwargs – input kwargs

  • verbose – verbosity level

Returns:

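an iterator of tuples, one for each pair of results compared; the example below reads the first four items as the node indices and result names in both models, and the last item as a dictionary of discrepancies (abs, rel)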
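The name-based matching can be sketched as follows, assuming both sets of intermediate results were already collected into dictionaries keyed by result name. The helper `compare_by_name_sketch` and its relative-error formula are assumptions for illustration, not the library's implementation, which runs both models node by node as the verbose log in the example shows.

```python
from typing import Dict

import numpy as np


def compare_by_name_sketch(
    torch_values: Dict[str, np.ndarray], onnx_values: Dict[str, np.ndarray]
) -> Dict[str, Dict[str, float]]:
    report = {}
    # Only names present in both models can be compared.
    for name in sorted(torch_values.keys() & onnx_values.keys()):
        expected = np.asarray(torch_values[name], dtype=np.float64)
        got = np.asarray(onnx_values[name], dtype=np.float64)
        diff = np.abs(expected - got)
        # Assumed formula: relative error = |diff| / (|expected| + epsilon).
        rel = diff / (np.abs(expected) + 1e-10)
        report[name] = dict(abs=float(diff.max()), rel=float(rel.max()))
    return report
```

Results whose names differ between both models are silently skipped here, which is why equal naming matters.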

Example:

<<<

import pprint
import pandas
import torch
from onnx_diagnostic.reference import (
    # This can be replaced by any runtime taking a NodeProto as input.
    ExtendedReferenceEvaluator as ReferenceEvaluator,
)
from onnx_diagnostic.torch_onnx.sbs import run_aligned


class Model(torch.nn.Module):
    def forward(self, x):
        ry = x.abs()
        rz = ry.exp()
        rw = rz + 1
        ru = rw.log() + rw
        return ru


def post_process(obs):
    dobs = dict(zip(["ep_id_node", "onnx_id_node", "ep_name", "onnx_name"], obs))
    dobs["err_abs"] = obs[-1]["abs"]
    dobs["err_rel"] = obs[-1]["rel"]
    return dobs


x = torch.randn((5, 4))
Model()(x)  # checks that the model runs
ep = torch.export.export(
    Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
)
onx = torch.onnx.export(
    Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},), dynamo=True
).model_proto
results = list(
    map(
        post_process,
        run_aligned(
            ep,
            onx,
            (x,),
            check_conversion_cls=dict(cls=ReferenceEvaluator, atol=1e-5, rtol=1e-5),
            verbose=1,
        ),
    ),
)
print("------------")
print("final results")
df = pandas.DataFrame(results)
print(df)

>>>

    [torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`...
    [torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`... ✅
    [torch.onnx] Run decomposition...
    [torch.onnx] Run decomposition... ✅
    [torch.onnx] Translate the graph into ONNX...
    [torch.onnx] Translate the graph into ONNX... ✅
    [run_aligned] +onnx-input: x: T1s5x4[-2.37814998626709,1.8000156879425049:A0.2102274749428034]
    [run_aligned] run ep.graph.nodes[0]: placeholder -> 'x'
    [run_aligned] +torch x=T1s5x4[-2.37814998626709,1.8000156879425049:A0.2102274749428034]
    [run_aligned] run ep.graph.nodes[1]: call_function[aten.abs.default] -> 'abs_1'
    [run_aligned] +torch abs_1=T1s5x4[0.0842517837882042,2.37814998626709:A0.8864509236067534]
    [run_aligned] run onx.graph.node[0]: Abs(x) -> abs_1
    [run_aligned] +onnx abs_1=A1s5x4[0.0842517837882042,2.37814998626709:A0.8864509236067534]
    [run_aligned] =common results abs_1: abs=0.0, rel=0.0
    [run_aligned] run ep.graph.nodes[2]: call_function[aten.exp.default] -> 'exp'
    [run_aligned] +torch exp=T1s5x4[1.0879027843475342,10.784932136535645:A2.931220865249634]
    [run_aligned] run onx.graph.node[1]: Exp(abs_1) -> exp
    [run_aligned] +onnx exp=A1s5x4[1.0879027843475342,10.784933090209961:A2.931220918893814]
    [run_aligned] =common results exp: abs=9.5367431640625e-07, rel=9.311585534838043e-08, n=20.0
    [run_aligned] run ep.graph.nodes[3]: call_function[aten.add.Tensor] -> 'add'
    [run_aligned] +torch add=T1s5x4[2.087902784347534,11.784932136535645:A3.9312208533287047]
    [run_aligned] run ep.graph.nodes[4]: call_function[aten.log.default] -> 'log'
    [run_aligned] +torch log=T1s5x4[0.7361600995063782,2.4668216705322266:A1.2638569533824922]
    [run_aligned] run onx.graph.node[2]: Constant() -> scalar_tensor_default
    [run_aligned] +onnx scalar_tensor_default=A1s=1.0
    [run_aligned] run onx.graph.node[3]: Add(exp, scalar_tensor_default) -> add_6
    [run_aligned] +onnx add_6=A1s5x4[2.087902784347534,11.784933090209961:A3.9312209248542787]
    [run_aligned] run onx.graph.node[4]: Log(add_6) -> log
    [run_aligned] +onnx log=A1s5x4[0.7361600995063782,2.4668219089508057:A1.26385697722435]
    [run_aligned] =common results log: abs=2.384185791015625e-07, rel=1.4452490643151165e-07, n=20.0
    [run_aligned] run ep.graph.nodes[5]: call_function[aten.add.Tensor] -> 'add_1'
    [run_aligned] +torch add_1=T1s5x4[2.8240628242492676,14.251753807067871:A5.195077753067016]
    [run_aligned] run ep.graph.nodes[6]: output -> 'output'
    [run_aligned] +torch output=(T1s5x4[2.8240628242492676,14.251753807067871:A5.195077753067016],)
    [run_aligned] run onx.graph.node[5]: Add(log, add_6) -> add_13
    [run_aligned] +onnx add_13=A1s5x4[2.8240628242492676,14.251754760742188:A5.1950778603553776]
    [run_aligned] =common results* add_1/add_13: abs=9.5367431640625e-07, rel=1.5597590704800244e-07, n=20.0
    ------------
    final results
       ep_id_node  onnx_id_node ep_name onnx_name       err_abs       err_rel
    0           1             0   abs_1     abs_1  0.000000e+00  0.000000e+00
    1           2             1     exp       exp  9.536743e-07  9.311586e-08
    2           4             4     log       log  2.384186e-07  1.445249e-07
    3           6             5  add_13     add_1  9.536743e-07  1.559759e-07

onnx_diagnostic.torch_onnx.sbs.run_fx_node(node: Node, args: Tuple[Any, ...], kwargs: Dict[str, Any] | None = None) -> Tuple[Any, ...]

Executes an fx node.

Parameters:
  • node – node to execute

  • args – unnamed inputs to the node

  • kwargs – named inputs to the node

Returns:

the outputs of the node, as a tuple
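For the common `call_function` case, executing a node amounts to calling its target on the prepared inputs. A minimal sketch, assuming single results are wrapped into a one-element tuple as the return annotation suggests; `run_fx_node_sketch` is hypothetical and only handles `call_function` nodes.

```python
from typing import Any, Dict, Optional, Tuple

import torch
import torch.fx


def run_fx_node_sketch(
    node: torch.fx.Node,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, ...]:
    if node.op != "call_function":
        raise NotImplementedError(f"op={node.op!r} is not handled in this sketch")
    # node.target is the callable recorded in the graph (torch.abs, aten ops, ...).
    outputs = node.target(*args, **(kwargs or {}))
    # Normalizes a single result into a one-element tuple.
    return outputs if isinstance(outputs, tuple) else (outputs,)
```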

onnx_diagnostic.torch_onnx.sbs.validate_fx_outputs(node: Node, outputs: Tuple[Any, ...]) -> None

Validates the outputs of a node using metadata stored in the node.

Parameters:
  • node – node whose metadata describes the expected outputs

  • outputs – outputs obtained by executing the node
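A minimal sketch of such a check, assuming the expected values are stored in `node.meta["val"]`, as `torch.export` does for most nodes, and that only static dimensions (plain `int`) can be compared while symbolic ones are skipped. `validate_fx_outputs_sketch` is a hypothetical name, not the library's implementation.

```python
from typing import Any, Tuple

import torch
import torch.fx


def validate_fx_outputs_sketch(node: torch.fx.Node, outputs: Tuple[Any, ...]) -> None:
    # Assumption: expected values live in node.meta["val"].
    expected = node.meta.get("val", None)
    if expected is None:
        return  # no metadata, nothing to check
    if not isinstance(expected, (tuple, list)):
        expected = (expected,)
    if len(expected) != len(outputs):
        raise ValueError(
            f"node {node.name!r}: expected {len(expected)} outputs, got {len(outputs)}"
        )
    for i, (exp, out) in enumerate(zip(expected, outputs)):
        if not isinstance(exp, torch.Tensor):
            continue  # only tensors are checked in this sketch
        if not isinstance(out, torch.Tensor):
            raise TypeError(f"node {node.name!r}, output {i}: expected a tensor")
        if len(exp.shape) != out.dim():
            raise ValueError(f"node {node.name!r}, output {i}: rank mismatch")
        for e, o in zip(exp.shape, out.shape):
            # Symbolic dimensions (SymInt) are not plain ints and are skipped.
            if isinstance(e, int) and e != o:
                raise ValueError(
                    f"node {node.name!r}, output {i}: expected shape "
                    f"{tuple(exp.shape)}, got {tuple(out.shape)}"
                )
```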

onnx_diagnostic.torch_onnx.sbs.validate_fx_tensor(node: Node, tensor: Tensor, expected_shape: Tuple[Any, ...]) -> None

Checks that the shape of a tensor matches the expected shape.

Parameters:
  • node – node that produced the tensor

  • tensor – tensor to check

  • expected_shape – shape the tensor is expected to have
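A minimal sketch of this shape check, assuming `expected_shape` may mix integers with symbolic or unknown dimensions and that only the integer entries are enforced. `validate_fx_tensor_sketch` is hypothetical, and here `node` is only used to build the error message.

```python
from typing import Any, Tuple

import torch


def validate_fx_tensor_sketch(
    node: Any, tensor: torch.Tensor, expected_shape: Tuple[Any, ...]
) -> None:
    # `node` only appears in error messages in this sketch.
    if tensor.dim() != len(expected_shape):
        raise ValueError(
            f"{node}: expected rank {len(expected_shape)}, got shape {tuple(tensor.shape)}"
        )
    for axis, (expected, actual) in enumerate(zip(expected_shape, tensor.shape)):
        # Only plain integers are checked; symbolic or unknown dimensions are skipped.
        if isinstance(expected, int) and expected != actual:
            raise ValueError(f"{node}: dimension {axis} is {actual}, expected {expected}")
```

For instance, a tensor of shape (5, 4) passes against ("batch", 4) but fails against (5, 3).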