yobx.torch.new_tracing

trace_model

yobx.torch.new_tracing.trace_model(func: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any] | None = None, dynamic_shapes: Dict[str, Any] | None = None, verbose: int = 0, module_leaves: Dict[type, Callable[..., bool]] | None = None) → Graph

Convenience wrapper that creates a DispatchTracer and uses it to trace func.

Parameters:
  • func – Callable to trace.

  • args – Positional tensor arguments (real tensors; shapes/dtypes are used for placeholder metadata).

  • kwargs – Optional keyword tensor arguments.

  • dynamic_shapes – Optional dynamic shape specifications; see DispatchTracer.trace() for the format.

  • verbose – Verbosity level (default 0).

  • module_leaves – Optional mapping from module type to a predicate f(module, module_qualified_name=name) -> bool. Modules whose type appears in this mapping and whose predicate returns True are treated as leaves: the tracer emits a single call_function node for the whole module call instead of tracing through its internals. See GraphTracer for details.

Returns:

A torch.fx.Graph representing the computation.
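The module_leaves rule combines two conditions: the module's type must be a key of the mapping, and the predicate stored under that key must return True for that module and its qualified name. The sketch below restates that rule with plain torch so it runs without yobx. The model Block, the predicate keep_encoder_opaque and the helper would_be_leaf are hypothetical, and the sketch assumes an exact type lookup and qualified names as produced by named_modules(); the tracer itself may resolve either differently.

```python
import torch


class Block(torch.nn.Module):
    """Hypothetical model with two Linear layers and an activation."""

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(4, 4)
        self.decoder = torch.nn.Linear(4, 4)
        self.act = torch.nn.ReLU()

    def forward(self, x):
        return self.decoder(self.act(self.encoder(x)))


def keep_encoder_opaque(module, module_qualified_name=""):
    # Hypothetical predicate: only the Linear layer named "encoder" is a leaf.
    return module_qualified_name == "encoder"


# Mapping from module type to predicate, in the shape module_leaves expects.
module_leaves = {torch.nn.Linear: keep_encoder_opaque}


def would_be_leaf(module, name, leaves):
    # Hypothetical helper restating the documented rule; assumes exact type lookup.
    predicate = leaves.get(type(module))
    return predicate is not None and bool(predicate(module, module_qualified_name=name))


model = Block()
leaf_names = [
    name for name, mod in model.named_modules() if would_be_leaf(mod, name, module_leaves)
]
print(leaf_names)  # only "encoder" satisfies both conditions
```

With yobx installed, the same dictionary would be passed as trace_model(model, (torch.randn(2, 4),), module_leaves=module_leaves); per the description of module_leaves, the encoder call should then show up as a single call_function node while decoder and act are traced through.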

<<<

import torch
from yobx.torch.new_tracing import trace_model

graph = trace_model(
    torch.nn.Linear(4, 4),
    (torch.randn(2, 4),),
)
print(graph)

>>>

    graph():
        %weight : [num_users=0] = placeholder[target=weight]
        %bias : [num_users=1] = placeholder[target=bias]
        %input_1 : [num_users=1] = placeholder[target=input]
        %param_1 : [num_users=1] = placeholder[target=param_1]
        %addmm_default : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%bias, %input_1, %param_1), kwargs = {})
        return addmm_default
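Because trace_model returns a plain torch.fx.Graph, the standard node API (node.op, node.target, node.users) can be used to inspect the result. The sketch below defines a hypothetical helper, summarize_graph, that counts node kinds, lists the inputs and reports placeholders that nothing consumes. So that it runs without yobx, it is exercised here on a stand-in graph built with torch.fx.symbolic_trace rather than on the output of trace_model.

```python
from collections import Counter

import torch
import torch.fx


def summarize_graph(graph: torch.fx.Graph):
    """Hypothetical helper: count node kinds and list the graph inputs."""
    # node.op is one of "placeholder", "get_attr", "call_function",
    # "call_method", "call_module" or "output".
    op_counts = Counter(node.op for node in graph.nodes)
    # For placeholder nodes, node.target holds the input name.
    inputs = [node.target for node in graph.nodes if node.op == "placeholder"]
    # Placeholders with no users, like %weight (num_users=0) in the output above.
    unused = [
        node.name
        for node in graph.nodes
        if node.op == "placeholder" and len(node.users) == 0
    ]
    return op_counts, inputs, unused


# Stand-in graph built with torch.fx.symbolic_trace so this runs without yobx.
def f(x, y, z):
    return torch.relu(x + y)  # z is deliberately left unused


gm = torch.fx.symbolic_trace(f)
op_counts, inputs, unused = summarize_graph(gm.graph)
print(op_counts, inputs, unused)
```

Called on the graph printed above, the same helper should report weight among the unused placeholders, since its num_users is 0.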