Dumps intermediate results of a torch model
Hunting for discrepancies quickly becomes tedious. Discrepancies
appear when the same model, implemented in two different ways,
pytorch and onnx, produces two different results.
Models are big, so where do the discrepancies come from? That’s the
unavoidable question. Unless there is an obvious reason,
the only way to find out is to compare intermediate outputs along the computation.
The first step in that direction is to dump the intermediate results
coming from pytorch.
We use onnx_diagnostic.helpers.torch_helper.steal_forward()
for that.
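The general idea of stealing a forward method can be sketched in a few lines of plain Python. The sketch below is not the implementation of steal_forward: the classes `Scale` and `Pipeline`, the helper `spy_forward` and the layout of the `records` dictionary are all hypothetical. They only illustrate the mechanism: temporarily replace `forward`, record inputs and outputs, and restore the original on exit.

```python
import contextlib


class Scale:
    """Toy stand-in for a submodule (hypothetical, not a torch module)."""

    def __init__(self, factor):
        self.factor = factor

    def forward(self, x):
        return [v * self.factor for v in x]


class Pipeline:
    """Toy stand-in for a model calling two submodules in sequence."""

    def __init__(self):
        self.first = Scale(2)
        self.second = Scale(10)

    def forward(self, x):
        return self.second.forward(self.first.forward(x))


@contextlib.contextmanager
def spy_forward(named_modules, records):
    """Temporarily wraps ``forward`` of every listed object to record inputs/outputs."""
    wrapped = []
    for name, module in named_modules:
        original = module.forward  # bound method of the class

        def wrapper(*args, _name=name, _original=original, **kwargs):
            records[_name, "I"] = (args, kwargs)
            output = _original(*args, **kwargs)
            records[_name, "O"] = output
            return output

        module.forward = wrapper  # the instance attribute shadows the class method
        wrapped.append(module)
    try:
        yield records
    finally:
        for module in wrapped:
            del module.forward  # drops the wrapper, the class method is visible again


model = Pipeline()
records = {}
with spy_forward([("model", model), ("model.first", model.first)], records):
    result = model.forward([1, 2])

print(result)
for key, value in records.items():
    print(key, value)
```

Only the objects listed in the call are spied on: `model.second` runs normally and leaves no trace in `records`.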
A simple LLM Model
See onnx_diagnostic.helpers.torch_helper.dummy_llm()
for its definition. It is mostly used for unit tests and examples.
import numpy as np
import pandas
import onnx
import torch
import onnxruntime
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
from onnx_diagnostic.helpers.torch_helper import dummy_llm, steal_forward
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ReportResultComparison
model, inputs, ds = dummy_llm(dynamic_shapes=True)
We use float16.
model = model.to(torch.float16)
Let’s check.
print(f"type(model)={type(model)}")
print(f"inputs={string_type(inputs, with_shape=True)}")
print(f"ds={string_type(ds, with_shape=True)}")
type(model)=<class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.LLM'>
inputs=(T7s2x30,)
ds=dict(input_ids:{0:Dim(batch),1:Dim(length)})
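The compact strings printed above appear to follow the pattern `T<element type>s<shape>`: `T7s2x30` would then be a tensor of ONNX element type 7 (int64) with shape 2x30. The sketch below decodes the notation under that assumption. `decode_tensor_string` is a hypothetical helper, not part of onnx_diagnostic, and the table only lists a few element types from the ONNX specification.

```python
import re

# A few element types from the ONNX specification (TensorProto.DataType).
ONNX_ELEMENT_TYPES = {
    1: "float32",
    6: "int32",
    7: "int64",
    9: "bool",
    10: "float16",
    11: "float64",
}


def decode_tensor_string(text):
    """Splits a string such as 'T10s2x30x16' into (element type name, shape).

    Assumption: the format is 'T' + ONNX element type + 's' + dimensions joined by 'x'.
    """
    match = re.fullmatch(r"T(\d+)s(\d+(?:x\d+)*)", text)
    if match is None:
        raise ValueError(f"unexpected format: {text!r}")
    code = int(match.group(1))
    shape = tuple(int(dim) for dim in match.group(2).split("x"))
    return ONNX_ELEMENT_TYPES.get(code, f"onnx type {code}"), shape


print(decode_tensor_string("T7s2x30"))      # ('int64', (2, 30))
print(decode_tensor_string("T10s2x30x16"))  # ('float16', (2, 30, 16))
```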
It contains the following submodules.
for name, mod in model.named_modules():
    print(f"- {name}: {type(mod)}")
- : <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.LLM'>
- embedding: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.Embedding'>
- embedding.embedding: <class 'torch.nn.modules.sparse.Embedding'>
- embedding.pe: <class 'torch.nn.modules.sparse.Embedding'>
- decoder: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.DecoderLayer'>
- decoder.attention: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.MultiAttentionBlock'>
- decoder.attention.attention: <class 'torch.nn.modules.container.ModuleList'>
- decoder.attention.attention.0: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.AttentionBlock'>
- decoder.attention.attention.0.query: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.0.key: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.0.value: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.1: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.AttentionBlock'>
- decoder.attention.attention.1.query: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.1.key: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.1.value: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.linear: <class 'torch.nn.modules.linear.Linear'>
- decoder.feed_forward: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.FeedForward'>
- decoder.feed_forward.linear_1: <class 'torch.nn.modules.linear.Linear'>
- decoder.feed_forward.relu: <class 'torch.nn.modules.activation.ReLU'>
- decoder.feed_forward.linear_2: <class 'torch.nn.modules.linear.Linear'>
- decoder.norm_1: <class 'torch.nn.modules.normalization.LayerNorm'>
- decoder.norm_2: <class 'torch.nn.modules.normalization.LayerNorm'>
Steal and dump the output of submodules
The following context manager spies on the inputs and outputs of the listed module and submodules. It stores all of them in a single onnx file.
with steal_forward(
    [
        ("model", model),
        ("model.decoder", model.decoder),
        ("model.decoder.attention", model.decoder.attention),
        ("model.decoder.feed_forward", model.decoder.feed_forward),
        ("model.decoder.norm_1", model.decoder.norm_1),
        ("model.decoder.norm_2", model.decoder.norm_2),
    ],
    dump_file="plot_dump_intermediate_results.inputs.onnx",
    verbose=1,
    storage_limit=2**28,
):
    expected = model(*inputs)
+model -- stolen forward for class LLM -- iteration 0
<- args=(T7s2x30,) --- kwargs={}
+model.decoder -- stolen forward for class DecoderLayer -- iteration 0
<- args=(T10s2x30x16,) --- kwargs={}
+model.decoder.norm_1 -- stolen forward for class LayerNorm -- iteration 0
<- args=(T10s2x30x16,) --- kwargs={}
-> T10s2x30x16
-model.decoder.norm_1.
-- stores key=('model.decoder.norm_1', 0), size 1Kb -- T10s2x30x16
+model.decoder.attention -- stolen forward for class MultiAttentionBlock -- iteration 0
<- args=(T10s2x30x16,) --- kwargs={}
-> T10s2x30x16
-model.decoder.attention.
-- stores key=('model.decoder.attention', 0), size 1Kb -- T10s2x30x16
+model.decoder.norm_2 -- stolen forward for class LayerNorm -- iteration 0
<- args=(T10s2x30x16,) --- kwargs={}
-> T10s2x30x16
-model.decoder.norm_2.
-- stores key=('model.decoder.norm_2', 0), size 1Kb -- T10s2x30x16
+model.decoder.feed_forward -- stolen forward for class FeedForward -- iteration 0
<- args=(T10s2x30x16,) --- kwargs={}
-> T10s2x30x16
-model.decoder.feed_forward.
-- stores key=('model.decoder.feed_forward', 0), size 1Kb -- T10s2x30x16
-> T10s2x30x16
-model.decoder.
-- stores key=('model.decoder', 0), size 1Kb -- T10s2x30x16
-> T10s2x30x16
-model.
-- stores key=('model', 0), size 1Kb -- T10s2x30x16
-- gather stored 12 objects, size=0 Mb
-- dumps stored objects
-- done dump stored objects
Restores saved inputs/outputs
All the intermediate tensors were saved in a single onnx model,
every tensor is stored in a constant node.
The model can be run with any runtime to restore the tensors,
and function create_input_tensors_from_onnx_model
can also restore their names.
saved_tensors = create_input_tensors_from_onnx_model(
    "plot_dump_intermediate_results.inputs.onnx"
)
for k, v in saved_tensors.items():
    print(f"{k} -- {string_type(v, with_shape=True)}")
('model', 0, 'I') -- ((T7s2x30,),{})
('model.decoder', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.norm_1', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.norm_1', 0, 'O') -- T10s2x30x16
('model.decoder.attention', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.attention', 0, 'O') -- T10s2x30x16
('model.decoder.norm_2', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.norm_2', 0, 'O') -- T10s2x30x16
('model.decoder.feed_forward', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.feed_forward', 0, 'O') -- T10s2x30x16
('model.decoder', 0, 'O') -- T10s2x30x16
('model', 0, 'O') -- T10s2x30x16
Let’s explain the naming convention.
('model.decoder.norm_2', 0, 'I') -- ((T10s2x30x16,),{})
            |            |   |
            |            |   +--> input, the format is args, kwargs
            |            |
            |            +--> iteration, 0 means the first time the execution
            |                 went through that module,
            |                 the model can be called multiple times
            |                 to store more
            |
            +--> the name given to function steal_forward
The same goes for the output except 'I' is replaced by 'O'.
('model.decoder.norm_2', 0, 'O') -- T10s2x30x16
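Because every key follows the same `(name, iteration, kind)` layout, inputs and outputs are easy to pair again. The sketch below uses placeholder strings instead of real tensors, and `pair_inputs_outputs` is a hypothetical helper written for this illustration only.

```python
def pair_inputs_outputs(saved):
    """Groups entries keyed by (name, iteration, 'I' or 'O') per (name, iteration)."""
    paired = {}
    for (name, iteration, kind), value in saved.items():
        entry = paired.setdefault((name, iteration), {})
        entry["inputs" if kind == "I" else "output"] = value
    return paired


# Placeholder values standing in for the tensors restored from the dump.
saved = {
    ("model.decoder.norm_2", 0, "I"): (("x",), {}),
    ("model.decoder.norm_2", 0, "O"): "y",
    ("model.decoder", 0, "O"): "z",
}
paired = pair_inputs_outputs(saved)
args, kwargs = paired["model.decoder.norm_2", 0]["inputs"]
print(args, kwargs, paired["model.decoder.norm_2", 0]["output"])
```

With real data, `args` and `kwargs` could be fed back to the corresponding submodule to replay it in isolation.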
This trick can be used to compare intermediate results coming from pytorch to any other implementation of the same model as long as it is possible to map the stored inputs/outputs.
Conversion to ONNX
The difficult point is to be able to map the saved intermediate results to intermediate results in ONNX. Let’s create the ONNX model.
ep = torch.export.export(model, inputs, dynamic_shapes=ds)
epo = torch.onnx.export(ep, dynamo=True)
epo.optimize()
epo.save("plot_dump_intermediate_results.onnx")
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 4 of general pattern rewrite rules.
Discrepancies
We have a torch model, intermediate results and an ONNX graph equivalent to the torch model. Let’s see how we can check the discrepancies. First the discrepancies of the whole model.
sess = onnxruntime.InferenceSession(
    "plot_dump_intermediate_results.onnx", providers=["CPUExecutionProvider"]
)
feeds = dict(
    zip([i.name for i in sess.get_inputs()], [t.detach().cpu().numpy() for t in inputs])
)
got = sess.run(None, feeds)
diff = max_diff(expected, got)
print(f"discrepancies torch/ORT: {string_diff(diff)}")
discrepancies torch/ORT: abs=0.001953125, rel=0.045897074220567706, n=960.0,amax=0,0,8
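The absolute error reported above, 0.001953125, is exactly 2^-9. That is the gap between two consecutive float16 values in [2, 4), so the figure is consistent with both runtimes disagreeing by a single rounding step on one coefficient. The sketch below checks this arithmetic with the standard library only. `float16_spacing` and `to_float16` are hypothetical helpers, and the formula only covers normal (non-subnormal) float16 values.

```python
import math
import struct


def float16_spacing(x):
    """Gap between consecutive float16 values around a normal, nonzero x."""
    exponent = math.floor(math.log2(abs(x)))
    return 2.0 ** (exponent - 10)  # float16 stores 10 explicit mantissa bits


def to_float16(x):
    """Rounds a Python float to the nearest float16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


gap = float16_spacing(2.5)
print(gap)  # 0.001953125
# 2.5 and 2.5 + gap are both exactly representable in float16,
# a value in between is rounded to the closest of the two.
print(to_float16(2.5 + gap) - to_float16(2.5) == gap)
print(to_float16(2.5 + gap / 4) == 2.5)
```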
What about intermediate results? Let’s use a runtime still based on onnxruntime running an eager evaluation.
sess_eager = OnnxruntimeEvaluator(
    "plot_dump_intermediate_results.onnx",
    providers=["CPUExecutionProvider"],
    torch_or_numpy=True,
)
feeds_tensor = dict(zip([i.name for i in sess.get_inputs()], inputs))
got = sess_eager.run(None, feeds_tensor)
diff = max_diff(expected, got)
print(f"discrepancies torch/eager ORT: {string_diff(diff)}")
discrepancies torch/eager ORT: abs=0.001953125, rel=0.058166589111214514, n=960.0,amax=0,0,11
They are almost the same. That’s good. Let’s now dig into the intermediate results. They are compared to the outputs stored in saved_tensors during the execution of the model.
baseline = {}
for k, v in saved_tensors.items():
    if k[-1] == "I":  # inputs are excluded
        continue
    if isinstance(v, torch.Tensor):
        baseline[f"{k[0]}.{k[1]}".replace("model.decoder", "decoder")] = v
report_cmp = ReportResultComparison(baseline)
sess_eager.run(None, feeds_tensor, report_cmp=report_cmp)
[tensor([[[-2.2637e+00, -4.0991e-01, 4.7754e-01, 1.5400e+00, 8.5156e-01,
-1.2939e+00, 9.7559e-01, 8.1201e-01, -2.4648e+00, 5.7910e-01,
2.1465e+00, -1.1602e+00, -7.2559e-01, 2.0332e+00, -8.1445e-01,
-3.0879e+00],
[-1.7490e+00, 1.6846e-01, -1.5732e+00, 1.4053e+00, 7.3193e-01,
-1.9453e+00, -9.8779e-01, 2.9419e-01, 8.0566e-01, -2.2402e+00,
-1.0527e+00, 1.0254e+00, -5.5176e-02, -5.1318e-01, -1.3193e+00,
-1.2031e+00],
[-1.0430e+00, -6.1426e-01, -1.0898e+00, -2.3999e-01, -5.6274e-02,
1.9951e+00, -6.3721e-01, 1.6035e+00, -3.0640e-01, -7.4951e-02,
1.0020e+00, -1.8672e+00, 6.7383e-01, -1.0977e+00, 1.4590e+00,
-1.8691e+00],
[ 4.9316e-02, 3.7070e+00, 8.3594e-01, -1.9082e+00, 1.9592e-01,
1.1133e-01, 2.5859e+00, 5.7324e-01, 3.4082e-01, 9.4238e-02,
-7.2998e-02, -7.7539e-01, 4.5264e-01, -4.3750e+00, 2.9414e+00,
2.0430e+00],
[ 2.4766e+00, -4.8486e-01, -2.2949e+00, 1.1309e+00, -1.5137e+00,
-1.9727e-01, 2.5078e+00, -8.4961e-02, -2.3828e+00, -3.5474e-01,
-7.4072e-01, -3.9648e+00, -1.5791e+00, -1.4521e+00, -9.0771e-01,
-7.3242e-01],
[ 3.1543e-01, 6.2939e-01, 3.8281e-01, 1.1797e+00, 1.1445e+00,
-7.9492e-01, -1.3789e+00, 1.8477e+00, -2.5146e-01, 2.1472e-01,
9.1895e-01, -2.0625e+00, -1.4219e+00, 2.3789e+00, -2.8906e-01,
1.7651e-01],
[-2.4048e-01, 1.5117e+00, 1.0732e+00, 5.4688e-01, -2.3672e+00,
-1.2292e-01, -2.8809e+00, 7.2705e-01, -1.0674e+00, 2.3789e+00,
-5.1904e-01, 3.2324e-01, -3.5840e+00, 5.9570e-01, 2.3499e-01,
1.2412e+00],
[-2.1875e+00, -3.0347e-01, 1.5986e+00, 2.3828e-01, -1.3301e+00,
5.4199e-02, -2.0703e+00, 1.7725e+00, -1.3660e-01, 5.8252e-01,
1.1318e+00, 1.5801e+00, 2.6172e-01, 1.0801e+00, 7.7588e-01,
2.2051e+00],
[ 1.0439e+00, -1.6028e-01, -9.3408e-01, -8.4619e-01, -1.5547e+00,
7.3828e-01, -4.5801e-01, -1.6250e+00, 2.0723e+00, 6.0645e-01,
-1.1182e+00, -3.1953e+00, 2.4980e+00, 8.5498e-01, -4.4678e-01,
-6.2598e-01],
[-1.5537e+00, -1.9521e+00, 2.6074e-01, 6.3818e-01, -1.3740e+00,
-3.4033e-01, 3.0371e+00, 8.1641e-01, 8.5449e-01, 2.5625e+00,
-8.4229e-01, -7.0312e-01, 1.5625e+00, -2.7363e+00, 1.0254e+00,
9.0479e-01],
[-1.0215e+00, 6.0449e-01, 2.7100e-01, -1.0830e+00, -4.3579e-01,
-1.5449e+00, -3.0273e+00, 3.0723e+00, -1.4229e+00, 2.1152e+00,
-2.0488e+00, -2.6777e+00, -8.0750e-02, -2.0273e+00, 1.3652e+00,
-4.0503e-01],
[-2.1606e-01, 2.0742e+00, 7.9102e-02, 7.0264e-01, -3.2012e+00,
2.3750e+00, 3.4863e-01, -7.7100e-01, -1.7764e+00, -2.5488e+00,
1.8008e+00, -9.5605e-01, 4.9463e-01, 1.2915e-01, -1.8164e+00,
1.9248e+00],
[-8.7646e-01, 2.1484e+00, 1.4277e+00, -3.1104e-01, 2.5605e+00,
-1.9521e+00, 1.1562e+00, 7.3975e-01, -1.6904e+00, 1.4980e+00,
-1.2988e-01, -1.4121e+00, 2.1914e+00, -4.8901e-01, -9.9219e-01,
-1.6357e+00],
[ 6.6797e-01, -8.1787e-01, 7.8027e-01, -1.9873e-01, -7.5293e-01,
-3.4414e+00, -3.4131e-01, 1.4180e+00, -2.3242e+00, 4.8828e-01,
1.4514e-01, -3.0117e+00, 2.0820e+00, 4.1431e-01, -2.1250e+00,
-9.0918e-01],
[-5.0537e-01, -1.8384e-01, 3.3911e-01, -6.6211e-01, 1.0762e+00,
-2.5352e+00, 1.3223e+00, 1.2476e-01, 8.8721e-01, 1.1211e+00,
3.8501e-01, 5.2368e-02, -3.0518e-01, 3.4741e-01, 8.4570e-01,
-5.6787e-01],
[ 8.6035e-01, 1.7383e+00, -1.6396e+00, 5.4395e-01, -5.4077e-02,
-1.4863e+00, -6.2256e-01, -1.0254e+00, -4.7949e-01, -9.2773e-01,
6.3232e-02, 8.8086e-01, 1.3789e+00, 3.7441e+00, 2.4395e+00,
3.6572e-01],
[ 7.3047e-01, -9.5215e-03, 1.2354e+00, 1.5498e+00, -2.3945e+00,
-7.7295e-01, -1.3125e+00, -1.6201e+00, -1.7832e+00, -6.4648e-01,
2.0905e-02, 6.6699e-01, 1.3135e+00, -3.2734e+00, -1.8271e+00,
6.0254e-01],
[-5.6836e-01, 5.1270e-01, -8.2861e-01, 8.8867e-01, -5.4395e-01,
-7.8906e-01, -1.0322e+00, 7.8418e-01, -1.1875e+00, -1.5161e-01,
-1.8372e-02, -5.2441e-01, -9.1309e-02, -3.5303e-01, 1.4121e+00,
-2.0449e+00],
[ 1.2051e+00, -9.9731e-02, 3.5693e-01, 6.9922e-01, -1.0342e+00,
1.8965e+00, 7.4036e-02, 3.5229e-01, -1.7871e+00, 1.5508e+00,
-9.4287e-01, 8.8330e-01, -1.0186e+00, -1.8652e+00, 4.2627e-01,
-8.4912e-01],
[-1.4709e-01, 1.8564e+00, 1.8242e+00, -1.0283e+00, 3.3984e-01,
-3.2578e+00, 1.1836e+00, 9.8877e-03, -6.3574e-01, 6.8164e-01,
4.6313e-01, -9.4434e-01, -2.0117e+00, 1.8225e-01, 1.0479e+00,
1.8721e+00],
[ 3.5425e-01, 1.7266e+00, 8.5254e-01, -1.8232e+00, -2.4766e+00,
-8.0518e-01, -7.7051e-01, -7.4316e-01, -3.7329e-01, 2.2207e+00,
-2.3398e+00, 1.3896e+00, -8.3008e-01, 8.3545e-01, 7.3096e-01,
-1.2207e-03],
[-8.0713e-01, 7.3047e-01, 2.6934e+00, 1.6968e-02, 8.2886e-02,
-1.6543e+00, 1.3555e+00, 1.9551e+00, -1.6870e-01, 5.7422e-01,
3.1328e+00, -4.3359e-01, 1.8701e+00, 1.1387e+00, 2.0000e+00,
-8.0029e-01],
[-2.0032e-01, -3.5132e-01, 2.5635e-01, 1.2490e+00, -1.4365e+00,
-1.2031e+00, 1.2246e+00, -8.9648e-01, -1.4023e+00, 1.8633e+00,
1.7959e+00, -2.5605e+00, 2.7891e+00, 4.8022e-01, 6.8750e-01,
1.7588e+00],
[ 6.7676e-01, 2.2441e+00, 1.2598e+00, 1.1221e+00, -1.1016e+00,
-1.9844e+00, -8.5645e-01, -1.4229e+00, -1.0488e+00, 1.0068e+00,
1.8154e+00, -2.3066e+00, -1.7676e-01, -1.9170e+00, 1.8604e-01,
1.5405e-01],
[-3.6182e-01, -6.9238e-01, 1.9287e-02, -4.3921e-01, -7.7393e-01,
3.2959e-02, 1.1611e+00, 2.0684e+00, 6.2598e-01, 7.5684e-03,
-4.0381e-01, 8.0029e-01, -3.8965e-01, 1.3379e-01, 3.2305e+00,
2.9639e-01],
[-6.6309e-01, 7.2656e-01, -2.6001e-01, -7.3633e-01, 3.2715e-01,
-2.8345e-01, 1.1875e+00, -1.1875e+00, -3.6719e+00, 2.2246e+00,
9.6875e-01, 1.1123e+00, 1.2373e+00, 1.3535e+00, 5.5664e-01,
-1.9443e+00],
[-1.7012e+00, -1.6201e+00, -2.4043e+00, 3.2520e+00, 4.2603e-02,
-1.5293e+00, 3.0420e-01, 5.5615e-01, 1.1172e+00, 9.7949e-01,
-1.3818e+00, -1.1533e+00, -5.7910e-01, 4.3213e-02, -2.3730e+00,
4.4287e-01],
[-1.5244e+00, -1.6748e+00, 1.1455e+00, 2.1484e-02, 4.8584e-01,
-1.8047e+00, -1.6328e+00, 2.8848e+00, 1.4404e+00, 1.2500e+00,
-8.8916e-01, -2.4824e+00, -2.0625e+00, -1.5732e+00, -9.1406e-01,
2.0996e-01],
[-1.7053e-01, 1.7590e-01, 1.1074e+00, 2.0898e+00, 2.3965e+00,
-8.0566e-01, 4.7192e-01, -1.6953e+00, 1.3008e+00, -2.2510e-01,
3.8984e+00, -2.7520e+00, -1.5654e+00, -7.1533e-01, -2.2192e-01,
-1.4844e+00],
[-2.7793e+00, -9.0137e-01, 1.5498e+00, 3.8477e-01, 1.8750e+00,
-6.9385e-01, 2.8184e+00, 4.1943e-01, 2.4824e+00, -6.6797e-01,
1.7139e+00, 1.4697e-01, -4.5929e-02, -6.3525e-01, -7.7454e-02,
4.8486e-01]],
[[ 1.7832e+00, -1.2402e+00, 2.8906e+00, -2.1934e+00, -7.4756e-01,
3.8647e-01, -1.1455e+00, 8.1592e-01, 2.2012e+00, 1.5332e+00,
-4.9341e-01, 1.3174e+00, -1.5713e+00, -1.0547e+00, 2.3657e-01,
4.5459e-01],
[-1.7109e+00, 5.5957e-01, -6.5723e-01, 7.0117e-01, 2.0781e+00,
-9.9707e-01, -4.6338e-01, -2.0312e+00, 1.0234e+00, 1.2109e+00,
-1.2471e+00, -1.4062e+00, -2.3359e+00, -3.2617e+00, -1.8145e+00,
1.8524e-02],
[ 1.8086e+00, -1.5469e+00, 2.5254e+00, -1.5000e+00, -2.6777e+00,
-2.2793e+00, -2.3242e+00, 6.9678e-01, -2.8633e+00, -6.8066e-01,
-3.4131e-01, 1.6914e+00, 9.9170e-01, -1.5942e-01, -1.4141e+00,
-6.1914e-01],
[ 4.0234e+00, 6.5039e-01, -6.3623e-01, 2.3613e+00, 3.1982e-01,
2.5508e+00, 2.3789e+00, 4.7852e-01, -4.1367e+00, 9.5410e-01,
4.4897e-01, 2.6523e+00, -1.3223e+00, 1.0596e+00, -1.8203e+00,
9.4629e-01],
[ 1.8271e+00, 1.8174e+00, 9.5654e-01, -9.6680e-01, -9.5898e-01,
8.3740e-02, 9.2676e-01, 4.1943e-01, -2.2012e+00, 1.7266e+00,
-1.3320e+00, -6.0596e-01, -1.2764e+00, -2.1719e+00, -7.3389e-01,
3.6182e-01],
[ 1.7920e+00, 1.0859e+00, 8.7891e-01, -1.6953e+00, 1.3301e+00,
1.8232e+00, 4.5044e-01, -1.2598e+00, 1.3428e-02, -4.6802e-01,
-2.0391e+00, 2.8398e+00, -2.1545e-01, 1.5576e+00, 1.2559e+00,
-2.8203e+00],
[ 2.7344e+00, 2.7422e+00, -4.8706e-01, 6.8555e-01, -2.5171e-01,
1.3311e+00, -9.3701e-01, -2.4414e-01, 2.0801e+00, 2.3267e-01,
-7.9346e-01, 2.3320e+00, -9.4727e-02, -3.8770e-01, -5.2490e-02,
3.0117e+00],
[ 1.9717e+00, 1.8799e-01, -4.9316e-02, 8.6060e-02, 1.1035e+00,
2.7148e+00, 8.3496e-01, 3.8770e-01, -1.6123e+00, -8.4424e-01,
-1.2520e+00, -1.9619e+00, 8.6182e-01, 4.2773e-01, 8.0859e-01,
-3.3105e-01],
[-2.0840e+00, -6.6357e-01, -2.4414e-02, -1.3760e+00, -5.2734e-01,
-8.4570e-01, 1.9717e+00, 2.0312e+00, 2.5742e+00, -1.3062e-02,
5.0293e-01, 8.1055e-01, -1.2585e-01, 7.9492e-01, 5.1660e-01,
-8.5938e-01],
[ 3.0176e+00, -1.9153e-01, -2.2500e+00, 1.0498e+00, -1.3184e+00,
-3.1372e-02, 2.4277e+00, -7.0557e-01, -2.1133e+00, -6.7285e-01,
-1.2588e+00, -3.7031e+00, -1.6318e+00, -1.2109e+00, -1.2197e+00,
-1.1533e+00],
[-9.3994e-02, -1.4824e+00, 2.9336e+00, -1.4883e+00, -5.0781e-01,
-2.4219e+00, 9.0527e-01, -2.2090e+00, 3.9014e-01, -1.6895e+00,
-1.1309e+00, -1.9238e+00, 1.2012e+00, -2.3047e+00, 6.9531e-01,
1.5244e+00],
[-2.6406e+00, 2.8418e-01, -4.4043e-01, 8.5156e-01, 2.2891e+00,
2.1445e+00, 9.2920e-01, -1.1201e+00, 1.5186e+00, -8.3936e-01,
-2.7793e+00, -5.3223e-01, 2.4727e+00, 8.9600e-02, -7.7832e-01,
1.5605e+00],
[-7.5244e-01, -2.3125e+00, 1.1865e+00, 1.0645e+00, -1.8574e+00,
5.9521e-01, 9.7754e-01, 3.0547e+00, -3.2104e-02, -6.2354e-01,
-7.7832e-01, -1.1543e+00, -7.0752e-01, -7.8125e-01, 1.3799e+00,
-1.8203e+00],
[-1.5967e+00, 2.8857e-01, -1.8301e+00, -2.1914e+00, 4.2798e-01,
-1.4453e+00, -1.8652e+00, 4.7388e-01, 1.4814e+00, -2.4922e+00,
3.3936e-02, 6.7773e-01, -9.5801e-01, -1.1504e+00, -4.8340e-01,
3.5801e+00],
[ 2.0078e+00, 7.7759e-02, -1.0176e+00, 1.4014e+00, 5.6836e-01,
3.9233e-01, -1.4636e-01, -7.4707e-01, -1.9736e+00, 1.1885e+00,
-6.4258e-01, -1.6973e+00, -8.3398e-01, -3.0396e-01, 1.0840e+00,
-2.0000e+00],
[-4.0869e-01, 2.3096e-01, -6.3047e+00, 8.6060e-02, -1.5098e+00,
-5.0000e-01, 1.2295e+00, 1.8730e+00, 1.5127e+00, -2.9297e+00,
4.4702e-01, -9.5825e-02, 2.3613e+00, -3.5078e+00, 1.3806e-01,
-2.3486e-01],
[ 2.3926e+00, 6.5332e-01, 8.7012e-01, -8.5449e-02, 6.3574e-01,
-1.7441e+00, 4.5044e-02, 7.9297e-01, 9.7998e-01, 2.5898e+00,
7.9199e-01, 1.6143e+00, -1.2354e+00, -8.2520e-01, -3.9136e-01,
6.4062e-01],
[ 1.3555e+00, 1.4238e+00, -1.0869e+00, -3.5547e-01, 1.9990e+00,
-3.8940e-01, 1.7334e+00, -6.5674e-02, 7.9883e-01, 1.7236e+00,
-2.3621e-01, -2.8574e+00, -1.3486e+00, -1.6152e+00, 8.3252e-01,
-6.3623e-01],
[-8.9258e-01, -2.2534e-01, 1.2305e+00, 1.3232e+00, -9.0479e-01,
-1.4160e-01, -1.1250e+00, 2.5820e+00, 8.5083e-02, -3.7158e-01,
-1.9727e+00, 1.8176e-01, -2.1729e-01, 2.2305e+00, -2.3145e+00,
-1.7178e+00],
[ 4.7363e-01, 1.5283e+00, 2.3359e+00, 1.4961e+00, -4.0967e-01,
9.7900e-01, 1.6670e+00, 3.0840e+00, -1.8311e+00, -2.8564e-02,
-2.8848e+00, 1.5137e+00, 3.5664e+00, -5.6738e-01, 2.4531e+00,
-1.0801e+00],
[-1.3320e+00, 1.6865e+00, -2.6641e+00, -1.4033e+00, -1.1230e+00,
2.0840e+00, -2.7148e+00, 1.5254e+00, -1.3369e+00, -1.1270e+00,
1.7314e+00, -4.9683e-02, -1.0879e+00, 1.4160e-02, 3.2441e+00,
-8.1836e-01],
[ 7.7930e-01, 6.4648e-01, 1.2100e+00, -1.5684e+00, 1.7373e+00,
3.6255e-01, 8.4521e-01, 3.3633e+00, 1.1191e+00, 1.6250e+00,
3.6758e+00, -9.5605e-01, 2.5977e-01, 5.5481e-02, 1.1396e+00,
2.2046e-01],
[ 4.3970e-01, 1.0264e+00, 2.6904e-01, -1.8789e+00, 9.3359e-01,
-1.5088e+00, 3.5332e+00, 1.7725e+00, -1.7383e+00, -4.7656e-01,
-3.0215e+00, -1.0742e+00, 1.5908e+00, 9.1553e-01, -1.5039e+00,
-1.3105e+00],
[ 3.7085e-01, -1.8574e+00, 9.0625e-01, -7.9980e-01, -1.4648e+00,
7.2168e-01, -3.5596e-01, 5.2979e-02, 1.3213e+00, -7.3926e-01,
5.8594e-01, -4.3945e-02, 7.7930e-01, -1.1572e+00, -2.2637e+00,
-1.9248e+00],
[-1.3574e-01, 5.3418e-01, 3.1787e-01, 1.5137e+00, -3.7842e-01,
-2.5117e+00, -5.6738e-01, -1.2236e+00, 4.1836e+00, -4.1602e-01,
-7.0508e-01, 1.4258e+00, -6.3086e-01, -8.8379e-01, -2.3281e+00,
1.8701e-01],
[ 7.7246e-01, -4.6240e-01, 1.1514e+00, 1.9102e+00, -2.5625e+00,
-1.5508e+00, 4.4678e-01, -1.4395e+00, 9.5642e-02, 3.5547e-01,
7.3535e-01, 9.3115e-01, 1.8857e+00, -8.3936e-01, -6.3379e-01,
-2.4336e+00],
[ 2.2383e+00, 9.3701e-01, -1.0381e+00, -1.3164e+00, 1.1631e+00,
-2.0605e-01, -4.2188e-01, -2.2461e-01, 4.6436e-01, 1.6602e+00,
2.9414e+00, 1.6631e+00, -3.8940e-01, -5.6348e-01, -1.3562e-01,
-2.7891e+00],
[ 1.1680e+00, 1.5039e-01, -1.8652e+00, -1.5020e+00, -2.2266e+00,
-1.4971e+00, 7.7197e-01, 1.2471e+00, -2.0386e-01, 9.4141e-01,
1.4551e+00, -1.1045e+00, -3.3032e-01, 1.1016e+00, 1.0762e+00,
-2.4277e+00],
[-1.1230e+00, 7.5586e-01, 1.0234e+00, -1.2666e+00, 2.8984e+00,
1.4512e+00, -2.1934e+00, -1.0479e+00, 1.4043e+00, -5.2197e-01,
1.6809e-01, 8.7402e-01, -1.5303e+00, -1.4355e+00, -8.1421e-02,
1.0127e+00],
[-2.2695e+00, -9.0137e-01, 1.3584e+00, -8.9355e-01, 1.2598e-01,
-4.7998e-01, 1.3926e+00, -6.8604e-01, 8.4326e-01, 1.5664e+00,
-9.4238e-01, -2.7612e-01, -1.3418e+00, 3.1763e-01, -2.4121e+00,
6.5918e-01]]], dtype=torch.float16)]
Let’s see the results.
data = report_cmp.data
df = pandas.DataFrame(data)
piv = df.pivot(index=("run_index", "run_name"), columns="ref_name", values="abs")
print(piv)
ref_name decoder.0 decoder.attention.0 ... decoder.norm_2.0 model.0
run_index run_name ...
1 embedding 3.896484 3.340332 ... 3.148438 3.896484
2 embedding_1 3.543526 3.952759 ... 2.581055 3.543526
3 add_8 1.546875 6.360962 ... 3.443359 1.546875
4 layer_norm 3.529297 3.010254 ... 0.855225 3.529297
5 linear 5.829834 1.977417 ... 3.794922 5.829834
6 linear_1 6.206848 2.011230 ... 3.361328 6.206848
7 linear_2 5.543945 2.264160 ... 3.975586 5.543945
17 matmul_1 6.373596 1.702148 ... 3.105957 6.373596
18 linear_3 6.734619 2.851562 ... 3.710938 6.734619
19 linear_4 5.492188 1.987793 ... 3.340576 5.492188
20 linear_5 5.259766 1.946289 ... 4.120117 5.259766
30 matmul_3 6.466553 1.946289 ... 3.006836 6.466553
32 val_48 6.278046 0.168030 ... 2.981689 6.278046
33 linear_6 6.427368 0.000977 ... 2.949707 6.427368
34 add_115 0.960938 6.239868 ... 3.322266 0.960938
35 layer_norm_1 3.509766 2.949707 ... 0.001465 3.509766
39 val_54 6.189697 1.502197 ... 3.029297 6.189697
40 linear_8 6.116943 1.543945 ... 2.943115 6.116943
41 add_136 0.001953 6.427368 ... 3.509766 0.001953
[19 rows x 6 columns]
Let’s clean that up a little: every discrepancy greater than or equal to 1 is hidden.
piv[piv >= 1] = np.nan
print(piv.dropna(axis=0, how="all"))
ref_name decoder.0 decoder.attention.0 ... decoder.norm_2.0 model.0
run_index run_name ...
4 layer_norm NaN NaN ... 0.855225 NaN
32 val_48 NaN 0.168030 ... NaN NaN
33 linear_6 NaN 0.000977 ... NaN NaN
34 add_115 0.960938 NaN ... NaN 0.960938
35 layer_norm_1 NaN NaN ... 0.001465 NaN
39 val_54 NaN NaN ... NaN NaN
40 linear_8 NaN NaN ... NaN NaN
41 add_136 0.001953 NaN ... NaN 0.001953
[8 rows x 6 columns]
We can identify which result is mapped to which expected tensor.
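One way to read the table is to keep, for every expected tensor, the ONNX result with the smallest absolute discrepancy. The sketch below does that on a handful of values copied from the table above. The record layout (a list of dictionaries with the keys `run_name`, `ref_name` and `abs`) is an assumption based on the columns used in the pivot, and `best_matches` is a hypothetical helper.

```python
def best_matches(records):
    """Returns {ref_name: (run_name, abs)} keeping the smallest discrepancy."""
    best = {}
    for rec in records:
        ref, err = rec["ref_name"], rec["abs"]
        if ref not in best or err < best[ref][1]:
            best[ref] = (rec["run_name"], err)
    return best


# A few values copied from the table printed above (assumed record layout).
records = [
    {"run_name": "linear_6", "ref_name": "decoder.attention.0", "abs": 0.000977},
    {"run_name": "val_48", "ref_name": "decoder.attention.0", "abs": 0.168030},
    {"run_name": "layer_norm", "ref_name": "decoder.norm_2.0", "abs": 0.855225},
    {"run_name": "layer_norm_1", "ref_name": "decoder.norm_2.0", "abs": 0.001465},
    {"run_name": "add_115", "ref_name": "model.0", "abs": 0.960938},
    {"run_name": "add_136", "ref_name": "model.0", "abs": 0.001953},
]
matches = best_matches(records)
for ref, (run, err) in matches.items():
    print(f"{ref} <- {run} (abs={err})")
```

A small discrepancy only suggests a mapping. Two unrelated tensors of the same shape may still be close, so the result is worth checking against the graph.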
Picture of the model

doc.plot_legend("steal and dump\nintermediate\nresults", "steal_forward", "blue")
