Dumps intermediate results of a torch model

Tracking down discrepancies quickly becomes tedious. They appear when the same model, implemented in two different ways (here pytorch and onnx), returns two different results. Models are big, so where do the differences come from? That’s the unavoidable question. Unless there is an obvious reason, the only way to answer it is to compare intermediate outputs along the computation. The first step in that direction is to dump the intermediate results coming from pytorch. We use onnx_diagnostic.helpers.torch_helper.steal_forward() for that.

A simple LLM Model

See onnx_diagnostic.helpers.torch_helper.dummy_llm() for its definition. It is mostly used for unit tests or examples.

import numpy as np
import pandas
import onnx
import torch
import onnxruntime
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
from onnx_diagnostic.helpers.torch_helper import dummy_llm, steal_forward
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ReportResultComparison


model, inputs, ds = dummy_llm(dynamic_shapes=True)

We use float16.

Let’s check.

print(f"type(model)={type(model)}")
print(f"inputs={string_type(inputs, with_shape=True)}")
print(f"ds={string_type(ds, with_shape=True)}")
type(model)=<class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.LLM'>
inputs=(T7s2x30,)
ds=dict(input_ids:{0:Dim(batch),1:Dim(length)})
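
The compact strings printed above pack an element type and a shape. Judging from this output, `T7s2x30` reads as a tensor (`T`) with ONNX element type 7 (int64) and shape 2x30, while `T10` stands for float16. The decoder below is only a reading aid: `decode_tensor_string` is a hypothetical helper, not part of onnx_diagnostic, it covers a handful of element types, and it does not claim to mirror how string_type is implemented.

```python
import re

# Subset of the ONNX TensorProto element type codes.
ONNX_ELEM_TYPES = {
    1: "float32",
    6: "int32",
    7: "int64",
    9: "bool",
    10: "float16",
    11: "float64",
    16: "bfloat16",
}


def decode_tensor_string(text):
    """Splits a string such as 'T10s2x30x16' into (dtype name, shape)."""
    match = re.fullmatch(r"T(\d+)s(\d+(?:x\d+)*)", text)
    if match is None:
        raise ValueError(f"unexpected tensor description: {text!r}")
    code = int(match.group(1))
    shape = tuple(int(dim) for dim in match.group(2).split("x"))
    # Unknown codes are reported as-is instead of being guessed.
    return ONNX_ELEM_TYPES.get(code, f"onnx_type_{code}"), shape


print(decode_tensor_string("T7s2x30"))      # ('int64', (2, 30))
print(decode_tensor_string("T10s2x30x16"))  # ('float16', (2, 30, 16))
```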

It contains the following submodules.

for name, mod in model.named_modules():
    print(f"- {name}: {type(mod)}")
- : <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.LLM'>
- embedding: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.Embedding'>
- embedding.embedding: <class 'torch.nn.modules.sparse.Embedding'>
- embedding.pe: <class 'torch.nn.modules.sparse.Embedding'>
- decoder: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.DecoderLayer'>
- decoder.attention: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.MultiAttentionBlock'>
- decoder.attention.attention: <class 'torch.nn.modules.container.ModuleList'>
- decoder.attention.attention.0: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.AttentionBlock'>
- decoder.attention.attention.0.query: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.0.key: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.0.value: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.1: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.AttentionBlock'>
- decoder.attention.attention.1.query: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.1.key: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.1.value: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.linear: <class 'torch.nn.modules.linear.Linear'>
- decoder.feed_forward: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.FeedForward'>
- decoder.feed_forward.linear_1: <class 'torch.nn.modules.linear.Linear'>
- decoder.feed_forward.relu: <class 'torch.nn.modules.activation.ReLU'>
- decoder.feed_forward.linear_2: <class 'torch.nn.modules.linear.Linear'>
- decoder.norm_1: <class 'torch.nn.modules.normalization.LayerNorm'>
- decoder.norm_2: <class 'torch.nn.modules.normalization.LayerNorm'>

Steal and dump the output of submodules

The following context manager spies on the intermediate results of the model and the listed submodules. It stores all their inputs and outputs in a single onnx file.

with steal_forward(
    [
        ("model", model),
        ("model.decoder", model.decoder),
        ("model.decoder.attention", model.decoder.attention),
        ("model.decoder.feed_forward", model.decoder.feed_forward),
        ("model.decoder.norm_1", model.decoder.norm_1),
        ("model.decoder.norm_2", model.decoder.norm_2),
    ],
    dump_file="plot_dump_intermediate_results.inputs.onnx",
    verbose=1,
    storage_limit=2**28,
):
    expected = model(*inputs)
+model -- stolen forward for class LLM -- iteration 0
  <- args=(T7s2x30,) --- kwargs={}
+model.decoder -- stolen forward for class DecoderLayer -- iteration 0
  <- args=(T10s2x30x16,) --- kwargs={}
+model.decoder.norm_1 -- stolen forward for class LayerNorm -- iteration 0
  <- args=(T10s2x30x16,) --- kwargs={}
  -> T10s2x30x16
-model.decoder.norm_1.
-- stores key=('model.decoder.norm_1', 0), size 1Kb -- T10s2x30x16
+model.decoder.attention -- stolen forward for class MultiAttentionBlock -- iteration 0
  <- args=(T10s2x30x16,) --- kwargs={}
  -> T10s2x30x16
-model.decoder.attention.
-- stores key=('model.decoder.attention', 0), size 1Kb -- T10s2x30x16
+model.decoder.norm_2 -- stolen forward for class LayerNorm -- iteration 0
  <- args=(T10s2x30x16,) --- kwargs={}
  -> T10s2x30x16
-model.decoder.norm_2.
-- stores key=('model.decoder.norm_2', 0), size 1Kb -- T10s2x30x16
+model.decoder.feed_forward -- stolen forward for class FeedForward -- iteration 0
  <- args=(T10s2x30x16,) --- kwargs={}
  -> T10s2x30x16
-model.decoder.feed_forward.
-- stores key=('model.decoder.feed_forward', 0), size 1Kb -- T10s2x30x16
  -> T10s2x30x16
-model.decoder.
-- stores key=('model.decoder', 0), size 1Kb -- T10s2x30x16
  -> T10s2x30x16
-model.
-- stores key=('model', 0), size 1Kb -- T10s2x30x16
-- gather stored 12 objects, size=0 Mb
-- dumps stored objects
-- done dump stored objects
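
The log above suggests that each listed module has its forward method temporarily replaced by a wrapper that records what goes in and what comes out. The sketch below illustrates that general idea on a plain Python object. It is an assumption about the technique, not the implementation of steal_forward: `record_forward` and `Doubler` are hypothetical names, and nothing is written to disk.

```python
from contextlib import contextmanager


class Doubler:
    """Stand-in for a torch module: any object with a ``forward`` method."""

    def forward(self, x):
        return 2 * x


@contextmanager
def record_forward(named_modules, storage):
    """Temporarily wraps ``forward`` to record inputs and outputs in ``storage``."""
    originals = []
    counters = {}
    for name, mod in named_modules:
        original = mod.forward
        originals.append((mod, original))

        def wrapper(*args, _name=name, _original=original, **kwargs):
            # The iteration counts how many times this module was called.
            iteration = counters.get(_name, 0)
            counters[_name] = iteration + 1
            storage[_name, iteration, "I"] = (args, kwargs)
            out = _original(*args, **kwargs)
            storage[_name, iteration, "O"] = out
            return out

        mod.forward = wrapper
    try:
        yield storage
    finally:
        # Put the original bound methods back, even if the body raised.
        for mod, original in originals:
            mod.forward = original


storage = {}
mod = Doubler()
with record_forward([("doubler", mod)], storage):
    mod.forward(3)
    mod.forward(5)

for key, value in storage.items():
    print(key, value)
```

Calling the module twice inside the context produces two iterations, 0 and 1, each with an 'I' and an 'O' entry.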

Restores saved inputs/outputs

All the intermediate tensors were saved in a single onnx model; every tensor is stored in a constant node. The model can be run with any runtime to restore the inputs, and the function create_input_tensors_from_onnx_model can restore their names.

saved_tensors = create_input_tensors_from_onnx_model(
    "plot_dump_intermediate_results.inputs.onnx"
)
for k, v in saved_tensors.items():
    print(f"{k} -- {string_type(v, with_shape=True)}")
('model', 0, 'I') -- ((T7s2x30,),{})
('model.decoder', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.norm_1', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.norm_1', 0, 'O') -- T10s2x30x16
('model.decoder.attention', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.attention', 0, 'O') -- T10s2x30x16
('model.decoder.norm_2', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.norm_2', 0, 'O') -- T10s2x30x16
('model.decoder.feed_forward', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.feed_forward', 0, 'O') -- T10s2x30x16
('model.decoder', 0, 'O') -- T10s2x30x16
('model', 0, 'O') -- T10s2x30x16

Let’s explain the naming convention.

('model.decoder.norm_2', 0, 'I') -- ((T10s2x30x16,),{})
            |            |   |
            |            |   +--> input, the format is args, kwargs
            |            |
            |            +--> iteration, 0 means the first time the execution
            |                 went through that module;
            |                 the model can be called multiple times
            |                 to store more
            |
            +--> the name given to function steal_forward

The same goes for output except 'I' is replaced by 'O'.

('model.decoder.norm_2', 0, 'O') -- T10s2x30x16

This trick can be used to compare intermediate results coming from pytorch to any other implementation of the same model as long as it is possible to map the stored inputs/outputs.
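
When consuming the restored dictionary, it can help to regroup the flat keys by module name and iteration so that inputs and outputs sit side by side. The sketch below does that with plain strings standing in for tensors; `group_saved` is a hypothetical helper written for this illustration, not a function of onnx_diagnostic.

```python
def group_saved(saved):
    """Regroups keys (name, iteration, 'I' or 'O') by (name, iteration)."""
    grouped = {}
    for (name, iteration, kind), value in saved.items():
        entry = grouped.setdefault((name, iteration), {})
        if kind == "I":
            # Inputs are stored as a pair (args, kwargs).
            entry["args"], entry["kwargs"] = value
        elif kind == "O":
            entry["output"] = value
        else:
            raise ValueError(f"unexpected kind {kind!r}")
    return grouped


# Strings are used as stand-ins for the real tensors.
saved = {
    ("model.decoder.norm_2", 0, "I"): (("T10s2x30x16",), {}),
    ("model.decoder.norm_2", 0, "O"): "T10s2x30x16",
}
grouped = group_saved(saved)
print(grouped[("model.decoder.norm_2", 0)])
```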

Conversion to ONNX

The difficult part is mapping the saved intermediate results to intermediate results in ONNX. Let’s create the ONNX model.

ep = torch.export.export(model, inputs, dynamic_shapes=ds)
epo = torch.onnx.export(ep)
epo.optimize()
epo.save("plot_dump_intermediate_results.onnx")
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 4 of general pattern rewrite rules.

Discrepancies

We have a torch model, intermediate results and an ONNX graph equivalent to the torch model. Let’s see how we can check the discrepancies. First the discrepancies of the whole model.

sess = onnxruntime.InferenceSession(
    "plot_dump_intermediate_results.onnx", providers=["CPUExecutionProvider"]
)
feeds = dict(
    zip([i.name for i in sess.get_inputs()], [t.detach().cpu().numpy() for t in inputs])
)
got = sess.run(None, feeds)
diff = max_diff(expected, got)
print(f"discrepancies torch/ORT: {string_diff(diff)}")
discrepancies torch/ORT: abs=0.001953125, rel=0.054789607905642336, n=960.0,amax=0,2,1
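
The line above reports a maximum absolute difference, a maximum relative difference, a count (960 matches the 2x30x16 elements of the output) and what looks like the position of the largest absolute difference. The sketch below computes similar quantities on flat lists of floats. It is not the implementation of max_diff: `simple_max_diff` is a hypothetical name, and the relative formula `|e - g| / (|e| + eps)` is an assumption chosen for the illustration.

```python
def simple_max_diff(expected, got, eps=1e-8):
    """Returns max absolute/relative differences between two flat sequences."""
    if len(expected) != len(got):
        raise ValueError("both sequences must have the same length")
    best_abs, best_rel, argmax = 0.0, 0.0, 0
    for i, (e, g) in enumerate(zip(expected, got)):
        d = abs(e - g)
        if d > best_abs:
            best_abs, argmax = d, i
        # Assumed definition of the relative difference.
        best_rel = max(best_rel, d / (abs(e) + eps))
    return {"abs": best_abs, "rel": best_rel, "n": len(expected), "argmax": argmax}


report = simple_max_diff([1.0, -2.0, 0.5], [1.0, -2.25, 0.5])
print(report)
```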

What about intermediate results? Let’s use a runtime still based on onnxruntime running an eager evaluation.

sess_eager = OnnxruntimeEvaluator(
    "plot_dump_intermediate_results.onnx",
    providers=["CPUExecutionProvider"],
    torch_or_numpy=True,
)
feeds_tensor = dict(zip([i.name for i in sess.get_inputs()], inputs))
got = sess_eager.run(None, feeds_tensor)
diff = max_diff(expected, got)
print(f"discrepancies torch/eager ORT: {string_diff(diff)}")
discrepancies torch/eager ORT: abs=0.001953125, rel=0.05526083112290009, n=960.0,amax=0,4,7
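
Both runs report the same absolute discrepancy, 0.001953125, which is exactly 2**-9. That is the gap between two neighboring float16 values in the interval [2, 4), so the figure is compatible with results that differ only in the last bit for a value of that magnitude. The sketch below checks this arithmetic with the standard library; `float16_spacing` and `round_to_float16` are hypothetical helpers, and the formula only holds for normal float16 values.

```python
import math
import struct


def float16_spacing(x):
    """Gap between adjacent normal float16 values around x.

    float16 stores 10 explicit mantissa bits, hence the exponent minus 10.
    """
    exponent = math.floor(math.log2(abs(x)))
    return 2.0 ** (exponent - 10)


def round_to_float16(x):
    """Rounds a Python float to the nearest float16 value."""
    return struct.unpack("e", struct.pack("e", x))[0]


print(float16_spacing(2.85))  # 0.001953125
print(float16_spacing(1.5))   # 0.0009765625

# Two consecutive float16 values in [2, 4) differ by that spacing.
a = round_to_float16(2.85)
b = round_to_float16(a + float16_spacing(a))
print(b - a)
```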

They are almost the same. That’s good. Let’s now dig into the intermediate results. They are compared to the outputs stored in saved_tensors during the execution of the model.

baseline = {}
for k, v in saved_tensors.items():
    if k[-1] == "I":  # inputs are excluded
        continue
    if isinstance(v, torch.Tensor):
        baseline[f"{k[0]}.{k[1]}".replace("model.decoder", "decoder")] = v

report_cmp = ReportResultComparison(baseline)
sess_eager.run(None, feeds_tensor, report_cmp=report_cmp)
[tensor([[[-1.4600e+00, -1.6279e+00, -2.0449e+00,  1.8223e+00, -1.0059e-01,
           5.5420e-01, -1.3105e+00, -3.3545e-01, -1.4082e+00,  1.7793e+00,
           1.3545e+00, -8.2178e-01,  3.1738e-03, -1.4795e+00, -3.7305e-01,
          -1.2573e-01],
         [ 1.2139e+00, -1.4639e+00, -2.0684e+00,  1.4570e+00, -1.1475e-02,
           2.0349e-01, -1.9082e+00,  5.0488e-01,  4.4751e-01, -5.1807e-01,
           9.8145e-01,  6.3281e-01, -5.2051e-01, -1.0684e+00,  1.2471e+00,
           1.3086e-01],
         [-1.4111e-01, -2.8516e+00, -1.5205e+00, -4.7485e-02,  2.0332e+00,
           1.9482e+00,  2.0850e-01, -1.8594e+00, -1.6709e+00, -8.4473e-01,
          -1.8340e+00, -1.8008e+00, -1.2817e-03, -1.2695e+00, -2.6094e+00,
           2.8760e-01],
         [-1.2158e-01, -1.4365e+00,  1.4746e+00,  1.0879e+00, -2.2930e+00,
          -1.0527e+00, -3.9868e-01,  2.9956e-01,  6.8311e-01,  5.7715e-01,
           6.2561e-02,  1.8018e+00, -8.3691e-01, -7.8125e-01, -3.0469e+00,
          -2.0938e+00],
         [ 2.1948e-01, -2.2988e+00, -5.0732e-01,  2.4243e-01,  5.5811e-01,
          -2.8076e-01, -8.0908e-01, -3.1113e+00, -1.7471e+00, -2.3682e-01,
          -1.9639e+00,  1.2109e-01,  7.8809e-01, -2.4585e-01, -1.4053e+00,
           1.4980e+00],
         [-9.9805e-01, -2.6816e+00, -4.3604e-01, -1.0527e+00,  1.9746e+00,
          -6.0400e-01, -1.3535e+00, -2.4570e+00, -5.4395e-01, -1.8643e+00,
          -3.0605e+00,  1.1973e+00,  1.2168e+00,  5.7715e-01,  9.1309e-01,
          -1.1172e+00],
         [ 3.0469e+00, -5.6982e-01, -8.0176e-01,  1.8574e+00, -1.2383e+00,
          -2.1035e+00, -4.5410e-02,  1.6543e+00,  5.0659e-03, -2.5742e+00,
          -1.4707e+00, -1.4688e+00,  2.0234e+00,  4.6411e-01, -1.8057e+00,
          -1.1025e+00],
         [ 1.1152e+00,  1.3721e-01, -1.1182e+00,  1.1182e+00,  9.7559e-01,
           2.2520e+00, -2.2461e-01, -1.1279e-01, -4.2041e-01,  1.9023e+00,
           1.3159e-01, -1.0859e+00,  5.4443e-01,  2.3887e+00,  1.1406e+00,
          -2.0410e+00],
         [ 9.1699e-01, -1.3281e+00,  2.3828e+00,  2.8931e-02, -1.4734e-01,
          -1.7705e+00,  5.5762e-01, -4.0039e-02,  1.3213e+00,  1.6510e-02,
           1.6855e+00, -8.5156e-01, -3.2031e+00, -1.1383e-01, -1.8652e+00,
          -1.0498e+00],
         [-1.0439e+00,  1.9424e+00, -9.1895e-01,  1.0430e+00,  2.1094e+00,
          -2.9736e-01,  2.4180e+00, -6.0547e-01,  3.3276e-01, -2.8730e+00,
          -3.0640e-01, -1.3086e+00, -1.7402e+00, -2.1465e+00, -1.8730e+00,
           4.7656e-01],
         [-9.2383e-01,  6.8970e-02, -1.4912e+00,  8.4863e-01,  5.1172e-01,
          -5.6152e-02, -1.7998e+00, -2.9336e+00, -1.8994e-01,  2.8540e-01,
          -7.1289e-01,  1.1353e-01, -2.3457e+00,  2.4243e-01, -2.2241e-01,
           1.4287e+00],
         [ 4.0820e-01, -4.0527e-01,  1.6650e+00,  1.8965e+00, -2.4043e+00,
           6.5430e-01, -6.6943e-01,  2.0020e-02,  1.3496e+00, -4.8047e-01,
           1.7656e+00, -1.6919e-01, -3.7964e-01, -1.7676e+00, -5.0781e-01,
          -7.1924e-01],
         [ 7.0117e-01, -2.1523e+00, -5.3711e-01,  2.3059e-01, -5.5566e-01,
          -3.9331e-01,  1.9165e-01, -1.9473e+00,  2.3477e+00,  3.6816e-01,
          -1.9902e+00, -2.5757e-02, -1.5537e+00,  6.3672e-01, -1.3398e+00,
          -1.7549e+00],
         [ 1.1025e+00, -5.9180e-01,  9.0234e-01,  1.0010e+00, -3.4180e-01,
           1.7031e+00, -3.0781e+00,  1.1484e+00,  9.0283e-01,  9.9609e-01,
          -3.7148e+00, -7.5635e-01,  1.9912e+00,  1.9556e-01, -1.6055e+00,
          -6.2744e-01],
         [-5.8789e-01, -3.2461e+00, -9.4873e-01,  1.1729e+00,  2.4727e+00,
           1.2471e+00, -5.1910e-02, -1.1348e+00,  1.1162e+00,  9.6680e-01,
           1.5918e+00,  3.6841e-01,  1.6045e+00, -1.9775e+00, -1.0020e+00,
          -2.0703e+00],
         [ 2.2510e-01, -1.7988e+00,  1.0332e+00, -3.5254e-01,  4.2041e-01,
          -2.8945e+00, -2.1924e-01, -1.1078e-01,  1.3154e+00,  9.1602e-01,
           7.9785e-01, -8.0444e-02,  1.3857e+00, -7.7148e-01, -1.3340e+00,
           3.4863e+00],
         [-1.1836e+00, -6.6699e-01,  1.2705e+00, -8.0469e-01,  1.3984e+00,
           1.3008e+00, -1.1841e-01, -4.7388e-01,  3.3926e+00, -1.5039e+00,
          -5.2148e-01, -4.9390e-01, -2.8496e+00, -1.5586e+00, -2.6641e+00,
          -1.1846e+00],
         [-4.8901e-01,  3.6230e-01, -6.8555e-01,  5.8594e-03,  4.6631e-01,
          -1.6318e+00, -8.4473e-01, -1.7407e-01, -2.2227e+00,  1.2617e+00,
          -6.3525e-01,  2.1699e+00, -1.2344e+00,  7.0459e-01,  1.9531e-03,
          -5.7617e-01],
         [-3.4131e-01,  2.0586e+00, -4.7974e-01,  1.5342e+00, -9.8242e-01,
           6.1084e-01, -3.0225e-01,  2.2246e+00,  6.9922e-01,  6.5820e-01,
           1.6797e+00,  5.8496e-01,  2.1367e+00,  1.0419e-01,  7.9395e-01,
          -1.1045e+00],
         [-4.7046e-01, -2.8101e-01, -1.4043e+00, -1.2158e-01, -1.1406e+00,
          -1.5771e+00, -9.9512e-01,  3.1309e+00,  3.8306e-01,  1.8154e+00,
           1.2402e+00, -1.6296e-01,  2.4961e+00,  1.2524e-01, -6.3965e-01,
          -5.1855e-01],
         [ 9.5898e-01, -1.5840e+00,  1.1445e+00,  1.1953e+00,  4.8389e-01,
          -4.7217e-01, -4.6777e-01,  1.3564e+00, -1.3711e+00, -2.3120e-01,
          -1.2568e+00,  1.4619e+00,  7.0654e-01,  8.9600e-02, -2.8379e+00,
          -4.4971e-01],
         [ 3.3066e+00, -3.8809e+00,  7.8516e-01,  1.5830e+00, -4.5605e-01,
          -1.8457e-01, -1.0674e+00,  3.4785e+00,  2.5254e+00,  2.0154e-01,
           7.2363e-01, -2.9199e-01,  1.6191e+00, -1.2520e+00, -2.5254e+00,
           2.0039e+00],
         [ 6.9531e-01, -2.5562e-01,  5.5469e-01,  1.3408e+00, -3.4399e-01,
          -4.0332e-01,  1.1836e+00, -2.8828e+00,  9.2383e-01,  1.9639e+00,
           9.4824e-01, -4.7388e-01,  1.6035e+00,  3.2227e+00,  1.7627e+00,
           8.7402e-01],
         [ 1.6296e-01, -6.0352e-01,  1.5254e+00,  9.8438e-01, -5.2002e-01,
          -8.9209e-01,  2.2437e-01,  2.8750e+00,  7.1094e-01,  2.0547e+00,
          -3.3008e+00,  1.1758e+00, -8.3350e-01,  4.1870e-02, -2.0039e+00,
          -3.1973e+00],
         [ 2.5928e-01, -8.4424e-01,  1.9189e+00,  6.9141e-01,  8.5889e-01,
           1.2861e+00, -2.0391e+00, -1.3379e+00, -1.0225e+00,  5.2148e-01,
          -8.9160e-01,  8.0566e-02,  3.1172e+00, -1.1699e+00,  4.8901e-01,
          -1.3750e+00],
         [ 1.8096e+00,  8.4814e-01,  1.8652e-01, -2.9844e+00,  7.3389e-01,
          -4.0967e-01, -1.8616e-01,  6.3818e-01,  1.2451e-01,  1.3057e+00,
           2.4277e+00,  1.2041e+00, -2.9863e+00, -1.0430e+00,  8.2031e-01,
          -2.1309e+00],
         [ 1.7998e+00,  1.6418e-01, -5.1270e-01,  1.8203e+00,  5.9375e-01,
          -2.0977e+00, -6.9238e-01, -2.5859e+00,  1.1836e+00, -1.1211e+00,
          -1.4307e+00, -2.2480e+00, -1.0028e-01, -1.6416e+00,  5.7422e-01,
           7.2559e-01],
         [ 3.6621e-01,  6.8970e-02,  1.1377e+00,  8.5449e-03, -1.6230e+00,
           1.9666e-01,  1.9385e-01, -3.8672e+00, -5.4590e-01,  1.3389e+00,
           1.9385e+00,  2.3413e-01, -9.7070e-01, -2.9443e-01,  3.8940e-01,
           1.2168e+00],
         [ 9.4629e-01, -2.2500e+00, -2.2207e+00,  4.4043e-01,  1.8506e+00,
          -1.9434e+00, -3.7695e-01,  7.9199e-01,  6.1719e-01, -7.6660e-01,
          -1.4414e+00, -1.5898e+00,  7.0740e-02,  2.3242e-01, -2.2217e-01,
           4.6143e-02],
         [ 6.8115e-01, -2.3999e-01,  5.4199e-01,  1.3516e+00, -3.5889e-01,
          -4.2578e-01,  1.1631e+00, -2.8496e+00,  9.3701e-01,  1.9434e+00,
           1.0156e+00, -4.7852e-01,  1.5898e+00,  3.2383e+00,  1.7363e+00,
           8.7793e-01]],

        [[ 1.5791e+00,  3.2104e-02, -3.0391e+00,  1.4512e+00, -2.3945e+00,
          -4.0117e+00,  1.7715e+00, -8.9941e-01, -1.5996e+00,  9.7168e-01,
           1.6318e+00,  1.8594e+00, -2.3848e+00, -1.4287e+00,  6.5674e-01,
           1.1240e+00],
         [ 9.6387e-01,  5.4382e-02, -1.5625e+00,  1.0410e+00,  7.5049e-01,
           2.4570e+00, -3.1348e-01, -9.2773e-03, -3.7207e-01,  2.1230e+00,
           2.7466e-01, -1.2627e+00,  3.3496e-01,  2.5859e+00,  8.7988e-01,
          -2.1465e+00],
         [ 4.6606e-01,  1.1890e-01, -3.1875e+00, -7.6465e-01,  1.6240e+00,
          -7.2559e-01, -6.9043e-01, -1.1641e+00,  7.8662e-01,  2.6245e-01,
           2.6738e+00,  1.3672e+00, -5.7959e-01,  1.1807e+00,  1.3672e-02,
          -2.2031e+00],
         [ 2.8887e+00,  6.2061e-01, -1.1582e+00, -3.1191e+00,  2.0276e-01,
          -4.5679e-01, -7.2510e-01, -5.3711e-01,  2.0957e+00, -5.2441e-01,
          -1.1221e+00,  1.8877e+00, -2.3047e+00,  1.2871e+00, -1.9297e+00,
           1.1836e+00],
         [ 7.1289e-01, -1.8713e-01,  6.3525e-01,  2.3389e-01,  2.5854e-01,
           3.4424e-02, -1.3857e+00,  9.6094e-01,  1.2646e-01,  2.5840e+00,
           7.2510e-02, -6.3599e-02,  2.9785e-02,  2.1914e+00,  8.5840e-01,
           2.7969e+00],
         [ 9.2773e-01,  1.9775e-02,  1.6064e+00, -5.0195e-01, -2.0918e+00,
           2.5098e-01,  1.3975e+00,  2.1777e+00, -1.9141e+00, -4.8364e-01,
           2.2437e-01,  7.3096e-01, -1.2578e+00, -1.5684e+00, -9.7949e-01,
          -9.9805e-01],
         [ 1.5801e+00, -2.2034e-01, -1.8982e-01, -1.4343e-01,  1.1162e+00,
           5.8984e-01, -3.0176e-01,  5.8594e-01,  1.5000e+00,  9.4116e-02,
           4.6338e-01,  1.5127e+00, -3.4082e-01,  1.2373e+00, -1.5791e+00,
          -1.7412e+00],
         [ 3.4883e+00, -8.1982e-01,  1.3301e+00, -2.1924e-01, -1.6982e+00,
          -3.9673e-01, -1.2769e-01, -1.4819e-01, -1.6221e+00,  2.2852e+00,
          -6.0107e-01,  1.1406e+00, -5.1172e-01,  1.1816e+00,  4.7168e-01,
           2.8320e-02],
         [ 7.2363e-01, -6.7041e-01,  6.0254e-01,  8.2031e-02, -6.5918e-01,
          -7.3193e-01,  2.6318e-01,  6.0986e-01,  1.2930e+00, -2.1008e-01,
           1.3545e+00, -1.5625e-02, -9.1492e-02, -2.0752e-02,  1.1777e+00,
          -1.2002e+00],
         [ 1.7500e+00, -8.2520e-01,  8.4375e-01, -1.8567e-01, -1.0127e+00,
          -1.7383e+00,  2.7773e+00,  7.6367e-01, -4.0747e-01, -6.5479e-01,
          -1.1670e+00,  2.8926e+00, -3.3984e+00,  3.3008e-01, -2.0020e-01,
           4.7729e-02],
         [-5.5566e-01, -4.0918e-01,  3.1367e+00, -1.0039e+00, -9.4922e-01,
           1.7334e-01, -2.5938e+00, -8.7012e-01,  2.4792e-01,  1.3477e+00,
          -3.0273e-01, -1.7275e+00, -8.1445e-01, -7.3438e-01,  7.5977e-01,
          -1.4121e+00],
         [ 3.5449e-01, -9.4385e-01, -8.1104e-01,  2.7295e-01,  4.8706e-01,
           1.2656e+00, -7.9883e-01, -2.2695e+00,  1.1113e+00, -1.1975e-01,
           1.5928e+00, -1.1318e+00, -1.2036e-01,  1.8535e+00,  1.6641e+00,
           1.4854e+00],
         [-5.8203e-01, -3.8818e-01,  3.1055e+00, -9.5508e-01, -1.0146e+00,
           9.0820e-02, -2.6309e+00, -8.7256e-01,  2.5757e-01,  1.3477e+00,
          -2.6465e-01, -1.7109e+00, -8.1592e-01, -7.6611e-01,  7.5195e-01,
          -1.4209e+00],
         [ 1.2168e+00, -5.4688e-01, -1.0156e+00,  1.0469e+00,  5.3223e-01,
          -6.0889e-01, -5.4590e-01, -1.9482e+00,  2.7319e-01,  1.2061e+00,
          -6.8848e-01, -1.3711e+00,  3.0078e-01, -8.0176e-01,  1.5205e+00,
           4.3628e-01],
         [ 3.0859e-01, -2.8867e+00,  1.0469e+00,  2.1973e-01, -3.3613e+00,
          -1.8057e+00,  1.3984e+00, -3.5889e-02,  1.3770e+00,  2.0586e+00,
          -1.0830e+00,  1.9424e+00, -2.8848e+00,  1.9355e+00, -1.0254e+00,
          -1.8389e+00],
         [ 2.0625e+00,  6.9580e-03, -2.4688e+00, -5.4688e-01, -1.5796e-01,
           5.2344e-01, -5.2734e-01,  4.3921e-01,  1.8042e-01,  4.9341e-01,
           2.6836e+00,  2.3711e+00,  1.1426e-01,  4.0479e-01,  1.6934e+00,
           2.4438e-01],
         [-4.8828e-03,  1.3745e-01,  1.8970e-01,  1.3467e+00, -1.0898e+00,
           6.1475e-01,  3.1787e-01,  2.3848e+00, -9.7998e-01, -5.4785e-01,
           9.2285e-02, -2.6147e-01, -1.0332e+00, -7.2217e-01,  1.4868e-01,
           2.5078e+00],
         [ 1.5898e+00,  3.6230e-01,  1.0898e+00, -8.5840e-01, -6.9873e-01,
           1.4922e+00, -5.0879e-01,  1.2588e+00,  2.7295e-01,  2.0288e-01,
          -1.4557e-02, -1.6333e-01,  3.1396e-01,  1.4453e+00, -1.3047e+00,
           1.7734e+00],
         [-3.8379e-01,  1.1084e+00, -2.8730e+00,  7.4170e-01, -2.4492e+00,
          -2.5293e-01,  6.1493e-02, -1.3799e+00,  2.2192e-01, -2.6719e+00,
          -8.4521e-01,  1.6602e+00, -3.4180e+00,  1.3711e+00, -1.8579e-01,
           1.8301e+00],
         [-2.1172e+00,  5.0586e-01,  1.2324e+00,  3.3105e+00,  1.0107e+00,
           1.5161e-01,  2.5723e+00, -8.3691e-01,  4.7217e-01, -4.8999e-01,
           1.3594e+00,  3.9331e-01, -7.6709e-01, -1.1250e+00, -2.0117e+00,
          -1.9590e+00],
         [ 1.2754e+00, -2.2498e-01,  1.3613e+00, -2.0449e+00,  6.3086e-01,
          -1.0771e+00,  7.6172e-01,  2.8052e-01,  9.4336e-01,  6.9189e-01,
           1.3105e+00,  6.1768e-01, -2.4219e-01, -2.0386e-01,  2.1328e+00,
          -5.2783e-01],
         [ 1.1895e+00, -2.4785e+00, -5.3809e-01,  1.5625e-01, -5.9424e-01,
          -1.1982e+00,  1.0400e+00, -9.1113e-01,  1.6133e+00,  2.6465e+00,
          -2.1445e+00,  1.0986e+00, -1.1475e+00,  2.1973e+00,  1.8320e+00,
           1.9053e+00],
         [ 5.7666e-01, -3.5449e-01, -7.5586e-01,  3.0723e+00, -1.0625e+00,
          -1.3936e+00,  1.9072e+00,  5.8643e-01,  2.7070e+00,  1.1094e+00,
          -2.2012e+00,  2.4590e+00,  2.2388e-01,  5.2637e-01, -4.3115e-01,
          -6.7725e-01],
         [ 1.2236e+00, -1.4775e+00, -3.3887e-01,  3.0786e-01,  1.2178e+00,
          -2.6777e+00,  1.2373e+00,  1.0088e+00, -1.8096e+00,  5.7812e-01,
          -1.2168e+00,  1.2168e+00,  3.1172e+00,  3.1582e+00, -8.1836e-01,
          -1.3193e+00],
         [ 1.2930e+00, -5.8203e-01, -9.0625e-01,  9.4531e-01,  6.4795e-01,
          -5.7275e-01, -5.2051e-01, -1.9736e+00,  2.6611e-01,  1.1641e+00,
          -7.6074e-01, -1.4102e+00,  3.4424e-01, -7.6855e-01,  1.6592e+00,
           4.0479e-01],
         [ 2.7227e+00, -9.8633e-01, -9.0234e-01, -1.2861e+00, -2.4746e+00,
          -1.6943e+00,  5.7520e-01, -1.2275e+00, -1.9834e+00,  1.8789e+00,
           3.5840e+00,  4.7412e-01, -1.4150e+00,  5.5371e-01,  1.8604e+00,
           8.4082e-01],
         [-5.4688e-01,  5.7764e-01, -8.5547e-01,  2.6855e-02,  3.6621e-01,
          -1.5742e+00, -7.8027e-01, -2.9590e-01, -2.1133e+00,  1.1348e+00,
          -3.7305e-01,  2.2383e+00, -1.4717e+00,  5.0391e-01,  4.0222e-02,
          -6.2158e-01],
         [-6.7285e-01, -6.9971e-01,  1.2012e+00, -1.5405e-01, -2.1016e+00,
          -2.1758e+00, -1.4844e+00, -4.6436e-01, -4.7192e-01,  7.6855e-01,
           1.0332e+00, -1.0537e+00,  4.1455e-01,  1.5127e+00,  2.3027e+00,
           1.4316e+00],
         [ 4.2090e-01,  1.2178e+00, -1.2085e-01, -1.7256e+00,  1.4268e+00,
           3.1909e-01,  7.5098e-01, -8.5742e-01, -7.0557e-01, -6.8652e-01,
           1.5352e+00,  1.4233e-01, -1.3213e+00, -8.8037e-01,  2.9414e+00,
           1.2793e+00],
         [-2.6428e-02,  7.8467e-01,  1.7939e+00, -1.7891e+00,  1.2559e+00,
          -2.7466e-01,  1.7969e+00,  1.4062e+00,  1.8184e+00, -2.0332e+00,
           8.3594e-01, -8.1396e-01,  2.4707e+00,  3.7451e-01, -1.6777e+00,
          -6.2891e-01]]], dtype=torch.float16)]

Let’s see the results.

data = report_cmp.data
df = pandas.DataFrame(data)
piv = df.pivot(index=("run_index", "run_name"), columns="ref_name", values="abs")
print(piv)
ref_name                decoder.0  decoder.attention.0  decoder.feed_forward.0  decoder.norm_1.0  decoder.norm_2.0   model.0
run_index run_name
1         embedding      2.988281             3.177734                3.321289          2.795898          2.767578  2.988281
2         embedding_1    3.465393             2.936523                3.054443          2.400269          2.413940  3.465393
3         add_8          0.950195             4.609375                4.033447          2.177734          2.289062  0.950195
4         layer_norm     1.970703             2.946106                2.804688          0.000000          0.771484  1.970703
5         linear         4.901367             2.021484                1.908813          3.631836          3.657227  4.901367
6         linear_1       4.474609             2.086914                1.990967          3.713867          3.666992  4.474609
7         linear_2       5.083008             1.811035                1.806396          3.571289          3.594727  5.083008
17        matmul_1       5.083008             1.515137                1.426392          3.112305          3.000977  5.083008
18        linear_3       4.153320             2.042358                2.004883          3.366211          3.497070  4.153320
19        linear_4       4.126709             1.947067                2.140137          3.583008          3.585938  4.126709
20        linear_5       4.644531             2.136963                2.065918          3.437012          3.404297  4.644531
30        matmul_3       3.957764             1.836914                1.677734          3.216797          3.230469  3.957764
32        val_48         4.333008             0.163635                1.057129          2.890442          2.865051  4.333008
33        linear_6       4.402344             0.000488                1.055420          2.946106          2.920715  4.402344
34        add_115        0.954163             4.218750                4.183350          1.953125          1.955078  0.954163
35        layer_norm_1   2.082031             2.920715                2.763672          0.771484          0.001953  2.082031
39        val_54         3.901001             1.033890                0.087646          2.749878          2.710449  3.901001
40        linear_8       3.948975             1.055481                0.000488          2.804688          2.763672  3.948975
41        add_136        0.001953             4.402344                3.948975          1.970703          2.082031  0.001953

Let’s clean this up a little by masking every difference greater than or equal to 1.

piv[piv >= 1] = np.nan
print(piv.dropna(axis=0, how="all"))
ref_name                decoder.0  decoder.attention.0  decoder.feed_forward.0  decoder.norm_1.0  decoder.norm_2.0   model.0
run_index run_name
3         add_8          0.950195                  NaN                     NaN               NaN               NaN  0.950195
4         layer_norm          NaN                  NaN                     NaN          0.000000          0.771484       NaN
32        val_48              NaN             0.163635                     NaN               NaN               NaN       NaN
33        linear_6            NaN             0.000488                     NaN               NaN               NaN       NaN
34        add_115        0.954163                  NaN                     NaN               NaN               NaN  0.954163
35        layer_norm_1        NaN                  NaN                     NaN          0.771484          0.001953       NaN
39        val_54              NaN                  NaN                0.087646               NaN               NaN       NaN
40        linear_8            NaN                  NaN                0.000488               NaN               NaN       NaN
41        add_136        0.001953                  NaN                     NaN               NaN               NaN  0.001953

We can identify which result is mapped to which expected tensor.
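
The mapping can also be extracted programmatically by keeping, for every reference tensor, the ONNX result with the smallest absolute difference. The sketch below works on a few rows copied from the table above, using the same field names as the pivot (`run_name`, `ref_name`, `abs`). The function `closest_results` and the tolerance of 0.01 are assumptions made for this illustration.

```python
# A few rows copied from the table printed above.
records = [
    {"run_name": "add_8", "ref_name": "decoder.0", "abs": 0.950195},
    {"run_name": "layer_norm", "ref_name": "decoder.norm_1.0", "abs": 0.000000},
    {"run_name": "layer_norm", "ref_name": "decoder.norm_2.0", "abs": 0.771484},
    {"run_name": "val_48", "ref_name": "decoder.attention.0", "abs": 0.163635},
    {"run_name": "linear_6", "ref_name": "decoder.attention.0", "abs": 0.000488},
    {"run_name": "layer_norm_1", "ref_name": "decoder.norm_1.0", "abs": 0.771484},
    {"run_name": "layer_norm_1", "ref_name": "decoder.norm_2.0", "abs": 0.001953},
    {"run_name": "val_54", "ref_name": "decoder.feed_forward.0", "abs": 0.087646},
    {"run_name": "linear_8", "ref_name": "decoder.feed_forward.0", "abs": 0.000488},
    {"run_name": "add_136", "ref_name": "decoder.0", "abs": 0.001953},
    {"run_name": "add_136", "ref_name": "model.0", "abs": 0.001953},
]


def closest_results(rows, tolerance=0.01):
    """Maps every reference name to the closest result below the tolerance."""
    best = {}
    for row in rows:
        ref = row["ref_name"]
        if row["abs"] > tolerance:
            continue
        if ref not in best or row["abs"] < best[ref]["abs"]:
            best[ref] = row
    return {ref: row["run_name"] for ref, row in best.items()}


for ref, run in sorted(closest_results(records).items()):
    print(f"{ref} <- {run}")
```

With these rows, both decoder.0 and model.0 map to add_136, which is consistent with the decoder output being the model output in this small model.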

Picture of the model

onx = onnx.load("plot_dump_intermediate_results.onnx")
plot_dot(onx)
plot dump intermediate results
doc.plot_legend("steal and dump\nintermediate\nresults", "steal_forward", "blue")
plot dump intermediate results

Total running time of the script: (0 minutes 12.893 seconds)

Related examples

Export microsoft/phi-2

Intermediate results with onnxruntime

Export with dynamic dimensions in {0,1} into ONNX (custom)
