Dumps intermediate results of a torch model¶
Looking for discrepancies quickly becomes tedious. A discrepancy appears
when the same model, implemented in two different ways (pytorch and onnx),
produces two different results. Models are big, so where does it come from?
That’s the unavoidable question. Unless there is an obvious reason,
the only way is to compare intermediate outputs along the computation.
The first step in that direction is to dump the intermediate results
coming from pytorch.
We use onnx_diagnostic.helpers.torch_helper.steal_forward()
for that.
A simple LLM Model¶
See onnx_diagnostic.helpers.torch_helper.dummy_llm()
for its definition. It is mostly used for unit tests and examples.
import numpy as np
import pandas
import onnx
import torch
import onnxruntime
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
from onnx_diagnostic.helpers.torch_helper import dummy_llm, steal_forward
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ReportResultComparison
model, inputs, ds = dummy_llm(dynamic_shapes=True)
We use float16.
model = model.to(torch.float16)
Let’s check.
print(f"type(model)={type(model)}")
print(f"inputs={string_type(inputs, with_shape=True)}")
print(f"ds={string_type(ds, with_shape=True)}")
type(model)=<class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.LLM'>
inputs=(T7s2x30,)
ds=dict(input_ids:{0:Dim(batch),1:Dim(length)})
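Judging from the outputs above, string_type abbreviates a tensor as T<element type>s<shape>, where the element type is the ONNX TensorProto code: 7 is int64 (the token ids) and 10 is float16, the type used by the model here. The helper decode_tensor_shorthand below is hypothetical, written only to show how to read such tokens back; it assumes static integer dimensions and knows only a few element types.

```python
import re

# A few ONNX TensorProto element types (TensorProto.FLOAT == 1, INT64 == 7,
# FLOAT16 == 10, DOUBLE == 11).
ONNX_ELEM_TYPES = {1: "float32", 7: "int64", 10: "float16", 11: "float64"}

# Hypothetical pattern: "T" + element type + "s" + dimensions joined by "x".
_SHORTHAND = re.compile(r"^T(\d+)s(\d+(?:x\d+)*)$")


def decode_tensor_shorthand(text):
    """Hypothetical helper: 'T7s2x30' -> ('int64', (2, 30))."""
    match = _SHORTHAND.match(text)
    if match is None:
        raise ValueError(f"not a static tensor shorthand: {text!r}")
    elem_type = int(match.group(1))
    shape = tuple(int(d) for d in match.group(2).split("x"))
    return ONNX_ELEM_TYPES.get(elem_type, f"onnx type {elem_type}"), shape


print(decode_tensor_shorthand("T7s2x30"))      # ('int64', (2, 30))
print(decode_tensor_shorthand("T10s2x30x16"))  # ('float16', (2, 30, 16))
```

Symbolic dimensions (such as batch or length in dynamic shapes) are rejected by this sketch.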
It contains the following submodules.
for name, mod in model.named_modules():
    print(f"- {name}: {type(mod)}")
- : <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.LLM'>
- embedding: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.Embedding'>
- embedding.embedding: <class 'torch.nn.modules.sparse.Embedding'>
- embedding.pe: <class 'torch.nn.modules.sparse.Embedding'>
- decoder: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.DecoderLayer'>
- decoder.attention: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.MultiAttentionBlock'>
- decoder.attention.attention: <class 'torch.nn.modules.container.ModuleList'>
- decoder.attention.attention.0: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.AttentionBlock'>
- decoder.attention.attention.0.query: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.0.key: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.0.value: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.1: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.AttentionBlock'>
- decoder.attention.attention.1.query: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.1.key: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.attention.1.value: <class 'torch.nn.modules.linear.Linear'>
- decoder.attention.linear: <class 'torch.nn.modules.linear.Linear'>
- decoder.feed_forward: <class 'onnx_diagnostic.helpers.torch_helper.dummy_llm.<locals>.FeedForward'>
- decoder.feed_forward.linear_1: <class 'torch.nn.modules.linear.Linear'>
- decoder.feed_forward.relu: <class 'torch.nn.modules.activation.ReLU'>
- decoder.feed_forward.linear_2: <class 'torch.nn.modules.linear.Linear'>
- decoder.norm_1: <class 'torch.nn.modules.normalization.LayerNorm'>
- decoder.norm_2: <class 'torch.nn.modules.normalization.LayerNorm'>
Steal and dump the output of submodules¶
The following context manager spies on the intermediate results of the module and submodules listed below. It stores all their inputs and outputs in a single onnx file.
with steal_forward(
    [
        ("model", model),
        ("model.decoder", model.decoder),
        ("model.decoder.attention", model.decoder.attention),
        ("model.decoder.feed_forward", model.decoder.feed_forward),
        ("model.decoder.norm_1", model.decoder.norm_1),
        ("model.decoder.norm_2", model.decoder.norm_2),
    ],
    dump_file="plot_dump_intermediate_results.inputs.onnx",
    verbose=1,
    storage_limit=2**28,
):
    expected = model(*inputs)
+model -- stolen forward for class LLM -- iteration 0
<- args=(T7s2x30,) --- kwargs={}
+model.decoder -- stolen forward for class DecoderLayer -- iteration 0
<- args=(T10s2x30x16,) --- kwargs={}
+model.decoder.norm_1 -- stolen forward for class LayerNorm -- iteration 0
<- args=(T10s2x30x16,) --- kwargs={}
-> T10s2x30x16
-model.decoder.norm_1.
-- stores key=('model.decoder.norm_1', 0), size 1Kb -- T10s2x30x16
+model.decoder.attention -- stolen forward for class MultiAttentionBlock -- iteration 0
<- args=(T10s2x30x16,) --- kwargs={}
-> T10s2x30x16
-model.decoder.attention.
-- stores key=('model.decoder.attention', 0), size 1Kb -- T10s2x30x16
+model.decoder.norm_2 -- stolen forward for class LayerNorm -- iteration 0
<- args=(T10s2x30x16,) --- kwargs={}
-> T10s2x30x16
-model.decoder.norm_2.
-- stores key=('model.decoder.norm_2', 0), size 1Kb -- T10s2x30x16
+model.decoder.feed_forward -- stolen forward for class FeedForward -- iteration 0
<- args=(T10s2x30x16,) --- kwargs={}
-> T10s2x30x16
-model.decoder.feed_forward.
-- stores key=('model.decoder.feed_forward', 0), size 1Kb -- T10s2x30x16
-> T10s2x30x16
-model.decoder.
-- stores key=('model.decoder', 0), size 1Kb -- T10s2x30x16
-> T10s2x30x16
-model.
-- stores key=('model', 0), size 1Kb -- T10s2x30x16
-- gather stored 12 objects, size=0 Mb
-- dumps stored objects
-- done dump stored objects
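The log shows the forward method of each listed module being intercepted ("stolen forward"), its inputs recorded on entry and its output on exit, under a key made of the name and an iteration counter. The idea can be sketched in plain torch by temporarily replacing forward on each module instance. spy_forward is a hypothetical helper illustrating that assumption; it is not the implementation of steal_forward, which also dumps everything into an onnx file.

```python
import contextlib

import torch


@contextlib.contextmanager
def spy_forward(named_modules):
    """Hypothetical helper: records inputs and outputs of each module's forward.

    Keys follow the convention (name, iteration, "I" or "O"); inputs are
    stored as (args, kwargs), outputs as returned by forward.
    """
    storage = {}
    counters = {}
    patched = []

    def make_wrapper(name, original):
        def wrapper(*args, **kwargs):
            iteration = counters.get(name, 0)
            counters[name] = iteration + 1
            storage[name, iteration, "I"] = (args, kwargs)
            output = original(*args, **kwargs)
            storage[name, iteration, "O"] = output
            return output

        return wrapper

    for name, module in named_modules:
        # An instance attribute shadows the class method called by nn.Module.__call__.
        module.forward = make_wrapper(name, module.forward)
        patched.append(module)
    try:
        yield storage
    finally:
        for module in patched:
            # Assumes forward was not already overridden on the instance.
            del module.forward


model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.ReLU())
x = torch.randn(2, 4)
with torch.no_grad(), spy_forward([("model", model), ("model.0", model[0])]) as stored:
    y = model(x)
for key in stored:
    print(key)
```

The keys come out in the same order as in the dump: the outer module's input first, then the inner module's input and output, then the outer module's output.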
Restores saved inputs/outputs¶
All the intermediate tensors were saved in a single onnx model,
each tensor stored in a constant node.
The model can be run with any runtime to restore the tensors,
and the function create_input_tensors_from_onnx_model
also restores their names.
saved_tensors = create_input_tensors_from_onnx_model(
    "plot_dump_intermediate_results.inputs.onnx"
)
for k, v in saved_tensors.items():
    print(f"{k} -- {string_type(v, with_shape=True)}")
('model', 0, 'I') -- ((T7s2x30,),{})
('model.decoder', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.norm_1', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.norm_1', 0, 'O') -- T10s2x30x16
('model.decoder.attention', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.attention', 0, 'O') -- T10s2x30x16
('model.decoder.norm_2', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.norm_2', 0, 'O') -- T10s2x30x16
('model.decoder.feed_forward', 0, 'I') -- ((T10s2x30x16,),{})
('model.decoder.feed_forward', 0, 'O') -- T10s2x30x16
('model.decoder', 0, 'O') -- T10s2x30x16
('model', 0, 'O') -- T10s2x30x16
Let’s explain the naming convention.
('model.decoder.norm_2', 0, 'I') -- ((T10s2x30x16,),{})
   |                     |   |
   |                     |   +--> input, stored as (args, kwargs)
   |                     |
   |                     +--> iteration, 0 means the first time the execution
   |                          went through that module; the model can be called
   |                          several times to store more iterations
   |
   +--> the name given to function steal_forward
The same goes for outputs, except that 'I' is replaced by 'O'
and the stored value is the output itself.
('model.decoder.norm_2', 0, 'O') -- T10s2x30x16
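Since every key follows the (name, iteration, 'I' or 'O') convention, the stored items can be regrouped into one input/output pair per module call. group_by_call is a hypothetical helper written for this illustration; the values are placeholder strings standing in for the stored tensors, and the second call is made up to show the iteration counter.

```python
def group_by_call(stored):
    """Hypothetical helper: {(name, it, kind): value} -> {(name, it): {"I": ..., "O": ...}}."""
    grouped = {}
    for (name, iteration, kind), value in stored.items():
        if kind not in ("I", "O"):
            raise ValueError(f"unexpected kind {kind!r} for {name!r}")
        grouped.setdefault((name, iteration), {})[kind] = value
    return grouped


# Placeholder strings standing in for the stored (args, kwargs) and output tensors.
stored = {
    ("model.decoder.norm_2", 0, "I"): "((T10s2x30x16,),{})",
    ("model.decoder.norm_2", 0, "O"): "T10s2x30x16",
    ("model.decoder.norm_2", 1, "I"): "((T10s2x31x16,),{})",  # hypothetical second call
    ("model.decoder.norm_2", 1, "O"): "T10s2x31x16",
}
for call, io in group_by_call(stored).items():
    print(call, io["I"], "->", io["O"])
```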
This trick can be used to compare intermediate results coming from pytorch with any other implementation of the same model, as long as the stored inputs/outputs can be mapped to the intermediate results of that implementation.
Conversion to ONNX¶
The difficult part is mapping the saved intermediate results to intermediate results in ONNX. Let’s first create the ONNX model.
epo = torch.onnx.export(model, inputs, dynamic_shapes=ds, dynamo=True)
epo.optimize()
epo.save("plot_dump_intermediate_results.onnx")
[torch.onnx] Obtain model graph for `LLM([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `LLM([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 4 of general pattern rewrite rules.
Discrepancies¶
We have a torch model, its intermediate results and an ONNX graph equivalent to the torch model. Let’s see how to check the discrepancies, starting with those of the whole model.
sess = onnxruntime.InferenceSession(
    "plot_dump_intermediate_results.onnx", providers=["CPUExecutionProvider"]
)
feeds = dict(
    zip([i.name for i in sess.get_inputs()], [t.detach().cpu().numpy() for t in inputs])
)
got = sess.run(None, feeds)
diff = max_diff(expected, got)
print(f"discrepancies torch/ORT: {string_diff(diff)}")
discrepancies torch/ORT: abs=0.00390625, rel=0.034148137590269856, n=960.0,amax=0,13,5
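To make the printed figures concrete, here is a minimal numpy sketch of comparable measures for two arrays of the same shape: the largest absolute difference (abs), a relative difference (rel), the number of compared elements (n, 960 = 2 x 30 x 16 above) and the position of the largest absolute difference (amax, presumably what 0,13,5 denotes). simple_max_diff is a hypothetical helper, and the definition of rel used here (absolute difference divided by the magnitude of the expected value plus a small epsilon) is an assumption, not necessarily the one used by onnx_diagnostic.helpers.max_diff.

```python
import numpy as np


def simple_max_diff(expected, got, eps=1e-5):
    """Hypothetical helper, not onnx_diagnostic.helpers.max_diff."""
    expected = np.asarray(expected, dtype=np.float64)
    got = np.asarray(got, dtype=np.float64)
    if expected.shape != got.shape:
        raise ValueError(f"shape mismatch {expected.shape} != {got.shape}")
    diff = np.abs(expected - got)
    # Assumed definition of the relative difference.
    rel = diff / (np.abs(expected) + eps)
    amax = np.unravel_index(int(diff.argmax()), diff.shape)
    return dict(
        abs=float(diff.max()),
        rel=float(rel.max()),
        n=float(diff.size),
        amax=tuple(int(i) for i in amax),
    )


print(simple_max_diff([[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.5], [3.0, 4.0]]))
print(float(np.spacing(np.float16(4.0))))  # 0.00390625
```

The abs=0.00390625 reported above is exactly np.spacing(np.float16(4.0)), the gap between two consecutive float16 values in [4, 8), so the discrepancy is of the order of a few float16 rounding steps.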
What about intermediate results? Let’s use a runtime, still based on onnxruntime, that evaluates the model eagerly, one node at a time, so that every intermediate result is available.
sess_eager = OnnxruntimeEvaluator(
    "plot_dump_intermediate_results.onnx",
    providers=["CPUExecutionProvider"],
    torch_or_numpy=True,
)
feeds_tensor = dict(zip([i.name for i in sess.get_inputs()], inputs))
got = sess_eager.run(None, feeds_tensor)
diff = max_diff(expected, got)
print(f"discrepancies torch/eager ORT: {string_diff(diff)}")
discrepancies torch/eager ORT: abs=0.001953125, rel=0.015601597603594607, n=960.0,amax=0,3,5
Both runtimes give discrepancies of the same order of magnitude, small for float16. Let’s now dig into the intermediate results. They are compared to the outputs stored in saved_tensors during the execution of the model.
baseline = {}
for k, v in saved_tensors.items():
    if k[-1] == "I":  # inputs are excluded
        continue
    if isinstance(v, torch.Tensor):
        # e.g. ('model.decoder.norm_1', 0, 'O') becomes 'decoder.norm_1.0'
        baseline[f"{k[0]}.{k[1]}".replace("model.decoder", "decoder")] = v
report_cmp = ReportResultComparison(baseline)
sess_eager.run(None, feeds_tensor, report_cmp=report_cmp)
[tensor([[[-7.5928e-01, 6.1914e-01, 1.7402e+00, 2.6245e-01, 2.3438e-02,
1.3721e+00, -2.8467e-01, -1.9141e+00, -9.2871e-01, 1.3223e+00,
1.6768e+00, 2.2227e+00, -9.5166e-01, -4.0405e-01, -3.9697e-01,
-5.2393e-01],
[-1.7129e+00, 5.5176e-01, 1.1904e+00, 8.2080e-01, -1.1729e+00,
-2.5000e-01, 1.6777e+00, 3.7500e-01, -2.6230e+00, -2.0254e+00,
-1.0889e+00, 7.2168e-01, 2.1953e+00, 2.3926e-01, -3.4729e-02,
1.0215e+00],
[-5.8838e-01, -9.3457e-01, -6.7090e-01, 6.8945e-01, -4.1992e-02,
3.8184e-01, 3.2593e-01, 3.0195e+00, 1.5986e+00, -1.8066e-02,
-6.8604e-02, 3.0195e+00, -3.1113e+00, 1.7842e+00, -3.4277e-01,
-1.4961e+00],
[-1.9092e+00, -1.0273e+00, 1.0234e+00, 2.1641e+00, 6.2256e-01,
-2.6895e+00, 1.1963e-01, 6.3281e-01, 8.2764e-01, -1.4551e+00,
-5.7227e-01, 2.6099e-01, 2.7031e+00, 3.0547e+00, 4.1504e-01,
-7.0312e-01],
[ 9.5801e-01, -1.7578e+00, -6.0010e-01, -1.4580e+00, 2.1621e+00,
8.6523e-01, -1.2715e+00, 4.9829e-01, -1.0181e-01, -8.5010e-01,
-2.3496e+00, 1.4526e-01, -7.6465e-01, 3.1406e+00, 2.3379e+00,
-7.1289e-01],
[-1.3389e+00, 7.5391e-01, -1.0361e+00, 9.5068e-01, -2.9805e+00,
-9.8535e-01, -1.9180e+00, -3.2007e-01, 7.1680e-01, -8.5645e-01,
-5.5713e-01, -9.3945e-01, 2.3999e-01, -1.5732e+00, -1.5020e+00,
5.0354e-02],
[ 1.7891e+00, 4.5557e-01, -1.9805e+00, -1.4033e+00, -1.6553e+00,
7.9150e-01, -1.8281e+00, -2.8955e-01, 7.5684e-03, 1.2471e+00,
-3.8574e-02, 2.2109e+00, -1.5459e+00, -1.4883e+00, 2.5977e+00,
-1.1416e+00],
[-1.3223e+00, 2.4878e-01, 1.7246e+00, 1.0479e+00, 5.0391e-01,
-7.4219e-01, -3.2676e+00, 8.5693e-02, -1.2402e+00, 8.9160e-01,
-1.0186e+00, -1.7832e+00, 5.8398e-01, -1.0479e+00, 3.0762e-01,
-2.7324e+00],
[ 2.7954e-02, -7.7686e-01, -1.1553e+00, 4.8340e-02, -1.7754e+00,
4.6704e-01, -1.0586e+00, -1.3794e-01, 4.4873e-01, -8.4473e-01,
-7.0654e-01, -4.4336e-01, -1.9072e+00, 2.2583e-02, 1.9062e+00,
-1.6807e+00],
[ 2.3203e+00, -1.2334e+00, 1.2598e-01, -1.4375e+00, 1.1707e-01,
-6.2402e-01, -2.0898e-01, 1.6172e+00, 1.7031e+00, -4.2139e-01,
1.0254e+00, -1.3965e-01, 7.5098e-01, -1.7646e+00, 1.5967e+00,
-1.9150e+00],
[ 7.6318e-01, -9.8633e-01, 1.1060e-01, -5.0146e-01, 1.2754e+00,
7.2949e-01, -2.2422e+00, -6.7871e-01, 1.9004e+00, -9.9316e-01,
-2.0078e+00, 3.0137e+00, -3.9429e-01, -1.5215e+00, -1.1279e+00,
-6.5039e-01],
[ 8.8196e-03, 1.3926e+00, -6.1230e-01, -7.5781e-01, -1.4893e-02,
-9.8730e-01, -5.3418e-01, -1.0381e+00, -1.5869e-01, -1.8125e+00,
1.9316e+00, -5.2612e-02, 1.6094e+00, 8.1689e-01, 9.8975e-01,
4.3115e-01],
[ 1.5449e+00, -1.5801e+00, -1.1533e+00, 5.3662e-01, 9.5276e-02,
-1.1230e+00, -1.4124e-01, 9.7363e-01, -2.2168e-01, -4.7412e-01,
-3.2520e-01, 1.7273e-01, 3.3594e-01, -1.9561e+00, 2.4170e-02,
-1.3379e+00],
[-1.6670e+00, 1.5840e+00, -2.3066e+00, 5.9082e-02, 2.1045e-01,
4.6562e+00, -2.6094e+00, 1.1074e+00, 1.4189e+00, 1.2539e+00,
4.5020e-01, 3.2871e+00, -1.9531e+00, 6.5430e-01, 1.6174e-01,
1.1045e+00],
[ 2.8398e+00, -1.8467e+00, -1.9031e-01, 4.4214e-01, -6.4160e-01,
9.6973e-01, 2.1230e+00, -1.9287e+00, 5.5469e-01, 1.5127e+00,
-3.3027e+00, -1.1113e+00, 2.3816e-01, -8.4961e-01, -1.0693e-01,
-1.1318e+00],
[ 2.2900e-01, -4.0469e+00, 4.9756e-01, -1.8535e+00, -4.8730e-01,
1.0156e+00, 1.2588e+00, -7.6758e-01, -1.7764e+00, 1.3350e+00,
-2.3242e-01, -7.0264e-01, 1.4971e+00, 1.7627e+00, -3.9551e-01,
-2.1758e+00],
[ 1.7891e+00, 1.0879e+00, 8.1934e-01, 1.2617e+00, -2.2188e+00,
8.7598e-01, 6.4355e-01, 2.9419e-02, 3.9331e-01, 3.1445e+00,
-9.3945e-01, 1.5637e-01, -1.9463e+00, -2.3574e+00, -3.9805e+00,
1.6279e+00],
[ 1.3838e+00, -1.4365e+00, 5.7959e-01, 1.2236e+00, -1.4043e+00,
-2.7124e-01, 4.7363e-01, -4.2725e-03, 8.8281e-01, 1.1270e+00,
-6.3330e-01, 1.7568e+00, -9.9951e-01, -7.1484e-01, -2.0156e+00,
-3.3472e-01],
[ 2.1350e-01, 3.4424e-02, -1.5469e+00, 1.3838e+00, 8.4863e-01,
1.2129e+00, -1.3691e+00, 9.7949e-01, -3.6841e-01, -3.0566e-01,
-2.7422e+00, 1.0605e+00, -1.3447e+00, -5.8008e-01, 7.1826e-01,
-1.1025e+00],
[-2.4670e-01, 7.3389e-01, 1.1777e+00, -4.5020e-01, -3.9502e-01,
-7.1826e-01, 3.3765e-01, -3.6475e-01, -2.8066e+00, 1.2559e+00,
-3.0640e-01, -4.1577e-01, 5.3418e-01, -8.6426e-01, -8.5059e-01,
1.7910e+00],
[ 1.0547e+00, 2.2441e+00, -1.3848e+00, 2.7891e+00, -1.0156e+00,
-4.7607e-01, -8.4863e-01, -1.5605e+00, -2.3242e-01, -1.4805e+00,
2.2485e-01, -1.0732e+00, 1.3428e-01, -1.6479e-01, -1.1045e+00,
-1.1807e+00],
[ 5.6396e-01, 5.9766e-01, -9.2969e-01, -2.3652e+00, 7.0166e-01,
2.6113e+00, -1.3379e+00, -1.0098e+00, -2.0435e-01, 4.5508e-01,
3.9233e-01, -9.6777e-01, -2.6562e+00, 5.5518e-01, -3.1787e-01,
-3.9258e-01],
[-4.2041e-01, 1.0371e+00, -5.0781e-01, -4.8022e-01, -5.3076e-01,
-1.6006e+00, 1.0596e+00, -9.5557e-01, 1.9131e+00, -2.4976e-01,
-1.2061e+00, -7.9492e-01, 1.9189e+00, -1.2207e+00, 1.5479e-01,
-9.6289e-01],
[-5.4785e-01, -1.3447e+00, 2.9272e-01, 3.2422e+00, 2.4727e+00,
6.4600e-01, -7.8186e-02, 5.6982e-01, -3.1445e-01, 9.2188e-01,
-3.2559e+00, 8.2812e-01, 8.7988e-01, -1.4717e+00, -1.5195e+00,
1.2988e+00],
[-3.8086e-02, 4.9219e-01, -1.7354e+00, 4.8889e-02, -3.0762e-01,
2.4573e-01, -1.8740e+00, 8.9258e-01, -6.6699e-01, -1.4375e+00,
-2.2383e+00, -1.7070e+00, 5.7422e-01, -7.8027e-01, 2.8125e+00,
1.4629e+00],
[-5.1855e-01, -7.0312e-01, 1.7490e+00, 2.0391e+00, 2.7207e+00,
-1.9170e+00, -1.4277e+00, -1.4111e+00, -1.5518e+00, 2.0361e-01,
1.0781e+00, 4.9927e-01, 1.6816e+00, -8.0859e-01, -2.8359e+00,
2.3608e-01],
[ 5.0537e-01, 6.4087e-02, -1.7383e+00, -1.7852e+00, -4.9316e-02,
3.8477e-01, 2.3262e+00, -1.2773e+00, -2.1367e+00, -1.8135e+00,
-1.7737e-01, -1.4785e+00, -2.1250e+00, 1.8623e+00, -2.1895e+00,
4.7241e-01],
[-1.2471e+00, 1.2881e+00, 5.4590e-01, 1.0840e+00, 1.6553e-01,
8.1201e-01, -4.6692e-02, 9.8535e-01, 6.8262e-01, 2.1133e+00,
-8.5107e-01, -1.8105e+00, -3.7939e-01, 2.3320e+00, -5.2734e-01,
2.4451e-01],
[ 7.6465e-01, -2.1250e+00, 8.3740e-01, -1.1738e+00, 1.1172e+00,
-2.1602e+00, -7.2388e-02, 1.4590e+00, 1.2080e+00, -1.5352e+00,
7.8027e-01, -1.1738e+00, 9.9805e-01, 6.0107e-01, -1.0486e-01,
-9.4141e-01],
[-1.1318e+00, -5.7031e-01, -7.1826e-01, 6.6748e-01, -2.3755e-01,
-5.3027e-01, 8.7500e-01, 2.1133e+00, 5.3369e-01, -7.3145e-01,
-9.4824e-01, -1.8535e+00, 5.4443e-01, 1.5762e+00, 1.0166e+00,
-1.5479e-01]],
[[ 1.1484e+00, 2.1133e+00, -6.3281e-01, 2.4902e+00, -1.0205e+00,
2.2930e+00, -2.3267e-01, 3.0786e-01, 8.1543e-02, 5.7031e-01,
-2.0781e+00, 2.0195e+00, -5.4883e-01, -2.2891e+00, -2.3398e+00,
-1.8223e+00],
[-3.8242e+00, -2.5781e-01, 6.3477e-01, 1.9199e+00, 1.3701e+00,
3.6670e-01, -2.8271e-01, -2.0227e-01, -1.4951e+00, 1.8323e-01,
1.2529e+00, 6.8848e-01, -4.9072e-02, 1.6055e+00, -6.7480e-01,
2.2344e+00],
[ 1.3643e+00, -3.0117e+00, -2.9590e+00, 1.7314e+00, 4.6606e-01,
3.1758e+00, -1.4941e+00, -6.9629e-01, -1.9521e+00, 1.4087e-01,
4.4849e-01, 2.6035e+00, 9.9756e-01, -3.8135e-01, -1.6963e+00,
-7.7832e-01],
[-4.8218e-01, -2.0781e+00, -1.5059e+00, -2.1367e+00, 3.9258e-01,
-1.0820e+00, -8.5059e-01, -1.6309e-01, -1.9629e-01, -1.3594e+00,
1.1240e+00, 1.1113e+00, -5.4053e-01, 3.8867e-01, -7.4268e-01,
9.2773e-03],
[-1.6328e+00, 3.0664e+00, 1.6289e+00, 3.9697e-01, -2.9028e-01,
1.3877e+00, -2.0020e-02, -1.2920e+00, -1.4707e+00, 6.6895e-01,
-1.9736e+00, 1.0967e+00, 1.3477e-01, 1.4355e+00, 2.6807e-01,
3.9575e-01],
[ 5.7520e-01, 2.9883e-01, -8.0225e-01, -8.8623e-01, -2.2742e-01,
-2.0840e+00, -2.4004e+00, 4.7656e+00, -1.6855e+00, -1.9395e+00,
-1.2578e+00, 1.0898e+00, -5.8789e-01, 5.1025e-01, 2.0195e+00,
2.0801e+00],
[-1.2422e+00, 1.9961e+00, 1.9263e-01, -9.8145e-01, -8.9160e-01,
-2.5488e+00, 1.3125e+00, 4.2285e-01, 2.6016e+00, -1.0781e+00,
1.1172e+00, 1.3252e+00, 6.9189e-01, 1.0537e+00, 9.0234e-01,
-1.9580e-01],
[ 2.8730e+00, 7.9346e-02, -5.0391e-01, 7.4121e-01, 5.8936e-01,
2.1855e+00, -1.3604e+00, 4.7363e-01, 1.4268e+00, 6.1426e-01,
-1.8457e+00, 9.9170e-01, -8.4766e-01, -1.4121e+00, 1.0195e+00,
-2.0850e-01],
[ 6.7285e-01, -7.3926e-01, -5.2612e-02, -5.3711e-03, 1.7461e+00,
9.4580e-01, 6.2109e-01, 1.2769e-01, -2.3560e-01, -6.5723e-01,
-1.1230e+00, 1.8906e+00, 4.9316e-01, 1.4795e+00, -1.5186e+00,
6.6504e-01],
[-6.5430e-02, 4.5679e-01, 4.4385e-01, 4.3086e+00, 2.6709e-01,
-1.4863e+00, 2.2266e+00, 6.8506e-01, 1.2686e+00, 4.2114e-01,
-1.0547e+00, 1.2246e+00, 2.6758e+00, -5.9668e-01, 2.9824e+00,
1.7764e+00],
[ 2.7266e+00, 3.3838e-01, 1.7200e-01, -1.9844e+00, 1.1224e-01,
-1.1602e+00, -8.4326e-01, -1.6016e+00, -4.6021e-01, -2.9517e-01,
-2.0273e+00, -1.3477e+00, 2.1399e-01, 1.9666e-01, 2.9609e+00,
-1.2539e+00],
[ 2.5957e+00, 1.9666e-01, 9.6484e-01, 1.4004e+00, -3.4521e-01,
1.9473e+00, -2.7124e-01, 1.1582e+00, -2.6836e+00, 2.1074e+00,
-2.3203e+00, 3.0078e+00, 7.6709e-01, -1.4551e-01, -2.2969e+00,
-1.3438e+00],
[-4.8950e-01, -8.5107e-01, 3.0742e+00, -8.2129e-01, -7.1289e-02,
-1.4482e+00, 2.0547e+00, -1.5225e+00, -1.0566e+00, -1.8203e+00,
1.1875e+00, -5.8472e-02, 2.4004e+00, -3.1543e+00, 2.4922e+00,
-9.1992e-01],
[ 4.9414e+00, -3.8635e-02, 1.2031e+00, 8.7793e-01, 2.1912e-01,
7.0850e-01, 6.7236e-01, -3.7451e-01, -8.2031e-01, 7.5684e-01,
7.2559e-01, -1.3711e+00, -1.1787e+00, -2.1621e+00, 2.1406e+00,
6.6797e-01],
[-2.0762e+00, 2.5879e-01, 4.0391e+00, 3.4692e-01, 8.8623e-01,
1.7236e-01, 9.9951e-01, 1.8271e+00, 1.0762e+00, 8.4521e-01,
-3.0176e-01, -9.1650e-01, -1.1816e+00, -6.1621e-01, 5.3613e-01,
-4.3872e-01],
[ 2.0820e+00, -8.4717e-01, 1.8494e-01, 6.0742e-01, 1.1096e-01,
1.5039e+00, 6.2402e-01, -1.0156e+00, -1.1152e+00, -3.0000e+00,
4.0356e-01, -5.3125e-01, -1.6611e+00, -3.1567e-01, -1.4199e+00,
-3.4741e-01],
[ 3.6475e-01, -3.1006e-01, 4.0991e-01, 2.6611e-02, -4.6069e-01,
6.0840e-01, 9.3555e-01, 2.1914e+00, -2.4395e+00, -2.9648e+00,
1.9971e+00, -7.9102e-02, -4.3945e-01, 9.8938e-02, 8.7219e-02,
-8.1348e-01],
[ 2.7930e+00, 9.7754e-01, -1.1133e+00, -9.7559e-01, -1.8037e+00,
-7.0312e-01, 4.4336e-01, -1.9756e+00, -9.2383e-01, 1.1436e+00,
-2.8242e+00, -5.9570e-02, -6.0596e-01, -1.2012e+00, 8.5059e-01,
-1.6895e-01],
[-9.6069e-02, 2.9961e+00, 1.7295e+00, -1.1133e+00, -7.6074e-01,
1.8848e+00, -2.6484e+00, -1.4492e+00, -1.3525e+00, 1.3936e+00,
-4.1919e-01, 2.5000e+00, 7.5977e-01, -3.4473e+00, 2.0371e+00,
-4.3750e-01],
[ 2.0371e+00, 1.1934e+00, -1.2344e+00, 1.3037e-01, 2.6680e+00,
-4.6387e-01, -7.3926e-01, -4.5898e-01, -2.1934e+00, 8.2715e-01,
2.0068e-01, 2.6934e+00, 3.0859e-01, -3.8135e-01, -8.4863e-01,
1.5762e+00],
[ 1.4685e-01, -2.7480e+00, 3.1885e-01, -1.0293e+00, 7.1143e-01,
-7.6270e-01, -1.0332e+00, -1.9824e+00, 4.1699e-01, 1.6973e+00,
-1.1934e+00, 1.5459e+00, -4.1943e-01, 2.0059e+00, -6.9238e-01,
1.0107e+00],
[ 7.4219e-01, -1.1494e+00, -1.9688e+00, 6.5527e-01, -1.7959e+00,
2.2676e+00, -3.5430e+00, 1.1748e+00, 7.5439e-02, 1.8945e+00,
-1.1230e+00, 1.1768e+00, 2.1621e+00, -2.1113e+00, -1.1973e+00,
-3.4355e+00],
[ 2.8652e+00, -1.4697e+00, 2.3203e+00, -2.9272e-01, 7.5586e-01,
1.9512e+00, -3.4351e-01, -2.6514e-01, -1.1768e+00, 1.1182e+00,
-3.8989e-01, -1.1436e+00, -8.6304e-02, -1.0312e+00, -8.4668e-01,
-1.4375e+00],
[-3.6133e-01, -4.5349e-02, 2.5020e+00, 1.1406e+00, 2.1055e+00,
1.6943e-01, 2.0056e-01, 5.6836e-01, -9.3323e-02, 2.7441e-01,
-7.6660e-02, 1.0225e+00, 4.2656e+00, 2.4453e+00, 1.4277e+00,
-1.7178e+00],
[-9.3213e-01, -2.2031e+00, -6.4648e-01, -2.4473e+00, 2.4524e-01,
-1.2139e+00, 1.6191e+00, -3.3008e-01, 5.6445e-01, 5.8398e-01,
1.0869e+00, 1.2402e-01, -4.2676e-01, 3.2397e-01, -7.9248e-01,
-1.4424e+00],
[-1.7217e+00, -3.2178e-01, 7.8906e-01, -7.0020e-01, -1.1250e+00,
6.2402e-01, -1.1395e-01, 2.5195e+00, -5.3223e-02, 1.2891e+00,
-2.3418e+00, -8.7891e-01, -6.9092e-02, -7.7344e-01, -1.5986e+00,
-1.4912e+00],
[ 8.7109e-01, -8.4619e-01, -1.6797e-01, 1.5361e+00, 1.0977e+00,
1.0977e+00, -8.9990e-01, -1.5586e+00, 1.5771e+00, -2.6221e-01,
8.7061e-01, -2.0332e+00, -7.7539e-01, 9.2285e-01, 1.8877e+00,
3.7754e+00],
[ 1.3408e+00, -6.1426e-01, 9.7949e-01, -3.6865e-01, 1.7451e+00,
4.7913e-02, 9.4141e-01, 6.6772e-02, 6.2500e-01, -6.2988e-01,
1.1084e+00, -3.3008e+00, -7.9688e-01, 4.2163e-01, -1.3789e+00,
2.7734e+00],
[ 3.8818e-02, -2.5215e+00, 2.1406e+00, -4.9829e-01, -5.0146e-01,
-8.7207e-01, 3.3906e+00, 7.4512e-01, -3.4961e-01, -1.0918e+00,
-9.8145e-01, 2.4487e-01, 1.5049e+00, -6.6797e-01, 3.0059e+00,
-2.5547e+00],
[-8.7158e-02, -1.8125e+00, -3.5684e+00, 4.5679e-01, -2.0723e+00,
-6.5430e-01, 1.0000e+00, -8.9844e-02, -2.9688e+00, 9.8486e-01,
2.0137e+00, -7.7344e-01, 2.0449e+00, 8.9111e-01, -1.4248e+00,
-3.7646e-01]]], dtype=torch.float16)]
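Conceptually, such a report pairs every intermediate result with every stored reference tensor. A minimal numpy sketch under stated assumptions: only tensors with identical shapes are compared, the measure is the maximum absolute difference, and the results are given as (name, array) pairs in execution order. compare_to_baseline is a hypothetical helper, not the implementation of ReportResultComparison, and the arrays in the usage example are random toy data.

```python
import numpy as np


def compare_to_baseline(results, baseline):
    """Hypothetical helper: one record per (intermediate result, reference) pair.

    results: list of (run_name, array) in execution order.
    baseline: dict ref_name -> array.
    """
    records = []
    for run_index, (run_name, value) in enumerate(results):
        value = np.asarray(value, dtype=np.float64)
        for ref_name, ref in baseline.items():
            ref = np.asarray(ref, dtype=np.float64)
            if ref.shape != value.shape:
                continue  # tensors of different shapes are not compared
            records.append(
                dict(
                    run_index=run_index,
                    run_name=run_name,
                    ref_name=ref_name,
                    abs=float(np.abs(value - ref).max()),
                )
            )
    return records


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3)).astype(np.float16)
results = [("embedding", x), ("shape", np.array([2, 3])), ("relu", np.maximum(x, 0))]
baseline = {"norm.0": x, "relu.0": np.maximum(x, 0)}
for record in compare_to_baseline(results, baseline):
    print(record)
```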
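Let’s see the results. Each row is an intermediate result of the ONNX model (run_index, run_name), each column a tensor stored in baseline (ref_name), and each cell the discrepancy between the two. A value close to 0 means the ONNX result matches the stored tensor.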
ref_name decoder.0 decoder.attention.0 decoder.feed_forward.0 decoder.norm_1.0 decoder.norm_2.0 model.0
run_index run_name
1 embedding 3.610352 3.648682 3.727091 2.897949 2.694824 3.610352
2 embedding_1 3.890381 3.751709 3.790283 2.792480 2.880859 3.890381
3 add_8 1.029297 4.624146 4.740356 2.097656 2.099609 1.029297
4 layer_norm 2.447266 2.893860 3.274902 0.000000 0.565918 2.447266
5 linear 4.568848 1.706055 1.953369 3.262695 3.194336 4.568848
6 linear_1 5.594238 2.059937 2.334961 3.664062 3.698242 5.594238
7 linear_2 6.108398 1.869141 1.727783 3.930664 3.985352 6.108398
24 matmul_1 5.087769 1.447021 1.320801 3.236816 3.227051 5.087769
25 linear_3 5.139648 1.672363 1.940796 3.065430 3.053711 5.139648
26 linear_4 5.377441 1.771362 2.413574 3.622070 3.676758 5.377441
27 linear_5 5.612305 1.799927 1.863525 3.434570 3.489258 5.612305
42 matmul_3 5.140137 1.373413 1.651855 3.126221 3.116455 5.140137
44 val_48 4.887329 0.145508 1.208618 2.793091 2.796997 4.887329
45 linear_6 4.795898 0.000488 1.121338 2.893860 2.897766 4.795898
46 add_115 0.967285 4.768677 4.884888 2.175781 2.177734 0.967285
47 layer_norm_1 2.388672 2.897766 3.208496 0.566162 0.001953 2.388672
51 val_54 4.840149 1.115479 0.078125 3.197266 3.130859 4.840149
52 linear_8 4.912292 1.121338 0.000488 3.274902 3.208496 4.912292
53 add_136 0.001953 4.796021 4.912231 2.447266 2.388672 0.001953
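The table above can be read as a pivot of comparison records: one row per intermediate result, one column per reference. Assuming the records are dictionaries with the hypothetical keys run_index, run_name, ref_name and abs (an assumption, not the documented format of ReportResultComparison), pandas builds it in one call. The values below are copied from the table above.

```python
import pandas

# Values copied from the table above; the field names are hypothetical.
records = [
    dict(run_index=4, run_name="layer_norm", ref_name="decoder.norm_1.0", abs=0.000000),
    dict(run_index=4, run_name="layer_norm", ref_name="decoder.attention.0", abs=2.893860),
    dict(run_index=45, run_name="linear_6", ref_name="decoder.norm_1.0", abs=2.893860),
    dict(run_index=45, run_name="linear_6", ref_name="decoder.attention.0", abs=0.000488),
]
df = pandas.DataFrame(records)
# One row per (run_index, run_name), one column per reference tensor.
table = df.pivot(index=["run_index", "run_name"], columns="ref_name", values="abs")
print(table)
```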
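Let’s clean it a little: discrepancies greater than or equal to 1 are replaced by NaN, then the rows containing only NaN are dropped.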
piv[piv >= 1] = np.nan
print(piv.dropna(axis=0, how="all"))
ref_name decoder.0 decoder.attention.0 decoder.feed_forward.0 decoder.norm_1.0 decoder.norm_2.0 model.0
run_index run_name
4 layer_norm NaN NaN NaN 0.000000 0.565918 NaN
44 val_48 NaN 0.145508 NaN NaN NaN NaN
45 linear_6 NaN 0.000488 NaN NaN NaN NaN
46 add_115 0.967285 NaN NaN NaN NaN 0.967285
47 layer_norm_1 NaN NaN NaN 0.566162 0.001953 NaN
51 val_54 NaN NaN 0.078125 NaN NaN NaN
52 linear_8 NaN NaN 0.000488 NaN NaN NaN
53 add_136 0.001953 NaN NaN NaN NaN 0.001953
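We can now identify which ONNX result is mapped to which expected tensor: layer_norm matches decoder.norm_1.0, layer_norm_1 matches decoder.norm_2.0, linear_6 matches decoder.attention.0, linear_8 matches decoder.feed_forward.0, and add_136 matches both decoder.0 and model.0.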
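Reading the matches off the table can be automated: for each reference column, take the intermediate result with the smallest discrepancy and keep it only if it is below a tolerance. best_matches and the 0.01 threshold are hypothetical choices for this sketch; the threshold has to be adapted to the precision of the model (float16 here). The values are copied from the first table above.

```python
import pandas


def best_matches(table, threshold=0.01):
    """Hypothetical helper: maps each reference column to its closest result row."""
    best_row = table.idxmin(axis=0)  # row label with the smallest value per column
    best_val = table.min(axis=0)
    return {ref: best_row[ref] for ref in table.columns if best_val[ref] <= threshold}


# Values copied from the first table above.
index = pandas.MultiIndex.from_tuples(
    [(4, "layer_norm"), (45, "linear_6"), (52, "linear_8")],
    names=["run_index", "run_name"],
)
table = pandas.DataFrame(
    {
        "decoder.norm_1.0": [0.000000, 2.893860, 3.274902],
        "decoder.attention.0": [2.893860, 0.000488, 1.121338],
        "decoder.feed_forward.0": [3.274902, 1.121338, 0.000488],
    },
    index=index,
)
for ref, row in best_matches(table).items():
    print(ref, "<-", row)
```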
Picture of the model¶

doc.plot_legend("steal and dump\nintermediate\nresults", "steal_forward", "blue")

Total running time of the script: (0 minutes 7.864 seconds)