yobx.torch.tracing#
- class yobx.torch.tracing.CondCCOp[source]#
Cannot be imported from torch.ops.higher_order.cond (the function cond overwrites the submodule cond).
- class yobx.torch.tracing.CustomAttribute(root: CustomProxy, attr: str)[source]#
A proxy used to trace attribute accesses.
- class yobx.torch.tracing.CustomParameterProxy(tracer: TracerBase, node: Node, name, param)[source]#
A special proxy which lets “shape”, “size”, “dim”, and a few other attribute accesses pass through to the underlying module parameter object, so that conditional tests on these attributes will not throw an exception during tracing.
- numel()[source]#
Records a `numel()` call in the FX graph and returns a `CustomProxyInt`.
When the fake-tensor metadata carries a concrete element count, `concrete_val` is set on the returned proxy so that comparisons such as `if x.numel() == 0:` resolve to a plain Python `bool` at trace time without raising `TraceError`. All other comparisons (e.g. `x.numel() > 0`) keep the standard FX-proxy behaviour.
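As a minimal sketch of the idea (not the yobx implementation, and with no FX machinery), a proxy holding a known concrete value can resolve equality against plain numbers directly:

```python
# Hypothetical sketch: a proxy that resolves == concretely when fake-tensor
# metadata supplied a real element count. All names here are illustrative.
_MISSING = object()

class IntProxySketch:
    def __init__(self, concrete_val=_MISSING):
        # concrete_val would come from fake-tensor metadata during tracing.
        self.concrete_val = concrete_val

    def __eq__(self, other):
        # Resolve == against a plain number concretely when a value is known,
        # so `if proxy == 0:` yields a real bool instead of a trace error.
        if self.concrete_val is not _MISSING and isinstance(other, (int, float)):
            return int(self.concrete_val) == other
        return NotImplemented  # fall back to normal proxy behaviour

p = IntProxySketch(concrete_val=12)
print(p == 0)   # False
print(p == 12)  # True
```

A real tracer would emit an FX node for the `NotImplemented` branch instead of falling back to identity comparison.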
- size(dim=None)[source]#
Records a `size()` or `size(dim)` call in the FX graph.
- `size()` – returns a `CustomProxyShape` (dynamic shapes) or a plain `torch.Size` (all-static shapes), matching the behaviour of accessing `.shape`.
- `size(dim)` – returns a `CustomProxyInt` whose `concrete_val` is set from fake-tensor metadata so that comparisons such as `if x.size(0) == 0:` resolve to a plain Python `bool` at trace time without raising `TraceError`.
- class yobx.torch.tracing.CustomProxy(node: Node, tracer: TracerBase | None = None)[source]#
Defines a custom proxy to trace the execution of a model and convert it into an FX graph. Works with `CustomTracer`.
- classmethod cat(tensors: List[CustomProxy], dim: int = 0, *, out=None, axis: int | None = None) CustomProxy[source]#
Implements cat for tensors.
- numel(*args: Any, **kwargs: Any) CustomProxyInt[source]#
Records a `numel()` call in the FX graph and returns a `CustomProxyInt`.
When the fake-tensor metadata carries a concrete element count, `concrete_val` is set on the returned proxy so that comparisons such as `if x.numel() == 0:` resolve to a plain Python `bool` at trace time without raising `TraceError`. All other comparisons (e.g. `x.numel() > 0`) keep the standard FX-proxy behaviour.
- size(dim: int | None = None)[source]#
Records a `size()` or `size(dim)` call in the FX graph.
- `size()` – returns a `CustomProxyShape` (dynamic shapes) or a plain `torch.Size` (all-static shapes), matching the behaviour of accessing `.shape`.
- `size(dim)` – returns a `CustomProxyInt` whose `concrete_val` is set from fake-tensor metadata so that comparisons such as `if x.size(0) == 0:` resolve to a plain Python `bool` at trace time without raising `TraceError`.
- class yobx.torch.tracing.CustomProxyBool(node: Node, tracer: TracerBase | None = None)[source]#
A proxy for a boolean.
- class yobx.torch.tracing.CustomProxyFloat(node: Node, tracer: TracerBase | None = None)[source]#
A proxy for a float.
- class yobx.torch.tracing.CustomProxyInt(node: Node, tracer: TracerBase | None = None, concrete_val: Any = <object object>, only_positive: bool = False, can_be_null: bool = True)[source]#
A proxy for an integer.
When constructed with a concrete_val (e.g. a backed `SymInt` from fake-tensor metadata), `==` and `!=` against a plain Python integer / float are evaluated concretely so that patterns like `if x.shape[2] == 0:` work during symbolic tracing without raising `TraceError`. All other comparisons (including proxy-vs-proxy) still create FX nodes as usual.
When only_positive is `True` (set e.g. after seeing `torch._check(value > 0)`), comparisons against non-positive constants are evaluated concretely:
- `value > c` for `c <= 0` → `True`
- `value >= c` for `c <= 0` → `True` (since value > 0 ≥ c)
- `value < c` for `c <= 0` → `False`
- `value <= c` for `c <= 0` → `False`
- `value != c` for `c <= 0` → `True`
When can_be_null is `False` (set e.g. after seeing `torch._check(value != 0)`), zero-comparisons are evaluated concretely:
- `value == 0` → `False`
- `value != 0` → `True`
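The rules above can be sketched in isolation. This is a hypothetical, simplified model of the decision table, not the yobx code; a real tracer would emit an FX node whenever the answer is unknown:

```python
# Hypothetical sketch of the comparison rules for a value known (via
# torch._check-style assertions) to be positive and/or non-zero.
class CheckedIntSketch:
    def __init__(self, only_positive=False, can_be_null=True):
        self.only_positive = only_positive
        self.can_be_null = can_be_null

    def gt(self, c):
        # value > c is certainly True for c <= 0 when value > 0 is known.
        if self.only_positive and c <= 0:
            return True
        return None  # unknown: a real tracer emits an FX node here

    def ne(self, c):
        # value != c is certain for c <= 0 (positive value) or c == 0
        # (non-null value).
        if self.only_positive and c <= 0:
            return True
        if not self.can_be_null and c == 0:
            return True
        return None

v = CheckedIntSketch(only_positive=True)
print(v.gt(0))   # True
print(v.ne(-3))  # True
print(v.gt(5))   # None (would become a graph node)
```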
- class yobx.torch.tracing.CustomProxyShape(node: Node, tracer: TracerBase | None = None, concrete_val: Any = <object object>)[source]#
A `tuple` of `CustomProxyInt` instances representing a tensor shape with dynamic dimensions.
Each element is a valid FX proxy node (so dynamic-shape operators such as `torch.full` are recorded correctly in the graph) and supports equality comparison against plain integer constants so that `if x.shape[2] == 0:` evaluates to a Python `bool` without raising `TraceError`.
Use `from_proxy()` to construct an instance from a shape `CustomAttribute` proxy and its corresponding concrete `torch.Size`.
- classmethod from_proxy(shape_proxy: CustomProxy, concrete_shape: Size) CustomProxyShape[source]#
Build a `CustomProxyShape` from a shape/size proxy and the corresponding concrete (possibly symbolic) `torch.Size`.
- Parameters:
shape_proxy – An FX proxy whose `[i]` subscript creates a `getitem` graph node (e.g. a `CustomAttribute` for `tensor.shape` or a `CustomProxy` wrapping a `call_method("size", …)` node).
concrete_shape – The concrete `torch.Size` from the fake tensor’s metadata (may contain backed `SymInt` values for dynamic dimensions).
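A toy illustration of the shape-as-tuple idea, using made-up stand-in classes rather than the actual `CustomProxyShape` and `CustomProxyInt`:

```python
# Hypothetical sketch: a shape represented as a tuple of per-dimension
# entries, where each entry with a known concrete value compares concretely
# against plain ints, mirroring the behaviour described above.
class DimSketch:
    def __init__(self, concrete_val):
        self.concrete_val = concrete_val

    def __eq__(self, other):
        # Resolve == against a plain int concretely; everything else would
        # stay a graph operation in a real tracer.
        if isinstance(other, int):
            return self.concrete_val == other
        return NotImplemented

class ShapeSketch(tuple):
    """A tuple of DimSketch entries standing in for CustomProxyShape."""

shape = ShapeSketch(DimSketch(d) for d in (2, 3, 0))
print(shape[2] == 0)  # True: resolves to a plain Python bool at trace time
```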
- class yobx.torch.tracing.CustomTracer(autowrap_modules: ~typing.Tuple[~types.ModuleType, ...] = (<module 'math' (built-in)>,), autowrap_functions: ~typing.Tuple[~typing.Callable, ...] = (<built-in method ones of type object>, <built-in method zeros of type object>, <built-in method full of type object>, <built-in method empty of type object>, <built-in method arange of type object>), param_shapes_constant: bool = False, module_leaves: ~typing.Dict[type, ~typing.Callable[[~torch.nn.modules.module.Module, str], bool]] | None = None)[source]#
Defines a custom tracer to trace the execution of a model and convert it into an FX graph. Works with `CustomProxy`.
- ::
from yobx.torch.tracing import CustomTracer
graph = CustomTracer().trace(model)
- Parameters:
autowrap_modules – defaults to (math,), Python modules whose functions should be wrapped automatically without needing to use fx.wrap().
autowrap_functions – defaults to `_AUTOWRAP_FUNCTIONS`, Python functions that should be wrapped automatically without needing to use fx.wrap(). Includes tensor-creating functions (e.g. `torch.ones`) so that calls with proxy size arguments are captured as traced nodes rather than executed immediately.
param_shapes_constant – when this flag is set, calls to shape, size and a few other shape-like attributes of a module’s parameter will be evaluated directly, rather than returning a new Proxy value for an attribute access.
module_leaves – modules to be considered as leaves, mapped to a callable `f(module, module_qualified_name) -> bool` that decides whether a specific module instance is a leaf; the tracer does not trace into leaf modules and emits `call_module` nodes for them instead.
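The module_leaves contract (type mapped to a predicate over the instance and its qualified name) can be sketched without any torch dependency; the class and names below are illustrative only:

```python
# Hypothetical sketch of the module_leaves contract: a dict mapping a module
# type to a predicate f(module, qualified_name) -> bool deciding leafness.
class FakeNorm:          # stand-in for an nn.Module subclass
    pass

module_leaves = {
    # Treat every FakeNorm as a leaf except the one named "encoder.norm".
    FakeNorm: lambda mod, qualname: qualname != "encoder.norm",
}

def is_leaf(module, qualified_name):
    pred = module_leaves.get(type(module))
    return bool(pred and pred(module, qualified_name))

print(is_leaf(FakeNorm(), "decoder.norm"))  # True -> call_module node
print(is_leaf(FakeNorm(), "encoder.norm"))  # False -> traced through
```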
- call_module(m: Module, forward: Callable[[...], Any], args: tuple[Any, ...], kwargs: dict[str, Any]) Any[source]#
Method that specifies the behavior of this `Tracer` when it encounters a call to an `nn.Module` instance.
By default, the behavior is to check if the called module is a leaf module via `is_leaf_module`. If it is, emit a `call_module` node referring to `m` in the `Graph`. Otherwise, call the `Module` normally, tracing through the operations in its `forward` function.
This method can be overridden to create, for example, nested traced GraphModules, or any other behavior you would want while tracing across `Module` boundaries.
- Parameters:
m (Module) – The module for which a call is being emitted
forward (Callable) – The forward() method of the `Module` to be invoked
args (Tuple) – args of the module callsite
kwargs (Dict) – kwargs of the module callsite
- Returns:
The return value from the Module call. In the case that a `call_module` node was emitted, this is a `Proxy` value. Otherwise, it is whatever value was returned from the `Module` invocation.
Note
Backwards-compatibility for this API is guaranteed.
- create_args_for_root(root_fn, is_module, concrete_args=None)[source]#
Create `placeholder` nodes corresponding to the signature of the `root` Module. This method introspects root’s signature and emits those nodes accordingly, also supporting `*args` and `**kwargs`.
Warning
This API is experimental and is NOT backward-compatible.
- classmethod graph_erase_node(graph: Graph, node: Node)[source]#
Removes a node and all predecessors that are only consumed by this one.
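The dead-predecessor cleanup can be illustrated on a toy graph; this is a simplified sketch of the idea, not the FX-based implementation:

```python
# Hypothetical sketch of erasing a node together with predecessors that no
# other node consumes, on a toy graph of {name: {"inputs", "users"}} records.
def erase_with_dead_inputs(graph, name):
    """Remove `name` and, recursively, inputs left without any user."""
    inputs = graph[name]["inputs"]
    del graph[name]
    for inp in inputs:
        users = graph[inp]["users"]
        users.discard(name)
        if not users:  # input is now dead: erase it too
            erase_with_dead_inputs(graph, inp)

graph = {
    "a": {"inputs": [], "users": {"b"}},
    "b": {"inputs": ["a"], "users": {"c"}},
    "c": {"inputs": ["b"], "users": set()},
}
erase_with_dead_inputs(graph, "c")
print(sorted(graph))  # [] : b and a became dead once c was gone
```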
- is_leaf_module(m: Module, module_qualified_name: str) bool[source]#
A method to specify whether a given `nn.Module` is a “leaf” module.
Leaf modules are the atomic units that appear in the IR, referenced by `call_module` calls. By default, Modules in the PyTorch standard library namespace (torch.nn) are leaf modules. All other modules are traced through and their constituent ops are recorded, unless specified otherwise via this parameter.
- Parameters:
m (Module) – The module being queried about
module_qualified_name (str) – The path to root of this module. For example, if you have a module hierarchy where submodule `foo` contains submodule `bar`, which contains submodule `baz`, that module will appear with the qualified name `foo.bar.baz` here.
Note
Backwards-compatibility for this API is guaranteed.
- proxy(node: Node, cls: type[CustomProxy] = <class 'yobx.torch.tracing.CustomProxy'>, **kwargs) Proxy[source]#
Overrides this method to replace the default `Proxy` with `CustomProxy`.
- register_callable(name: str, fn: Callable) Node[source]#
Registers a function and returns a unique name.
- Parameters:
name – prefix to prepend to the function name
fn – function
- Returns:
new_name
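A plausible naming scheme (prefix plus function name, with a numeric suffix on collision) can be sketched as follows; this is a guess at the behaviour, not the actual yobx implementation:

```python
# Hypothetical sketch of unique-name registration: prefix + function name,
# with a numeric suffix appended on collisions.
def make_registry():
    registered = {}

    def register_callable(name, fn):
        new_name = f"{name}_{fn.__name__}"
        i = 0
        while new_name in registered:  # keep bumping until unique
            i += 1
            new_name = f"{name}_{fn.__name__}_{i}"
        registered[new_name] = fn
        return new_name

    return register_callable

register = make_registry()
print(register("branch", abs))  # branch_abs
print(register("branch", abs))  # branch_abs_1 (collision -> suffix)
```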
- classmethod remove_batch_dim_nodes(graph: Graph, verbose: int = 0) int[source]#
Rewrites `_add_batch_dim` and `_remove_batch_dim` nodes introduced by `torch.vmap` lowering using a three-phase batch-dimension tracking strategy.
Phase 1 – tag: `_add_batch_dim(x, batch_dim, level)` is replaced by `aten.clone.default(x)` and the actual batch-dimension position is stored in the replacement node’s metadata under the key `"vmap_batch_dim"` as a `dict` mapping level → batch_dim.
Phase 2 – propagate: for every `call_function` node whose input nodes carry `"vmap_batch_dim"` metadata, that metadata is copied onto the output node so that downstream consumers can find it.
Phase 3 – remove: `_remove_batch_dim(x, level, batch_size, out_dim)` looks up the actual batch-dimension position of x for the given level from the metadata (falling back to `level - 1` if absent), then:
- if the actual size of the batch dimension in x is 1 but `batch_size > 1` (the tensor was broadcast / “not batched” inside vmap), an `aten.expand.default` node is emitted first to materialise the full batch;
- erases the node (replacing all uses with x) when no expand is needed and `actual_batch_dim == out_dim` – a no-op case;
- replaces the node with `aten.movedim.int(x, actual_batch_dim, out_dim)` otherwise.
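The tag/propagate/remove bookkeeping can be modelled on toy nodes (plain dicts with a metadata field); this is a simplified sketch of the strategy described above, not the FX pass itself:

```python
# Hypothetical sketch of the three phases, tracking where the vmap batch
# dimension lives via per-node metadata.
def tag(node, level, batch_dim):
    # Phase 1: _add_batch_dim becomes a clone; remember the batch position.
    node["meta"].setdefault("vmap_batch_dim", {})[level] = batch_dim

def propagate(node, inputs):
    # Phase 2: copy batch-dim metadata from inputs onto the output node.
    for inp in inputs:
        node["meta"].setdefault("vmap_batch_dim", {}).update(
            inp["meta"].get("vmap_batch_dim", {}))

def remove(node, level, out_dim):
    # Phase 3: decide the rewrite for _remove_batch_dim (falling back to
    # level - 1 when no metadata was recorded).
    actual = node["meta"].get("vmap_batch_dim", {}).get(level, level - 1)
    return "noop" if actual == out_dim else ("movedim", actual, out_dim)

x = {"meta": {}}
tag(x, level=1, batch_dim=2)
y = {"meta": {}}
propagate(y, [x])
print(remove(y, level=1, out_dim=0))  # ('movedim', 2, 0)
print(remove(y, level=1, out_dim=2))  # 'noop'
```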
- Parameters:
graph – FX graph to modify in-place
verbose – verbosity level
- Returns:
number of nodes replaced or removed
- classmethod remove_inplace(graph: Graph, exported_program: ExportedProgram | None = None, verbose: int = 0, exc: bool = True, recursive: bool = False) int[source]#
Removes inplace operations.
- Parameters:
graph – graph to modify
exported_program – if available, it is used in the error message to make it easier to trace the code source
verbose – verbosity
exc – raise an exception if not possible, otherwise return -1
recursive – remove node inside submodules
- Returns:
number of inplace nodes removed; a negative number means there are still inplace nodes left that this function is unable to remove, only decompositions may help in that case
- classmethod remove_unnecessary_slices(graph: Graph, verbose: int = 0) int[source]#
Removes unnecessary slices and other nodes doing nothing.
- Parameters:
graph – graph to modify
verbose – verbosity level
- Returns:
number of nodes removed
Example of a node this pass removes (a slice spanning the full dimension):
%slice_11 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor] (args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
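The no-op detection in the example above can be sketched directly: a slice whose arguments keep the whole dimension is replaceable by its input. This is an illustrative check, not the actual pass:

```python
# Hypothetical sketch: an aten.slice.Tensor call with start 0, end INT64_MAX
# (9223372036854775807) and step 1 keeps the whole dimension, so the node
# can be replaced by its input.
INT64_MAX = 2**63 - 1

def is_noop_slice(args):
    # args mirrors (input, dim, start, end[, step]); step defaults to 1.
    _, _, start, end, *rest = args
    step = rest[0] if rest else 1
    return start == 0 and end == INT64_MAX and step == 1

print(is_noop_slice(("%clone", 0, 0, 9223372036854775807)))  # True
print(is_noop_slice(("%clone", 0, 1, 9223372036854775807)))  # False
```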
- trace(root: Module | Callable[[...], Any], concrete_args: Dict[str, Any] | None = None, remove_inplace: bool = True, update_model_with_callable: bool = True, dynamic_shapes: Any | None = None, verbose: int = 0) Graph[source]#
Trace `root` and return the corresponding FX `Graph` representation. `root` can either be an `nn.Module` instance or a Python callable.
Note that after this call, `self.root` may be different from the `root` passed in here. For example, when a free function is passed to `trace()`, we will create an `nn.Module` instance to use as the root and add embedded constants to it.
- Parameters:
root – Either a `Module` or a function to be traced through. Backwards-compatibility for this parameter is guaranteed.
concrete_args – Concrete arguments that should not be treated as Proxies. This parameter is experimental and its backwards-compatibility is NOT guaranteed.
remove_inplace – Removes inplace nodes
update_model_with_callable – in some cases (control flow), the model needs to be updated
dynamic_shapes – dynamic shapes
verbose – verbosity
- Returns:
A `Graph` representing the semantics of the passed-in `root`.
If the model had to be wrapped before being traced, the attribute `traced_model` is added to the tracer.
- class yobx.torch.tracing.ScanCCOp[source]#
Replacement for `torch.ops.higher_order.scan()` during FX symbolic tracing. The real scan operator cannot be called with `CustomProxy` arguments because it tries to execute the body function eagerly. This proxy operator defers to `CustomProxy.__torch_function__()`, which records the scan call as an FX `call_function` node instead.
- yobx.torch.tracing.replace_problematic_function_before_tracing() Generator[source]#
Replaces functions, such as `torch.cat()`, that cannot be traced with the default tracer.