Export a model using a custom type as input

We will use a class found in many models: transformers.cache_utils.DynamicCache.

First try: it fails

import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache


class ModelTakingDynamicCacheAsInput(torch.nn.Module):
    """Toy module whose second input is a cache object rather than a tensor.

    ``forward`` only reads the ``key_cache`` and ``value_cache`` attributes
    of ``dc``, each a sequence of tensors.
    """

    def forward(self, x, dc):
        """Add to ``x`` the sum over axis 2 of the concatenated keys + values.

        The tensors of each cache list are concatenated along axis 1, the two
        results are added, then reduced over axis 2 (kept as a size-1 axis)
        before being added to ``x``.
        """
        keys = torch.cat(dc.key_cache, dim=1)
        values = torch.cat(dc.value_cache, dim=1)
        reduced = torch.sum(keys + values, dim=2, keepdim=True)
        return x + reduced

Let’s check the model runs.

# x is (3, 8, 7, 1); the cache holds one (key, value) pair of (3, 8, 5, 6)
# tensors, keys filled with ones and values filled with twos.
x = torch.randn(3, 8, 7, 1)
cache = make_dynamic_cache(
    [
        (
            torch.ones(3, 8, 5, 6),
            2 * torch.ones(3, 8, 5, 6),
        )
    ]
)

model = ModelTakingDynamicCacheAsInput()
expected = model(x, cache)

print(expected.shape)
torch.Size([3, 8, 7, 6])

Let’s check it works with other shapes.

# Same computation with a different first dimension (4 instead of 3) and a
# different third dimension for the cached tensors (11 instead of 5).
x = torch.randn(4, 8, 7, 1)
cache = make_dynamic_cache(
    [
        (
            torch.ones(4, 8, 11, 6),
            2 * torch.ones(4, 8, 11, 6),
        )
    ]
)

model = ModelTakingDynamicCacheAsInput()
expected = model(x, cache)

print(expected.shape)
torch.Size([4, 8, 7, 6])

Let’s export.

# Export using the last (x, cache) pair as example inputs; no dynamic_shapes
# argument is given here.
ep = torch.export.export(model, args=(x, cache))
print(ep.graph)
graph():
    %x : [num_users=1] = placeholder[target=x]
    %dc_key_cache_0 : [num_users=1] = placeholder[target=dc_key_cache_0]
    %dc_value_cache_0 : [num_users=1] = placeholder[target=dc_value_cache_0]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%dc_key_cache_0], 1), kwargs = {})
    %cat_1 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%dc_value_cache_0], 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %cat_1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%add, [2], True), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %sum_1), kwargs = {})
    return (add_1,)

With dynamic shapes now.

# Two named dimensions, each declared with min=1 and max=1024.
batch = torch.export.Dim("batch", min=1, max=1024)
clength = torch.export.Dim("clength", min=1, max=1024)

# The dynamic shapes mirror the inputs: one dict for x, then two lists (each
# with one dict) for the cache. ``failed`` records whether export or printing
# raised.
try:
    ep = torch.export.export(
        model,
        args=(x, cache),
        dynamic_shapes=(
            {0: batch},
            [
                [{0: batch, 2: clength}],
                [{0: batch, 2: clength}],
            ],
        ),
    )
    print(ep.graph)
except Exception as e:
    print("FAILS:", e)
    failed = True
else:
    failed = False
graph():
    %x : [num_users=1] = placeholder[target=x]
    %dc_key_cache_0 : [num_users=1] = placeholder[target=dc_key_cache_0]
    %dc_value_cache_0 : [num_users=1] = placeholder[target=dc_value_cache_0]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%dc_key_cache_0], 1), kwargs = {})
    %cat_1 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%dc_value_cache_0], 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %cat_1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%add, [2], True), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %sum_1), kwargs = {})
    return (add_1,)

If it failed, let’s understand why.

if failed:

    class Model(torch.nn.Module):
        """Minimal module adding the first key tensor to the first value tensor."""

        def forward(self, dc):
            return dc.key_cache[0] + dc.value_cache[0]

    # Export the reduced model with the same dynamic shapes for the cache,
    # then list the value recorded in the metadata of every graph node.
    ep = torch.export.export(
        Model(),
        args=(cache,),
        dynamic_shapes={"dc": [[{0: batch, 2: clength}], [{0: batch, 2: clength}]]},
    )
    for node in ep.graph.nodes:
        print(f"{node.name} -> {node.meta.get('val', '-')}")
        # it prints out ``dc_key_cache_0 -> FakeTensor(..., size=(4, 8, 11, 6))``
        # but it should be ``dc_key_cache_0 -> FakeTensor(..., size=(s0, 8, s1, 6))``

Total running time of the script: (0 minutes 0.173 seconds)

Related examples

Do not use Module as inputs!

Do not use Module as inputs!

Check the exporter on a dummy from HuggingFace

Check the exporter on a dummy from HuggingFace

to_onnx and a custom operator registered with a function

to_onnx and a custom operator registered with a function

torch.onnx.export and Phi-2

torch.onnx.export and Phi-2

to_onnx and Phi-2

to_onnx and Phi-2

Gallery generated by Sphinx-Gallery