yobx.helpers.patch_helper

class yobx.helpers.patch_helper.PatchDetails

This class stores patching information. It helps understand which rewriting was applied to which method or function. See the page Applying patches to a model and displaying the diff.

append(patch: PatchInfo)

Adds a patch to the list of patches.

data() → List[Dict[str, Any]]

Returns the data for a dataframe.

extend(patches: Iterable[PatchInfo])

Adds several patches to the list of patches.

find(name: str) → PatchInfo | None

Finds a patch by name.

make_report(patches: List[Tuple[PatchInfo, List[torch.fx.Node]]], format: str = 'raw') → str

Creates a report based on the involved patches.

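Parameters:
  • patches – pairs of a patch and the list of graph nodes it impacts

  • format – format of the report, 'raw' by default

Returns:

report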

See example Applying patches to a model and displaying the diff.

--- original
+++ rewritten
@@ -1,6 +1,5 @@
def _print_Symbol(self, expr: sympy.Symbol) -> str:
-    if not isinstance(expr, sympy.Symbol):
-        raise AssertionError(f"Expected sympy.Symbol, got {type(expr)}")
-    if not self.symbol_to_source.get(expr):
-        raise AssertionError(f"Unknown symbol {expr} created by constraints solver")
-    return self.symbol_to_source[expr][0].name
+    assert isinstance(expr, sympy.Symbol), str(type(expr))
+    if self.symbol_to_source.get(expr):  # type: ignore
+        return self.symbol_to_source[expr][0].name  # type: ignore
+    return str(expr)

matching_pair(patch: PatchInfo, node: torch.fx.Node) → bool

Performs the last validation for a (patch, node) pair. RotaryEmbedding has many rewritings, and they all end up on the same line of code.

property n_patches: int

Returns the number of stored patches.

patches_involved_in_graph(graph: Any) → List[Tuple[PatchInfo, List[Any]]]

Enumerates all patches impacting a graph. The function goes through the graph nodes (only those of the main graph) and looks into their metadata to determine whether a listed patch was involved.

Parameters:

graph

a graph object whose nodes can be iterated. The method is designed for torch.fx.Graph but works with any object that satisfies the following minimal contract:

  • graph.nodes — iterable of node objects.

  • node.meta — a dict attached to each node.

  • node.meta["stack_trace"] — a string containing the call-stack captured when the node was created.

Any custom graph representation that provides these three attributes will work just as well as a native fx.Graph.

Returns:

list of pairs, each made of a patch and the list of nodes it impacts
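The minimal contract described above can be illustrated without torch. The sketch below is not the library's implementation: the stack traces are written by hand, and nodes_mentioning is a hypothetical, simplified stand-in for the real matching logic. It only shows that any object exposing nodes, meta and meta["stack_trace"] can be scanned the same way.

```python
from types import SimpleNamespace


def make_node(name, stack_trace):
    # A node only needs a ``meta`` dict holding a "stack_trace" string.
    return SimpleNamespace(name=name, meta={"stack_trace": stack_trace})


# Hypothetical graph: any object with an iterable ``nodes`` attribute works.
graph = SimpleNamespace(
    nodes=[
        make_node("add", 'File "model.py", line 10, in forward\n    return x + y'),
        make_node("cos", 'File "patches.py", line 42, in patched_rotary\n    return x.cos()'),
        make_node("out", ""),
    ]
)


def nodes_mentioning(graph, function_name):
    # Simplified stand-in for the real matching: keep the nodes whose
    # captured stack trace mentions the given function name.
    return [
        node
        for node in graph.nodes
        if function_name in node.meta.get("stack_trace", "")
    ]


matched = [n.name for n in nodes_mentioning(graph, "patched_rotary")]
```

Here matched contains only the node whose stack trace goes through the hypothetical patched_rotary function.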

class yobx.helpers.patch_helper.PatchInfo(patch: Callable, do: Callable[[], None], undo: Callable[[Callable], None], family: str = '', _last_patched_function: Callable | None = None)

Stores information about patches.

Parameters:
  • function_to_patch – function to patch

  • patch – function patched

  • family – a category, anything to classify the patch

  • do – applies the patch; this function returns the patched function

  • undo – removes the patch

  • _last_patched_function – this is used for patches applied when another is used

do()

Applies the patch.

format_diff(format: str = 'raw') → str

Formats a diff between two functions as a string.

Parameters:

  • format – 'raw' or 'rst'

Returns:

diff

classmethod make(patch: Callable, module_or_class: Any, method_or_function_name: str, family: str, _last_patched_function: Callable | None = None) → PatchInfo

Creates a patch with the given information.

Parameters:
  • patch – the patched method or function

  • module_or_class – the module or the class to patch

  • method_or_function_name – method or function name to patch

  • family – category of the patch

  • _last_patched_function – this is used for patches applied when another is used

Returns:

the patch
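The do/undo pair follows the usual monkey-patching pattern: remember the current function, install the replacement, and restore the original later. The sketch below illustrates that pattern only. SimplePatch, Model and patched_forward are hypothetical names, and the class is not PatchInfo itself.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class SimplePatch:
    # Hypothetical stand-in for PatchInfo, reduced to what do/undo need.
    owner: Any  # module or class holding the function to replace
    name: str  # attribute name of the function to replace
    patch: Callable  # replacement function
    original: Optional[Callable] = None

    def do(self):
        # Remember the current function, then install the replacement.
        self.original = getattr(self.owner, self.name)
        setattr(self.owner, self.name, self.patch)

    def undo(self):
        # Put the remembered function back.
        setattr(self.owner, self.name, self.original)


class Model:
    def forward(self, x):
        return x + 1


def patched_forward(self, x):
    return x * 2


p = SimplePatch(Model, "forward", patched_forward)
before = Model().forward(3)
p.do()
during = Model().forward(3)
p.undo()
after = Model().forward(3)
```

The original behaviour is visible before do() and after undo(); the replacement is only active in between.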

make_diff() → str

Returns a diff as a string. See example Applying patches to a model and displaying the diff.

--- original
+++ rewritten
@@ -1,6 +1,5 @@
def _print_Symbol(self, expr: sympy.Symbol) -> str:
-    if not isinstance(expr, sympy.Symbol):
-        raise AssertionError(f"Expected sympy.Symbol, got {type(expr)}")
-    if not self.symbol_to_source.get(expr):
-        raise AssertionError(f"Unknown symbol {expr} created by constraints solver")
-    return self.symbol_to_source[expr][0].name
+    assert isinstance(expr, sympy.Symbol), str(type(expr))
+    if self.symbol_to_source.get(expr):  # type: ignore
+        return self.symbol_to_source[expr][0].name  # type: ignore
+    return str(expr)

property name: str

Returns the name of the patch.

to_dict() → Dict[str, Any]

Returns the patch information as a dictionary.

undo()

Removes the patch.

yobx.helpers.patch_helper.clean_code_with_black(code: str) → str

Changes the code style with black if available.
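One possible way to format code with black only when it is installed is sketched below. The helper name format_if_black_available is hypothetical and the actual function may behave differently; the sketch relies on black.format_str and black.Mode from black's Python API.

```python
def format_if_black_available(code: str) -> str:
    # Hypothetical helper: returns the input unchanged when black is missing.
    try:
        import black
    except ImportError:
        return code
    # Reformat the source string with black's default settings.
    return black.format_str(code, mode=black.Mode())


formatted = format_if_black_available("def f( x ):\n    return x\n")
```

Whether or not black is installed, the result remains valid Python, so the same code path can be used before computing a diff.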

yobx.helpers.patch_helper.make_diff_code(code1: str, code2: str, output: str | None = None) → str

Creates a diff between two pieces of code.

Parameters:
  • code1 – first code

  • code2 – second code

  • output – if not empty, stores the output in this file

Returns:

diff
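A diff with the same "original" and "rewritten" headers as the ones shown on this page can be produced with the standard library. The sketch below assumes a unified diff built with difflib; diff_code is a hypothetical name and not necessarily how make_diff_code is implemented.

```python
import difflib
from typing import Optional


def diff_code(code1: str, code2: str, output: Optional[str] = None) -> str:
    # Compare the two sources line by line and build a unified diff.
    text = "\n".join(
        difflib.unified_diff(
            code1.strip().splitlines(),
            code2.strip().splitlines(),
            fromfile="original",
            tofile="rewritten",
            lineterm="",
        )
    )
    if output:
        # Optionally store the diff in a file, as the output parameter suggests.
        with open(output, "w", encoding="utf-8") as f:
            f.write(text)
    return text


diff = diff_code(
    "def f(x):\n    return x + 1\n",
    "def f(x):\n    return x + 2\n",
)
```

Removed lines start with "-" and added lines start with "+", as in the diffs displayed above.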