yobx.helpers.patch_helper
- class yobx.helpers.patch_helper.PatchDetails
This class stores patching information. It helps to understand which rewriting was applied to which method or function. See page Applying patches to a model and displaying the diff.
- make_report(patches: List[Tuple[PatchInfo, List[torch.fx.Node]]], format: str = 'raw') → str
Creates a report based on the involved patches.
- Parameters:
patches – the list returned by method patches_involved_in_graph()
format – format of the report
- Returns:
report
See example Applying patches to a model and displaying the diff.
```diff
--- original
+++ rewritten
@@ -1,6 +1,5 @@
 def _print_Symbol(self, expr: sympy.Symbol) -> str:
-    if not isinstance(expr, sympy.Symbol):
-        raise AssertionError(f"Expected sympy.Symbol, got {type(expr)}")
-    if not self.symbol_to_source.get(expr):
-        raise AssertionError(f"Unknown symbol {expr} created by constraints solver")
-    return self.symbol_to_source[expr][0].name
+    assert isinstance(expr, sympy.Symbol), str(type(expr))
+    if self.symbol_to_source.get(expr):  # type: ignore
+        return self.symbol_to_source[expr][0].name  # type: ignore
+    return str(expr)
```
- matching_pair(patch: PatchInfo, node: torch.fx.Node) → bool
Performs the last validation of a (patch, node) pair: RotaryEmbedding has many rewritings, and they all end up on the same line of code.
- patches_involved_in_graph(graph: Any) → List[Tuple[PatchInfo, List[Any]]]
Enumerates all patches impacting a graph. The function goes through the graph nodes (main graph only) and inspects their metadata to determine whether a listed patch was involved.
- Parameters:
graph – a graph object whose nodes can be iterated. The method is designed for torch.fx.Graph but works with any object that satisfies the following minimal contract:
  - graph.nodes: an iterable of node objects;
  - node.meta: a dict attached to each node;
  - node.meta["stack_trace"]: a string containing the call stack captured when the node was created.
Any custom graph representation that provides these three attributes works just as well as a native fx.Graph.
- Returns:
list of pairs (patch, list of the nodes impacted by this patch)
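The duck-typed contract above can be illustrated with a minimal graph built from plain dataclasses. This is a sketch under stated assumptions: FakeNode, FakeGraph and nodes_with_frame are hypothetical names, and the filename check only stands in for whatever matching patches_involved_in_graph performs internally; it is not the library's logic.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class FakeNode:
    """Hypothetical node: only carries the ``meta`` dict required by the contract."""
    name: str
    meta: Dict[str, Any] = field(default_factory=dict)


@dataclass
class FakeGraph:
    """Hypothetical graph: only exposes an iterable ``nodes`` attribute."""
    nodes: List[FakeNode] = field(default_factory=list)


def nodes_with_frame(graph: Any, filename: str) -> List[Any]:
    """Hypothetical helper (not part of yobx): returns the nodes whose captured
    stack trace mentions ``filename``, reading only graph.nodes, node.meta
    and node.meta["stack_trace"]."""
    found = []
    for node in graph.nodes:
        trace = node.meta.get("stack_trace") or ""
        if filename in trace:
            found.append(node)
    return found


graph = FakeGraph(
    nodes=[
        FakeNode("x", {"stack_trace": 'File "model.py", line 3, in forward'}),
        FakeNode("y", {"stack_trace": 'File "patches.py", line 12, in patched_forward'}),
    ]
)
print([n.name for n in nodes_with_frame(graph, "patches.py")])  # ['y']
```

Because graph exposes graph.nodes, node.meta and node.meta["stack_trace"], it satisfies the contract described above and, under that contract, could be passed to patches_involved_in_graph in place of a torch.fx.Graph.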
- class yobx.helpers.patch_helper.PatchInfo(patch: Callable, do: Callable[[], None], undo: Callable[[Callable], None], family: str = '', _last_patched_function: Callable | None = None)
Stores information about patches.
- Parameters:
function_to_patch – function to patch
patch – the patched function
family – a category, anything to classify the patch
do – applies the patch; this function returns the patched function
undo – removes the patch
_last_patched_function – used when this patch is applied while another patch is already in place
- format_diff(format: str = 'raw') → str
Formats a diff between two functions as a string.
- Parameters:
format – 'raw' or 'rst'
- Returns:
diff
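A diff like the one shown on this page can be approximated with the standard library difflib module. The sketch below is an assumption about the approach, not the code behind format_diff or make_diff: the helper unified_function_diff is hypothetical, and it takes source strings rather than function objects so that it runs without inspect.getsource.

```python
import difflib
import textwrap


def unified_function_diff(original_src: str, rewritten_src: str) -> str:
    """Hypothetical helper: unified diff between two function sources,
    labelled 'original' and 'rewritten' like the example diff."""
    lines = difflib.unified_diff(
        textwrap.dedent(original_src).splitlines(),
        textwrap.dedent(rewritten_src).splitlines(),
        fromfile="original",
        tofile="rewritten",
        lineterm="",  # splitlines() already removed the newlines
    )
    return "\n".join(lines)


ORIGINAL = '''\
def _print_Symbol(self, expr: sympy.Symbol) -> str:
    if not isinstance(expr, sympy.Symbol):
        raise AssertionError(f"Expected sympy.Symbol, got {type(expr)}")
    if not self.symbol_to_source.get(expr):
        raise AssertionError(f"Unknown symbol {expr} created by constraints solver")
    return self.symbol_to_source[expr][0].name
'''

REWRITTEN = '''\
def _print_Symbol(self, expr: sympy.Symbol) -> str:
    assert isinstance(expr, sympy.Symbol), str(type(expr))
    if self.symbol_to_source.get(expr):  # type: ignore
        return self.symbol_to_source[expr][0].name  # type: ignore
    return str(expr)
'''

print(unified_function_diff(ORIGINAL, REWRITTEN))
```

Running it yields the same header lines as the example diff: `--- original`, `+++ rewritten` and `@@ -1,6 +1,5 @@`, since only the `def` line is shared between both versions.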
- classmethod make(patch: Callable, module_or_class: Any, method_or_function_name: str, family: str, _last_patched_function: Callable | None = None) → PatchInfo
Creates a patch with the given information.
- Parameters:
patch – the patched method or function
module_or_class – the module or the class to patch
method_or_function_name – method or function name to patch
family – category of the patch
_last_patched_function – used when this patch is applied while another patch is already in place
- Returns:
the patch
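The idea of building a patch from a module or class plus an attribute name can be sketched with getattr and setattr. This is a simplified, hypothetical stand-in (SimplePatch is not PatchInfo): here do and undo are methods instead of callables passed to the constructor, _last_patched_function is ignored, and only plain functions and methods are handled (descriptors such as staticmethod would need extra care).

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class SimplePatch:
    """Hypothetical stand-in for PatchInfo: keeps what is needed to apply
    and revert a patch on an attribute of a module or a class."""
    patch: Callable
    module_or_class: Any
    name: str
    family: str = ""
    original: Optional[Callable] = None  # filled in by do()

    @classmethod
    def make(cls, patch: Callable, module_or_class: Any, name: str, family: str = "") -> "SimplePatch":
        return cls(patch, module_or_class, name, family)

    def do(self) -> Callable:
        # Remember the current attribute, then replace it with the patch.
        self.original = getattr(self.module_or_class, self.name)
        setattr(self.module_or_class, self.name, self.patch)
        return self.patch

    def undo(self) -> None:
        # Restore the attribute saved by do().
        setattr(self.module_or_class, self.name, self.original)


class Greeter:
    def hello(self) -> str:
        return "hello"


def patched_hello(self) -> str:
    return "patched hello"


p = SimplePatch.make(patched_hello, Greeter, "hello", family="demo")
p.do()
print(Greeter().hello())  # patched hello
p.undo()
print(Greeter().hello())  # hello
```

Keeping the original attribute on the patch object is what makes undo possible; a real implementation also has to track patches stacked on top of each other, which is the role the documentation gives to _last_patched_function.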
- make_diff() → str
Returns a diff as a string. See example Applying patches to a model and displaying the diff.
```diff
--- original
+++ rewritten
@@ -1,6 +1,5 @@
 def _print_Symbol(self, expr: sympy.Symbol) -> str:
-    if not isinstance(expr, sympy.Symbol):
-        raise AssertionError(f"Expected sympy.Symbol, got {type(expr)}")
-    if not self.symbol_to_source.get(expr):
-        raise AssertionError(f"Unknown symbol {expr} created by constraints solver")
-    return self.symbol_to_source[expr][0].name
+    assert isinstance(expr, sympy.Symbol), str(type(expr))
+    if self.symbol_to_source.get(expr):  # type: ignore
+        return self.symbol_to_source[expr][0].name  # type: ignore
+    return str(expr)
```