# Source code for onnx_diagnostic.torch_export_patches.patch_inputs

import inspect
from typing import Any, Dict, Optional, Tuple
import torch
import transformers
from ..helpers import string_type
from ..helpers.cache_helper import make_dynamic_cache


def _process_cache(k: str, v):
    assert k != "position_ids" or isinstance(
        k, torch.Tensor
    ), f"Unexpected type for parameter {k!r} {string_type(v, with_shape=True)}"
    if (
        isinstance(v, list)
        and all(isinstance(i, tuple) for i in v)
        and set(len(t) for t in v) == {2}
    ):
        # A dynamicCache
        cache = make_dynamic_cache(v)
        return cache
    if isinstance(v, torch.Tensor):
        return v
    raise NotImplementedError(
        f"Unable to process parameter {k!r} with v={string_type(v,with_shape=True)}"
    )


def _make_shape(subset: Dict, cls: type, value: Any) -> Any:
    """
    Builds the dynamic shape of a cache input from the dynamic axes of its
    flattened entries (names such as ``present.0.key``).

    :param subset: dynamic axes of every flattened entry belonging to the cache;
        all entries must share the same axes
    :param cls: type the input was converted into, only
        ``transformers.cache_utils.DynamicCache`` is supported
    :param value: the converted cache, ``len(value.key_cache)`` gives the
        number of entries per list
    :return: two lists (presumably keys then values), each repeating the
        shared axes once per entry of ``value.key_cache``
    :raises NotImplementedError: for any other ``cls``
    """
    if cls is not transformers.cache_utils.DynamicCache:
        raise NotImplementedError(
            f"_make_shape not implemented for cls={cls}, "
            f"subset={subset}, value={string_type(value)}"
        )
    assert subset, "DynamicCache cannot be empty"
    distinct = {str(axes) for axes in subset.values()}
    assert len(distinct) == 1, (
        f"Inconsistencies in subset={subset}, found={distinct}, "
        f"it cannot be a {cls}, value={string_type(value)}"
    )
    n_entries = len(value.key_cache)
    # every entry shares the same axes, the first one is representative
    shared_axes = next(iter(subset.values()))
    return [[shared_axes] * n_entries, [shared_axes] * n_entries]


def _map_args_to_kwargs(
    model: Any, args: Optional[Tuple[Any, ...]], verbose: int
) -> Dict[str, Any]:
    """Maps positional arguments onto the parameter names of ``model.forward``."""
    new_kwargs: Dict[str, Any] = {}
    if not args:
        return new_kwargs
    assert hasattr(model, "forward"), f"Missing method 'forward' for {model!r}"
    # For a torch.nn.Module, ``model.forward`` is bound and its signature starts
    # with the first input. Otherwise the first parameter is skipped
    # (presumably ``self`` of an unbound ``forward``).
    plus = 0 if isinstance(model, torch.nn.Module) else 1
    if verbose:
        print(
            f"[convert_dynamic_axes_into_dynamic_shapes] "
            f"mapping args to kwargs for model="
            f"{model if plus else model.__class__.__name__}"
        )
    pars = inspect.signature(model.forward).parameters
    # The skipped leading parameter cannot receive an argument, otherwise the
    # last positional argument would be silently dropped.
    assert len(pars) - plus >= len(
        args
    ), f"Length mismatch, len(args)={len(args)}, pars={list(pars)}"
    names = list(pars)[plus : plus + len(args)]
    for i, (name, arg) in enumerate(zip(names, args)):
        if verbose:
            print(
                f"[convert_dynamic_axes_into_dynamic_shapes] mapping args[{i}] "
                f"to {name!r} ({string_type(arg)})"
            )
        new_kwargs[name] = arg
    return new_kwargs


def _convert_inputs(
    new_kwargs: Dict[str, Any], verbose: int
) -> Tuple[Dict[str, Any], Dict[str, type]]:
    """Keeps tensors, converts list inputs with ``_process_cache`` and records type changes."""
    updated_kwargs: Dict[str, Any] = {}
    changes: Dict[str, type] = {}
    for k, v in new_kwargs.items():
        if isinstance(v, torch.Tensor):
            updated_kwargs[k] = v
            continue
        if isinstance(v, list):
            # cache?
            updated_kwargs[k] = _process_cache(k, v)
            if type(updated_kwargs[k]) is not type(v):
                # A cache was introduced.
                if verbose:
                    print(
                        f"[convert_dynamic_axes_into_dynamic_shapes] parameter "
                        f"{k!r} was changed into {type(updated_kwargs[k])}"
                    )
                changes[k] = type(updated_kwargs[k])
                continue
        raise NotImplementedError(
            f"Unexpected type {type(v)} for parameter {k!r} "
            f"({string_type(v, with_shape=True)})"
        )
    return updated_kwargs, changes


