Source code for experimental_experiment.torch_models import pprint from typing import Any, Dict, Iterator [docs] def flatten_outputs(output: Any) -> Iterator[Any]: """ Flattens output results. """ if isinstance(output, (list, tuple)): for item in output: yield from flatten_outputs(item) elif isinstance(output, dict): yield from flatten_outputs(list(output.values())) elif hasattr(output, "to_tuple"): yield from flatten_outputs(output.to_tuple()) else: yield output [docs] def assert_found(kwargs: Dict[str, Any], config: Dict[str, Any]): """ Checks a parameter is available. """ for k in kwargs: assert ( k in config or k == "_attn_implementation" ), f"Parameter {k!r} is not mentioned in the configuration {pprint.pformat(config)}"