Note
Go to the end to download the full example code.
Export a model using a custom type as input
We will use a class found in many models: transformers.cache_utils.DynamicCache.
First try: it fails
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
class ModelTakingDynamicCacheAsInput(torch.nn.Module):
    """Toy model whose second input is a cache object.

    ``forward`` reads the ``key_cache`` and ``value_cache`` sequences of
    tensors from ``dc``. It concatenates each sequence along dimension 1 and
    adds the two results. It then reduces dimension 2 with ``keepdim=True``
    and adds the reduced tensor to ``x``, relying on broadcasting.
    """

    def forward(self, x, dc):
        keys = torch.cat(dc.key_cache, dim=1)
        values = torch.cat(dc.value_cache, dim=1)
        # keepdim=True leaves a size-1 dimension 2 so the result broadcasts against x.
        reduced = torch.sum(keys + values, dim=2, keepdim=True)
        return x + reduced
Let’s check that the model runs.
# Example input plus a cache built from a list of (key, value) tensor pairs,
# presumably one pair per layer; keys are all ones and values all twos.
x = torch.randn(3, 8, 7, 1)
cache = make_dynamic_cache([(torch.ones((3, 8, 5, 6)), (torch.ones((3, 8, 5, 6)) * 2))])
model = ModelTakingDynamicCacheAsInput()
# The cache is reduced to (3, 8, 1, 6) and broadcast with x (3, 8, 7, 1),
# giving the (3, 8, 7, 6) shape printed below.
expected = model(x, cache)
print(expected.shape)
torch.Size([3, 8, 7, 6])
Let’s check that it works with other shapes.
# Same model with a different batch size (4) and cache length (11): these are
# the two dimensions made dynamic later in the example.
x = torch.randn(4, 8, 7, 1)
cache = make_dynamic_cache([(torch.ones((4, 8, 11, 6)), (torch.ones((4, 8, 11, 6)) * 2))])
# The model defines no __init__ or parameters, so a fresh instance behaves like the previous one.
model = ModelTakingDynamicCacheAsInput()
expected = model(x, cache)
print(expected.shape)
torch.Size([4, 8, 7, 6])
Let’s export.
# No dynamic_shapes are given, so torch.export treats every dimension as static.
# The printed graph shows the cache flattened into one placeholder per tensor
# (dc_key_cache_0, dc_value_cache_0).
ep = torch.export.export(model, (x, cache))
print(ep.graph)
graph():
%x : [num_users=1] = placeholder[target=x]
%dc_key_cache_0 : [num_users=1] = placeholder[target=dc_key_cache_0]
%dc_value_cache_0 : [num_users=1] = placeholder[target=dc_value_cache_0]
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%dc_key_cache_0], 1), kwargs = {})
%cat_1 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%dc_value_cache_0], 1), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %cat_1), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%add, [2], True), kwargs = {})
%add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %sum_1), kwargs = {})
return (add_1,)
Now with dynamic shapes.
# Declare the batch dimension and the cache length as dynamic.
batch = torch.export.Dim("batch", min=1, max=1024)
clength = torch.export.Dim("clength", min=1, max=1024)
try:
    # dynamic_shapes mirrors the inputs: one dict for x, then for the cache a
    # [key_cache, value_cache] pair of per-layer lists (presumably matching how
    # the cache is flattened into dc_key_cache_* / dc_value_cache_* placeholders).
    ep = torch.export.export(
        model,
        (x, cache),
        dynamic_shapes=({0: batch}, [[{0: batch, 2: clength}], [{0: batch, 2: clength}]]),
    )
except Exception as e:
    # Deliberately broad: the example only reports whatever export raises.
    print("FAILS:", e)
    failed = True
else:
    # Success path kept outside the try. An error here must not be misreported
    # as an export failure, which would trigger the diagnostic below.
    print(ep.graph)
    failed = False
graph():
%x : [num_users=1] = placeholder[target=x]
%dc_key_cache_0 : [num_users=1] = placeholder[target=dc_key_cache_0]
%dc_value_cache_0 : [num_users=1] = placeholder[target=dc_value_cache_0]
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%dc_key_cache_0], 1), kwargs = {})
%cat_1 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%dc_value_cache_0], 1), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %cat_1), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%add, [2], True), kwargs = {})
%add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %sum_1), kwargs = {})
return (add_1,)
If it failed, let’s understand why.
if failed:
    # Minimal model that only consumes the cache. It isolates what torch.export
    # records for the cache tensors.
    class Model(torch.nn.Module):
        def forward(self, dc):
            # First layer of keys plus first layer of values.
            return dc.key_cache[0] + dc.value_cache[0]

    ep = torch.export.export(
        Model(),
        (cache,),
        dynamic_shapes={"dc": [[{0: batch, 2: clength}], [{0: batch, 2: clength}]]},
    )
    # Show the traced value stored for each node of the exported graph.
    for node in ep.graph.nodes:
        print(f"{node.name} -> {node.meta.get('val', '-')}")
    # it prints out ``dc_key_cache_0 -> FakeTensor(..., size=(4, 8, 11, 6))``
    # but it should be ``dc_key_cache_0 -> FakeTensor(..., size=(s0, 8, s1, 6))``
Total running time of the script: (0 minutes 0.173 seconds)
Related examples

to_onnx and a custom operator registered with a function
to_onnx and a custom operator registered with a function