onnx_diagnostic.torch_export_patches.patch_module_helper
- class onnx_diagnostic.torch_export_patches.patch_module_helper.OrToBitOrTransformer
- onnx_diagnostic.torch_export_patches.patch_module_helper.ast_or_into_bitor(node: ast.Node) → ast.Node
Replaces every operator or with |.
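The replacement described above can be illustrated with the standard library alone. The following is a minimal sketch, not the library's implementation: the names SketchOrToBitOr and sketch_or_into_bitor are hypothetical, and the sketch assumes the rewrite turns each boolean `or` into a chain of bitwise `|` operations (it requires Python 3.9+ for ast.unparse).

```python
import ast


class SketchOrToBitOr(ast.NodeTransformer):
    """Hypothetical stand-in: rewrites `a or b` into `a | b`."""

    def visit_BoolOp(self, node):
        # Rewrite nested expressions first.
        self.generic_visit(node)
        if not isinstance(node.op, ast.Or):
            return node  # leave `and` untouched
        # `a or b or c` holds three values; fold them into (a | b) | c.
        result = node.values[0]
        for value in node.values[1:]:
            result = ast.BinOp(left=result, op=ast.BitOr(), right=value)
        return ast.copy_location(result, node)


def sketch_or_into_bitor(source: str) -> str:
    """Parses source, applies the transformer and returns the new source."""
    tree = SketchOrToBitOr().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)


rewritten = sketch_or_into_bitor("flag = a or b or c")
print(rewritten)
```

Note that `|` does not short-circuit and has a different precedence than `or`, so such a rewrite is only safe where both operands can always be evaluated, for example on boolean tensors.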
- onnx_diagnostic.torch_export_patches.patch_module_helper.code_needing_rewriting(cls_name: str) → List[Any] | None
Returns the list of known rewritings for a class whose control flow has to be rewritten. See known_transformers_rewritings_clamp_float16().
- Parameters:
cls_name – name of the class
- Returns:
a list of rewritings
<<<
import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    code_needing_rewriting,
)

pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))
>>>
[{'filter_node': <function rewritings_transformers_clamp_float16.<locals>.<lambda> at 0x7e958c111a80>,
  'function': <function BartEncoderLayer.forward at 0x7e958b595800>,
  'pre_rewriter': <function ast_or_into_bitor at 0x7e958df44180>},
 {'filter_node': <function rewritings_transformers_clamp_float16.<locals>.<lambda> at 0x7e958c111a80>,
  'function': <function PLBartEncoderLayer.forward at 0x7e958b5acae0>,
  'pre_rewriter': <function ast_or_into_bitor at 0x7e958df44180>}]
- onnx_diagnostic.torch_export_patches.patch_module_helper.known_transformers_rewritings_clamp_float16() → Dict[str, str]
This function returns the known classes in transformers to be rewritten. Each class is mapped to an alias; this alias is then given to rewritings_transformers_clamp_float16() to rewrite the encoder layers because of a specific control flow.
<<<
import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    known_transformers_rewritings_clamp_float16,
)

pprint.pprint(known_transformers_rewritings_clamp_float16())
>>>
{'AutoformerEncoder': 'AutoformerEncoderLayer',
 'AutoformerEncoderLayer': 'AutoformerEncoderLayer',
 'AutoformerForPrediction': 'AutoformerEncoderLayer',
 'AutoformerModel': 'AutoformerEncoderLayer',
 'BartEncoderLayer': 'BartEncoderLayer',
 'BartForConditionalGeneration': 'BartEncoderLayer',
 'BigBirdPegasusForCausalLM': 'BigBirdPegasusEncoderLayer',
 'BigBirdPegasusForConditionalGeneration': 'BigBirdPegasusEncoderLayer',
 'BigBirdPegasusForQuestionAnswering': 'BigBirdPegasusEncoderLayer',
 'BlenderbotSmallEncoderLayer': 'BlenderbotSmallEncoderLayer',
 'BlenderbotSmallForCausalLM': 'BlenderbotSmallEncoderLayer',
 'BlenderbotSmallForConditionalGeneration': 'BlenderbotSmallEncoderLayer',
 'InformerEncoderLayer': 'InformerEncoderLayer',
 'InformerForPrediction': 'InformerEncoderLayer',
 'LEDClassificationHead': 'LEDEncoderLayer',
 'LEDEncoderLayer': 'LEDEncoderLayer',
 'LEDForConditionalGeneration': 'LEDEncoderLayer',
 'MarianEncoder': 'MarianEncoderLayer',
 'MarianEncoderLayer': 'MarianEncoderLayer',
 'MarianMTModel': 'MarianEncoderLayer',
 'MarianModel': 'MarianEncoderLayer',
 'MvpEncoderLayer': 'MvpEncoderLayer',
 'MvpForCausalLM': 'MvpEncoderLayer',
 'MvpForConditionalGeneration': 'MvpEncoderLayer',
 'MvpForQuestionAnswering': 'MvpEncoderLayer',
 'MvpForSequenceClassification': 'MvpEncoderLayer',
 'MvpPrompt': 'MvpEncoderLayer',
 'NllbMoeEncoderLayer': 'NllbMoeEncoderLayer',
 'NllbMoeForConditionalGeneration': 'NllbMoeEncoderLayer',
 'PLBartEncoderLayer': 'BartEncoderLayer',
 'PLBartForConditionalGeneration': 'BartEncoderLayer',
 'TimeSeriesTransformerEncoderLayer': 'TimeSeriesTransformerEncoderLayer',
 'TimeSeriesTransformerForPrediction': 'TimeSeriesTransformerEncoderLayer'}
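To see which classes share an alias, the mapping can be inverted. The sketch below is only an illustration: it uses a small excerpt copied from the output above rather than calling the library, and the helper name group_by_alias is hypothetical.

```python
from collections import defaultdict

# Excerpt copied from the output above (not the full mapping).
known = {
    "BartEncoderLayer": "BartEncoderLayer",
    "BartForConditionalGeneration": "BartEncoderLayer",
    "PLBartEncoderLayer": "BartEncoderLayer",
    "PLBartForConditionalGeneration": "BartEncoderLayer",
    "MarianMTModel": "MarianEncoderLayer",
}


def group_by_alias(mapping):
    """Hypothetical helper: returns {alias: sorted list of class names}."""
    groups = defaultdict(list)
    for cls_name, alias in sorted(mapping.items()):
        groups[alias].append(cls_name)
    return dict(groups)


groups = group_by_alias(known)
for alias, cls_names in groups.items():
    print(alias, "<-", ", ".join(cls_names))
```

In this excerpt the PLBart classes resolve to the alias BartEncoderLayer, which is consistent with the earlier example where BartForConditionalGeneration yields rewritings for both BartEncoderLayer.forward and PLBartEncoderLayer.forward.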
- onnx_diagnostic.torch_export_patches.patch_module_helper.rewritings_transformers_clamp_float16(cls_name) → List[type]
Rewrites known control flows equivalent to the following:
if hidden_states.dtype == torch.float16 and (
    torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
):
    clamp_value = torch.finfo(hidden_states.dtype).max - 1000
    hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
cls_name is the class name. It is mapped to a list of other class names to rename. Here is the known list:
<<<
import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    _rewrite_forward_clamp_float16,
)

pprint.pprint(_rewrite_forward_clamp_float16())
>>>
{'AutoformerEncoderLayer': [<class 'transformers.models.autoformer.modeling_autoformer.AutoformerEncoderLayer'>],
 'BartEncoderLayer': [<class 'transformers.models.bart.modeling_bart.BartEncoderLayer'>,
                      <class 'transformers.models.plbart.modeling_plbart.PLBartEncoderLayer'>],
 'BigBirdPegasusEncoderLayer': [<class 'transformers.models.bigbird_pegasus.modeling_bigbird_pegasus.BigBirdPegasusEncoderLayer'>],
 'BlenderbotSmallEncoderLayer': [<class 'transformers.models.blenderbot_small.modeling_blenderbot_small.BlenderbotSmallEncoderLayer'>],
 'InformerEncoderLayer': [<class 'transformers.models.informer.modeling_informer.InformerEncoderLayer'>],
 'LEDEncoderLayer': [<class 'transformers.models.led.modeling_led.LEDEncoderLayer'>],
 'MarianEncoderLayer': [<class 'transformers.models.marian.modeling_marian.MarianEncoderLayer'>],
 'MvpEncoderLayer': [<class 'transformers.models.mvp.modeling_mvp.MvpEncoderLayer'>],
 'NllbMoeEncoderLayer': [<class 'transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeEncoderLayer'>],
 'TimeSeriesTransformerEncoderLayer': [<class 'transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerEncoderLayer'>]}
Function _rewrite_forward_clamp_float16 collects all model classes using those layers.
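The float16 clamping test quoted in the description of rewritings_transformers_clamp_float16 can be located in a forward method by walking its syntax tree. The sketch below is an assumption about how such a node filter might look, not the library's own filter_node: the predicate is_clamp_float16_test and the sample SOURCE are hypothetical, and the source is only parsed, never executed, so torch is not required.

```python
import ast
import textwrap

# Sample forward method containing the control flow (parsed, never executed).
SOURCE = textwrap.dedent(
    """
    def forward(hidden_states):
        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(
                hidden_states, min=-clamp_value, max=clamp_value
            )
        return hidden_states
    """
)


def is_clamp_float16_test(node):
    """Hypothetical filter: True for an `if` whose test mentions `.float16`."""
    if not isinstance(node, ast.If):
        return False
    return any(
        isinstance(n, ast.Attribute) and n.attr == "float16"
        for n in ast.walk(node.test)
    )


tree = ast.parse(SOURCE)
matches = [n for n in ast.walk(tree) if is_clamp_float16_test(n)]
print(len(matches), "matching if statement(s)")
```

A predicate of this kind lets a rewriter touch only the targeted `if` statement and leave any other control flow in the method unchanged.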