onnx_diagnostic.torch_export_patches.patch_module

class onnx_diagnostic.torch_export_patches.patch_module.RewriteControlFlow(wrapper_name)
generic_visit(node)

Called if no explicit visitor function exists for a node.
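This description matches the docstring of the standard library's ast.NodeVisitor.generic_visit. That suggests, though the page does not say so, that RewriteControlFlow builds on ast.NodeVisitor or ast.NodeTransformer. If so, it inherits the dispatch rule shown in the sketch below. visit() calls visit_<ClassName> when the subclass defines it and otherwise falls back to generic_visit, which visits the node's children. The IfCounter class is a hypothetical illustration and is not part of onnx_diagnostic.

```python
import ast


class IfCounter(ast.NodeVisitor):
    """Hypothetical visitor: counts `if` statements, everything else uses generic_visit."""

    def __init__(self) -> None:
        self.n_if = 0

    def visit_If(self, node: ast.If) -> None:
        # Explicit visitor: dispatched by visit() for every ast.If node.
        self.n_if += 1
        # Keep walking into the branches so nested ifs are counted too.
        self.generic_visit(node)


source = (
    "def f(x):\n"
    "    if x > 0:\n"
    "        if x > 1:\n"
    "            return 2\n"
    "        return 1\n"
    "    return 0\n"
)
counter = IfCounter()
# Module and FunctionDef have no visit_* method here, so generic_visit handles them.
counter.visit(ast.parse(source))
print(counter.n_if)  # 2
```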

class onnx_diagnostic.torch_export_patches.patch_module.RewrittenMethod(tree, func)

Stores a method rewritten by onnx_diagnostic.torch_export_patches.patch_module.transform_method().

Parameters:
  • tree – ast tree of the rewritten method

  • func – callable compiled from the tree

property code: str

Returns the source code of the rewritten method.
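The sketch below shows one way a container like this could work, using only the standard library. RewrittenMethodSketch and compile_from_tree are hypothetical names. The sketch assumes that code is regenerated from the tree with ast.unparse (Python 3.9+), which may not be how the library produces its source.

```python
import ast
from dataclasses import dataclass
from typing import Callable


@dataclass
class RewrittenMethodSketch:
    """Hypothetical stand-in for RewrittenMethod: keeps the tree and the compiled callable."""

    tree: ast.AST
    func: Callable

    @property
    def code(self) -> str:
        # Regenerate the source text from the AST (assumption: ast.unparse, Python 3.9+).
        return ast.unparse(self.tree)


def compile_from_tree(tree: ast.Module, name: str, namespace: dict) -> RewrittenMethodSketch:
    # Fill in line/column info so compile() accepts trees that were edited by hand.
    ast.fix_missing_locations(tree)
    code_obj = compile(tree, filename="<rewritten>", mode="exec")
    # Executing the module defines the function inside `namespace`.
    exec(code_obj, namespace)
    return RewrittenMethodSketch(tree=tree, func=namespace[name])


tree = ast.parse("def double(x):\n    return x * 2\n")
rm = compile_from_tree(tree, "double", {})
print(rm.func(21))  # 42
print(rm.code)
```

Keeping the tree next to the callable means the rewritten code stays inspectable, which helps when debugging a transformation.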

onnx_diagnostic.torch_export_patches.patch_module.transform_method(func: Callable, if_name='torch_cond') → RewrittenMethod

Returns a new function based on func in which every test (if statement) is replaced by a call to torch.cond().

Parameters:
  • func – method or function to rewrite

  • if_name – name of the function that calls the test

Returns:

rewritten method
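The plain-Python sketch below illustrates the idea behind transform_method. It is not the library's implementation, and it differs from the real function in three ways:

  • It handles only the restricted pattern `if test: return a` followed by `else: return b`.
  • It uses a hypothetical stand-in, fake_cond, in place of torch.cond() so that it runs without PyTorch. fake_cond mirrors the (pred, true_fn, false_fn, operands) calling convention.
  • It starts from source text rather than a function object. inspect.getsource would be the usual way to obtain that text from a function defined in a file.

IfToCond, transform_source and fake_cond are all hypothetical names. Treating if_name as the name bound to the condition function inside the rewritten code is an assumption about the real parameter.

```python
import ast
import textwrap
from typing import Callable


def fake_cond(pred, true_fn, false_fn, operands):
    # Hypothetical stand-in mirroring torch.cond(pred, true_fn, false_fn, operands).
    return true_fn(*operands) if pred else false_fn(*operands)


class IfToCond(ast.NodeTransformer):
    """Hypothetical sketch: `if c: return a else: return b` becomes
    `return <if_name>(c, lambda <args>: a, lambda <args>: b, (<args>,))`."""

    def __init__(self, if_name: str, arg_names: list) -> None:
        self.if_name = if_name
        self.arg_names = arg_names

    def _branch(self, body: list) -> ast.Lambda:
        # Only branches made of a single `return <expr>` are supported here.
        if len(body) != 1 or not isinstance(body[0], ast.Return) or body[0].value is None:
            raise NotImplementedError("only 'return <expr>' branches are handled")
        args = ast.arguments(
            posonlyargs=[], args=[ast.arg(arg=n) for n in self.arg_names],
            vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[],
        )
        return ast.Lambda(args=args, body=body[0].value)

    def visit_If(self, node: ast.If) -> ast.Return:
        self.generic_visit(node)  # rewrite nested ifs first
        operands = ast.Tuple(
            elts=[ast.Name(id=n, ctx=ast.Load()) for n in self.arg_names], ctx=ast.Load()
        )
        call = ast.Call(
            func=ast.Name(id=self.if_name, ctx=ast.Load()),
            args=[node.test, self._branch(node.body), self._branch(node.orelse), operands],
            keywords=[],
        )
        return ast.Return(value=call)


def transform_source(source: str, name: str, if_name: str = "torch_cond") -> Callable:
    tree = ast.parse(textwrap.dedent(source))
    arg_names = [a.arg for a in tree.body[0].args.args]
    tree = IfToCond(if_name, arg_names).visit(tree)
    ast.fix_missing_locations(tree)
    # Bind if_name to the stand-in; a real rewrite would bind it to torch.cond.
    namespace = {if_name: fake_cond}
    exec(compile(tree, "<rewritten>", "exec"), namespace)
    return namespace[name]


src = """
def forward(x, y):
    if x > 0:
        return x + y
    else:
        return abs(x) + y + 1
"""
f = transform_source(src, "forward")
print(f(2, 3))   # 5
print(f(-2, 3))  # 6
```

Each branch becomes a lambda that takes the function's arguments, mirroring the way torch.cond() passes operands explicitly to both branches. A fuller rewrite would also need to handle branches that assign variables instead of returning. This sketch rejects such branches with NotImplementedError.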