Dumps intermediate results of a torch model

Looking for discrepancies quickly becomes tedious. Discrepancies appear when the same model is implemented in two different ways, pytorch and onnx, and the two implementations do not return the same results. Models are big, so where do the discrepancies come from? That's the unavoidable question. Unless there is an obvious reason, the only way is to compare intermediate outputs along the computation. The first step in that direction is to dump the intermediate results coming from pytorch. We use onnx_diagnostic.helpers.torch_helper.steal_forward() for that.

A simple LLM Model

See onnx_diagnostic.helpers.torch_helper.dummy_llm() for its definition. It is mostly used for unit tests or examples.

import numpy as np
import pandas
import onnx
import torch
import onnxruntime
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
from onnx_diagnostic.helpers.torch_helper import dummy_llm, steal_forward
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ReportResultComparison


model, inputs, ds = dummy_llm(dynamic_shapes=True)

The model uses float16. Let's check its type, its inputs and its dynamic shapes.

print(f"type(model)={type(model)}")
print(f"inputs={string_type(inputs, with_shape=True)}")
print(f"ds={string_type(ds, with_shape=True)}")
type(model)=<class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.LLM'>
inputs=(T7s2x30,)
ds=dict(input_ids:{0:Dim(batch),1:Dim(length)})
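
The notation produced by string_type is compact: T<dtype>s<shape>. If the integer after T follows the onnx TensorProto element types (an assumption of this note, the helper's documentation is authoritative), then T7s2x30 is an int64 tensor of shape (2, 30) and T10s2x30x16 a float16 tensor of shape (2, 30, 16).

# onnx.TensorProto.INT64 is 7 and onnx.TensorProto.FLOAT16 is 10, which
# matches the notations T7... and T10... seen above.
print(onnx.TensorProto.INT64, onnx.TensorProto.FLOAT16)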

It contains the following submodules.

for name, mod in model.named_modules():
    print(f"- {name}: {type(mod)}")
- : <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.LLM'>
- embedding: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.Embedding'>
- embedding.embedding: <class 'torch.nn.modules.sparse.Embedding'>
- embedding.pe: <class 'torch.nn.modules.sparse.Embedding'>
- decoder: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.DecoderLayer'>
- decoder.attention: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.MultiAttentionBlock'>
- decoder.attention.attention: <class 'torch.nn.modules.container.ModuleList'>
- decoder.attention.attention.0: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.AttentionBlock'>
- decoder.attention.attention.0.query: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.0.key: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.0.value: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.1: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.AttentionBlock'>
- decoder.attention.attention.1.query: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.1.key: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.1.value: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.linear: <class 'torch.nn.modules.linear.Linear'>
- decoder.feed_forward: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.FeedForward'>
- decoder.feed_forward.linear_1: <class 'torch.nn.modules.linear.Linear'>
- decoder.feed_forward.relu: <class 'torch.nn.modules.activation.ReLU'>
- decoder.feed_forward.linear_2: <class 'torch.nn.modules.linear.Linear'>
- decoder.norm_1: <class 'torch.nn.modules.normalization.LayerNorm'>
- decoder.norm_2: <class 'torch.nn.modules.normalization.LayerNorm'>

Steal and dump the output of submodules

The following context manager spies on the intermediate results of the listed module and submodules. It stores all their inputs and outputs in a single onnx file.

with steal_forward(
    [
        ("model", model),
        ("model.decoder", model.decoder),
        ("model.decoder.attention", model.decoder.attention),
        ("model.decoder.feed_forward", model.decoder.feed_forward),
        ("model.decoder.norm_1", model.decoder.norm_1),
        ("model.decoder.norm_2", model.decoder.norm_2),
    ],
    dump_file="plot_dump_intermediate_results.inputs.onnx",
    verbose=1,
    storage_limit=2**28,
):
    expected = model(*inputs)
+model -- stolen forward for class LLM -- iteration 0
  <- args=(T7s2x30,) --- kwargs={}
+model.decoder -- stolen forward for class DecoderLayer -- iteration 0
  <- args=(T10s2x30x16,) --- kwargs={}
+model.decoder.norm_1 -- stolen forward for class LayerNorm -- iteration 0
  <- args=(T10s2x30x16,) --- kwargs={}
  -> T10s2x30x16
-model.decoder.norm_1.
-- stores key=('model.decoder.norm_1', 0), size 1Kb -- T10s2x30x16
+model.decoder.attention -- stolen forward for class MultiAttentionBlock -- iteration 0
  <- args=(T10s2x30x16,) --- kwargs={}
  -> T10s2x30x16
-model.decoder.attention.
-- stores key=('model.decoder.attention', 0), size 1Kb -- T10s2x30x16
+model.decoder.norm_2 -- stolen forward for class LayerNorm -- iteration 0
  <- args=(T10s2x30x16,) --- kwargs={}
  -> T10s2x30x16
-model.decoder.norm_2.
-- stores key=('model.decoder.norm_2', 0), size 1Kb -- T10s2x30x16
+model.decoder.feed_forward -- stolen forward for class FeedForward -- iteration 0
  <- args=(T10s2x30x16,) --- kwargs={}
  -> T10s2x30x16
-model.decoder.feed_forward.
-- stores key=('model.decoder.feed_forward', 0), size 1Kb -- T10s2x30x16
  -> T10s2x30x16
-model.decoder.
-- stores key=('model.decoder', 0), size 1Kb -- T10s2x30x16
  -> T10s2x30x16
-model.
-- stores key=('model', 0), size 1Kb -- T10s2x30x16
-- gather stored 12 objects, size=0 Mb
-- dumps stored objects
-- done dump stored objects

Restore the saved inputs/outputs

All the intermediate tensors were saved in a single onnx model; every tensor is stored in a constant node. The model can be run with any runtime to restore the tensors, and the function create_input_tensors_from_onnx_model restores their names as well.

saved_tensors = create_input_tensors_from_onnx_model(
    "plot_dump_intermediate_results.inputs.onnx"
)
for k, v in saved_tensors.items():
    print(f"{k} -- {string_type(v, with_shape=True)}")
('model', 0, 'I') -- ((T7s2x30,),{})
('model.decoder', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.norm_1', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.norm_1', 0, 'O') -- T10s2x30x16
('model.decoder.attention', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.attention', 0, 'O') -- T10s2x30x16
('model.decoder.norm_2', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.norm_2', 0, 'O') -- T10s2x30x16
('model.decoder.feed_forward', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.feed_forward', 0, 'O') -- T10s2x30x16
('model.decoder', 0, 'O') -- T10s2x30x16
('model', 0, 'O') -- T10s2x30x16

Let's explain the naming convention.

('model.decoder.norm_2', 0, 'I') -- ((T10s2x30x16,),{})
            |            |   |
            |            |   +--> input, the format is args, kwargs
            |            |
            |            +--> iteration, 0 means the first time the execution
            |                 went through that module; the model can be
            |                 called multiple times to store more iterations
            |
            +--> the name given to function steal_forward

The same goes for outputs, except 'I' is replaced by 'O'.

('model.decoder.norm_2', 0, 'O') -- T10s2x30x16

This trick can be used to compare intermediate results coming from pytorch with any other implementation of the same model, as long as it is possible to map the stored inputs and outputs.
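
For instance, a stored output can be fetched with its key and compared with max_diff. A minimal sketch, where other_result stands for the tensor produced by the other implementation and is only a perturbed copy here:

stored = saved_tensors[("model.decoder.norm_2", 0, "O")]
other_result = stored + 1e-3  # placeholder for the other implementation's result
print(string_diff(max_diff(stored, other_result)))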

Conversion to ONNX

The difficult part is to map the saved intermediate results to the intermediate results of the ONNX model. Let's create the ONNX model first.

epo = torch.onnx.export(model, inputs, dynamic_shapes=ds, dynamo=True)
epo.optimize()
epo.save("plot_dump_intermediate_results.onnx")
[torch.onnx] Obtain model graph for `LLM([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `LLM([...]` with `torch.export.export(..., strict=False)`... ❌
[torch.onnx] Obtain model graph for `LLM([...]` with `torch.export.export(..., strict=True)`...
[torch.onnx] Obtain model graph for `LLM([...]` with `torch.export.export(..., strict=True)`... ❌
[torch.onnx] Obtain model graph for `LLM([...]` with `torch.export draft_export`...
[torch.onnx] Obtain model graph for `LLM([...]` with `torch.export draft_export`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 4 of general pattern rewrite rules.
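
The names of the nodes in the exported graph are the ones we will later try to map to the stored torch results. A quick way to look at the first few of them (a sketch based on the file written above):

onx_preview = onnx.load("plot_dump_intermediate_results.onnx")
for node in list(onx_preview.graph.node)[:10]:
    print(f"{node.op_type}: {node.name} -> {list(node.output)}")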

Discrepancies

We have a torch model, its intermediate results, and an ONNX graph equivalent to the torch model. Let's see how we can check the discrepancies, starting with the discrepancies of the whole model.

sess = onnxruntime.InferenceSession(
    "plot_dump_intermediate_results.onnx", providers=["CPUExecutionProvider"]
)
feeds = dict(
    zip([i.name for i in sess.get_inputs()], [t.detach().cpu().numpy() for t in inputs])
)
got = sess.run(None, feeds)
diff = max_diff(expected, got)
print(f"discrepancies torch/ORT: {string_diff(diff)}")
discrepancies torch/ORT: abs=0.00390625, rel=0.0560350204259831, n=960.0,amax=1,26,5
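
In the printed summary, abs is the maximum absolute error, rel a relative error, n the number of compared values and amax most likely the coordinates of the largest difference. The maximum absolute error can be cross-checked by hand; a sketch assuming the model returns a single tensor so that got[0] corresponds to expected:

manual_abs = np.abs(
    expected.detach().cpu().numpy().astype(np.float32) - got[0].astype(np.float32)
).max()
print(f"manual max absolute error: {manual_abs}")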

What about the intermediate results? Let's use a runtime still based on onnxruntime but running the model node by node (eager evaluation).

sess_eager = OnnxruntimeEvaluator(
    "plot_dump_intermediate_results.onnx",
    providers=["CPUExecutionProvider"],
    torch_or_numpy=True,
)
feeds_tensor = dict(zip([i.name for i in sess.get_inputs()], inputs))
got = sess_eager.run(None, feeds_tensor)
diff = max_diff(expected, got)
print(f"discrepancies torch/eager ORT: {string_diff(diff)}")
discrepancies torch/eager ORT: abs=0.001953125, rel=0.0255780642520974, n=960.0,amax=0,7,9

The results are almost the same. That's good. Let's now dig into the intermediate results: every result computed by the eager runtime is compared to the outputs stored in saved_tensors during the execution of the torch model.

# Build the baseline: keep only the stored outputs ('O') and give them
# shorter names, they become the columns of the comparison report.
baseline = {}
for k, v in saved_tensors.items():
    if k[-1] == "I":  # inputs are excluded
        continue
    if isinstance(v, torch.Tensor):
        baseline[f"{k[0]}.{k[1]}".replace("model.decoder", "decoder")] = v

# Every intermediate result produced by the eager evaluation is compared
# to every tensor of the baseline.
report_cmp = ReportResultComparison(baseline)
sess_eager.run(None, feeds_tensor, report_cmp=report_cmp)
[tensor([[[-7.6221e-01,  2.2109e+00, -2.9297e-01,  ..., -5.1660e-01, -7.8125e-02],
          ...,
          [ 1.7017e-01, -9.3457e-01, -1.9189e+00,  ..., -1.1533e+00,  1.7090e+00]],

         [[-1.6094e+00, -1.2832e+00, -2.6035e+00,  ..., -2.5469e+00, -4.4434e-02],
          ...,
          [-1.5928e+00,  2.9224e-01,  1.0479e+00,  ...,  1.0098e+00,  1.4160e+00]]],
        dtype=torch.float16)]

Let’s see the results.

data = report_cmp.data
df = pandas.DataFrame(data)
piv = df.pivot(index=("run_index", "run_name"), columns="ref_name", values="abs")
print(piv)
ref_name                decoder.0  decoder.attention.0  decoder.feed_forward.0  decoder.norm_1.0  decoder.norm_2.0   model.0
run_index run_name
1         embedding      3.353760             3.804688                3.646154          3.273438          3.194336  3.353760
2         embedding_1    3.757812             2.981689                3.132324          2.791992          2.773438  3.757812
3         add_8          0.901367             5.424744                5.466675          2.386719          2.488281  0.901367
4         layer_norm     2.167969             3.038025                3.079956          0.000488          0.837891  2.167969
5         linear         5.480957             1.773804                2.070557          3.582031          3.587891  5.480957
6         linear_1       5.811035             1.938354                1.949707          3.869141          3.912109  5.811035
7         linear_2       5.172760             1.906738                1.933350          3.458008          3.479492  5.172760
24        matmul_1       5.122120             1.672363                1.486816          2.954151          2.852589  5.122120
25        linear_3       5.262817             2.174622                2.492188          3.267090          3.274414  5.262817
26        linear_4       4.065918             1.873535                1.839844          2.889648          2.946289  4.065918
27        linear_5       6.014160             1.930847                1.868774          3.846191          3.744629  6.014160
37        matmul_3       4.951782             1.638062                1.408295          3.580078          3.302734  4.951782
39        val_48         5.161041             0.154663                0.970764          2.993073          2.891510  5.161041
40        linear_6       5.205994             0.000488                0.934631          3.038025          2.936462  5.205994
41        add_115        0.833008             5.334900                5.376831          2.296875          2.398438  0.833008
42        layer_norm_1   2.269531             2.936462                2.978394          0.837524          0.001953  2.269531
46        val_54         5.178833             0.929688                0.084534          3.010864          2.909302  5.178833
47        linear_8       5.247803             0.934967                0.000488          3.079834          2.978271  5.247803
48        add_136        0.001953             5.205994                5.247925          2.167969          2.269531  0.001953

Let's clean this up a little by keeping only the small differences.

piv[piv >= 1] = np.nan
print(piv.dropna(axis=0, how="all"))
ref_name                decoder.0  decoder.attention.0  decoder.feed_forward.0  decoder.norm_1.0  decoder.norm_2.0   model.0
run_index run_name
3         add_8          0.901367                  NaN                     NaN               NaN               NaN  0.901367
4         layer_norm          NaN                  NaN                     NaN          0.000488          0.837891       NaN
39        val_48              NaN             0.154663                0.970764               NaN               NaN       NaN
40        linear_6            NaN             0.000488                0.934631               NaN               NaN       NaN
41        add_115        0.833008                  NaN                     NaN               NaN               NaN  0.833008
42        layer_norm_1        NaN                  NaN                     NaN          0.837524          0.001953       NaN
46        val_54              NaN             0.929688                0.084534               NaN               NaN       NaN
47        linear_8            NaN             0.934967                0.000488               NaN               NaN       NaN
48        add_136        0.001953                  NaN                     NaN               NaN               NaN  0.001953

We can identify which ONNX result maps to which expected tensor.
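
A minimal sketch to automate that mapping: for every stored torch tensor (one column of the pivot table), keep the ONNX result with the smallest absolute difference.

best_match = piv.idxmin(axis=0)  # (run_index, run_name) of the smallest difference per column
print(best_match)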

Picture of the model

onx = onnx.load("plot_dump_intermediate_results.onnx")
plot_dot(onx)
doc.plot_legend("steal and dump\nintermediate\nresults", "steal_forward", "blue")

Total running time of the script: (0 minutes 30.663 seconds)

Related examples

Export microsoft/phi-2

Intermediate results with onnxruntime

Find where a model is failing by running submodels
