JSON returns a list when the original dynamic shapes are a list or a tuple

Dynamic shapes given to torch.export.export() must have the same nested structure as the inputs, including the distinction between tuples and lists. What if tuples and lists get confused when defining the dynamic shapes? How can the expected type be restored, assuming the inputs are known? This is not often useful, but it may teach us more about optree.

Dynamic Shapes After JSON

The JSON format does not distinguish between a list and a tuple. After serializing to JSON and restoring, both of them become lists.

import json
import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

# Sizes of the cache tensors below, which are shaped (bsize, nheads, slen, dim).
bsize, nheads, slen, dim = 2, 1, 30, 96

# Example inputs: a tuple of three int64 tensors (the tuple is what a JSON round
# trip loses) and a DynamicCache holding one key tensor and one value tensor.
# NOTE: the random calls run in this exact order; reordering them changes the values.
inputs = dict(
    input_mask_position=(
        torch.randint(15, size=(2, 3), dtype=torch.int64),
        # randint(1, ...) draws from [0, 1), so this tensor is all zeros.
        torch.randint(1, size=(2, 33), dtype=torch.int64),
        torch.arange(3, dtype=torch.int64),
    ),
    past_key_values=make_dynamic_cache(
        [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
    ),
)

# Show the nesting, dtypes and shapes of the inputs.
print(string_type(inputs, with_shape=True))
dict(input_mask_position:(T7s2x3,T7s2x33,T7s3),past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))

The function onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs() produces the corresponding dynamic shapes, assuming every dimension is dynamic: ds = all_dynamic_shape_from_inputs(inputs), displayed with pprint.pprint(ds).

{'input_mask_position': ({0: 'd_0_0', 1: 'd_0_1'},
                         {0: 'd_1_0', 1: 'd_1_1'},
                         {0: 'd_2_0'}),
 'past_key_values': {'key_cache': [{0: 'd_3_0',
                                    1: 'd_3_1',
                                    2: 'd_3_2',
                                    3: 'd_3_3'}],
                     'value_cache': [{0: 'd_4_0',
                                      1: 'd_4_1',
                                      2: 'd_4_2',
                                      3: 'd_4_3'}]}}

Converted into JSON.

# JSON has no tuple type: every tuple of the dynamic shapes is written as an array.
json_str = json.dumps(ds, ensure_ascii=False, indent=2)
print(json_str)
{
  "input_mask_position": [
    {
      "0": "d_0_0",
      "1": "d_0_1"
    },
    {
      "0": "d_1_0",
      "1": "d_1_1"
    },
    {
      "0": "d_2_0"
    }
  ],
  "past_key_values": {
    "key_cache": [
      {
        "0": "d_3_0",
        "1": "d_3_1",
        "2": "d_3_2",
        "3": "d_3_3"
      }
    ],
    "value_cache": [
      {
        "0": "d_4_0",
        "1": "d_4_1",
        "2": "d_4_2",
        "3": "d_4_3"
      }
    ]
  }
}

Restoration from JSON: ds2 = json.loads(json_str), displayed with pprint.pprint(ds2).

{'input_mask_position': [{'0': 'd_0_0', '1': 'd_0_1'},
                         {'0': 'd_1_0', '1': 'd_1_1'},
                         {'0': 'd_2_0'}],
 'past_key_values': {'key_cache': [{'0': 'd_3_0',
                                    '1': 'd_3_1',
                                    '2': 'd_3_2',
                                    '3': 'd_3_3'}],
                     'value_cache': [{'0': 'd_4_0',
                                      '1': 'd_4_1',
                                      '2': 'd_4_2',
                                      '3': 'd_4_3'}]}}

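Tuples are replaced by lists. The integer dimension keys also come back as strings ('0' instead of 0), because JSON object keys are always strings.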

# The trick to restore tuples when expected
# +++++++++++++++++++++++++++++++++++++++++


def flatten_unflatten_like_dynamic_shapes(obj):
    if isinstance(obj, torch.Tensor):
        return obj
    flat, spec = torch.utils._pytree.tree_flatten(obj)
    start = 0
    end = 0
    subtrees = []
    for subspec in spec.children_specs:
        end += subspec.num_leaves
        value = subspec.unflatten(flat[start:end])
        value = flatten_unflatten_like_dynamic_shapes(value)
        subtrees.append(value)
        start = end
    if spec.type is dict or spec.context:
        return dict(zip(spec.context, subtrees))
    if spec.type is tuple:
        return tuple(subtrees)
    return subtrees


def _align(inputs, ds):
    """Return *ds* with the same container types as *inputs*.

    *inputs* and *ds* are walked together. Wherever *inputs* holds a tuple the
    result holds a tuple, and wherever it holds a list the result holds a list.
    Dicts are rebuilt from the keys of *ds*. A tensor is a leaf: the matching
    entry of *ds* (its per-dimension mapping) is returned as is. A ``None``
    entry of *ds* (no dynamic dimension in that subtree) is kept as ``None``
    whatever *inputs* holds there.

    :param inputs: inputs made of tensors, tuples, lists and dicts, e.g. the
        output of ``flatten_unflatten_like_dynamic_shapes``
    :param ds: dynamic shapes with the same nesting, where tuples may have
        become lists (as after a JSON round trip)
    :return: *ds* with tuples restored where *inputs* has tuples
    :raises ValueError: if a tuple or list of *inputs* and its counterpart in
        *ds* do not have the same length
    :raises TypeError: if *inputs* contains any other type under a non-None
        entry of *ds*
    """
    if ds is None:
        # Nothing dynamic in this subtree: there is no container to fix.
        return None
    if isinstance(inputs, torch.Tensor):
        return ds
    if isinstance(inputs, (tuple, list)):
        if len(inputs) != len(ds):
            # zip would silently drop the extra entries.
            raise ValueError(
                f"Length mismatch: inputs has {len(inputs)} elements, "
                f"ds has {len(ds)}"
            )
        aligned = [_align(o, d) for o, d in zip(inputs, ds)]
        return tuple(aligned) if isinstance(inputs, tuple) else aligned
    if isinstance(inputs, dict):
        return {k: _align(inputs[k], d) for k, d in ds.items()}
    raise TypeError(f"Unexpected types inputs is {type(inputs)}, ds is {type(ds)}")


def fix_dynamic_shapes(inputs, dynamic_shapes):
    """Restore the tuples of *dynamic_shapes* by following the structure of *inputs*.

    *inputs* is first rewritten with builtin dict/tuple/list containers, so that
    registered classes such as the cache become dicts. *dynamic_shapes* is
    then walked alongside it, and every list sitting where *inputs* has a
    tuple is turned back into a tuple.

    :param inputs: the inputs the dynamic shapes describe
    :param dynamic_shapes: dynamic shapes, for example reloaded from JSON
    :return: the dynamic shapes with the container types of *inputs*
    """
    return _align(flatten_unflatten_like_dynamic_shapes(inputs), dynamic_shapes)


# Use the inputs as the reference structure to restore the tuples lost in JSON.
fixed_ds = fix_dynamic_shapes(inputs=inputs, dynamic_shapes=ds2)
pprint.pprint(fixed_ds)
{'input_mask_position': ({'0': 'd_0_0', '1': 'd_0_1'},
                         {'0': 'd_1_0', '1': 'd_1_1'},
                         {'0': 'd_2_0'}),
 'past_key_values': {'key_cache': [{'0': 'd_3_0',
                                    '1': 'd_3_1',
                                    '2': 'd_3_2',
                                    '3': 'd_3_3'}],
                     'value_cache': [{'0': 'd_4_0',
                                      '1': 'd_4_1',
                                      '2': 'd_4_2',
                                      '3': 'd_4_3'}]}}

The code turned the list back into a tuple, as expected. The dimension keys remain strings: only the container types are restored.

# The JSON round trip turned the tuple into a list; fix_dynamic_shapes restored it.
assert isinstance(ds2["input_mask_position"], list)
assert isinstance(fixed_ds["input_mask_position"], tuple)
# Presumably renders the legend figure shown below the example (gallery thumbnail).
doc.plot_legend("dynamic shapes\nto json\nfrom json", "torch.export.export", "green")
plot dynamic shapes json

Total running time of the script: (0 minutes 1.055 seconds)

Related examples

Builds dynamic shapes from any input

Builds dynamic shapes from any input

0, 1, 2 for a Dynamic Dimension in the dummy example to export a model

0, 1, 2 for a Dynamic Dimension in the dummy example to export a model

Export a model with a control flow (If)

Export a model with a control flow (If)

Gallery generated by Sphinx-Gallery