yobx.torch.tracing#
- class yobx.torch.tracing.CondCCOp[source]#
Cannot be imported from torch.ops.higher_order.cond (the function cond overwrites the submodule cond).
- class yobx.torch.tracing.CustomAttribute(root: CustomProxy, attr: str)[source]#
Used to trace attribute accesses.
- class yobx.torch.tracing.CustomParameterProxy(tracer: TracerBase, node: Node, name, param)[source]#
A special proxy which lets “shape”, “size”, “dim”, and a few other attribute accesses pass through to the underlying module parameter object, so that conditional tests on these attributes will not throw an exception during tracing.
- class yobx.torch.tracing.CustomProxy(node: Node, tracer: TracerBase | None = None)[source]#
Defines a custom proxy to trace the execution of a model and converts it into an fx graph. Works with CustomTracer.
- classmethod cat(tensors: List[CustomProxy], dim: int = 0, *, out=None, axis: int | None = None) CustomProxy[source]#
Implements cat for tensors.
- class yobx.torch.tracing.CustomProxyFloat(node: Node, tracer: TracerBase | None = None)[source]#
A proxy for a float.
- class yobx.torch.tracing.CustomProxyInt(node: Node, tracer: TracerBase | None = None)[source]#
A proxy for an integer.
- class yobx.torch.tracing.CustomTracer(autowrap_modules: ~typing.Tuple[~types.ModuleType, ...] = (<module 'math' (built-in)>,), autowrap_functions: ~typing.Tuple[~typing.Callable, ...] = (<built-in method ones of type object>, <built-in method zeros of type object>, <built-in method full of type object>, <built-in method empty of type object>, <built-in method arange of type object>), param_shapes_constant: bool = False, module_leaves: ~typing.Dict[type, ~typing.Callable[[~torch.nn.modules.module.Module, str], bool]] | None = None)[source]#
Defines a custom tracer to trace the execution of a model and converts it into an fx graph. Works with CustomProxy.
from yobx.torch.tracing import CustomTracer
graph = CustomTracer().trace(model)
- Parameters:
autowrap_modules – defaults to (math, ), Python modules whose functions should be wrapped automatically without needing to use fx.wrap().
autowrap_functions – defaults to _AUTOWRAP_FUNCTIONS, Python functions that should be wrapped automatically without needing to use fx.wrap(). Includes tensor-creating functions (e.g. torch.ones) so that calls with proxy size arguments are captured as traced nodes rather than executed immediately.
param_shapes_constant – when this flag is set, calls to shape, size and a few other shape-like attributes of a module’s parameter will be evaluated directly, rather than returning a new Proxy value for an attribute access.
module_leaves – modules to be considered as leaves, mapped to a callable f(module, module_qualified_name) -> bool that decides whether a specific module instance is a leaf; the tracer does not trace into leaf modules and emits call_module nodes for them instead
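The module_leaves mechanism can be illustrated with the standard torch.fx Tracer, on which CustomTracer is built. The sketch below is an assumption-laden analog, not yobx's implementation: the names LeafTracer, Block, and MODULE_LEAVES are made up, and only documented torch.fx behavior is used. A type-to-predicate mapping decides that every Block instance is a leaf, so tracing emits one opaque call_module node instead of Block's internal ops.

```python
import torch
import torch.fx as fx
from torch import nn

class Block(nn.Module):
    # a custom module the tracer would normally trace through
    def forward(self, x):
        return torch.relu(x) + 1

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()

    def forward(self, x):
        return self.block(x) * 2

# module_leaves-style mapping: type -> predicate deciding leaf-ness
MODULE_LEAVES = {Block: lambda mod, qualname: True}

class LeafTracer(fx.Tracer):
    def is_leaf_module(self, m: nn.Module, qualname: str) -> bool:
        pred = MODULE_LEAVES.get(type(m))
        if pred is not None:
            return pred(m, qualname)
        return super().is_leaf_module(m, qualname)

graph = LeafTracer().trace(Model())
```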
- call_module(m: Module, forward: Callable[[...], Any], args: tuple[Any, ...], kwargs: dict[str, Any]) Any[source]#
Method that specifies the behavior of this Tracer when it encounters a call to an nn.Module instance.
By default, the behavior is to check if the called module is a leaf module via is_leaf_module. If it is, emit a call_module node referring to m in the Graph. Otherwise, call the Module normally, tracing through the operations in its forward function.
This method can be overridden to, for example, create nested traced GraphModules, or any other behavior you would want while tracing across Module boundaries.
- Parameters:
m (Module) – The module for which a call is being emitted
forward (Callable) – The forward() method of the Module to be invoked
args (Tuple) – args of the module callsite
kwargs (Dict) – kwargs of the module callsite
- Returns:
The return value from the Module call. In the case that a call_module node was emitted, this is a Proxy value. Otherwise, it is whatever value was returned from the Module invocation.
Note
Backwards-compatibility for this API is guaranteed.
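Since this hook follows the base torch.fx.Tracer.call_module contract described above, it can be overridden the same way. The sketch below is illustrative only (LoggingTracer and seen are made-up names): it records each module the tracer encounters while delegating to the default leaf/trace-through behavior.

```python
import torch.fx as fx
from torch import nn

seen = []  # class names of modules hit during tracing

class LoggingTracer(fx.Tracer):
    def call_module(self, m, forward, args, kwargs):
        seen.append(type(m).__name__)
        # default behavior: emit call_module for leaves, trace through otherwise
        return super().call_module(m, forward, args, kwargs)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
graph = LoggingTracer().trace(model)
```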
- create_args_for_root(root_fn, is_module, concrete_args=None)[source]#
Create placeholder nodes corresponding to the signature of the root Module. This method introspects root’s signature and emits those nodes accordingly, also supporting *args and **kwargs.
Warning
This API is experimental and is NOT backward-compatible.
- classmethod graph_erase_node(graph: Graph, node: Node)[source]#
Removes a node and all predecessors that are only consumed by this node.
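The idea can be sketched with plain torch.fx (the helper name erase_with_dead_inputs is hypothetical; yobx's actual implementation may differ): erase the node, then recursively erase any input node left without users.

```python
import torch
import torch.fx as fx

def f(x):
    unused = torch.relu(x) + 1  # dead chain: relu -> add, never returned
    return x * 2

gm = fx.symbolic_trace(f)

def erase_with_dead_inputs(graph: fx.Graph, node: fx.Node) -> None:
    # remove the node, then any input it was the only consumer of
    inputs = list(node.all_input_nodes)
    graph.erase_node(node)
    for inp in inputs:
        if inp.op not in ("placeholder", "output") and len(inp.users) == 0:
            erase_with_dead_inputs(graph, inp)

dead = next(
    n for n in gm.graph.nodes
    if n.op == "call_function" and len(n.users) == 0
)
erase_with_dead_inputs(gm.graph, dead)
gm.recompile()
```

Erasing the dead add node also removes the relu feeding it, leaving only the placeholder, the multiply, and the output.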
- is_leaf_module(m: Module, module_qualified_name: str) bool[source]#
A method to specify whether a given nn.Module is a “leaf” module.
Leaf modules are the atomic units that appear in the IR, referenced by call_module calls. By default, Modules in the PyTorch standard library namespace (torch.nn) are leaf modules. All other modules are traced through and their constituent ops are recorded, unless specified otherwise via this parameter.
- Parameters:
m (Module) – The module being queried about
module_qualified_name (str) – The path to root of this module. For example, if you have a module hierarchy where submodule foo contains submodule bar, which contains submodule baz, that module will appear with the qualified name foo.bar.baz here.
Note
Backwards-compatibility for this API is guaranteed.
- proxy(node: Node, cls: type[CustomProxy] = <class 'yobx.torch.tracing.CustomProxy'>) Proxy[source]#
Overrides this method to replace the default Proxy with CustomProxy.
- register_callable(name: str, fn: Callable) Node[source]#
Registers a function and returns a unique name.
- Parameters:
name – prefix to prepend to the function name
fn – function
- Returns:
new_name
- classmethod remove_inplace(graph: Graph, exported_program: ExportedProgram | None = None, verbose: int = 0, exc: bool = True, recursive: bool = False) int[source]#
Removes inplace operations.
- Parameters:
graph – graph to modify
exported_program – if available, it is used in the error message to make it easier to trace the code source
verbose – verbosity
exc – raise an exception if not possible, otherwise return -1
recursive – remove nodes inside submodules
- Returns:
number of inplace nodes removed; a negative number means there are still inplace nodes to be removed but this function is unable to do so, and only decompositions may help in that case
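As an illustration of the underlying idea, not yobx's actual pass, a plain torch.fx sketch can rewrite a traced inplace call_method such as add_ into its out-of-place counterpart:

```python
import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        # inplace op on a clone; traced as a call_method "add_" node
        return x.clone().add_(1)

gm = fx.symbolic_trace(M())

removed = 0
for node in gm.graph.nodes:
    if node.op == "call_method" and node.target == "add_":
        node.target = "add"  # swap to the out-of-place variant
        removed += 1
gm.recompile()
```

This simple renaming is only safe because the inplace result is the value actually returned; the real pass has to handle aliasing, which is why it can fail and return a negative count.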
- classmethod remove_unnecessary_slices(graph: Graph, verbose: int = 0) int[source]#
Removes unnecessary slices and other nodes doing nothing.
- Parameters:
graph – graph to modify
verbose – verbosity level
- Returns:
number of unnecessary nodes removed
Example of a slice node doing nothing:
%slice_11 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor] (args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
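A slice over the full range [0, 2**63 - 1), as in the node above, keeps every element and can be dropped. A hedged sketch of detecting and removing such slices with plain torch.fx (yobx's pass likely covers more cases than this):

```python
import operator
import sys
import torch
import torch.fx as fx

def f(x):
    return x[0:sys.maxsize] * 2  # the slice keeps every element

gm = fx.symbolic_trace(f)

removed = 0
for node in list(gm.graph.nodes):
    if (
        node.op == "call_function"
        and node.target is operator.getitem
        and isinstance(node.args[1], slice)
        and node.args[1] == slice(0, sys.maxsize)
    ):
        # reroute consumers to the slice's input, then drop the node
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)
        removed += 1
gm.recompile()
```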
- trace(root: Module | Callable[[...], Any], concrete_args: Dict[str, Any] | None = None, remove_inplace: bool = True, update_model_with_callable: bool = True, dynamic_shapes: Any | None = None, verbose: int = 0) Graph[source]#
Trace root and return the corresponding FX Graph representation. root can either be an nn.Module instance or a Python callable.
Note that after this call, self.root may be different from the root passed in here. For example, when a free function is passed to trace(), we will create an nn.Module instance to use as the root and add embedded constants to it.
- Parameters:
root – Either a Module or a function to be traced through. Backwards-compatibility for this parameter is guaranteed.
concrete_args – Concrete arguments that should not be treated as Proxies. This parameter is experimental and its backwards-compatibility is NOT guaranteed.
remove_inplace – Removes inplace nodes
update_model_with_callable – in some cases (control flow), the model needs to be updated
dynamic_shapes – dynamic shapes
verbose – verbosity
- Returns:
A Graph representing the semantics of the passed-in root
If the model had to be wrapped before being traced, attribute traced_model is added to the tracer.
- yobx.torch.tracing.replace_problematic_function_before_tracing() Generator[source]#
Replaces functions that cannot be traced with the default tracer, such as torch.cat().
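The helper is a context manager (it returns a Generator). The general pattern it relies on can be sketched in plain Python; the swap_function helper and the wrapper passed to it below are illustrative assumptions, not yobx's actual replacement for torch.cat: the function is swapped for the duration of tracing and restored on exit, even on error.

```python
import contextlib
import torch

@contextlib.contextmanager
def swap_function(module, name, replacement):
    # temporarily replace module.<name>, restoring it even on error
    original = getattr(module, name)
    setattr(module, name, replacement)
    try:
        yield original
    finally:
        setattr(module, name, original)

original_cat = torch.cat
with swap_function(torch, "cat", lambda ts, dim=0: original_cat(ts, dim)):
    assert torch.cat is not original_cat  # patched inside the block
assert torch.cat is original_cat  # restored afterwards
```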