Source code for experimental_experiment.torch_models
import pprint
from typing import Any, Dict, Iterator
def flatten_outputs(output: Any) -> Iterator[Any]:
    """Flattens output results."""
    if isinstance(output, (list, tuple)):
        for item in output:
            yield from flatten_outputs(item)
    elif isinstance(output, dict):
        yield from flatten_outputs(list(output.values()))
    elif hasattr(output, "to_tuple"):
        yield from flatten_outputs(output.to_tuple())
    elif hasattr(output, "shape"):
        yield output
    elif output.__class__.__name__ == "MambaCache":
        yield output.conv_states
        yield output.ssm_states
    elif output.__class__.__name__ == "DynamicCache":
        yield output.key_cache
        yield output.value_cache
    else:
        raise TypeError(f"Unable to flatten type {type(output)}")
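# Usage sketch (not part of the module): flatten_outputs walks nested lists,
# tuples, and dicts and yields every tensor-like object it finds. This example
# assumes torch is installed; any object exposing a ``shape`` attribute would
# be handled the same way.
import torch

nested = ([torch.zeros(2, 3), (torch.ones(4),)], {"logits": torch.rand(1, 5)})
flat = list(flatten_outputs(nested))
print([t.shape for t in flat])
# -> [torch.Size([2, 3]), torch.Size([4]), torch.Size([1, 5])]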
def assert_found(kwargs: Dict[str, Any], config: Dict[str, Any]):
    """Checks a parameter is available."""
    for k in kwargs:
        assert (
            k in config or k == "_attn_implementation"
        ), f"Parameter {k!r} is not mentioned in the configuration {pprint.pformat(config)}"
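# Usage sketch (not part of the module): assert_found checks that every
# overridden keyword is a known configuration entry or the special key
# "_attn_implementation". The configuration keys below are made up for
# illustration.
config = {"hidden_size": 16, "num_hidden_layers": 2}
assert_found({"hidden_size": 32}, config)                # known key, passes
assert_found({"_attn_implementation": "eager"}, config)  # always allowed, passes
# assert_found({"unknown_option": 1}, config)            # would raise AssertionError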