yobx.torch.new_tracing.tracer#
- class yobx.torch.new_tracing.tracer.GraphTracer(verbose: int = 0, module_leaves: Dict[type, Callable[[...], bool]] | None = None)[source]#
Traces a callable by intercepting all tensor operations via
__torch_dispatch__ and records them into a torch.fx.Graph.
<<<
import torch
from yobx.torch.new_tracing.tracer import GraphTracer

def add(x, y):
    return x + y

x = torch.randn(3, 4)
y = torch.randn(3, 4)
tracer = GraphTracer()
graph = tracer.trace(add, (x, y))
print(graph)
>>>
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %add_tensor : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
    return add_tensor
The class creates an empty torch.fx.Graph that a trace() call populates.
The same trace works with dynamic shapes:
<<<
import torch
from yobx.torch.new_tracing.tracer import GraphTracer

def add(x, y):
    return x + y

x = torch.randn(3, 4)
y = torch.randn(1, 4)
tracer = GraphTracer()
graph = tracer.trace(add, (x, y), {}, ({0: "batch"}, {}))
print(graph)
>>>
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %add_tensor : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
    return add_tensor
Leaf modules — sub-modules that should not be traced into but instead appear as a single
call_function node in the graph — can be declared via the module_leaves constructor argument:
<<<
import torch
from yobx.torch.new_tracing.tracer import GraphTracer

class MyLeaf(torch.nn.Module):
    def forward(self, x):
        return x * 2

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.leaf = MyLeaf()

    def forward(self, x):
        return self.leaf(x) + 1

tracer = GraphTracer(module_leaves={MyLeaf: lambda m, module_qualified_name=None: True})
graph = tracer.trace(Outer(), (torch.randn(2, 4),))
graph.print_tabular()  # print(graph) fails here; use print_tabular() instead
>>>
opcode         name        target           args            kwargs
-------------  ----------  ---------------  --------------  --------
placeholder    x           x                ()              {}
call_function  leaf_leaf   MyLeaf()         (x,)            {}
call_function  add_tensor  aten.add.Tensor  (leaf_leaf, 1)  {}
output         output      output           (add_tensor,)   {}
- Parameters:
verbose – Verbosity level (0 = silent).
module_leaves – Optional mapping from module type to a predicate
f(module, module_qualified_name=name) -> bool. When the predicate returns True for a given sub-module instance, that sub-module is treated as a leaf: a single call_function node is emitted for its call site and the tracer does not descend into its forward method. Internal parameters and buffers of leaf modules are also excluded from the graph’s placeholder nodes.
- dispatch(func: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]) Any[source]#
Handle one dispatched operation:
- Run the op on meta tensors for shape/dtype inference.
- Build FX node args (replacing TracingTensor instances with their nodes, and non-traced tensors with auto-placeholder nodes).
- Emit a call_function node in the graph.
- Return wrapped TracingTensor output(s).
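The four steps above can be sketched in plain Python. This is an illustrative stand-in only: MiniTracer, Node, and Traced are hypothetical classes mimicking the documented pipeline, not the actual yobx types, and the shape-inference step is elided.

```python
class Node:
    """Stand-in for a torch.fx node: opcode, target, and node args."""
    def __init__(self, op, target, args):
        self.op, self.target, self.args = op, target, args

class Traced:
    """Stand-in for TracingTensor: a value that carries its graph node."""
    def __init__(self, node):
        self.node = node

class MiniTracer:
    def __init__(self):
        self.nodes = []

    def dispatch(self, func_name, args):
        # 1) Shape/dtype inference on meta/fake tensors would happen here.
        # 2) Build node args: traced values contribute their nodes; scalar
        #    literals pass through (a plain, non-traced torch.Tensor would
        #    instead get an auto-placeholder node).
        node_args = tuple(a.node if isinstance(a, Traced) else a for a in args)
        # 3) Emit a call_function node into the graph.
        node = Node("call_function", func_name, node_args)
        self.nodes.append(node)
        # 4) Return the wrapped output.
        return Traced(node)

tracer = MiniTracer()
x = Traced(Node("placeholder", "x", ()))
out = tracer.dispatch("aten.add.Tensor", (x, 1.0))
print(out.node.op, out.node.target)  # call_function aten.add.Tensor
```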
- is_not_tensor(value: Any) bool[source]#
Return True if value contains no torch.Tensor leaves.
Scalars (int, float, str) and empty collections are treated as non-tensor. Lists and tuples are inspected recursively. Dicts are inspected by their values.
- Parameters:
value – The value to inspect. May be a scalar, tensor, list, tuple, or dict.
- Returns:
True when value has no tensor leaves; False when any leaf is a torch.Tensor (including TracingTensor).
- Raises:
TypeError – If value has a type that cannot be classified.
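The classification rules above can be sketched in pure Python; here Tensor is a stand-in class so the example runs without PyTorch, and the function is a reimplementation of the documented behavior, not the library's code.

```python
class Tensor:  # stand-in for torch.Tensor
    pass

def is_not_tensor(value):
    if isinstance(value, Tensor):
        return False
    if isinstance(value, (int, float, str)):
        return True                                        # scalars are non-tensor
    if isinstance(value, (list, tuple)):
        return all(is_not_tensor(v) for v in value)        # empty -> True
    if isinstance(value, dict):
        return all(is_not_tensor(v) for v in value.values())
    raise TypeError(f"cannot classify {type(value).__name__}")

print(is_not_tensor([1, 2.0, "a"]))    # True
print(is_not_tensor({"w": Tensor()}))  # False
print(is_not_tensor([]))               # True (empty collection)
```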
- make_fake(a: TracingTensor | Tensor) FakeTensor[source]#
Convert a into a FakeTensor for shape/dtype inference inside dispatch().
For a TracingTensor, each symbolic (string) dimension is mapped to a SymInt backed by the tracer’s ShapeEnv; previously seen dimension names are reused so that the same symbol appears wherever the same dynamic dimension is referenced.
For a plain torch.Tensor, a fake tensor with the same shape, dtype, and device is created directly.
- Parameters:
a – Either a TracingTensor (possibly with symbolic dimensions) or a plain torch.Tensor.
- Returns:
A FakeTensor suitable for passing to ATen operations in FakeTensorMode.
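The symbol-reuse behavior described above can be sketched without torch: the first time a dimension name is seen it gets a fresh symbol, and later uses of the same name resolve to that same symbol. Sym and SymbolTable here are illustrative stand-ins for the ShapeEnv-backed SymInt machinery.

```python
class Sym:
    """Stand-in for a ShapeEnv-backed SymInt."""
    def __init__(self, name):
        self.name = name
    def __repr__(self):
        return f"Sym({self.name!r})"

class SymbolTable:
    def __init__(self):
        self._symbols = {}

    def resolve(self, dim):
        # Concrete sizes pass through; string names map to one shared symbol.
        if isinstance(dim, int):
            return dim
        return self._symbols.setdefault(dim, Sym(dim))

table = SymbolTable()
shape_x = [table.resolve(d) for d in ("batch", 4)]
shape_y = [table.resolve(d) for d in ("batch", 4)]
print(shape_x[0] is shape_y[0])  # True: the same "batch" symbol is reused
```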
- make_names(n: int, name: str, arg, treespec)[source]#
Generate a list of n unique child names derived from name.
For a list or tuple arg, names are "<name>_0", "<name>_1", …, "<name>_{n-1}". For a dict arg whose length equals n, names are "<name>_<key>" for each key.
- Parameters:
n – Number of names to generate (must equal len(arg)).
name – Base name (typically the parameter name from the function signature).
arg – The original argument (list, tuple, or dict).
treespec – The torch.utils._pytree.TreeSpec of arg; included for error messages only.
- Returns:
A list of str names of length n.
- Raises:
NotImplementedError – If arg has an unsupported type.
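The naming rule is simple enough to show directly. This is a minimal reimplementation of the documented behavior (with the treespec parameter dropped, since it only affects error messages), not the library's code.

```python
def make_names(n, name, arg):
    # Index suffixes for sequences, key suffixes for dicts.
    if isinstance(arg, (list, tuple)):
        assert n == len(arg)
        return [f"{name}_{i}" for i in range(n)]
    if isinstance(arg, dict):
        assert n == len(arg)
        return [f"{name}_{key}" for key in arg]
    raise NotImplementedError(type(arg).__name__)

print(make_names(3, "xs", [10, 20, 30]))      # ['xs_0', 'xs_1', 'xs_2']
print(make_names(2, "kw", {"a": 1, "b": 2}))  # ['kw_a', 'kw_b']
```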
- make_tracing_args(args: Tuple[Any, ...], kwargs: Dict[str, Any] | None = None, dynamic_shapes: Tuple[Any, ...] | Dict[str, Any] | None = None, sig_names: List[str] | None = None) Tuple[Tuple[Any, ...], Dict[str, Any]][source]#
Convert args / kwargs into tracing counterparts.
Every torch.Tensor (or container of tensors) in args and kwargs is replaced by the corresponding TracingTensor placeholder(s). Non-tensor values are forwarded unchanged.
- Parameters:
args – Positional arguments as provided to trace().
kwargs – Keyword arguments as provided to trace().
dynamic_shapes – Optional per-argument dynamic shape mapping; passed through to make_tracing_arg().
sig_names – Parameter names extracted from the traced function’s signature; used to name placeholders and look up dynamic_shapes by name rather than by index.
- Returns:
A (tracing_args, tracing_kwargs) tuple whose tensor leaves are TracingTensor instances.
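The leaf replacement can be sketched in pure Python. Tensor and TracingTensor here are stand-in classes, and this simplified version handles only positional args and list containers; it illustrates the documented substitution, not the actual implementation.

```python
class Tensor:  # stand-in for torch.Tensor
    pass

class TracingTensor:  # stand-in: remembers its placeholder name
    def __init__(self, name):
        self.name = name

def to_tracing(value, name):
    # Tensor leaves become named tracing placeholders; containers are
    # rebuilt around converted children; everything else passes through.
    if isinstance(value, Tensor):
        return TracingTensor(name)
    if isinstance(value, (list, tuple)):
        return type(value)(
            to_tracing(v, f"{name}_{i}") for i, v in enumerate(value)
        )
    return value

def make_tracing_args(args, sig_names):
    return tuple(to_tracing(a, n) for a, n in zip(args, sig_names))

out = make_tracing_args((Tensor(), [Tensor(), 5]), ["x", "ys"])
print(out[0].name)     # 'x'
print(out[1][0].name)  # 'ys_0'
print(out[1][1])       # 5
```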
- place(tt: TracingTensor, name: str | None = None) TracingTensor[source]#
Ensure tt is registered in this tracer’s graph as a placeholder.
If tt already has a _node (i.e. it was produced by a previous call to placeholder() or dispatch()), it is returned unchanged. Otherwise a new placeholder node is created, tt is bound to it, and the updated tensor is returned.
- Parameters:
tt – A TracingTensor to register. It must not be owned by a different GraphTracer instance.
name – Unused placeholder for a future name override. The actual node name is always generated as tt_<counter>.
- Returns:
tt with _tracer set to self and _node set to the newly created (or pre-existing) placeholder node.
- Raises:
AssertionError – If tt already belongs to a different tracer.
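The idempotent registration described above can be sketched as follows. MiniTracer and the stand-in TracingTensor are illustrative; only the tt_<counter> naming scheme and the ownership assertion come from the documentation.

```python
class TracingTensor:
    def __init__(self):
        self._tracer = None
        self._node = None

class MiniTracer:
    def __init__(self):
        self.placeholders = []
        self._counter = 0

    def place(self, tt):
        assert tt._tracer in (None, self), "owned by a different tracer"
        if tt._node is not None:
            return tt                  # already registered: no-op
        node = f"tt_{self._counter}"   # node name is always auto-generated
        self._counter += 1
        self.placeholders.append(node)
        tt._tracer, tt._node = self, node
        return tt

tracer = MiniTracer()
t = TracingTensor()
tracer.place(t)
tracer.place(t)                        # second call changes nothing
print(len(tracer.placeholders))        # 1
```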
- placeholder(name: str, shape: Tuple[int, ...] | TracingShape, dtype: dtype, device: str | device) TracingTensor[source]#
Add a placeholder (input) node to the graph and return the corresponding TracingTensor.
- Parameters:
name – The name of the placeholder in the graph.
shape – Tensor shape (concrete sizes or TracingShape).
dtype – Tensor dtype.
device – Target device (string or torch.device).
- Returns:
A TracingTensor representing the graph input.
- register_module_parameters(module: Module) None[source]#
Pre-register all named parameters and buffers of module as placeholder nodes in the graph.
This gives each parameter a meaningful name in the graph (e.g. linear_weight instead of param_1) and ensures that shared tensors (the same torch.Tensor referenced under multiple names) map to exactly one placeholder node.
Parameters that belong to leaf sub-modules (see module_leaves) are skipped: leaf modules are treated as black boxes and their internal parameters are not exposed as graph inputs.
Each placeholder node receives two extra metadata entries:
- node.meta["torch_name"]: the original dotted parameter name as returned by torch.nn.Module.named_parameters() (e.g. "linear.weight").
- node.meta["torch_value"]: the actual torch.Tensor object (useful for retrieving concrete weight values later).
- Parameters:
module – The torch.nn.Module whose parameters and buffers should be pre-registered.
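The shared-tensor deduplication and leaf skipping can be sketched without torch: parameters are keyed by object identity, so a tensor reachable under several dotted names gets exactly one placeholder. The register_parameters function and leaf_prefixes argument are hypothetical illustrations of the documented behavior.

```python
def register_parameters(named_params, leaf_prefixes=()):
    seen = {}          # id(tensor) -> placeholder record
    placeholders = []
    for dotted_name, tensor in named_params:
        if any(dotted_name.startswith(p) for p in leaf_prefixes):
            continue   # parameters of leaf modules stay hidden
        if id(tensor) in seen:
            continue   # shared tensor: keep the first placeholder only
        record = {
            "name": dotted_name.replace(".", "_"),
            "meta": {"torch_name": dotted_name, "torch_value": tensor},
        }
        seen[id(tensor)] = record
        placeholders.append(record)
    return placeholders

w = object()  # stands in for a weight tensor shared under two names
params = [
    ("linear.weight", w),
    ("tied_head.weight", w),   # same object: deduplicated
    ("linear.bias", object()),
]
out = register_parameters(params)
print([p["name"] for p in out])  # ['linear_weight', 'linear_bias']
```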
- trace(func: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any] | None = None, dynamic_shapes: Dict[str, Any] | None = None) Graph[source]#
Trace func with the provided args and return the resulting torch.fx.Graph.
Tensor arguments are replaced by TracingTensor placeholders and every dispatched ATen operation is recorded as a graph node. Non-tensor arguments are forwarded as-is.
- Parameters:
func – The callable to trace (e.g. a torch.nn.Module instance or a plain Python function).
args – Positional arguments to func. Real torch.Tensor values should be supplied; their shapes and dtypes are used for placeholder metadata.
kwargs – Optional keyword arguments to func.
dynamic_shapes – Optional mapping from argument name (parameter name from func’s signature for positional args; key name for keyword args) to a list/tuple of TracingInt / int describing the dimensions symbolically. When provided, the corresponding placeholder is given a TracingShape instead of a concrete torch.Size.
- Returns:
A torch.fx.Graph representing the full computation.
<<<
import torch
from yobx.torch.new_tracing.tracer import GraphTracer

def add(x, y):
    return x + y

graph = GraphTracer().trace(add, (torch.randn(3, 4), torch.randn(3, 4)))
print(graph)
>>>
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %add_tensor : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
    return add_tensor