yobx.torch.tracing#

class yobx.torch.tracing.CondCCOp[source]#

Cannot be imported from torch.ops.higher_order.cond (the function cond overwrites the submodule cond).

class yobx.torch.tracing.CustomAttribute(root: CustomProxy, attr: str)[source]#

Proxy used to trace attribute accesses (e.g. tensor.shape) on a CustomProxy.

class yobx.torch.tracing.CustomParameterProxy(tracer: TracerBase, node: Node, name, param)[source]#

A special proxy which lets “shape”, “size”, “dim”, and a few other attribute accesses pass through to the underlying module parameter object, so that conditional tests on these attributes will not throw exception during tracing.

numel()[source]#

Records a numel() call in the FX graph and returns a CustomProxyInt.

When the fake-tensor metadata carries a concrete element count, concrete_val is set on the returned proxy so that comparisons such as if x.numel() == 0: resolve to a plain Python bool at trace time without raising TraceError. All other comparisons (e.g. x.numel() > 0) keep the standard FX-proxy behaviour.

size(dim=None)[source]#

Records a size() or size(dim) call in the FX graph.

  • size() — returns a CustomProxyShape (dynamic shapes) or a plain torch.Size (all-static shapes), matching the behaviour of accessing .shape.

  • size(dim) — returns a CustomProxyInt whose concrete_val is set from fake-tensor metadata so that comparisons such as if x.size(0) == 0: resolve to a plain Python bool at trace time without raising TraceError.

class yobx.torch.tracing.CustomProxy(node: Node, tracer: TracerBase | None = None)[source]#

Defines a custom proxy to trace the execution of a model and convert it into an FX graph. Works with CustomTracer.

classmethod cat(tensors: List[CustomProxy], dim: int = 0, *, out=None, axis: int | None = None) CustomProxy[source]#

Implements torch.cat for proxy tensors, accepting either dim or the NumPy-style axis alias.

length()[source]#

Returns a proxy for the length.

numel(*args: Any, **kwargs: Any) CustomProxyInt[source]#

Records a numel() call in the FX graph and returns a CustomProxyInt.

When the fake-tensor metadata carries a concrete element count, concrete_val is set on the returned proxy so that comparisons such as if x.numel() == 0: resolve to a plain Python bool at trace time without raising TraceError. All other comparisons (e.g. x.numel() > 0) keep the standard FX-proxy behaviour.

size(dim: int | None = None)[source]#

Records a size() or size(dim) call in the FX graph.

  • size() — returns a CustomProxyShape (dynamic shapes) or a plain torch.Size (all-static shapes), matching the behaviour of accessing .shape.

  • size(dim) — returns a CustomProxyInt whose concrete_val is set from fake-tensor metadata so that comparisons such as if x.size(0) == 0: resolve to a plain Python bool at trace time without raising TraceError.
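
The concrete-value shortcut described above can be sketched in plain Python; ConcreteInt below is an illustrative stand-in for the proxy class, not the actual yobx implementation:

```python
class ConcreteInt:
    """Illustrative stand-in for a proxy int carrying an optional concrete value."""

    _MISSING = object()

    def __init__(self, concrete_val=_MISSING):
        self.concrete_val = concrete_val

    def __eq__(self, other):
        # Equality against a plain int/float resolves concretely when
        # fake-tensor metadata supplied a concrete value.
        if self.concrete_val is not self._MISSING and isinstance(other, (int, float)):
            return self.concrete_val == other
        return NotImplemented  # in the real proxy, an FX node would be created

n = ConcreteInt(concrete_val=0)
assert (n == 0) is True   # `if x.numel() == 0:` resolves at trace time
assert (n == 1) is False
```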

class yobx.torch.tracing.CustomProxyBool(node: Node, tracer: TracerBase | None = None)[source]#

A proxy for a boolean.

class yobx.torch.tracing.CustomProxyFloat(node: Node, tracer: TracerBase | None = None)[source]#

A proxy for a float.

class yobx.torch.tracing.CustomProxyInt(node: Node, tracer: TracerBase | None = None, concrete_val: Any = <object object>, only_positive: bool = False, can_be_null: bool = True)[source]#

A proxy for an integer.

When constructed with a concrete_val (e.g. a backed SymInt from fake-tensor metadata), == and != against a plain Python integer / float are evaluated concretely so that patterns like if x.shape[2] == 0: work during symbolic tracing without raising TraceError. All other comparisons (including proxy-vs-proxy) still create FX nodes as usual.

When only_positive is True (set e.g. after seeing torch._check(value > 0)), comparisons against non-positive constants are evaluated concretely:

  • value > c for c <= 0 → True

  • value >= c for c <= 0 → True (since value > 0 ≥ c)

  • value < c for c <= 0 → False

  • value <= c for c <= 0 → False

  • value != c for c <= 0 → True

When can_be_null is False (set e.g. after seeing torch._check(value != 0)), zero-comparisons are evaluated concretely:

  • value == 0 → False

  • value != 0 → True
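
These shortcut rules can be sketched as follows; GuardedInt is an illustrative stand-in, and raising where an FX comparison node would normally be created is a simplification:

```python
class GuardedInt:
    """Illustrative stand-in showing how torch._check facts shortcut comparisons."""

    def __init__(self, only_positive=False, can_be_null=True):
        self.only_positive = only_positive
        self.can_be_null = can_be_null

    def _positive_vs(self, c):
        # True when the only_positive shortcut applies to constant c.
        return self.only_positive and isinstance(c, (int, float)) and c <= 0

    def __gt__(self, c):
        if self._positive_vs(c):
            return True          # value > 0 >= c
        raise NotImplementedError("would create an FX node")

    def __ge__(self, c):
        if self._positive_vs(c):
            return True
        raise NotImplementedError("would create an FX node")

    def __lt__(self, c):
        if self._positive_vs(c):
            return False
        raise NotImplementedError("would create an FX node")

    def __le__(self, c):
        if self._positive_vs(c):
            return False
        raise NotImplementedError("would create an FX node")

    def __eq__(self, c):
        if not self.can_be_null and c == 0:
            return False
        raise NotImplementedError("would create an FX node")

    def __ne__(self, c):
        if not self.can_be_null and c == 0:
            return True
        if self._positive_vs(c):
            return True
        raise NotImplementedError("would create an FX node")

v = GuardedInt(only_positive=True, can_be_null=False)
assert (v > 0) is True and (v >= -1) is True
assert (v < 0) is False and (v <= -2) is False
assert (v == 0) is False and (v != 0) is True
```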

class yobx.torch.tracing.CustomProxyShape(node: Node, tracer: TracerBase | None = None, concrete_val: Any = <object object>)[source]#

A tuple of CustomProxyInt instances representing a tensor shape with dynamic dimensions.

Each element is a valid FX proxy node (so dynamic-shape operators such as torch.full are recorded correctly in the graph) and supports equality comparison against plain integer constants so that if x.shape[2] == 0: evaluates to a Python bool without raising TraceError.

Use from_proxy() to construct an instance from a shape CustomAttribute proxy and its corresponding concrete torch.Size.

classmethod from_proxy(shape_proxy: CustomProxy, concrete_shape: Size) CustomProxyShape[source]#

Build a CustomProxyShape from a shape/size proxy and the corresponding concrete (possibly symbolic) torch.Size.

Parameters:
  • shape_proxy – An FX proxy whose [i] subscript creates a getitem graph node (e.g. a CustomAttribute for tensor.shape or a CustomProxy wrapping a call_method("size", …) node).

  • concrete_shape – The concrete torch.Size from the fake tensor’s metadata (may contain backed SymInt values for dynamic dimensions).

length() int[source]#

Returns a proxy for the length.

class yobx.torch.tracing.CustomTracer(autowrap_modules: ~typing.Tuple[~types.ModuleType, ...] = (<module 'math' (built-in)>,), autowrap_functions: ~typing.Tuple[~typing.Callable, ...] = (<built-in method ones of type object>, <built-in method zeros of type object>, <built-in method full of type object>, <built-in method empty of type object>, <built-in method arange of type object>), param_shapes_constant: bool = False, module_leaves: ~typing.Dict[type, ~typing.Callable[[~torch.nn.modules.module.Module, str], bool]] | None = None)[source]#

Defines a custom tracer to trace the execution of a model and convert it into an FX graph. Works with CustomProxy.

Example:

from yobx.torch.tracing import CustomTracer

graph = CustomTracer().trace(model)

Parameters:
  • autowrap_modules – defaults to (math, ), Python modules whose functions should be wrapped automatically without needing to use fx.wrap().

  • autowrap_functions – defaults to _AUTOWRAP_FUNCTIONS, Python functions that should be wrapped automatically without needing to use fx.wrap(). Includes tensor-creating functions (e.g. torch.ones) so that calls with proxy size arguments are captured as traced nodes rather than executed immediately.

  • param_shapes_constant – When this flag is set, calls to shape, size and a few other shape-like attributes of a module’s parameter will be evaluated directly, rather than returning a new Proxy value for the attribute access.

  • module_leaves – modules to be considered as leaves, mapped to a callable f(module, module_qualified_name) -> bool that decides whether a specific module instance is a leaf; the tracer does not trace into leaf modules and emits call_module nodes for them instead.
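
A plausible sketch of the module_leaves lookup (the exact matching rule, exact type versus walking the MRO, is an assumption here; LinearLike and BlockLike are hypothetical module classes standing in for nn.Module subclasses):

```python
# Hypothetical module classes standing in for nn.Module subclasses.
class LinearLike: ...
class BlockLike: ...

def is_leaf(module, qualified_name, module_leaves):
    """Walk the type's MRO so subclasses inherit the leaf decision."""
    for cls in type(module).__mro__:
        fn = module_leaves.get(cls)
        if fn is not None:
            # The predicate sees both the instance and its qualified name,
            # so the decision can depend on where the module sits in the tree.
            return fn(module, qualified_name)
    return False

leaves = {LinearLike: lambda m, name: not name.startswith("decoder.")}
assert is_leaf(LinearLike(), "encoder.proj", leaves) is True
assert is_leaf(LinearLike(), "decoder.proj", leaves) is False
assert is_leaf(BlockLike(), "block", leaves) is False
```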

call_module(m: Module, forward: Callable[[...], Any], args: tuple[Any, ...], kwargs: dict[str, Any]) Any[source]#

Method that specifies the behavior of this Tracer when it encounters a call to an nn.Module instance.

By default, the behavior is to check if the called module is a leaf module via is_leaf_module. If it is, emit a call_module node referring to m in the Graph. Otherwise, call the Module normally, tracing through the operations in its forward function.

This method can be overridden to, for example, create nested traced GraphModules, or implement any other behavior you would want while tracing across Module boundaries.

Parameters:
  • m (Module) – The module for which a call is being emitted

  • forward (Callable) – The forward() method of the Module to be invoked

  • args (Tuple) – args of the module callsite

  • kwargs (Dict) – kwargs of the module callsite

Returns:

The return value from the Module call. In the case that a call_module node was emitted, this is a Proxy value. Otherwise, it is whatever value was returned from the Module invocation.

Note

Backwards-compatibility for this API is guaranteed.

create_arg(a: Any) Argument[source]#

Overrides this method to handle additional argument types.

create_args_for_root(root_fn, is_module, concrete_args=None)[source]#

Create placeholder nodes corresponding to the signature of the root Module. This method introspects root’s signature and emits those nodes accordingly, also supporting *args and **kwargs.

Warning

This API is experimental and is NOT backward-compatible.

getattr(attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any])[source]#

See torch.fx.Tracer.getattr().

classmethod graph_erase_node(graph: Graph, node: Node)[source]#

Removes a node and all of its predecessors that are only consumed by this node.

is_leaf_module(m: Module, module_qualified_name: str) bool[source]#

A method to specify whether a given nn.Module is a “leaf” module.

Leaf modules are the atomic units that appear in the IR, referenced by call_module calls. By default, Modules in the PyTorch standard library namespace (torch.nn) are leaf modules. All other modules are traced through and their constituent ops are recorded, unless specified otherwise via this parameter.

Parameters:
  • m (Module) – The module being queried about

  • module_qualified_name (str) – The path to root of this module. For example, if you have a module hierarchy where submodule foo contains submodule bar, which contains submodule baz, that module will appear with the qualified name foo.bar.baz here.

Note

Backwards-compatibility for this API is guaranteed.

proxy(node: Node, cls: type[CustomProxy] = <class 'yobx.torch.tracing.CustomProxy'>, **kwargs) Proxy[source]#

Overrides this method to replace the default Proxy with CustomProxy.

register_callable(name: str, fn: Callable) Node[source]#

Registers a function and returns a unique name for it.

Parameters:
  • name – prefix to prepend to the function name

  • fn – function

Returns:

new_name

classmethod remove_batch_dim_nodes(graph: Graph, verbose: int = 0) int[source]#

Rewrites _add_batch_dim and _remove_batch_dim nodes introduced by torch.vmap lowering using a three-phase batch-dimension tracking strategy.

Phase 1 – tag: _add_batch_dim(x, batch_dim, level) is replaced by aten.clone.default(x) and the actual batch-dimension position is stored in the replacement node’s metadata under the key "vmap_batch_dim" as a dict mapping level → batch_dim.

Phase 2 – propagate: For every call_function node whose input nodes carry "vmap_batch_dim" metadata, that metadata is copied onto the output node so that downstream consumers can find it.

Phase 3 – remove: _remove_batch_dim(x, level, batch_size, out_dim) looks up the actual batch-dimension position of x for the given level from the metadata (falling back to level - 1 if absent), then:

  • If the actual size of the batch dimension in x is 1 but batch_size > 1 (the tensor was broadcast / “not batched” inside vmap), an aten.expand.default node is emitted first to materialise the full batch.

  • erases the node (replacing all uses with x) when no expand is needed and actual_batch_dim == out_dim – a no-op case;

  • replaces the node with aten.movedim.int(x, actual_batch_dim, out_dim) otherwise.

Parameters:
  • graph – FX graph to modify in-place

  • verbose – verbosity level

Returns:

number of nodes replaced or removed
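
The phase-3 decision can be sketched as a pure function (the function name, argument names, and the exact ordering of emitted steps are illustrative, derived from the description above):

```python
def remove_batch_dim_plan(actual_batch_dim, out_dim, dim_size, batch_size):
    """Return the rewrite steps for one _remove_batch_dim node (sketch)."""
    steps = []
    if dim_size == 1 and batch_size > 1:
        # Tensor was broadcast ("not batched") inside vmap:
        # materialise the full batch first, then move the dimension.
        steps.append("expand")
        steps.append("movedim")
    elif actual_batch_dim == out_dim:
        # No-op case: erase the node, replacing all uses with the input.
        steps.append("erase")
    else:
        steps.append("movedim")
    return steps

assert remove_batch_dim_plan(0, 0, 4, 4) == ["erase"]
assert remove_batch_dim_plan(1, 0, 4, 4) == ["movedim"]
assert remove_batch_dim_plan(0, 0, 1, 8) == ["expand", "movedim"]
```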

classmethod remove_inplace(graph: Graph, exported_program: ExportedProgram | None = None, verbose: int = 0, exc: bool = True, recursive: bool = False) int[source]#

Removes inplace operations.

Parameters:
  • graph – graph to modify

  • exported_program – if available, it is used in the error message to make it easier to trace the code source

  • verbose – verbosity

  • exc – raise an exception if removal is not possible, otherwise return -1

  • recursive – remove nodes inside submodules as well

Returns:

number of inplace nodes removed; a negative number means some inplace nodes remain that this function is unable to remove, in which case only decompositions may help
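
What counts as an inplace operation can be illustrated with the aten naming convention; this heuristic (trailing underscore, or an .out overload that writes into a caller-provided buffer) is a sketch, not necessarily the exact rule used here:

```python
def looks_inplace(op_name):
    """Heuristic: aten inplace variants end with '_' (add_, mul_, copy_),
    and '.out' overloads write into a provided output buffer."""
    base = op_name.split(".")[0]        # e.g. "add_" from "add_.Tensor"
    return base.endswith("_") or op_name.endswith(".out")

assert looks_inplace("add_.Tensor") is True
assert looks_inplace("add.Tensor") is False
assert looks_inplace("index_put_.default") is True
assert looks_inplace("mul.out") is True
```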

classmethod remove_unnecessary_slices(graph: Graph, verbose: int = 0) int[source]#

Removes unnecessary slices and other nodes doing nothing.

Parameters:
  • graph – graph to modify

  • verbose – verbosity level

Returns:

number of nodes removed

Example of a no-op slice node that is removed:

%slice_11 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor]
    (args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
trace(root: Module | Callable[[...], Any], concrete_args: Dict[str, Any] | None = None, remove_inplace: bool = True, update_model_with_callable: bool = True, dynamic_shapes: Any | None = None, verbose: int = 0) Graph[source]#

Trace root and return the corresponding FX Graph representation. root can either be an nn.Module instance or a Python callable.

Note that after this call, self.root may be different from the root passed in here. For example, when a free function is passed to trace(), we will create an nn.Module instance to use as the root and add embedded constants to.

Parameters:
  • root – Either a Module or a function to be traced through. Backwards-compatibility for this parameter is guaranteed.

  • concrete_args – Concrete arguments that should not be treated as Proxies. This parameter is experimental and its backwards-compatibility is NOT guaranteed.

  • remove_inplace – Removes inplace nodes

  • update_model_with_callable – in some cases (control flow), the model needs to be updated

  • dynamic_shapes – dynamic shapes

  • verbose – verbosity

Returns:

A Graph representing the semantics of the passed-in root

If the model had to be wrapped before being traced, attribute traced_model is added to the tracer.

class yobx.torch.tracing.LEAVE_INPLACE[source]#

Constant indicating inplace removal failed.

class yobx.torch.tracing.ScanCCOp[source]#

Replacement for torch.ops.higher_order.scan() during FX symbolic tracing. The real scan operator cannot be called with CustomProxy arguments because it tries to execute the body function eagerly. This proxy operator defers to CustomProxy.__torch_function__(), which records the scan call as an FX call_function node instead.

yobx.torch.tracing.replace_problematic_function_before_tracing() Generator[source]#

Replaces functions that cannot be traced with the default tracer, such as torch.cat().

yobx.torch.tracing.setitem_with_transformation(a, b, transformations)[source]#

Extended version of setitem to deal with inplace modification.

yobx.torch.tracing.tree_unflatten_with_proxy(tree_spec: Any, leaves: Iterable[Any]) Any[source]#

More robust implementation of pytree.tree_unflatten supporting DynamicCache.