def _convert_dynamic_axes(
    dynamic_axes: Dict[str, Any],
    updated_kwargs: Dict[str, Any],
    changes: Dict[str, type],
    prefix_mapping: Optional[Dict[str, str]],
    verbose: int,
) -> Dict[str, Any]:
    """Converts dynamic axes into dynamic shapes keyed by the names in ``updated_kwargs``."""
    dynamic_shapes: Dict[str, Any] = {}
    done = set()
    for k, v in dynamic_axes.items():
        if k not in changes and k in updated_kwargs and isinstance(v, dict):
            dynamic_shapes[k] = v
            continue
        if "." in k:
            # something like present.0.key
            prefix = k.split(".")[0]
            if prefix in done:
                continue
            args_prefix = (
                prefix_mapping[prefix]
                if prefix_mapping and prefix in prefix_mapping
                else prefix
            )
            if args_prefix in updated_kwargs and args_prefix in changes:
                # A cache: all flattened entries sharing the prefix are regrouped.
                dynamic_shapes[args_prefix] = _make_shape(
                    {
                        name: axes
                        for name, axes in dynamic_axes.items()
                        if name.startswith(f"{prefix}.")
                    },
                    changes[args_prefix],
                    updated_kwargs[args_prefix],
                )
                done.add(prefix)
                continue
        if k not in updated_kwargs:
            # dynamic axes not in the given inputs, should be raise an exception?
            if verbose:
                print(
                    f"[convert_dynamic_axes_into_dynamic_shapes] dropping axes "
                    f"{k!r}-{v!r}, not found in {set(updated_kwargs)}"
                )
            continue
        raise NotImplementedError(
            f"Unable to process dynamic axes {k!r}, axes={v}, "
            f"value={string_type(updated_kwargs[k], with_shape=True)}, "
            f"dynamic axes={dynamic_axes}, "
            f"updated_kwargs={string_type(updated_kwargs, with_shape=True)}"
        )
    return dynamic_shapes


def convert_dynamic_axes_into_dynamic_shapes(
    model: torch.nn.Module,
    args: Optional[Tuple[Any, ...]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None,
    prefix_mapping: Optional[Dict[str, str]] = None,
    verbose: int = 0,
) -> Tuple[Tuple[Any, ...], Dict[str, Any], Dict[str, Any]]:
    """
    Converts the input from an export to something
    :func:`torch.export.export` can handle.

    Positional arguments are mapped onto the parameter names of ``model.forward``.
    List inputs are converted by ``_process_cache``. The dynamic axes of the
    flattened entries of a converted input (``<prefix>.<i>.<name>``) are
    regrouped into a single dynamic shape for that input. Dynamic axes naming
    an unknown input are dropped.

    :param model: model to convert (used to extract the signature)
    :param args: positional arguments, mapped to the parameter names of ``model.forward``
    :param kwargs: named arguments, must not repeat a name already filled by ``args``
    :param dynamic_axes: dynamic axes, ``None`` means no dynamic axis
    :param prefix_mapping: maps the prefix of a flattened dynamic axis name
        (``present`` in ``present.0.key``) to the name of the input it belongs to
    :param verbose: verbosity
    :return: (args, kwargs, dynamic shapes); args is always empty and every
        input is returned in kwargs
    """
    new_kwargs = _map_args_to_kwargs(model, args, verbose)
    if kwargs:
        for k, v in kwargs.items():
            assert k not in new_kwargs, f"Argument {k!r} from kwargs already present in args."
            new_kwargs[k] = v
    updated_kwargs, changes = _convert_inputs(new_kwargs, verbose)
    # Always computed (previously unbound when no input was converted).
    dynamic_shapes = _convert_dynamic_axes(
        dynamic_axes or {}, updated_kwargs, changes, prefix_mapping, verbose
    )
    return (), updated_kwargs, dynamic_shapes