onnx_diagnostic.export.shape_helper

onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = 'd') → Any [source]

Returns the dynamic shapes for the given inputs. All dimensions are considered dynamic. dim_prefix can be a string (the function uses it as a prefix to name every dimension, as in the example below), torch.export.Dim.AUTO, or torch.export.Dim.DYNAMIC.

<<<

import pprint
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

# Dimensions of every cached key/value tensor, in the order (bsize, nheads, slen, dim).
bsize, nheads, slen, dim = 2, 1, 30, 96
cache_shape = (bsize, nheads, slen, dim)

# The printed output below numbers the inputs in this key order
# (input_ids -> d_0_*, attention_mask -> d_1_*, ...), so the order is kept.
inputs = {
    "input_ids": torch.randint(15, size=(2, 3), dtype=torch.int64),
    "attention_mask": torch.randint(1, size=(2, 33), dtype=torch.int64),
    "position_ids": torch.arange(3, dtype=torch.int64),
    # A single cache layer made of one (key, value) pair of tensors.
    "past_key_values": make_dynamic_cache(
        [(torch.randn(*cache_shape), torch.randn(*cache_shape))]
    ),
}

# Every dimension of every tensor is declared dynamic, named with the default prefix "d".
ds = all_dynamic_shape_from_inputs(inputs)
pprint.pprint(ds)

>>>

    {'attention_mask': {0: 'd_1_0', 1: 'd_1_1'},
     'input_ids': {0: 'd_0_0', 1: 'd_0_1'},
     'past_key_values': {'key_cache': [{0: 'd_3_0',
                                        1: 'd_3_1',
                                        2: 'd_3_2',
                                        3: 'd_3_3'}],
                         'value_cache': [{0: 'd_4_0',
                                          1: 'd_4_1',
                                          2: 'd_4_2',
                                          3: 'd_4_3'}]},
     'position_ids': {0: 'd_2_0'}}
onnx_diagnostic.export.shape_helper.guess_dynamic_shapes_from_inputs(inputs: List[Any], auto: bool | str = False) → Tuple[Tuple[Any, ...], Dict[str, Any]] [source]

Guesses which dimensions are dynamic from a set of inputs. Every dimension whose value differs across the sets of inputs is considered dynamic. Every dimension whose value does not change remains static.

Parameters:
  • inputs – a list of input sets

  • auto – True to use torch.export.Dim.AUTO, False to use torch.export.Dim.DYNAMIC, or a string used as a prefix to give every dynamic dimension a unique name

Returns:

a tuple (args, kwargs) holding the dynamic shapes of the positional and the keyword arguments

<<<

import pprint
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs


def make_inputs(bsize, seq_len, mask_len, slen, nheads=1, dim=96):
    """Builds one set of inputs: input_ids, attention_mask, position_ids and
    a one-layer dynamic cache whose key and value tensors have the shape
    ``(bsize, nheads, slen, dim)``."""
    return dict(
        input_ids=torch.randint(15, size=(bsize, seq_len), dtype=torch.int64),
        attention_mask=torch.randint(1, size=(bsize, mask_len), dtype=torch.int64),
        position_ids=torch.arange(seq_len, dtype=torch.int64),
        past_key_values=make_dynamic_cache(
            [
                (
                    torch.randn(bsize, nheads, slen, dim),
                    torch.randn(bsize, nheads, slen, dim),
                ),
            ]
        ),
    )


# The two sets differ in bsize, seq_len, mask_len and slen; nheads and dim
# are identical, so those dimensions stay static in the result.
inputs1 = make_inputs(bsize=2, seq_len=3, mask_len=33, slen=30)
inputs2 = make_inputs(bsize=3, seq_len=4, mask_len=34, slen=33)
ds = guess_dynamic_shapes_from_inputs([inputs1, inputs2], auto="d")
pprint.pprint(ds)

>>>

    ((),
     {'attention_mask': {0: 'd_0I0', 1: 'd_0I1'},
      'input_ids': {0: 'd_1I0', 1: 'd_1I1'},
      'past_key_values': [[{0: 'd_2I_0o_0l0', 2: 'd_2I_0o_0l2'}],
                          [{0: 'd_2I_1o_0l0', 2: 'd_2I_1o_0l2'}]],
      'position_ids': {0: 'd_3I0'}})

This function returns something equivalent to what the class torch.export.dynamic_shapes.AdditionalInputs produces, but AdditionalInputs also needs the model, as the following example shows.

<<<

import pprint
import torch
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

# The returned dict holds the model ("model") and, because of
# add_second_input=True, two sets of inputs ("inputs" and "inputs2").
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)

# Each call to add registers one (args, kwargs) example; every input here is
# passed as a keyword argument, hence the empty args tuple.
ds = torch.export.dynamic_shapes.AdditionalInputs()
ds.add((), data["inputs"])
ds.add((), data["inputs2"])

# Unlike guess_dynamic_shapes_from_inputs, AdditionalInputs requires the model
# to compute the dynamic shapes.
pprint.pprint(ds.dynamic_shapes(data["model"], (), data["inputs"]))

>>>

    {'attention_mask': (_DimHint(type=<_DimHintType.DYNAMIC: 3>,
                                 min=None,
                                 max=None,
                                 _factory=True),
                        _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                                 min=None,
                                 max=None,
                                 _factory=True)),
     'input_ids': (_DimHint(type=<_DimHintType.DYNAMIC: 3>,
                            min=None,
                            max=None,
                            _factory=True),
                   _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                            min=None,
                            max=None,
                            _factory=True)),
     'past_key_values': [[(_DimHint(type=<_DimHintType.DYNAMIC: 3>,
                                    min=None,
                                    max=None,
                                    _factory=True),
                           None,
                           _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                                    min=None,
                                    max=None,
                                    _factory=True),
                           None)],
                         [(_DimHint(type=<_DimHintType.DYNAMIC: 3>,
                                    min=None,
                                    max=None,
                                    _factory=True),
                           None,
                           _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                                    min=None,
                                    max=None,
                                    _factory=True),
                           None)]],
     'position_ids': (_DimHint(type=<_DimHintType.DYNAMIC: 3>,
                               min=None,
                               max=None,
                               _factory=True),
                      _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                               min=None,
                               max=None,
                               _factory=True))}