yobx.torch.tracing

class yobx.torch.tracing.CondCCOp

Cannot be imported from torch.ops.higher_order.cond because the function cond overwrites the submodule cond.

class yobx.torch.tracing.CustomAttribute(root: CustomProxy, attr: str)

Used to trace attribute accesses on a proxy.

class yobx.torch.tracing.CustomParameterProxy(tracer: TracerBase, node: Node, name, param)

A special proxy that lets “shape”, “size”, “dim”, and a few other attribute accesses pass through to the underlying module parameter object, so that conditional tests on these attributes do not raise exceptions during tracing.

class yobx.torch.tracing.CustomProxy(node: Node, tracer: TracerBase | None = None)

Custom proxy used to trace the execution of a model and convert it into an FX graph. Works with CustomTracer.

classmethod cat(tensors: List[CustomProxy], dim: int = 0, *, out=None, axis: int | None = None) → CustomProxy

Implements torch.cat for proxied tensors.

instanceof(cls)

Tells whether this proxy represents an instance of the given class.

length()

Returns a proxy for the length.

class yobx.torch.tracing.CustomProxyFloat(node: Node, tracer: TracerBase | None = None)

A proxy for a float.

instanceof(cls)

Equivalent of isinstance for this proxy.

class yobx.torch.tracing.CustomProxyInt(node: Node, tracer: TracerBase | None = None)

A proxy for an integer.

instanceof(cls)

Equivalent of isinstance for this proxy.

class yobx.torch.tracing.CustomTracer(autowrap_modules: Tuple[ModuleType, ...] = (math,), autowrap_functions: Tuple[Callable, ...] = (torch.ones, torch.zeros, torch.full, torch.empty, torch.arange), param_shapes_constant: bool = False, module_leaves: Dict[type, Callable[[Module, str], bool]] | None = None)

Custom tracer that traces the execution of a model and converts it into an FX graph. Works with CustomProxy.

::

import torch
from yobx.torch.tracing import CustomTracer

model = torch.nn.Linear(4, 2)  # any nn.Module or Python callable
graph = CustomTracer().trace(model)  # a torch.fx.Graph

Parameters:
  • autowrap_modules – defaults to (math,); Python modules whose functions should be wrapped automatically without needing to use fx.wrap().

  • autowrap_functions – defaults to _AUTOWRAP_FUNCTIONS, Python functions that should be wrapped automatically without needing to use fx.wrap(). Includes tensor-creating functions (e.g. torch.ones) so that calls with proxy size arguments are captured as traced nodes rather than executed immediately.

  • param_shapes_constant – when set, accesses to shape, size, and a few other shape-like attributes of a module’s parameter are evaluated directly instead of returning a new Proxy value.

  • module_leaves – maps module types to be treated as leaves to a callable f(module, module_qualified_name) -> bool that decides whether a specific module instance is a leaf; the tracer does not trace into leaf modules and emits call_module nodes for them instead.

call_module(m: Module, forward: Callable[[...], Any], args: tuple[Any, ...], kwargs: dict[str, Any]) → Any

Method that specifies the behavior of this Tracer when it encounters a call to an nn.Module instance.

By default, the behavior is to check if the called module is a leaf module via is_leaf_module. If it is, emit a call_module node referring to m in the Graph. Otherwise, call the Module normally, tracing through the operations in its forward function.

This method can be overridden, for example, to create nested traced GraphModules, or to implement any other behavior you want while tracing across Module boundaries.

Parameters:
  • m (Module) – The module for which a call is being emitted

  • forward (Callable) – The forward() method of the Module to be invoked

  • args (Tuple) – args of the module callsite

  • kwargs (Dict) – kwargs of the module callsite

Returns:

The return value from the Module call. In the case that a call_module node was emitted, this is a Proxy value. Otherwise, it is whatever value was returned from the Module invocation.

Note

Backwards-compatibility for this API is guaranteed.

create_arg(a: Any) → Argument

Overrides torch.fx.Tracer.create_arg to handle more argument types.

create_args_for_root(root_fn, is_module, concrete_args=None)

Create placeholder nodes corresponding to the signature of the root Module. This method introspects root’s signature and emits those nodes accordingly, also supporting *args and **kwargs.

Warning

This API is experimental and is NOT backward-compatible.

getattr(attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any])

See torch.fx.Tracer.getattr().

classmethod graph_erase_node(graph: Graph, node: Node)

Removes a node and all its predecessors that are only consumed by it.
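The docstring does not describe the algorithm, so the following is only a sketch of the idea with plain torch.fx. erase_with_exclusive_inputs is a hypothetical helper, and keeping placeholders (so the graph signature does not change) is an assumption. Starting from a node without users, it erases the node and walks backwards, erasing every predecessor left without users.

```python
import torch
import torch.fx


def erase_with_exclusive_inputs(graph: torch.fx.Graph, node: torch.fx.Node) -> int:
    """Hypothetical helper: erases `node`, then every predecessor left without users.

    Placeholders and the output node are never erased. Returns the number of erased nodes.
    """
    erased = set()
    stack = [node]
    while stack:
        current = stack.pop()
        if current in erased or current.users or current.op in ("placeholder", "output"):
            continue
        inputs = current.all_input_nodes  # captured before erase_node clears the args
        graph.erase_node(current)
        erased.add(current)
        stack.extend(inputs)
    return len(erased)


g = torch.fx.Graph()
x = g.placeholder("x")
a = g.call_function(torch.relu, (x,))
b = g.call_function(torch.neg, (a,))  # b has no users
g.output(x)
removed = erase_with_exclusive_inputs(g, b)
print(removed, [n.op for n in g.nodes])  # 2 ['placeholder', 'output']
```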

is_leaf_module(m: Module, module_qualified_name: str) → bool

A method to specify whether a given nn.Module is a “leaf” module.

Leaf modules are the atomic units that appear in the IR, referenced by call_module calls. By default, Modules in the PyTorch standard library namespace (torch.nn) are leaf modules. All other modules are traced through and their constituent ops are recorded, unless this method specifies otherwise.

Parameters:
  • m (Module) – The module being queried about

  • module_qualified_name (str) – The path to root of this module. For example, if you have a module hierarchy where submodule foo contains submodule bar, which contains submodule baz, that module will appear with the qualified name foo.bar.baz here.

Note

Backwards-compatibility for this API is guaranteed.
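The qualified name passed to is_leaf_module can be observed with a plain torch.fx.Tracer subclass. This sketch rebuilds the foo.bar.baz hierarchy from the parameter description with hypothetical modules, records every name the tracer asks about, and treats only Baz as a leaf. It uses torch.fx.Tracer directly rather than CustomTracer, whose default leaf rules may differ.

```python
import torch
import torch.fx


class Baz(torch.nn.Module):
    def forward(self, x):
        return x + 1


class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.baz = Baz()

    def forward(self, x):
        return self.baz(x)


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bar = Bar()

    def forward(self, x):
        return self.bar(x)


class Root(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = Foo()

    def forward(self, x):
        return self.foo(x)


class BazLeafTracer(torch.fx.Tracer):
    """Records every qualified name it is asked about; only Baz is a leaf."""

    def __init__(self):
        super().__init__()
        self.queried = []

    def is_leaf_module(self, m, module_qualified_name):
        self.queried.append(module_qualified_name)
        return isinstance(m, Baz)


tracer = BazLeafTracer()
graph = tracer.trace(Root())
calls = [n.target for n in graph.nodes if n.op == "call_module"]
print(tracer.queried)
print(calls)
```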

proxy(node: Node, cls: type[CustomProxy] = CustomProxy) → Proxy

Overrides torch.fx.Tracer.proxy so that it creates a CustomProxy (or an instance of cls) instead of the default Proxy.
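The mechanism this override relies on can be shown with plain torch.fx: Tracer.create_proxy goes through Tracer.proxy for every new node, so returning a subclass there changes the type of all proxies seen inside forward. TaggedProxy and TaggedTracer are hypothetical stand-ins for CustomProxy and CustomTracer.

```python
import torch
import torch.fx


class TaggedProxy(torch.fx.Proxy):
    """Hypothetical stand-in for CustomProxy."""


class TaggedTracer(torch.fx.Tracer):
    def proxy(self, node):
        # create_proxy calls this for placeholders and for every traced result.
        return TaggedProxy(node, self)


seen = []


class M(torch.nn.Module):
    def forward(self, x):
        seen.append(type(x))
        y = x * 2
        seen.append(type(y))
        return y


TaggedTracer().trace(M())
print(seen)
```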

register_callable(name: str, fn: Callable) → Node

Registers a function and returns a unique name.

Parameters:
  • name – prefix to prepend to the function name

  • fn – function to register

Returns:

new_name

classmethod remove_inplace(graph: Graph, exported_program: ExportedProgram | None = None, verbose: int = 0, exc: bool = True, recursive: bool = False) → int

Removes inplace operations.

Parameters:
  • graph – graph to modify

  • exported_program – if available, used in error messages to make it easier to trace the error back to the source code

  • verbose – verbosity

  • exc – if True, raise an exception when the removal is not possible; otherwise return -1

  • recursive – also remove inplace nodes inside submodules

Returns:

number of inplace nodes removed; a negative number means that some inplace nodes remain which this function cannot remove, in which case only decompositions may help
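The exact rewriting done by remove_inplace is not described here. This plain torch.fx sketch only shows why inplace nodes are awkward in a graph: the add_ node mutates y, yet nothing consumes its result, so the dependency of the output on the mutation is implicit. The helper find_inplace_nodes and its trailing-underscore heuristic (PyTorch's naming convention for inplace methods) are assumptions for illustration, not how remove_inplace detects such nodes.

```python
import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x):
        y = x.clone()
        y.add_(1)  # mutates y; the node's own result is never used
        return y


def find_inplace_nodes(graph: torch.fx.Graph) -> list:
    """Hypothetical heuristic: PyTorch names inplace methods with a trailing underscore."""
    return [
        n
        for n in graph.nodes
        if n.op == "call_method" and n.target.endswith("_") and not n.target.startswith("__")
    ]


gm = torch.fx.symbolic_trace(M())
inplace = find_inplace_nodes(gm.graph)
print([n.target for n in inplace])  # ['add_']
print(len(inplace[0].users))  # 0: the output only depends on the clone node
```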

classmethod remove_unnecessary_slices(graph: Graph, verbose: int = 0) → int

Removes unnecessary slices and other nodes that do nothing.

Parameters:
  • graph – graph to modify

  • verbose – verbosity level

Returns:

number of nodes removed

Example of such an unnecessary slice, which keeps the whole dimension 0 of %clone:

::

%slice_11 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%clone, 0, 0, 9223372036854775807), kwargs = {})

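A sketch of how such a slice could be detected and bypassed with plain torch.fx, assuming positional arguments only. drop_full_slices is a hypothetical helper that covers only aten.slice.Tensor, not the other no-op nodes remove_unnecessary_slices may handle. It rewires every user of the slice to the slice's input, then erases the slice.

```python
import torch
import torch.fx

INT64_MAX = 9223372036854775807  # 2**63 - 1, the "until the end" bound in the example above


def drop_full_slices(graph: torch.fx.Graph) -> int:
    """Hypothetical helper: removes aten.slice.Tensor nodes that keep the whole dimension."""
    removed = 0
    for node in list(graph.nodes):
        if node.op != "call_function" or node.target is not torch.ops.aten.slice.Tensor:
            continue
        if node.kwargs:
            continue  # keep the sketch simple: positional arguments only
        # aten.slice.Tensor(self, dim=0, start=None, end=None, step=1)
        defaults = (None, 0, None, None, 1)
        inp, _dim, start, end, step = tuple(node.args) + defaults[len(node.args):]
        full_start = start is None or (isinstance(start, int) and start == 0)
        full_end = end is None or (isinstance(end, int) and end >= INT64_MAX)
        if full_start and full_end and isinstance(step, int) and step == 1:
            node.replace_all_uses_with(inp)
            graph.erase_node(node)
            removed += 1
    return removed


g = torch.fx.Graph()
x = g.placeholder("x")
s = g.call_function(torch.ops.aten.slice.Tensor, (x, 0, 0, INT64_MAX))
r = g.call_function(torch.relu, (s,))
g.output(r)
removed = drop_full_slices(g)
print(removed, r.args[0] is x)  # 1 True
```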

trace(root: Module | Callable[[...], Any], concrete_args: Dict[str, Any] | None = None, remove_inplace: bool = True, update_model_with_callable: bool = True, dynamic_shapes: Any | None = None, verbose: int = 0) → Graph

Trace root and return the corresponding FX Graph representation. root can either be an nn.Module instance or a Python callable.

Note that after this call, self.root may be different from the root passed in here. For example, when a free function is passed to trace(), we will create an nn.Module instance to use as the root and add embedded constants to.

Parameters:
  • root – Either a Module or a function to be traced through. Backwards-compatibility for this parameter is guaranteed.

  • concrete_args – Concrete arguments that should not be treated as Proxies. This parameter is experimental and its backwards-compatibility is NOT guaranteed.

  • remove_inplace – if True, removes inplace nodes

  • update_model_with_callable – if True, updates the model when needed; in some cases (control flow), the model must be updated

  • dynamic_shapes – dynamic shapes

  • verbose – verbosity

Returns:

A Graph representing the semantics of the passed-in root

If the model had to be wrapped before being traced, the attribute traced_model is added to the tracer.

class yobx.torch.tracing.LEAVE_INPLACE

Constant indicating that inplace removal failed.

yobx.torch.tracing.replace_problematic_function_before_tracing() → Generator

Replaces functions that the default tracer cannot trace, such as torch.cat().
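The Generator return type suggests a generator-based context manager that swaps functions in and restores them afterwards; this is an assumption. The sketch below only shows that general pattern with a hypothetical temporarily_replace helper applied to math.sqrt, not the actual replacements made for torch.cat().

```python
import contextlib
import math
from typing import Any, Callable, Iterator


@contextlib.contextmanager
def temporarily_replace(owner: Any, name: str, replacement: Callable) -> Iterator[None]:
    """Hypothetical helper: swaps owner.name for replacement, restoring it on exit."""
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield
    finally:
        # Restored even if the body (e.g. tracing) raises.
        setattr(owner, name, original)


with temporarily_replace(math, "sqrt", lambda v: "patched"):
    inside = math.sqrt(4.0)
outside = math.sqrt(4.0)
print(inside, outside)  # patched 2.0
```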

yobx.torch.tracing.setitem_with_transformation(a, b, transformations)

Extended version of setitem that handles inplace modifications.

yobx.torch.tracing.tree_unflatten_with_proxy(tree_spec: Any, leaves: Iterable[Any]) → Any

More robust implementation of pytree.tree_unflatten supporting DynamicCache.
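For comparison, a round trip with torch's own pytree utilities (torch.utils._pytree, a private module). Note the argument order: torch's tree_unflatten takes (leaves, spec), while tree_unflatten_with_proxy takes (tree_spec, leaves).

```python
from torch.utils import _pytree as pytree

tree = {"a": 1, "b": (2, 3)}
leaves, spec = pytree.tree_flatten(tree)
print(leaves)  # [1, 2, 3]
rebuilt = pytree.tree_unflatten(leaves, spec)
print(rebuilt == tree)  # True
```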