yobx.torch.new_tracing.tracer

class yobx.torch.new_tracing.tracer.GraphTracer(verbose: int = 0, module_leaves: Dict[type, Callable[[...], bool]] | None = None)

Traces a callable by intercepting all tensor operations via __torch_dispatch__ and records them into a torch.fx.Graph.

<<<

import torch
from yobx.torch.new_tracing.tracer import GraphTracer


def add(x, y):
    return x + y


x = torch.randn(3, 4)
y = torch.randn(3, 4)
tracer = GraphTracer()
graph = tracer.trace(add, (x, y))
print(graph)

>>>

    graph():
        %x : [num_users=1] = placeholder[target=x]
        %y : [num_users=1] = placeholder[target=y]
        %add_tensor : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
        return add_tensor

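The constructor creates an empty torch.fx.Graph, which a call to trace() then populates. The same function can be traced with dynamic shapes; in the call below, {0: "batch"} names dimension 0 of x "batch" and the empty dict leaves y static: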

<<<

import torch
from yobx.torch.new_tracing.tracer import GraphTracer


def add(x, y):
    return x + y


x = torch.randn(3, 4)
y = torch.randn(1, 4)
tracer = GraphTracer()
graph = tracer.trace(add, (x, y), {}, ({0: "batch"}, {}))
print(graph)

>>>

    graph():
        %x : [num_users=1] = placeholder[target=x]
        %y : [num_users=1] = placeholder[target=y]
        %add_tensor : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
        return add_tensor

Leaf modules are sub-modules that the tracer does not trace into; each one appears as a single call_function node in the graph. They are declared via the module_leaves constructor argument:

<<<

import torch
from yobx.torch.new_tracing.tracer import GraphTracer


class MyLeaf(torch.nn.Module):
    def forward(self, x):
        return x * 2


class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.leaf = MyLeaf()

    def forward(self, x):
        return self.leaf(x) + 1


tracer = GraphTracer(module_leaves={MyLeaf: lambda m, module_qualified_name=None: True})
graph = tracer.trace(Outer(), (torch.randn(2, 4),))
# print(graph) fails for this graph, so print_tabular() is used instead
graph.print_tabular()

>>>

    opcode         name        target           args            kwargs
    -------------  ----------  ---------------  --------------  --------
    placeholder    x           x                ()              {}
    call_function  leaf_leaf   MyLeaf()         (x,)            {}
    call_function  add_tensor  aten.add.Tensor  (leaf_leaf, 1)  {}
    output         output      output           (add_tensor,)   {}

Parameters:
  • verbose – Verbosity level (0 = silent).

  • module_leaves – Optional mapping from module type to a predicate f(module, module_qualified_name=name) -> bool. When the predicate returns True for a given sub-module instance, that sub-module is treated as a leaf: a single call_function node is emitted for its call site and the tracer does not descend into its forward method. Internal parameters and buffers of leaf modules are also excluded from the graph’s placeholder nodes.
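The predicate contract can be illustrated with a small, dependency-free sketch. It is not GraphTracer's implementation: the helper is_leaf, the stand-in classes Block and Head, and the lookup by exact type are all assumptions made for the example.

```python
from typing import Any, Callable, Dict


class Block:  # hypothetical stand-in for a torch.nn.Module subclass
    pass


class Head:  # another stand-in, not declared as a leaf type
    pass


def is_leaf(
    module: Any,
    qualified_name: str,
    module_leaves: Dict[type, Callable[..., bool]],
) -> bool:
    """Return True when the predicate registered for type(module) accepts it."""
    # Assumption: lookup by exact type; the real tracer may match differently.
    predicate = module_leaves.get(type(module))
    if predicate is None:
        return False
    return bool(predicate(module, module_qualified_name=qualified_name))


# Treat Block instances as leaves only when they live under "encoder".
module_leaves = {
    Block: lambda m, module_qualified_name=None: (
        module_qualified_name or ""
    ).startswith("encoder.")
}

print(is_leaf(Block(), "encoder.block0", module_leaves))
print(is_leaf(Block(), "decoder.block0", module_leaves))
print(is_leaf(Head(), "encoder.head", module_leaves))
```

The three calls print True, False and False: the predicate decides per instance, so the same module type can be a leaf in one place and traced in another.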

dispatch(func: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any

Handle one dispatched operation:

  1. Run the op on meta tensors for shape/dtype inference.

  2. Build FX node args (replacing TracingTensor with their nodes, and non-traced tensors with auto-placeholder nodes).

  3. Emit a call_function node in the graph.

  4. Return wrapped TracingTensor output(s).
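The interception and shape-inference steps above can be tried with plain PyTorch. The sketch below is not GraphTracer.dispatch(): it relies on torch.utils._python_dispatch.TorchDispatchMode rather than a TracingTensor subclass, it only logs the ATen overloads it sees instead of emitting graph nodes, and the class name RecordingMode is hypothetical.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class RecordingMode(TorchDispatchMode):
    """Toy recorder: log each dispatched ATen overload, then run it."""

    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # A tracer would emit a call_function node here instead of logging.
        self.ops.append(func)
        return func(*args, **(kwargs or {}))


x = torch.randn(3, 4)
y = torch.randn(1, 4)

mode = RecordingMode()
with mode:
    z = x + y
print(mode.ops)

# Shape/dtype inference without computing values: run the op on meta tensors.
meta_out = torch.ops.aten.add.Tensor(
    torch.empty(3, 4, device="meta"), torch.empty(1, 4, device="meta")
)
print(tuple(meta_out.shape), meta_out.dtype)
```

The Python-level x + y reaches the dispatcher as torch.ops.aten.add.Tensor, which is the target that appears in the traced graphs on this page.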

is_not_tensor(value: Any) -> bool

Return True if value contains no torch.Tensor leaves.

Scalars (int, float, str) and empty collections are treated as non-tensor. Lists and tuples are inspected recursively. Dicts are inspected by their values.

Parameters:

value – The value to inspect. May be a scalar, tensor, list, tuple, or dict.

Returns:

True when value has no tensor leaves; False when any leaf is a torch.Tensor (including TracingTensor).

Raises:

TypeError – If value has a type that cannot be classified.
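The rules above amount to a short recursive check. The following is a minimal sketch written from this description only, not the library's code; the function name has_no_tensor is hypothetical, and how the real method treats types not listed here (for example None) is not assumed.

```python
import torch


def has_no_tensor(value):
    """Return True when no leaf of value is a torch.Tensor."""
    if isinstance(value, torch.Tensor):
        return False
    if isinstance(value, (int, float, str)):
        return True
    if isinstance(value, (list, tuple)):
        # all() over an empty sequence is True, so empty collections pass.
        return all(has_no_tensor(v) for v in value)
    if isinstance(value, dict):
        return all(has_no_tensor(v) for v in value.values())
    raise TypeError(f"cannot classify value of type {type(value)}")


print(has_no_tensor([1, (2.0, "a")]))
print(has_no_tensor({"a": [torch.zeros(1)]}))
print(has_no_tensor([]))
```

The calls print True, False and True; a tensor nested anywhere inside a list, tuple or dict value is enough to make the result False.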

make_fake(a: TracingTensor | Tensor) -> FakeTensor

Convert a into a FakeTensor for shape/dtype inference inside dispatch().

For a TracingTensor, each symbolic (string) dimension is mapped to a SymInt backed by the tracer’s ShapeEnv; previously seen dimension names are reused so that the same symbol appears wherever the same dynamic dimension is referenced.

For a plain torch.Tensor, a fake tensor with the same shape, dtype, and device is created directly.

Parameters:

a – Either a TracingTensor (possibly with symbolic dimensions) or a plain torch.Tensor.

Returns:

A FakeTensor suitable for passing to ATen operations in FakeTensorMode.
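For the plain torch.Tensor case, the idea can be sketched with PyTorch's own FakeTensorMode. This is an illustration under assumptions, not the body of make_fake(): it covers static shapes only and leaves out the ShapeEnv and SymInt handling used for symbolic dimensions.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

mode = FakeTensorMode()
real = torch.randn(3, 4)

# The fake tensor keeps shape, dtype and device but holds no real data.
fake = mode.from_tensor(real)

# ATen ops on fake tensors only propagate metadata.
with mode:
    out = torch.ops.aten.add.Tensor(fake, fake)

print(isinstance(out, FakeTensor), tuple(out.shape), out.dtype, out.device)
```

Because no values are computed, this kind of inference stays cheap even for large shapes.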

make_names(n: int, name: str, arg, treespec)

Generate a list of n unique child names derived from name.

For a list or tuple arg, names are "<name>_0", "<name>_1", …, "<name>_{n-1}". For a dict arg whose length equals n, names are "<name>_<key>" for each key.

Parameters:
  • n – Number of names to generate (must equal len(arg)).

  • name – Base name (typically the parameter name from the function signature).

  • arg – The original argument (list, tuple, or dict).

  • treespec – The torch.utils._pytree.TreeSpec of arg; included for error messages only.

Returns:

A list of str names of length n.

Raises:

NotImplementedError – If arg has an unsupported type.
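The naming scheme can be reproduced in a few lines. This is a sketch based only on the description above; the function child_names is hypothetical and omits the treespec argument, which the documentation says is used for error messages only.

```python
def child_names(n, name, arg):
    """Derive n child names from name, following the documented scheme."""
    if isinstance(arg, (list, tuple)):
        return [f"{name}_{i}" for i in range(n)]
    if isinstance(arg, dict) and len(arg) == n:
        return [f"{name}_{key}" for key in arg]
    raise NotImplementedError(f"unsupported argument type {type(arg)}")


print(child_names(2, "xs", [10, 20]))
print(child_names(2, "cache", {"key": 1, "value": 2}))
```

The first call yields ['xs_0', 'xs_1'] and the second ['cache_key', 'cache_value'].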

make_tracing_args(args: Tuple[Any, ...], kwargs: Dict[str, Any] | None = None, dynamic_shapes: Tuple[Any, ...] | Dict[str, Any] | None = None, sig_names: List[str] | None = None) -> Tuple[Tuple[Any, ...], Dict[str, Any]]

Convert args / kwargs into tracing counterparts.

Every torch.Tensor (or container of tensors) in args and kwargs is replaced by the corresponding TracingTensor placeholder(s). Non-tensor values are forwarded unchanged.

Parameters:
  • args – Positional arguments as provided to trace().

  • kwargs – Keyword arguments as provided to trace().

  • dynamic_shapes – Optional per-argument dynamic shape mapping; passed through to make_tracing_arg().

  • sig_names – Parameter names extracted from the traced function’s signature; used to name placeholders and look up dynamic_shapes by name rather than by index.

Returns:

A (tracing_args, tracing_kwargs) tuple whose tensor leaves are TracingTensor instances.
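The replacement of tensor leaves inside nested containers can be sketched with torch.utils._pytree.tree_map. The example is an assumption-laden illustration, not the method itself: the Stub dataclass stands in for TracingTensor, and placeholder naming and dynamic_shapes handling are left out.

```python
from dataclasses import dataclass

import torch
from torch.utils._pytree import tree_map


@dataclass
class Stub:
    """Hypothetical stand-in for a TracingTensor placeholder."""

    shape: tuple
    dtype: torch.dtype


def to_stub(value):
    # Tensors become stubs; every other leaf is forwarded unchanged.
    if isinstance(value, torch.Tensor):
        return Stub(tuple(value.shape), value.dtype)
    return value


args = (torch.randn(3, 4), [torch.zeros(2), 5])
kwargs = {"scale": 2.0}

tracing_args, tracing_kwargs = tree_map(to_stub, (args, kwargs))
print(tracing_args)
print(tracing_kwargs)
```

The container structure is preserved: the tuple, the inner list and the dict come back with the same layout, and only the tensor leaves change type.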

place(tt: TracingTensor, name: str | None = None) -> TracingTensor

Ensure tt is registered in this tracer’s graph as a placeholder.

If tt already has a _node (i.e. it was produced by a previous call to placeholder() or dispatch()), it is returned unchanged. Otherwise a new placeholder node is created, tt is bound to it, and the updated tensor is returned.

Parameters:
  • tt – A TracingTensor to register. It must not be owned by a different GraphTracer instance.

  • name – Unused placeholder for a future name override. The actual node name is always generated as tt_<counter>.

Returns:

tt with _tracer set to self and _node set to the newly created (or pre-existing) placeholder node.

Raises:

AssertionError – If tt already belongs to a different tracer.

placeholder(name: str, shape: Tuple[int, ...] | TracingShape, dtype: dtype, device: str | device) -> TracingTensor

Add a placeholder (input) node to the graph and return the corresponding TracingTensor.

Parameters:
  • name – The name of the placeholder in the graph.

  • shape – Tensor shape (concrete sizes or TracingShape).

  • dtype – Tensor dtype.

  • device – Target device (string or torch.device).

Returns:

A TracingTensor representing the graph input.
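At the torch.fx level, a placeholder is simply an input node of the graph. The sketch below builds by hand a graph with the same structure as the add example at the top of this page, using only the public torch.fx.Graph API. It does not involve TracingTensor, and the node.meta keys example_shape and example_dtype are made up for illustration; where GraphTracer stores this information is not assumed.

```python
import torch
import torch.fx

graph = torch.fx.Graph()

# Two graph inputs.
x = graph.placeholder("x")
y = graph.placeholder("y")

# Hypothetical metadata keys, used here only to show where such data could go.
for node in (x, y):
    node.meta["example_shape"] = (3, 4)
    node.meta["example_dtype"] = torch.float32

# One ATen call consuming both inputs, then the graph output.
add = graph.call_function(torch.ops.aten.add.Tensor, (x, y))
graph.output(add)

graph.lint()  # checks that the graph is well formed
print([node.op for node in graph.nodes])
```

The printed opcodes are placeholder, placeholder, call_function and output, matching the node kinds in the traced graphs shown on this page.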

register_module_parameters(module: Module) -> None

Pre-register all named parameters and buffers of module as placeholder nodes in the graph.

This gives each parameter a meaningful name in the graph (e.g. linear_weight instead of param_1) and ensures that shared tensors (the same torch.Tensor referenced under multiple names) map to exactly one placeholder node.

Parameters that belong to leaf sub-modules (see module_leaves) are skipped: leaf modules are treated as black boxes and their internal parameters are not exposed as graph inputs.

Each placeholder node receives two extra metadata entries:

  • node.meta["torch_name"]: the original dotted parameter name as returned by torch.nn.Module.named_parameters() (e.g. "linear.weight").

  • node.meta["torch_value"]: the actual torch.Tensor object (useful for retrieving concrete weight values later).

Parameters:

module – The torch.nn.Module whose parameters and buffers should be pre-registered.
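The naming and metadata conventions can be sketched with a bare torch.fx.Graph. This is not the method's implementation: replacing dots with underscores to form the node name and de-duplicating shared tensors by id() are assumptions inferred from the description, and leaf-module filtering is omitted.

```python
import torch
import torch.fx


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)
        self.register_buffer("scale", torch.ones(2))

    def forward(self, x):
        return self.linear(x) * self.scale


model = Model()
graph = torch.fx.Graph()
seen = {}  # id(tensor) -> placeholder node, so shared tensors map to one node

named = list(model.named_parameters()) + list(model.named_buffers())
for dotted_name, tensor in named:
    if id(tensor) in seen:
        continue
    node = graph.placeholder(dotted_name.replace(".", "_"))
    node.meta["torch_name"] = dotted_name  # original dotted name
    node.meta["torch_value"] = tensor      # the tensor object itself
    seen[id(tensor)] = node

print([node.meta["torch_name"] for node in graph.nodes])
```

Keeping the tensor in node.meta["torch_value"] means the concrete weights can be looked up later from the placeholder alone.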

trace(func: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any] | None = None, dynamic_shapes: Dict[str, Any] | None = None) -> Graph

Trace func with the provided args and return the resulting torch.fx.Graph.

Tensor arguments are replaced by TracingTensor placeholders and every dispatched ATen operation is recorded as a graph node. Non-tensor arguments are forwarded as-is.

Parameters:
  • func – The callable to trace (e.g. a torch.nn.Module instance or a plain Python function).

  • args – Positional arguments to func. Real torch.Tensor values should be supplied; their shapes and dtypes are used for placeholder metadata.

  • kwargs – Optional keyword arguments to func.

  • dynamic_shapes – Optional mapping from argument name (parameter name from func’s signature for positional args; key name for keyword args) to a list/tuple of TracingInt / int describing the dimensions symbolically. When provided, the corresponding placeholder is given a TracingShape instead of a concrete torch.Size.

Returns:

A torch.fx.Graph representing the full computation.

<<<

import torch
from yobx.torch.new_tracing.tracer import GraphTracer


def add(x, y):
    return x + y


graph = GraphTracer().trace(add, (torch.randn(3, 4), torch.randn(3, 4)))
print(graph)

>>>

    graph():
        %x : [num_users=1] = placeholder[target=x]
        %y : [num_users=1] = placeholder[target=y]
        %add_tensor : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
        return add_tensor