Dumps intermediate results of a torch model

Looking for discrepancies quickly becomes tedious. Discrepancies appear when the same model is implemented in two different ways, here pytorch and onnx, and the two implementations produce different results. Models are big, so where do the discrepancies come from? That’s the unavoidable question. Unless there is an obvious reason, the only way is to compare intermediate outputs along the computation. The first step in that direction is to dump the intermediate results coming from pytorch. We use onnx_diagnostic.helpers.torch_helper.steal_forward() for that.

A simple LLM Model

See onnx_diagnostic.helpers.torch_helper.dummy_llm() for its definition. It is mostly used for unit tests and examples.

import numpy as np
import pandas
import onnx
import torch
import onnxruntime
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
from onnx_diagnostic.helpers.torch_helper import dummy_llm, steal_forward
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ReportResultComparison


model, inputs, ds = dummy_llm(dynamic_shapes=True)

We convert the model to float16.

model = model.to(torch.float16)

Let’s check.

print(f"type(model)={type(model)}")
print(f"inputs={string_type(inputs, with_shape=True)}")
print(f"ds={string_type(ds, with_shape=True)}")
type(model)=<class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.LLM'>
inputs=(T7s2x30,)
ds=dict(input_ids:{0:Dim(batch),1:Dim(length)})

It contains the following submodules.

for name, mod in model.named_modules():
    print(f"- {name}: {type(mod)}")
- : <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.LLM'>
- embedding: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.Embedding'>
- embedding.embedding: <class 'torch.nn.modules.sparse.Embedding'>
- embedding.pe: <class 'torch.nn.modules.sparse.Embedding'>
- decoder: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.DecoderLayer'>
- decoder.attention: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.MultiAttentionBlock'>
- decoder.attention.attention: <class 'torch.nn.modules.container.ModuleList'>
- decoder.attention.attention.0: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.AttentionBlock'>
- decoder.attention.attention.0.query: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.0.key: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.0.value: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.1: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.AttentionBlock'>
- decoder.attention.attention.1.query: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.1.key: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.1.value: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.linear: <class 'torch.nn.modules.linear.Linear'>
- decoder.feed_forward: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.FeedForward'>
- decoder.feed_forward.linear_1: <class 'torch.nn.modules.linear.Linear'>
- decoder.feed_forward.relu: <class 'torch.nn.modules.activation.ReLU'>
- decoder.feed_forward.linear_2: <class 'torch.nn.modules.linear.Linear'>
- decoder.norm_1: <class 'torch.nn.modules.normalization.LayerNorm'>
- decoder.norm_2: <class 'torch.nn.modules.normalization.LayerNorm'>

Steal and dump the output of submodules

The following context manager spies on the intermediate results of the listed module and submodules. It stores all their inputs and outputs in a single ONNX file.

with steal_forward(
    [
        ("model", model),
        ("model.decoder", model.decoder),
        ("model.decoder.attention", model.decoder.attention),
        ("model.decoder.feed_forward", model.decoder.feed_forward),
        ("model.decoder.norm_1", model.decoder.norm_1),
        ("model.decoder.norm_2", model.decoder.norm_2),
    ],
    dump_file="plot_dump_intermediate_results.inputs.onnx",
    verbose=1,
    storage_limit=2**28,
):
    expected = model(*inputs)
+model -- stolen forward for class LLM -- iteration 0
  <- args=(T7s2x30,) --- kwargs={}
+model.decoder -- stolen forward for class DecoderLayer -- iteration 0
  <- args=(T10s2x30x16,) --- kwargs={}
+model.decoder.norm_1 -- stolen forward for class LayerNorm -- iteration 0
  <- args=(T10s2x30x16,) --- kwargs={}
  -> T10s2x30x16
-model.decoder.norm_1.
-- stores key=('model.decoder.norm_1', 0), size 1Kb -- T10s2x30x16
+model.decoder.attention -- stolen forward for class MultiAttentionBlock -- iteration 0
  <- args=(T10s2x30x16,) --- kwargs={}
  -> T10s2x30x16
-model.decoder.attention.
-- stores key=('model.decoder.attention', 0), size 1Kb -- T10s2x30x16
+model.decoder.norm_2 -- stolen forward for class LayerNorm -- iteration 0
  <- args=(T10s2x30x16,) --- kwargs={}
  -> T10s2x30x16
-model.decoder.norm_2.
-- stores key=('model.decoder.norm_2', 0), size 1Kb -- T10s2x30x16
+model.decoder.feed_forward -- stolen forward for class FeedForward -- iteration 0
  <- args=(T10s2x30x16,) --- kwargs={}
  -> T10s2x30x16
-model.decoder.feed_forward.
-- stores key=('model.decoder.feed_forward', 0), size 1Kb -- T10s2x30x16
  -> T10s2x30x16
-model.decoder.
-- stores key=('model.decoder', 0), size 1Kb -- T10s2x30x16
  -> T10s2x30x16
-model.
-- stores key=('model', 0), size 1Kb -- T10s2x30x16
-- gather stored 12 objects, size=0 Mb
-- dumps stored objects
-- done dump stored objects
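The mechanism behind steal_forward can be illustrated with a minimal sketch. It is not the library’s implementation: it assumes that replacing the forward method of each listed module with a recording wrapper is enough, keeps everything in a plain dictionary instead of an ONNX file, and uses the hypothetical names steal_forward_sketch and _make_wrapper. Keys follow the (name, iteration, 'I' or 'O') convention used for the stored tensors.

```python
import contextlib

import torch


def _copy(value):
    # Detached copies keep the stored values unchanged by later in-place updates.
    return value.detach().clone() if isinstance(value, torch.Tensor) else value


def _make_wrapper(name, original_forward, storage):
    iteration = [0]  # number of times this module has been called so far

    def wrapped(*args, **kwargs):
        it = iteration[0]
        iteration[0] += 1
        storage[name, it, "I"] = (
            tuple(_copy(a) for a in args),
            {k: _copy(v) for k, v in kwargs.items()},
        )
        output = original_forward(*args, **kwargs)
        storage[name, it, "O"] = _copy(output)
        return output

    return wrapped


@contextlib.contextmanager
def steal_forward_sketch(named_modules, storage):
    """Hypothetical sketch: records inputs/outputs of the given modules in
    ``storage`` under keys (name, iteration, 'I') and (name, iteration, 'O').
    Each module is assumed to appear only once in ``named_modules``."""
    patched = []
    for name, module in named_modules:
        # An instance attribute shadows the class-level forward method.
        module.forward = _make_wrapper(name, module.forward, storage)
        patched.append(module)
    try:
        yield storage
    finally:
        for module in patched:
            # Deleting the instance attribute restores the class-level forward.
            del module.forward


model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
storage = {}
with steal_forward_sketch([("model", model), ("model.0", model[0])], storage):
    y = model(torch.randn(3, 4))
for key, value in storage.items():
    print(key, type(value).__name__)
```

The nesting order is preserved: the inputs of the outer model are recorded first and its output last, as in the log above.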

Restores saved inputs/outputs

All the intermediate tensors were saved in a single ONNX model, each tensor being stored in a constant node. Any runtime can run this model to restore the tensors, and function create_input_tensors_from_onnx_model restores them along with the keys they were stored under.

saved_tensors = create_input_tensors_from_onnx_model(
    "plot_dump_intermediate_results.inputs.onnx"
)
for k, v in saved_tensors.items():
    print(f"{k} -- {string_type(v, with_shape=True)}")
('model', 0, 'I') -- ((T7s2x30,),{})
('model.decoder', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.norm_1', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.norm_1', 0, 'O') -- T10s2x30x16
('model.decoder.attention', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.attention', 0, 'O') -- T10s2x30x16
('model.decoder.norm_2', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.norm_2', 0, 'O') -- T10s2x30x16
('model.decoder.feed_forward', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.feed_forward', 0, 'O') -- T10s2x30x16
('model.decoder', 0, 'O') -- T10s2x30x16
('model', 0, 'O') -- T10s2x30x16

Let’s explain the naming convention.

('model.decoder.norm_2', 0, 'I') -- ((T10s2x30x16,),{})
            |            |   |
            |            |   +--> input, stored as (args, kwargs)
            |            |
            |            +--> iteration: 0 means the first time the execution
            |                 went through that module; the model can be
            |                 called multiple times to store more iterations
            |
            +--> the name given to function steal_forward

The same goes for outputs, except that 'I' is replaced by 'O'.

('model.decoder.norm_2', 0, 'O') -- T10s2x30x16

This trick can be used to compare intermediate results coming from pytorch to any other implementation of the same model as long as it is possible to map the stored inputs/outputs.
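As an illustration of such a mapping, the sketch below feeds the stored inputs of a submodule to a second implementation of the same computation and compares the result with the stored output. The second implementation is a hypothetical numpy version of torch.nn.Linear. The stored dictionary is built by hand with the same key convention instead of coming from steal_forward, and compare_with_other_implementation is a hypothetical helper.

```python
import numpy as np
import torch


def compare_with_other_implementation(saved, implementations):
    """For every stored input (name, it, 'I') whose name has an alternative
    implementation, runs it on the stored args/kwargs and returns the maximum
    absolute difference with the stored output (name, it, 'O')."""
    report = {}
    for (name, iteration, kind), value in saved.items():
        if kind != "I" or name not in implementations:
            continue
        args, kwargs = value
        expected = saved[name, iteration, "O"]
        got = implementations[name](*args, **kwargs)
        report[name, iteration] = float(np.abs(expected.numpy() - got).max())
    return report


torch.manual_seed(0)
linear = torch.nn.Linear(4, 3)
x = torch.randn(2, 4)
with torch.no_grad():
    saved = {("linear", 0, "I"): ((x,), {}), ("linear", 0, "O"): linear(x)}

# Hypothetical second implementation of the same layer, written with numpy.
weight = linear.weight.detach().numpy()
bias = linear.bias.detach().numpy()
implementations = {"linear": lambda t: t.numpy() @ weight.T + bias}

report = compare_with_other_implementation(saved, implementations)
print(report)
```

Replacing the numpy lambda by a call into another runtime gives a per-submodule comparison, provided the stored inputs can be converted to that runtime’s format.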

Conversion to ONNX

The difficult part is mapping the saved intermediate results to the intermediate results of the ONNX model. Let’s first create the ONNX model.

epo = torch.onnx.export(model, inputs, dynamic_shapes=ds, dynamo=True)
epo.optimize()
epo.save("plot_dump_intermediate_results.onnx")
[torch.onnx] Obtain model graph for `LLM([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `LLM([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 4 of general pattern rewrite rules.

Discrepancies

We now have a torch model, its intermediate results and an equivalent ONNX graph. Let’s see how to measure the discrepancies, starting with those of the whole model.

sess = onnxruntime.InferenceSession(
    "plot_dump_intermediate_results.onnx", providers=["CPUExecutionProvider"]
)
feeds = dict(
    zip([i.name for i in sess.get_inputs()], [t.detach().cpu().numpy() for t in inputs])
)
got = sess.run(None, feeds)
diff = max_diff(expected, got)
print(f"discrepancies torch/ORT: {string_diff(diff)}")
discrepancies torch/ORT: abs=0.00390625, rel=0.034148137590269856, n=960.0,amax=0,13,5
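string_diff summarizes the result of max_diff as abs, rel, n and amax. Below is a simplified sketch of these statistics. It assumes abs is the maximum absolute difference, rel the maximum relative difference, n the number of compared elements and amax the index of the largest absolute difference. The exact rules of max_diff, such as how it avoids dividing by zero or walks nested outputs, may differ, and simple_max_diff is a hypothetical name.

```python
import numpy as np


def simple_max_diff(expected, got, eps=1e-5):
    """Hypothetical simplified version of the statistics printed above."""
    expected = np.asarray(expected, dtype=np.float64)
    got = np.asarray(got, dtype=np.float64)
    abs_diff = np.abs(expected - got)
    # eps avoids a division by zero; its value is an arbitrary choice here.
    rel_diff = abs_diff / (np.abs(expected) + eps)
    amax = np.unravel_index(np.argmax(abs_diff), abs_diff.shape)
    return {
        "abs": float(abs_diff.max()),
        "rel": float(rel_diff.max()),
        "n": float(abs_diff.size),
        "amax": tuple(int(i) for i in amax),
    }


expected = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float16)
got = expected.copy()
got[1, 0] += np.float16(0.00390625)  # one float16 rounding step near 4
stats = simple_max_diff(expected, got)
print(stats)
```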

What about intermediate results? To get them, we use another runtime, still based on onnxruntime, that evaluates the model eagerly, one node at a time.

sess_eager = OnnxruntimeEvaluator(
    "plot_dump_intermediate_results.onnx",
    providers=["CPUExecutionProvider"],
    torch_or_numpy=True,
)
feeds_tensor = dict(zip([i.name for i in sess.get_inputs()], inputs))
got = sess_eager.run(None, feeds_tensor)
diff = max_diff(expected, got)
print(f"discrepancies torch/eager ORT: {string_diff(diff)}")
discrepancies torch/eager ORT: abs=0.001953125, rel=0.015601597603594607, n=960.0,amax=0,3,5

Both runtimes produce small discrepancies of the same order. Let’s now dig into the intermediate results: they are compared to the outputs stored in saved_tensors during the execution of the torch model.

baseline = {}
for k, v in saved_tensors.items():
    if k[-1] == "I":  # inputs are excluded
        continue
    if isinstance(v, torch.Tensor):
        baseline[f"{k[0]}.{k[1]}".replace("model.decoder", "decoder")] = v

report_cmp = ReportResultComparison(baseline)
_ = sess_eager.run(None, feeds_tensor, report_cmp=report_cmp)  # only the report is needed, not the outputs

Let’s see the results.

data = report_cmp.data
df = pandas.DataFrame(data)
piv = df.pivot(index=("run_index", "run_name"), columns="ref_name", values="abs")
print(piv)
ref_name                decoder.0  decoder.attention.0  decoder.feed_forward.0  decoder.norm_1.0  decoder.norm_2.0   model.0
run_index run_name
1         embedding      3.610352             3.648682                3.727091          2.897949          2.694824  3.610352
2         embedding_1    3.890381             3.751709                3.790283          2.792480          2.880859  3.890381
3         add_8          1.029297             4.624146                4.740356          2.097656          2.099609  1.029297
4         layer_norm     2.447266             2.893860                3.274902          0.000000          0.565918  2.447266
5         linear         4.568848             1.706055                1.953369          3.262695          3.194336  4.568848
6         linear_1       5.594238             2.059937                2.334961          3.664062          3.698242  5.594238
7         linear_2       6.108398             1.869141                1.727783          3.930664          3.985352  6.108398
24        matmul_1       5.087769             1.447021                1.320801          3.236816          3.227051  5.087769
25        linear_3       5.139648             1.672363                1.940796          3.065430          3.053711  5.139648
26        linear_4       5.377441             1.771362                2.413574          3.622070          3.676758  5.377441
27        linear_5       5.612305             1.799927                1.863525          3.434570          3.489258  5.612305
42        matmul_3       5.140137             1.373413                1.651855          3.126221          3.116455  5.140137
44        val_48         4.887329             0.145508                1.208618          2.793091          2.796997  4.887329
45        linear_6       4.795898             0.000488                1.121338          2.893860          2.897766  4.795898
46        add_115        0.967285             4.768677                4.884888          2.175781          2.177734  0.967285
47        layer_norm_1   2.388672             2.897766                3.208496          0.566162          0.001953  2.388672
51        val_54         4.840149             1.115479                0.078125          3.197266          3.130859  4.840149
52        linear_8       4.912292             1.121338                0.000488          3.274902          3.208496  4.912292
53        add_136        0.001953             4.796021                4.912231          2.447266          2.388672  0.001953
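The idea behind ReportResultComparison, one discrepancy per pair of intermediate result and reference tensor, can be sketched as follows. The sketch assumes that only results with the same shape as a reference are compared; the gaps in run_index above suggest that some results are skipped. It also assumes the metric is the maximum absolute difference. compare_all and pivot are hypothetical helpers, and the actual class may apply other rules.

```python
import numpy as np


def compare_all(results, references):
    """Compares every (run_index, run_name, array) with every reference of the
    same shape and returns one record per compared pair."""
    records = []
    for run_index, run_name, value in results:
        for ref_name, ref in references.items():
            if value.shape != ref.shape:
                continue  # a result with another shape cannot be compared
            diff = np.abs(value.astype(np.float64) - ref.astype(np.float64)).max()
            records.append(
                dict(run_index=run_index, run_name=run_name, ref_name=ref_name, abs=float(diff))
            )
    return records


def pivot(records):
    """Builds {(run_index, run_name): {ref_name: abs}}, like DataFrame.pivot above."""
    table = {}
    for r in records:
        table.setdefault((r["run_index"], r["run_name"]), {})[r["ref_name"]] = r["abs"]
    return table


rng = np.random.default_rng(0)
references = {"norm.0": rng.standard_normal((2, 3)), "linear.0": rng.standard_normal((2, 3))}
results = [
    (1, "layer_norm", references["norm.0"] + 1e-3),
    (2, "reshape", rng.standard_normal((6,))),  # different shape, skipped
    (3, "linear", references["linear.0"].copy()),
]
table = pivot(compare_all(results, references))
for row, values in table.items():
    print(row, values)
```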

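Let’s clean it up a little: every discrepancy greater than or equal to 1 is replaced by NaN, then the rows containing only NaN are dropped.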

piv[piv >= 1] = np.nan
print(piv.dropna(axis=0, how="all"))
ref_name                decoder.0  decoder.attention.0  decoder.feed_forward.0  decoder.norm_1.0  decoder.norm_2.0   model.0
run_index run_name
4         layer_norm          NaN                  NaN                     NaN          0.000000          0.565918       NaN
44        val_48              NaN             0.145508                     NaN               NaN               NaN       NaN
45        linear_6            NaN             0.000488                     NaN               NaN               NaN       NaN
46        add_115        0.967285                  NaN                     NaN               NaN               NaN  0.967285
47        layer_norm_1        NaN                  NaN                     NaN          0.566162          0.001953       NaN
51        val_54              NaN                  NaN                0.078125               NaN               NaN       NaN
52        linear_8            NaN                  NaN                0.000488               NaN               NaN       NaN
53        add_136        0.001953                  NaN                     NaN               NaN               NaN  0.001953

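We can now identify which ONNX result maps to which expected tensor: linear_6 matches decoder.attention.0, linear_8 matches decoder.feed_forward.0, layer_norm matches decoder.norm_1.0, layer_norm_1 matches decoder.norm_2.0, and add_136 matches both decoder.0 and model.0, whose columns are identical.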
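This identification can be automated. Below is a minimal sketch. It assumes a result matches a reference tensor when it has the smallest discrepancy in that reference’s column and that discrepancy is below an arbitrary threshold of 0.01. The numbers are copied from six rows and five columns of the table above, and map_results is a hypothetical helper.

```python
def map_results(table, threshold=0.01):
    """table: {run_name: {ref_name: abs}}. Returns {ref_name: run_name} for the
    closest result of each reference when its discrepancy is below threshold."""
    mapping = {}
    ref_names = {ref for row in table.values() for ref in row}
    for ref in sorted(ref_names):
        candidates = [(row[ref], run) for run, row in table.items() if ref in row]
        best_abs, best_run = min(candidates)
        if best_abs < threshold:
            mapping[ref] = best_run
    return mapping


columns = [
    "decoder.0",
    "decoder.attention.0",
    "decoder.feed_forward.0",
    "decoder.norm_1.0",
    "decoder.norm_2.0",
]
rows = {  # values copied from the table printed above
    "layer_norm": [2.447266, 2.893860, 3.274902, 0.000000, 0.565918],
    "val_48": [4.887329, 0.145508, 1.208618, 2.793091, 2.796997],
    "linear_6": [4.795898, 0.000488, 1.121338, 2.893860, 2.897766],
    "layer_norm_1": [2.388672, 2.897766, 3.208496, 0.566162, 0.001953],
    "linear_8": [4.912292, 1.121338, 0.000488, 3.274902, 3.208496],
    "add_136": [0.001953, 4.796021, 4.912231, 2.447266, 2.388672],
}
table = {run: dict(zip(columns, values)) for run, values in rows.items()}
mapping = map_results(table)
for ref, run in mapping.items():
    print(f"{ref} <- {run}")
```

A threshold is needed because a reference may have no counterpart at all in the ONNX graph, in which case the smallest discrepancy stays large.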

Picture of the model

onx = onnx.load("plot_dump_intermediate_results.onnx")
plot_dot(onx)
doc.plot_legend("steal and dump\nintermediate\nresults", "steal_forward", "blue")


Related examples

Export microsoft/phi-2

Intermediate results with onnxruntime

Find where a model is failing by running submodels