onnx_diagnostic.torch_export_patches.patch_module_helper
- class onnx_diagnostic.torch_export_patches.patch_module_helper.OrToBitOrTransformer
- onnx_diagnostic.torch_export_patches.patch_module_helper.ast_or_into_bitor(node: ast.Node) -> ast.Node
Replaces every boolean operator "or" with the bitwise operator "|".
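A minimal sketch of such a rewrite, built on the standard ast.NodeTransformer, is shown below. It is not the library's implementation, and the names OrToBitOrSketch and or_into_bitor_sketch are hypothetical. A chain such as "a or b or c" is a single ast.BoolOp node holding several values, so the sketch folds it left to right into nested ast.BinOp nodes using ast.BitOr.

```python
import ast


class OrToBitOrSketch(ast.NodeTransformer):
    """Hypothetical sketch: rewrites ``x or y`` into ``x | y``."""

    def visit_BoolOp(self, node: ast.BoolOp) -> ast.AST:
        # Rewrite nested expressions first, e.g. ``a or (b and (c or d))``.
        self.generic_visit(node)
        if not isinstance(node.op, ast.Or):
            return node  # ``and`` chains are left untouched
        # ``a or b or c`` is one BoolOp with three values:
        # fold it left to right into ``(a | b) | c``.
        result = node.values[0]
        for value in node.values[1:]:
            result = ast.BinOp(left=result, op=ast.BitOr(), right=value)
        return ast.copy_location(result, node)


def or_into_bitor_sketch(tree: ast.AST) -> ast.AST:
    """Applies the transformer, then fills in missing line numbers."""
    new_tree = OrToBitOrSketch().visit(tree)
    return ast.fix_missing_locations(new_tree)


if __name__ == "__main__":
    source = "def f(x, y, z):\n    return x or y or z\n"
    tree = or_into_bitor_sketch(ast.parse(source))
    print(ast.unparse(tree))  # requires Python 3.9+ for ast.unparse
```

The rewrite is not semantics-preserving in general: short-circuit evaluation is lost, and "or" and "|" only agree when the operands are booleans. In exchange, the rewritten expression no longer needs a truth-value test on its left operand.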
- onnx_diagnostic.torch_export_patches.patch_module_helper.code_needing_rewriting(cls_name: str) -> List[Any] | None
Returns the list of known methods or functions that must be rewritten for a specific model class because of their control flow.
- Parameters:
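cls_name – name of the model class, for example "BartForConditionalGeneration"
- Returns:
a list of rewritings, one dictionary per method or function to rewrite, with the keys function, filter_node and pre_rewriter (see the example below)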
<<<
import pprint

from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    code_needing_rewriting,
)

pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))
>>>
[{'filter_node': <function _rewrite_bart_encoder_layer.<locals>.<lambda> at 0x7fc605754220>,
  'function': <function BartEncoderLayer.forward at 0x7fc5fcf1d8a0>,
  'pre_rewriter': <function ast_or_into_bitor at 0x7fc6062e1300>},
 {'filter_node': <function _rewrite_bart_encoder_layer.<locals>.<lambda> at 0x7fc605754220>,
  'function': <function PLBartEncoderLayer.forward at 0x7fc5fcf9c5e0>,
  'pre_rewriter': <function ast_or_into_bitor at 0x7fc6062e1300>}]