Source code for onnx_diagnostic.helpers.config_helper

import functools
import importlib
import inspect
import re
from typing import Any, Dict, Optional, Tuple, Union
import transformers


[docs] def check_hasattr(config: Any, *args: Union[str, Tuple[Any, ...]]): """ Checks the confiugation has all the attributes in ``args``. Raises an exception otherwise. """ for a in args: assert isinstance(a, (str, tuple)), f"unexpected type {type(a)} in {args!r}" if isinstance(a, str): assert (isinstance(config, dict) and a in config) or hasattr( config, a ), f"Missing attribute {a!r} in\n{config}" elif isinstance(a, tuple): assert any( (isinstance(name, str) and hasattr(config, name)) or all(hasattr(config, _) for _ in name) for name in a ), f"All attributes in {a!r} are missing from\n{config}"
[docs] def update_config(config: Any, mkwargs: Dict[str, Any]): """Updates a configuration with different values.""" for k, v in mkwargs.items(): if isinstance(v, dict): assert hasattr( config, k ), f"missing attribute {k!r} in config={config}, cannot update it with {v}" update_config(getattr(config, k), v) else: setattr(config, k, v)
def _pick(config, *atts): """Returns the first value found in the configuration.""" for a in atts: if isinstance(a, str): if hasattr(config, a): return getattr(config, a) elif isinstance(a, tuple): if all(hasattr(config, _) for _ in a[1:]): return a[0]([getattr(config, _) for _ in a[1:]]) raise AssertionError(f"Unable to find any of these {atts!r} in {config}")
[docs] @functools.cache def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[type]: """ Retrieves the configuration class for a given architecture. :param arch: architecture (clas name) :param exc: raise an exception if not found :return: type """ cls = getattr(transformers, arch) mod_name = cls.__module__ mod = importlib.import_module(mod_name) source = inspect.getsource(mod) reg = re.compile("config: ([A-Za-z0-9]+)") fall = reg.findall(source) if len(fall) == 0: assert not exc, ( f"Unable to guess Configuration class name for arch={arch!r}, " f"module={mod_name!r}, no candidate, source is\n{source}" ) return None unique = set(fall) assert len(unique) == 1, ( f"Unable to guess Configuration class name for arch={arch!r}, " f"module={mod_name!r}, found={unique} (#{len(unique)}), " f"source is\n{source}" ) cls_name = unique.pop() return getattr(transformers, cls_name)