onnx_diagnostic.torch_export_patches.patch_module_helper

class onnx_diagnostic.torch_export_patches.patch_module_helper.OrToBitOrTransformer[source]
onnx_diagnostic.torch_export_patches.patch_module_helper.ast_or_into_bitor(node: ast.Node) → ast.Node[source]

Replaces every ``or`` operator with ``|``.

onnx_diagnostic.torch_export_patches.patch_module_helper.code_needing_rewriting(cls_name: str) → List[Any] | None[source]

Returns the known rewritings associated with a class, needed because of its control flow. See known_transformers_rewritings_clamp_float16().

Parameters:

cls_name – name of the class

Returns:

a list of rewritings, or None if the class is unknown

<<<

import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    code_needing_rewriting,
)

pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))

>>>

    [{'filter_node': <function rewritings_transformers_clamp_float16.<locals>.<lambda> at 0x7e958c111a80>,
      'function': <function BartEncoderLayer.forward at 0x7e958b595800>,
      'pre_rewriter': <function ast_or_into_bitor at 0x7e958df44180>},
     {'filter_node': <function rewritings_transformers_clamp_float16.<locals>.<lambda> at 0x7e958c111a80>,
      'function': <function PLBartEncoderLayer.forward at 0x7e958b5acae0>,
      'pre_rewriter': <function ast_or_into_bitor at 0x7e958df44180>}]
onnx_diagnostic.torch_export_patches.patch_module_helper.known_transformers_rewritings_clamp_float16() → Dict[str, str][source]

This function returns the list of known classes to be rewritten in transformers. Each class is mapped to an alias; this alias is then given to rewritings_transformers_clamp_float16() to rewrite the encoder layers because of a specific control flow.

<<<

import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    known_transformers_rewritings_clamp_float16,
)

pprint.pprint(known_transformers_rewritings_clamp_float16())

>>>

    {'AutoformerEncoder': 'AutoformerEncoderLayer',
     'AutoformerEncoderLayer': 'AutoformerEncoderLayer',
     'AutoformerForPrediction': 'AutoformerEncoderLayer',
     'AutoformerModel': 'AutoformerEncoderLayer',
     'BartEncoderLayer': 'BartEncoderLayer',
     'BartForConditionalGeneration': 'BartEncoderLayer',
     'BigBirdPegasusForCausalLM': 'BigBirdPegasusEncoderLayer',
     'BigBirdPegasusForConditionalGeneration': 'BigBirdPegasusEncoderLayer',
     'BigBirdPegasusForQuestionAnswering': 'BigBirdPegasusEncoderLayer',
     'BlenderbotSmallEncoderLayer': 'BlenderbotSmallEncoderLayer',
     'BlenderbotSmallForCausalLM': 'BlenderbotSmallEncoderLayer',
     'BlenderbotSmallForConditionalGeneration': 'BlenderbotSmallEncoderLayer',
     'InformerEncoderLayer': 'InformerEncoderLayer',
     'InformerForPrediction': 'InformerEncoderLayer',
     'LEDClassificationHead': 'LEDEncoderLayer',
     'LEDEncoderLayer': 'LEDEncoderLayer',
     'LEDForConditionalGeneration': 'LEDEncoderLayer',
     'MarianEncoder': 'MarianEncoderLayer',
     'MarianEncoderLayer': 'MarianEncoderLayer',
     'MarianMTModel': 'MarianEncoderLayer',
     'MarianModel': 'MarianEncoderLayer',
     'MvpEncoderLayer': 'MvpEncoderLayer',
     'MvpForCausalLM': 'MvpEncoderLayer',
     'MvpForConditionalGeneration': 'MvpEncoderLayer',
     'MvpForQuestionAnswering': 'MvpEncoderLayer',
     'MvpForSequenceClassification': 'MvpEncoderLayer',
     'MvpPrompt': 'MvpEncoderLayer',
     'NllbMoeEncoderLayer': 'NllbMoeEncoderLayer',
     'NllbMoeForConditionalGeneration': 'NllbMoeEncoderLayer',
     'PLBartEncoderLayer': 'BartEncoderLayer',
     'PLBartForConditionalGeneration': 'BartEncoderLayer',
     'TimeSeriesTransformerEncoderLayer': 'TimeSeriesTransformerEncoderLayer',
     'TimeSeriesTransformerForPrediction': 'TimeSeriesTransformerEncoderLayer'}
onnx_diagnostic.torch_export_patches.patch_module_helper.rewritings_transformers_clamp_float16(cls_name) → List[type][source]

Rewrites known control flows equivalent to the following:

if hidden_states.dtype == torch.float16 and (
    torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
):
    clamp_value = torch.finfo(hidden_states.dtype).max - 1000
    hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

cls_name is the class name; it is mapped to a list of other classes to rewrite. Here is the known list:

<<<

import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    _rewrite_forward_clamp_float16,
)

pprint.pprint(_rewrite_forward_clamp_float16())

>>>

    {'AutoformerEncoderLayer': [<class 'transformers.models.autoformer.modeling_autoformer.AutoformerEncoderLayer'>],
     'BartEncoderLayer': [<class 'transformers.models.bart.modeling_bart.BartEncoderLayer'>,
                          <class 'transformers.models.plbart.modeling_plbart.PLBartEncoderLayer'>],
     'BigBirdPegasusEncoderLayer': [<class 'transformers.models.bigbird_pegasus.modeling_bigbird_pegasus.BigBirdPegasusEncoderLayer'>],
     'BlenderbotSmallEncoderLayer': [<class 'transformers.models.blenderbot_small.modeling_blenderbot_small.BlenderbotSmallEncoderLayer'>],
     'InformerEncoderLayer': [<class 'transformers.models.informer.modeling_informer.InformerEncoderLayer'>],
     'LEDEncoderLayer': [<class 'transformers.models.led.modeling_led.LEDEncoderLayer'>],
     'MarianEncoderLayer': [<class 'transformers.models.marian.modeling_marian.MarianEncoderLayer'>],
     'MvpEncoderLayer': [<class 'transformers.models.mvp.modeling_mvp.MvpEncoderLayer'>],
     'NllbMoeEncoderLayer': [<class 'transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeEncoderLayer'>],
     'TimeSeriesTransformerEncoderLayer': [<class 'transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerEncoderLayer'>]}

Function _rewrite_forward_clamp_float16 collects all model classes using those layers.