onnx_diagnostic.torch_export_patches.patch_module_helper
- class onnx_diagnostic.torch_export_patches.patch_module_helper.OrToBitOrTransformer
- onnx_diagnostic.torch_export_patches.patch_module_helper.ast_or_into_bitor(node: ast.Node) → ast.Node
Replaces every boolean operator or with the bitwise operator |.
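A minimal sketch of this kind of rewrite, using only Python's standard ast module (Python 3.9+ for ast.unparse). OrToBitOrSketch and or_into_bitor_sketch are hypothetical names. This is not the library's implementation. One plausible motivation, which is an assumption and not stated above: or calls bool() on its operands, which forces data-dependent control flow when they are tensors. The | operator stays an elementwise tensor operation. Note that the rewrite drops short-circuit evaluation.

```python
import ast


class OrToBitOrSketch(ast.NodeTransformer):
    """Hypothetical re-implementation: rewrites `a or b or c` into `a | b | c`."""

    def visit_BoolOp(self, node: ast.BoolOp) -> ast.AST:
        # Rewrite nested expressions first so that inner `or`s are replaced too.
        self.generic_visit(node)
        if not isinstance(node.op, ast.Or):
            return node  # leave `and` untouched
        # A BoolOp keeps all operands in one list: fold them left to right
        # into a chain of binary `|` operations.
        result = node.values[0]
        for value in node.values[1:]:
            result = ast.BinOp(left=result, op=ast.BitOr(), right=value)
        return ast.copy_location(result, node)


def or_into_bitor_sketch(node: ast.AST) -> ast.AST:
    new_node = OrToBitOrSketch().visit(node)
    # Intermediate BinOp nodes need line/column info before compile() accepts the tree.
    return ast.fix_missing_locations(new_node)


tree = ast.parse("if x or y or z:\n    pass")
print(ast.unparse(or_into_bitor_sketch(tree)))
```

The transformed tree can still be passed to compile(). Expressions such as a and (b or c) become a and b | c, because | binds more tightly than and.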
- onnx_diagnostic.torch_export_patches.patch_module_helper.code_needing_rewriting(cls_name: str) → List[Any] | None
Returns the known list of methods or functions that must be rewritten because of their control flow, for a specific model class.
- Parameters:
cls_name – name of the class
- Returns:
a list of rewritings, or None
<<<
import pprint

from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    code_needing_rewriting,
)

pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))
>>>
[{'filter_node': <function code_needing_rewriting.<locals>.<lambda> at 0x7fe13c0553a0>,
  'function': <function BartEncoderLayer.forward at 0x7fe125f12520>,
  'pre_rewriter': <function ast_or_into_bitor at 0x7fe13c1dd620>},
 {'filter_node': <function code_needing_rewriting.<locals>.<lambda> at 0x7fe13c0553a0>,
  'function': <function PLBartEncoderLayer.forward at 0x7fe125fa8d60>,
  'pre_rewriter': <function ast_or_into_bitor at 0x7fe13c1dd620>}]
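The entries printed for BartForConditionalGeneration share three keys: function, filter_node and pre_rewriter. This hints at how a caller might consume them. The following is a hypothetical sketch, not the library's code. The registry, lookup_rewritings and nodes_to_rewrite are invented names. Several details are assumptions:

- "function" holds source text instead of a function object, so the sketch runs without transformers.
- filter_node is treated as a predicate that selects the AST nodes to rewrite.
- pre_rewriter is an identity placeholder where the real entries use ast_or_into_bitor.
- Unknown classes are assumed to return None, as the return type List[Any] | None suggests.

```python
import ast
from typing import Any, Dict, List, Optional

# Hypothetical toy method standing in for something like BartEncoderLayer.forward.
_TOY_SOURCE = "def forward(x, y):\n    if x or y:\n        return x\n    return y\n"

# Hypothetical registry mirroring the structure of the real output:
# class name -> list of entries with "function", "filter_node", "pre_rewriter".
_REGISTRY: Dict[str, List[Dict[str, Any]]] = {
    "ToyModel": [
        {
            # Assumption: source text instead of a function object.
            "function": _TOY_SOURCE,
            # Assumption: a predicate selecting the nodes to rewrite (here every `if`).
            "filter_node": lambda node: isinstance(node, ast.If),
            # Placeholder: the real entries use ast_or_into_bitor.
            "pre_rewriter": lambda tree: tree,
        }
    ],
}


def lookup_rewritings(cls_name: str) -> Optional[List[Dict[str, Any]]]:
    # Assumption: classes without known control-flow issues map to None.
    return _REGISTRY.get(cls_name)


def nodes_to_rewrite(entry: Dict[str, Any]) -> List[ast.AST]:
    # Parse the code, apply the pre-rewriter, then keep the nodes the filter selects.
    tree = entry["pre_rewriter"](ast.parse(entry["function"]))
    return [node for node in ast.walk(tree) if entry["filter_node"](node)]


entries = lookup_rewritings("ToyModel")
print(len(nodes_to_rewrite(entries[0])))  # 1: the `if x or y:` statement
print(lookup_rewritings("UnknownModel"))  # None
```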