onnx_diagnostic.torch_onnx.sbs

onnx_diagnostic.torch_onnx.sbs.run_aligned(ep: ExportedProgram, onx: ModelProto | FunctionProto, run_cls: Callable[[ModelProto | FunctionProto | GraphProto | NodeProto], List[ndarray | Tensor]], args: Tuple[Tensor, ...] | None = None, kwargs: Dict[str, Any] | None = None, use_tensor: bool = False, atol: float | None = None, rtol: float | None = None, verbose: int = 0, exc: bool = True, reset_names: List[str] | None = None, replay_configuration: ReplayConfiguration | None = None, run_onnx_with_torch_inputs: bool = False) → Iterator[RunAlignedRecord]

Runs the exported program and the onnx proto side by side and looks for discrepancies. The function matches results by name, so it assumes the exported program and the onnx model use the same names for equivalent results.
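The name-based matching can be pictured with a small pure-Python sketch. It is not the library's implementation: `align_by_name` is a hypothetical helper, both runtimes are represented by dictionaries mapping result names to flat lists of floats, and the discrepancy is assumed to be the maximum absolute difference.

```python
from typing import Dict, List, Tuple


def align_by_name(
    torch_results: Dict[str, List[float]],
    onnx_results: Dict[str, List[float]],
) -> List[Tuple[str, float]]:
    """Hypothetical helper: pairs results sharing the same name and returns
    (name, max absolute difference) for each pair, in torch order."""
    aligned = []
    for name, expected in torch_results.items():
        got = onnx_results.get(name)
        if got is None:
            # No onnx result with that name: nothing to compare.
            continue
        if len(got) != len(expected):
            raise ValueError(f"shape mismatch for {name!r}")
        err = max((abs(e - g) for e, g in zip(expected, got)), default=0.0)
        aligned.append((name, err))
    return aligned


# "add" only exists on the torch side, so it is skipped.
torch_side = {"abs_1": [1.0, 2.0], "exp": [2.718281, 7.389056], "add": [4.0]}
onnx_side = {"abs_1": [1.0, 2.0], "exp": [2.718282, 7.389056]}
result = align_by_name(torch_side, onnx_side)
print(result)
```

A result without a counterpart of the same name cannot be compared, which is why equivalent results need identical names in both models.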

Parameters:
  • ep – exported program

  • onx – model or function proto

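  • run_cls – runtime class used to execute the onnx pieces; any runtime taking a NodeProto as an input works, such as ExtendedReferenceEvaluator or OnnxruntimeEvaluator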

  • args – input args

  • kwargs – input kwargs

  • use_tensor – use torch tensors instead of numpy arrays for the onnx runtime

  • atol – absolute tolerance

  • rtol – relative tolerance

  • verbose – verbosity level

  • exc – stops if an exception is raised

  • reset_names – list of names; when a result name falls into that set, the onnx execution continues with the torch output instead of its own result

  • replay_configuration – configuration letting the user dump any problematic piece of the onnx graph in order to replay it and investigate later, see ReplayConfiguration (onnx_diagnostic.torch_onnx.sbs.ReplayConfiguration)

  • run_onnx_with_torch_inputs – runs an onnx operator with torch results as inputs when they are available
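The effect of reset_names can be illustrated with a toy chain of scalar operations. This is a sketch based only on the parameter description above, not on the library's code: `run_chain` and the step names are hypothetical. The "onnx" chain introduces an error at `exp`; resetting that name makes the next step consume the torch value, so the error stops propagating and later comparisons are no longer polluted by it.

```python
import math
from typing import Callable, Dict, List, Optional, Set, Tuple

Step = Tuple[str, Callable[[float], float]]


def run_chain(
    x: float,
    steps: List[Step],
    reference: Optional[Dict[str, float]] = None,
    reset_names: Optional[Set[str]] = None,
) -> Dict[str, float]:
    """Hypothetical runner: applies unary steps in order. When a step name is
    in reset_names, the next step receives the reference value instead."""
    results: Dict[str, float] = {}
    value = x
    for name, fn in steps:
        value = fn(value)
        results[name] = value  # keep the value this runtime computed
        if reset_names and reference is not None and name in reset_names:
            value = reference[name]  # continue with the reference (torch) value
    return results


torch_steps: List[Step] = [("abs", abs), ("exp", math.exp), ("log", math.log)]
# A faulty "onnx" exp adds 0.5 to its output.
onnx_steps: List[Step] = [
    ("abs", abs),
    ("exp", lambda v: math.exp(v) + 0.5),
    ("log", math.log),
]

expected = run_chain(-1.0, torch_steps)
without_reset = run_chain(-1.0, onnx_steps)
with_reset = run_chain(-1.0, onnx_steps, reference=expected, reset_names={"exp"})

for name in expected:
    print(name, abs(expected[name] - without_reset[name]), abs(expected[name] - with_reset[name]))
```

Without the reset, `log` inherits the error made by `exp`; with it, only `exp` shows a discrepancy.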

Returns:

an iterator of RunAlignedRecord; wrap the call with list to collect all records, as in the examples below
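Because the function returns an iterator, records are produced while both models run, and the caller can stop early. The sketch below uses a hypothetical stand-in generator, `fake_run_aligned`, and a simplified record type, `FakeRecord`, holding only `ep_name` and `err_abs` (named after two columns of the tables below); with the real function, `fake_run_aligned()` would be replaced by a call to run_aligned.

```python
from typing import Iterator, NamedTuple, Optional


class FakeRecord(NamedTuple):
    """Simplified stand-in for RunAlignedRecord (hypothetical fields)."""

    ep_name: str
    err_abs: Optional[float]


produced = []  # tracks which records the generator actually built


def fake_run_aligned() -> Iterator[FakeRecord]:
    """Hypothetical stand-in yielding records one at a time."""
    for name, err in [("x", None), ("abs_1", 0.0), ("exp", 0.3), ("log", 0.2)]:
        produced.append(name)
        yield FakeRecord(name, err)


atol = 1e-5
# Stops consuming the iterator at the first record above the tolerance.
first_bad = next(
    (r for r in fake_run_aligned() if r.err_abs is not None and r.err_abs > atol),
    None,
)
print(first_bad)
print(produced)
```

The `log` record is never built: consuming lazily saves the remaining computation once the first discrepancy is found.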

Example:

<<<

import pandas
import torch
from onnx_diagnostic.reference import (
    # This can be replaced by any runtime taking NodeProto as an input.
    ExtendedReferenceEvaluator as ReferenceEvaluator,
)
from onnx_diagnostic.torch_onnx.sbs import run_aligned


class Model(torch.nn.Module):
    def forward(self, x):
        ry = x.abs()
        rz = ry.exp()
        rw = rz + 1
        ru = rw.log() + rw
        return ru


x = torch.randn((5, 4))
Model()(x)  # to make sure the model is running
ep = torch.export.export(
    Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
)
onx = torch.onnx.export(
    Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
).model_proto
results = list(
    run_aligned(ep, onx, ReferenceEvaluator, (x,), atol=1e-5, rtol=1e-5, verbose=1)
)
print("------------")
print("final results")
df = pandas.DataFrame(results)
df = df.apply(lambda col: col.fillna("") if col.dtype == "object" else col)
print(df)

>>>

    [torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`...
    [torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`... ✅
    [torch.onnx] Run decomposition...
    /usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
      return cls.__new__(cls, *args)
    [torch.onnx] Run decomposition... ✅
    [torch.onnx] Translate the graph into ONNX...
    [torch.onnx] Translate the graph into ONNX... ✅
    [run_aligned] run_cls=<class 'onnx_diagnostic.reference.evaluator.ExtendedReferenceEvaluator'>
    [run_aligned] run_cls_kwargs={'opsets': {'': 20}, 'verbose': 0}
    [run_aligned] ep: model has 0 torch constants or weights.
    [run_aligned] ep: walks through 7 nodes from torch
    [run_aligned] ep: found 0 torch constants or weights.
    [run_aligned] ep: found inputs  ['x']
    [run_aligned] ep: found outputs ['add_1']
    [run_aligned] nx: walks through 5 nodes from onnx
    [run_aligned]   args: (CT1s5x4,)
    [run_aligned] kwargs: None
    [run_aligned]   onnx: #1[CT1s5x4]
    [run_aligned] nx: walks through 1 onnx inputs
    [run_aligned-nx] +inp: x: CT1s5x4
    [run_aligned] nx: handles 1 initializers from onnx
    [run_aligned] nx: handled 2 initializers from onnx
    [run_aligned] nx: memory cpu 0.000 Mb
    [run_aligned] nx: memory cuda 0.000 Mb
    [run_aligned] nx: 2 constants
    [run_aligned] nx: 1 inputs
    [run_aligned] nx: 1 outputs
    [run_aligned] bo: 1 outputs
    [run_aligned] run_cls_kwargs={'opsets': {'': 20}, 'verbose': 0}
    [run_aligned] ep: starts side-by-side with 7 fx nodes and 5 onnx nodes
    ep 6/7 nx 4/5 yielded=3 maxabs=0.000 #inf=0 #nan=0: 100%|##########| 12/12 [00:00<00:00, 982.23it/s]
    [run_aligned] done with status=yielded=4 maxabs=0.000 #inf=0 #nan=0
    ------------
    final results
       ep_id_node  onnx_id_node ep_name              onnx_name         ep_target onnx_op_type  onnx_id_output ep_shape_type onnx_shape_type       err_abs       err_rel  err_dev err_nan  err_h01  err_h001  ep_time_run  onnx_time_run err_abs2 err_rel2 err_dev2 err_nan2 err_h012 err_h0012 comment
    0         NaN            -1          scalar_tensor_default                    initializer             NaN                          CT1s           NaN           NaN      NaN              NaN       NaN          NaN            NaN                                                               
    1         0.0            -1       x                      x             input        input             NaN       CT1s5x4          A1s5x4           NaN           NaN      NaN              NaN       NaN          NaN            NaN                                                               
    2         1.0             0   abs_1                  abs_1  aten.abs.default          Abs             0.0       CT1s5x4          A1s5x4  0.000000e+00  0.000000e+00      0.0              0.0       0.0     0.000205       0.000171                                                               
    3         2.0             1     exp                    exp  aten.exp.default          Exp             0.0       CT1s5x4          A1s5x4  9.536743e-07  1.157886e-07      0.0              0.0       0.0     0.000176       0.000142                                                               
    4         4.0             3     log                    log  aten.log.default          Log             0.0       CT1s5x4          A1s5x4  1.192093e-07  1.682992e-07      0.0              0.0       0.0     0.000062       0.000131                                                               
    5         5.0             4   add_1                 add_13   aten.add.Tensor          Add             0.0       CT1s5x4          A1s5x4  9.536743e-07  1.742278e-07      0.0              0.0       0.0     0.000170       0.000128

This example runs the onnx model with ExtendedReferenceEvaluator from onnx_diagnostic.reference (imported as ReferenceEvaluator), but onnxruntime can also be used through onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch. InferenceSessionForTorch relies on onnxruntime and selects CPU or CUDA depending on the device where the inputs are located.
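The device-based choice can be sketched as follows. `providers_for_device` is a hypothetical helper written for illustration, not part of onnx_diagnostic; it only shows the usual onnxruntime convention of listing CUDAExecutionProvider before the CPU fallback when the inputs live on a CUDA device. The log of the next example, run with CPU inputs, reports `'providers': ['CPUExecutionProvider']`, which matches this rule.

```python
from typing import List


def providers_for_device(device: str) -> List[str]:
    """Hypothetical helper: picks onnxruntime execution providers from a
    torch-style device string such as "cpu", "cuda" or "cuda:1"."""
    if device.startswith("cuda"):
        # CUDA first, CPU kept as a fallback for operators CUDA does not cover.
        return ["CUDAExecutionProvider", "CPUExecutionProvider"]
    return ["CPUExecutionProvider"]


print(providers_for_device("cpu"))
print(providers_for_device("cuda:0"))
```

With a torch tensor `x`, the device string would come from `str(x.device)`.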

The torch.export.ExportedProgram can be saved on disk with torch.export.save(ep, "<filename>.pt2") and restored with torch.export.load("<filename>.pt2"). Only the inputs are then left to save, which torch.save handles. This decouples the export from the alignment.

<<<

import onnx
import torch
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class Model(torch.nn.Module):
    def forward(self, x):
        ry = x.abs()
        rz = ry.exp()
        rw = rz + 1
        ru = rw.log() + rw
        return ru


x = torch.randn((5, 4))
dynamic_shapes = ({0: "batch"},)
Model()(x)  # to make sure the model is running
ep = torch.export.export(Model(), (x,), dynamic_shapes=use_dyn_not_str(dynamic_shapes))
onx = torch.onnx.export(Model(), (x,), dynamic_shapes=dynamic_shapes).model_proto

torch.export.save(ep, "test_doc_sbs_example.pt2")
onnx.save(onx, "test_doc_sbs_example.onnx")
torch.save((x,), "test_doc_sbs_example.pt")

>>>

    [torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`...
    [torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`... ✅
    [torch.onnx] Run decomposition...
    /usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
      return cls.__new__(cls, *args)
    [torch.onnx] Run decomposition... ✅
    [torch.onnx] Translate the graph into ONNX...
    [torch.onnx] Translate the graph into ONNX... ✅

All three files can then be restored and aligned, here with onnxruntime through OnnxruntimeEvaluator and use_tensor=True.

<<<

import pandas
import onnx
import torch
from onnx_diagnostic.torch_onnx.sbs import run_aligned
from onnx_diagnostic.reference import OnnxruntimeEvaluator


ep = torch.export.load("test_doc_sbs_example.pt2")
onx = onnx.load("test_doc_sbs_example.onnx")
inputs = torch.load("test_doc_sbs_example.pt")


results = list(
    run_aligned(
        ep,
        onx,
        OnnxruntimeEvaluator,
        inputs,
        atol=1e-5,
        rtol=1e-5,
        verbose=1,
        use_tensor=True,
    )
)
print("------------")
print("final results")
df = pandas.DataFrame(results)
df = df.apply(lambda col: col.fillna("") if col.dtype == "object" else col)
print(df)

>>>

    [run_aligned] run_cls=<class 'onnx_diagnostic.reference.ort_evaluator.OnnxruntimeEvaluator'>
    [run_aligned] run_cls_kwargs={'ir_version': 10, 'opsets': {'': 20}, 'verbose': 0, 'providers': ['CPUExecutionProvider']}
    [run_aligned] ep: model has 0 torch constants or weights.
    [run_aligned] ep: walks through 7 nodes from torch
    [run_aligned] ep: found 0 torch constants or weights.
    [run_aligned] ep: found inputs  ['x']
    [run_aligned] ep: found outputs ['add_1']
    [run_aligned] nx: walks through 5 nodes from onnx
    [run_aligned]   args: (CT1s5x4,)
    [run_aligned] kwargs: None
    [run_aligned]   onnx: #1[CT1s5x4]
    [run_aligned] nx: walks through 1 onnx inputs
    [run_aligned-nx] +inp: x: CT1s5x4
    [run_aligned] nx: handles 1 initializers from onnx
    [run_aligned] nx: handled 2 initializers from onnx
    [run_aligned] nx: memory cpu 0.000 Mb
    [run_aligned] nx: memory cuda 0.000 Mb
    [run_aligned] nx: 2 constants
    [run_aligned] nx: 1 inputs
    [run_aligned] nx: 1 outputs
    [run_aligned] bo: 1 outputs
    [run_aligned] run_cls_kwargs={'ir_version': 10, 'opsets': {'': 20}, 'verbose': 0, 'providers': ['CPUExecutionProvider']}
    [run_aligned] ep: starts side-by-side with 7 fx nodes and 5 onnx nodes
    ep 6/7 nx 4/5 yielded=3 maxabs=0.000 #inf=0 #nan=0: 100%|##########| 12/12 [00:00<00:00, 325.14it/s]
    [run_aligned] done with status=yielded=4 maxabs=0.000 #inf=0 #nan=0
    ------------
    final results
       ep_id_node  onnx_id_node ep_name              onnx_name         ep_target onnx_op_type  onnx_id_output ep_shape_type onnx_shape_type       err_abs       err_rel  err_dev err_nan  err_h01  err_h001  ep_time_run  onnx_time_run err_abs2 err_rel2 err_dev2 err_nan2 err_h012 err_h0012 comment
    0         NaN            -1          scalar_tensor_default                    initializer             NaN                          CT1s           NaN           NaN      NaN              NaN       NaN          NaN            NaN                                                               
    1         0.0            -1       x                      x             input        input             NaN       CT1s5x4         CT1s5x4           NaN           NaN      NaN              NaN       NaN          NaN            NaN                                                               
    2         1.0             0   abs_1                  abs_1  aten.abs.default          Abs             0.0       CT1s5x4         CT1s5x4  0.000000e+00  0.000000e+00      0.0              0.0       0.0     0.000128       0.006312                                                               
    3         2.0             1     exp                    exp  aten.exp.default          Exp             0.0       CT1s5x4         CT1s5x4  4.768372e-07  6.229903e-08      0.0              0.0       0.0     0.000167       0.001670                                                               
    4         4.0             3     log                    log  aten.log.default          Log             0.0       CT1s5x4         CT1s5x4  0.000000e+00  0.000000e+00      0.0              0.0       0.0     0.000066       0.007119                                                               
    5         5.0             4   add_1                 add_13   aten.add.Tensor          Add             0.0       CT1s5x4         CT1s5x4  9.536743e-07  8.820589e-08      0.0              0.0       0.0     0.000230       0.001524

The same alignment can also be run from the command line:

python -m onnx_diagnostic sbs -i <tensors>.input.pt \
                              --ep <exported_program>.pt2 \
                              -m <model>.onnx  \
                              -o results.xlsx \
                              -v 1 --atol=0.1 --rtol=1
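When the comparison is driven from a Python script, the same command can be assembled as an argument list and handed to subprocess.run. The sketch below only builds and prints the list; `build_sbs_command` is a hypothetical helper, the file names are placeholders, and the options are exactly those shown above.

```python
import shlex
import sys
from typing import List


def build_sbs_command(
    inputs: str,
    ep: str,
    model: str,
    output: str,
    atol: float = 0.1,
    rtol: float = 1.0,
    verbose: int = 1,
) -> List[str]:
    """Hypothetical helper: builds the argument list for
    ``python -m onnx_diagnostic sbs`` with the options shown above."""
    return [
        sys.executable, "-m", "onnx_diagnostic", "sbs",
        "-i", inputs,
        "--ep", ep,
        "-m", model,
        "-o", output,
        "-v", str(verbose),
        f"--atol={atol}",
        f"--rtol={rtol}",
    ]


cmd = build_sbs_command("inputs.input.pt", "model.pt2", "model.onnx", "results.xlsx")
print(shlex.join(cmd))
# subprocess.run(cmd, check=True) would then launch the comparison.
```

Passing a list rather than a single string avoids shell quoting issues with paths containing spaces.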