Export a model using a custom type as input
We will use a class found in many models: transformers.cache_utils.DynamicCache.
First try: it fails
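The definition of ModelTakingDynamicCacheAsInput is not shown in the steps that follow. A plausible reconstruction is sketched here, inferred from the operations in the exported graph printed further down: two concatenations along dimension 1, an addition, a sum over dimension 2 with keepdim=True, and a final addition with x. It is an assumption, not necessarily the original definition, and it uses types.SimpleNamespace as a stand-in for the cache so that it depends only on torch.

```python
from types import SimpleNamespace

import torch


class ModelTakingDynamicCacheAsInput(torch.nn.Module):
    # Assumed definition, inferred from the exported graph; not confirmed by the source.
    def forward(self, x, dc):
        kc = torch.cat(dc.key_cache, dim=1)
        vc = torch.cat(dc.value_cache, dim=1)
        y = (kc + vc).sum(dim=2, keepdim=True)
        return x + y


# Stand-in for DynamicCache: any object exposing key_cache and value_cache lists.
dc = SimpleNamespace(
    key_cache=[torch.ones((3, 8, 5, 6))],
    value_cache=[torch.ones((3, 8, 5, 6)) * 2],
)
x = torch.randn(3, 8, 7, 1)
print(ModelTakingDynamicCacheAsInput()(x, dc).shape)
```

The cached tensors are reduced to shape (3, 8, 1, 6) and then broadcast against x, which is why the result keeps the 7 of x and the 6 of the cache.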
Let’s check the model runs.
import torch
import transformers

x = torch.randn(3, 8, 7, 1)
cache = transformers.cache_utils.DynamicCache(1)
cache.update(torch.ones((3, 8, 5, 6)), (torch.ones((3, 8, 5, 6)) * 2), 0)
model = ModelTakingDynamicCacheAsInput()
expected = model(x, cache)
print(expected.shape)
torch.Size([3, 8, 7, 6])
Let’s check it works with other shapes.
x = torch.randn(4, 8, 7, 1)
cache = transformers.cache_utils.DynamicCache(1)
cache.update(torch.ones((4, 8, 11, 6)), (torch.ones((4, 8, 11, 6)) * 2), 0)
model = ModelTakingDynamicCacheAsInput()
expected = model(x, cache)
print(expected.shape)
torch.Size([4, 8, 7, 6])
Let’s export.
try:
    torch.export.export(model, (x, cache))
except Exception as e:
    print("export failed with", e)
export failed with It looks like one of the inputs with type `<class 'transformers.cache_utils.DynamicCache'>` is not supported or pytree-flattenable.
Exported graphs inputs can only contain the following supported types: [<class 'torch.Tensor'>, <class 'torch.SymInt'>, <class 'torch.SymFloat'>, <class 'torch.SymBool'>, <class 'torch.ScriptObject'>, <class 'int'>, <class 'torch.layout'>, <class 'torch.device'>, <class 'torch.memory_format'>, <class 'NoneType'>, <class 'torch.dtype'>, <class 'float'>, <class 'bool'>, <class 'triton.language.core.dtype'>, <class 'code'>, <class 'torch.iinfo'>, <class 'complex'>, <class 'torch.nn.attention._SDPBackend'>, <class 'ellipsis'>, <class 'str'>, <class 'torch.finfo'>, <class 'bytes'>, <class 'torch._C._CudaDeviceProperties'>, <class 'NotImplementedType'>].
If you are using a custom class object, please register a pytree_flatten/unflatten function using `torch.utils._pytree.register_pytree_node` or `torch.export.register_dataclass`.
Register serialization of DynamicCache
As the error message suggests, the class must be registered as a pytree node.
Feel free to adapt the code to your own class.
The important information is that we want to serialize two attributes, key_cache and value_cache.
Both are lists of tensors of the same size.
from typing import Any, List, Tuple


def flatten_dynamic_cache(
    dynamic_cache: transformers.cache_utils.DynamicCache,
) -> Tuple[List[Any], torch.utils._pytree.Context]:
    # The children are the attribute values; the context keeps the attribute names.
    flat = [
        (k, getattr(dynamic_cache, k))
        for k in ["key_cache", "value_cache"]
        if hasattr(dynamic_cache, k)
    ]
    return [f[1] for f in flat], [f[0] for f in flat]


def unflatten_dynamic_cache(
    values: List[Any],
    context: torch.utils._pytree.Context,
    output_type=None,
) -> transformers.cache_utils.DynamicCache:
    # Rebuild an empty cache, then restore each attribute by name.
    cache = transformers.cache_utils.DynamicCache()
    values = dict(zip(context, values))
    for k, v in values.items():
        setattr(cache, k, v)
    return cache


def flatten_with_keys_dynamic_cache(
    d: transformers.cache_utils.DynamicCache,
) -> Tuple[
    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
    torch.utils._pytree.Context,
]:
    values, context = flatten_dynamic_cache(d)
    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context


torch.utils._pytree.register_pytree_node(
    transformers.cache_utils.DynamicCache,
    flatten_dynamic_cache,
    unflatten_dynamic_cache,
    serialized_type_name=f"{transformers.cache_utils.DynamicCache.__module__}.{transformers.cache_utils.DynamicCache.__name__}",
    flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
)
torch.fx._pytree.register_pytree_flatten_spec(
    transformers.cache_utils.DynamicCache, lambda x, _: [x.key_cache, x.value_cache]
)
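Before exporting, it can help to check that a flatten function and its unflatten counterpart are inverses of each other. The sketch below illustrates that round trip with a hypothetical stand-in class, FakeCache, and plain Python lists instead of tensors, so it runs without torch or transformers. The names FakeCache, flatten_cache and unflatten_cache are illustrative and belong to neither library.

```python
from typing import Any, List, Tuple


class FakeCache:
    """Hypothetical stand-in for DynamicCache: two parallel lists, one entry per layer."""

    def __init__(self) -> None:
        self.key_cache: List[Any] = []
        self.value_cache: List[Any] = []


def flatten_cache(cache: FakeCache) -> Tuple[List[Any], List[str]]:
    # Children are the attribute values; the context records which
    # attribute each child came from, in the same order.
    names = [k for k in ("key_cache", "value_cache") if hasattr(cache, k)]
    return [getattr(cache, k) for k in names], names


def unflatten_cache(values: List[Any], context: List[str]) -> FakeCache:
    # Rebuild an empty object, then restore each attribute by name.
    cache = FakeCache()
    for name, value in zip(context, values):
        setattr(cache, name, value)
    return cache


original = FakeCache()
original.key_cache.append([1.0, 1.0])
original.value_cache.append([2.0, 2.0])

values, context = flatten_cache(original)
restored = unflatten_cache(values, context)
print(context)
print(restored.key_cache == original.key_cache)
```

The same check can be applied to the registered functions by comparing the attributes of a cache before and after a flatten and unflatten pass.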
Let’s try to export again.
ep = torch.export.export(model, (x, cache))
print(ep.graph)
graph():
%c_dc___key_cache_0 : [num_users=0] = placeholder[target=c_dc___key_cache_0]
%c_dc___value_cache_0 : [num_users=0] = placeholder[target=c_dc___value_cache_0]
%x : [num_users=1] = placeholder[target=x]
%dc_key_cache_0 : [num_users=1] = placeholder[target=dc_key_cache_0]
%dc_value_cache_0 : [num_users=1] = placeholder[target=dc_value_cache_0]
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%dc_key_cache_0], 1), kwargs = {})
%cat_1 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%dc_value_cache_0], 1), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %cat_1), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%add, [2], True), kwargs = {})
%add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %sum_1), kwargs = {})
return (add_1,)
Now let’s export with dynamic shapes.
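The dynamic_shapes argument has to mirror the structure of the inputs: one entry for x, then, for the cache, a list of two lists (key_cache and value_cache) holding one dictionary per layer. As a rough sketch, a hypothetical helper named cache_dynamic_shapes can build that nested specification for any number of layers. Plain strings stand in for torch.export.Dim objects here so that the snippet runs on its own; the helper only builds the specification and does not change how the export behaves.

```python
from typing import Any, Dict, List


def cache_dynamic_shapes(n_layers: int, batch: Any, clength: Any) -> List[List[Dict[int, Any]]]:
    """Build the nested spec [[key specs...], [value specs...]] for a cache."""
    # A fresh dict per tensor avoids sharing one mutable object across entries.
    keys = [{0: batch, 2: clength} for _ in range(n_layers)]
    values = [{0: batch, 2: clength} for _ in range(n_layers)]
    return [keys, values]


# Strings stand in for torch.export.Dim("batch") and torch.export.Dim("clength").
spec = cache_dynamic_shapes(1, "batch", "clength")
print(spec)
```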
batch = torch.export.Dim("batch", min=1, max=1024)
clength = torch.export.Dim("clength", min=1, max=1024)
try:
    ep = torch.export.export(
        model,
        (x, cache),
        dynamic_shapes=({0: batch}, [[{0: batch, 2: clength}], [{0: batch, 2: clength}]]),
    )
    print(ep.graph)
    failed = False
except Exception as e:
    print("FAILS:", e)
    failed = True
FAILS: Constraints violated (batch)! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of batch = L['x'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (4).
Suggested fixes:
batch = 4
If it failed, let’s understand why. We export a smaller model that only reads the cache and print the shape recorded for every node.
if failed:

    class Model(torch.nn.Module):
        def forward(self, dc):
            kc = dc.key_cache[0]
            vc = dc.value_cache[0]
            return kc + vc

    ep = torch.export.export(
        Model(),
        (cache,),
        dynamic_shapes={"dc": [[{0: batch, 2: clength}], [{0: batch, 2: clength}]]},
    )
    for node in ep.graph.nodes:
        print(f"{node.name} -> {node.meta.get('val', '-')}")
    # it prints out ``dc_key_cache_0 -> FakeTensor(..., size=(4, 8, 11, 6))``
    # but it should be ``dc_key_cache_0 -> FakeTensor(..., size=(s0, 8, s1, 6))``
c_dc___key_cache_0 -> FakeTensor(..., size=(4, 8, 11, 6))
c_dc___value_cache_0 -> FakeTensor(..., size=(4, 8, 11, 6))
dc_key_cache_0 -> FakeTensor(..., size=(4, 8, 11, 6))
dc_value_cache_0 -> FakeTensor(..., size=(4, 8, 11, 6))
add -> FakeTensor(..., size=(4, 8, 11, 6))
output -> -
Let’s undo the registration.
torch.utils._pytree.SUPPORTED_NODES.pop(transformers.cache_utils.DynamicCache)
torch.fx._pytree.SUPPORTED_NODES.pop(transformers.cache_utils.DynamicCache)
torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH.pop(transformers.cache_utils.DynamicCache)
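Popping entries by hand is easy to forget when an exception interrupts the script. A more defensive pattern, sketched here under assumptions, is a context manager that removes the entry on exit. The helper temporary_entry is hypothetical and is demonstrated on a plain dictionary standing in for registries such as torch.utils._pytree.SUPPORTED_NODES. It only illustrates the cleanup pattern: a real registration should still go through register_pytree_node.

```python
from contextlib import contextmanager
from typing import Any, Dict, Hashable, Iterator


@contextmanager
def temporary_entry(registry: Dict[Hashable, Any], key: Hashable, value: Any) -> Iterator[None]:
    """Add registry[key] = value for the duration of the with block."""
    if key in registry:
        raise KeyError(f"{key!r} is already registered")
    registry[key] = value
    try:
        yield
    finally:
        # Runs even if the body raises, so the registry is always restored.
        registry.pop(key, None)


registry: Dict[Hashable, Any] = {}
with temporary_entry(registry, "DynamicCache", "flatten_fn"):
    print("DynamicCache" in registry)
print("DynamicCache" in registry)
```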