onnx_diagnostic.torch_export_patches.patch_module_helper

class onnx_diagnostic.torch_export_patches.patch_module_helper.OrToBitOrTransformer

onnx_diagnostic.torch_export_patches.patch_module_helper.ast_or_into_bitor(node: ast.Node) → ast.Node

Replaces every boolean operator or with the bitwise operator | in the given AST node.

onnx_diagnostic.torch_export_patches.patch_module_helper.code_needing_rewriting(cls_name: str) → List[Any] | None

Returns the list of methods or functions known to need rewriting, because of their control flow, for a specific model class.

Parameters:

cls_name – name of the class

Returns:

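a list of rewriting descriptions, one dictionary per function to rewrite, with the keys 'function', 'filter_node' and 'pre_rewriter'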

<<<

import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    code_needing_rewriting,
)

pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))

>>>

    [{'filter_node': <function _rewrite_bart_encoder_layer.<locals>.<lambda> at 0x7fc605754220>,
      'function': <function BartEncoderLayer.forward at 0x7fc5fcf1d8a0>,
      'pre_rewriter': <function ast_or_into_bitor at 0x7fc6062e1300>},
     {'filter_node': <function _rewrite_bart_encoder_layer.<locals>.<lambda> at 0x7fc605754220>,
      'function': <function PLBartEncoderLayer.forward at 0x7fc5fcf9c5e0>,
      'pre_rewriter': <function ast_or_into_bitor at 0x7fc6062e1300>}]