yobx.torch.new_tracing
trace_model
- yobx.torch.new_tracing.trace_model(func: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any] | None = None, dynamic_shapes: Dict[str, Any] | None = None, verbose: int = 0, module_leaves: Dict[type, Callable[..., bool]] | None = None) → Graph
Convenience wrapper: create a DispatchTracer and trace func.
- Parameters:
func – Callable to trace.
args – Positional tensor arguments (real tensors; shapes/dtypes are used for placeholder metadata).
kwargs – Optional keyword tensor arguments.
dynamic_shapes – Optional dynamic shape specifications; see DispatchTracer.trace() for the format.
verbose – Verbosity level.
module_leaves – Optional mapping from module type to a predicate
f(module, module_qualified_name=name) -> bool. Modules whose type appears in this mapping and whose predicate returns True are treated as leaves: the tracer emits a single call_function node for the whole module call instead of tracing through its internals. See GraphTracer for details.
- Returns:
A torch.fx.Graph representing the computation.
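The module_leaves mechanism parallels torch.fx's built-in leaf-module hook, Tracer.is_leaf_module(m, module_qualified_name). As a rough, runnable sketch of the same predicate-per-type idea (using plain torch.fx rather than yobx, and noting that torch.fx emits a call_module node for a leaf where this tracer emits call_function; the LeafAwareTracer and Net names here are illustrative, not part of the yobx API):

```python
import torch
import torch.fx


class LeafAwareTracer(torch.fx.Tracer):
    """Illustrative analogue of module_leaves: a {type: predicate} mapping
    consulted before torch.fx's default leaf-module rule."""

    def __init__(self, leaves):
        super().__init__()
        self.leaves = leaves  # {module type: f(module, qualified_name) -> bool}

    def is_leaf_module(self, m, module_qualified_name):
        pred = self.leaves.get(type(m))
        if pred is not None and pred(m, module_qualified_name):
            # Treat the whole module as a leaf: one node for the call,
            # no tracing through its internals.
            return True
        return super().is_leaf_module(m, module_qualified_name)


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())

    def forward(self, x):
        return self.block(x)


# Mark every nn.Sequential as a leaf; by default torch.fx traces through it.
tracer = LeafAwareTracer({torch.nn.Sequential: lambda m, name: True})
graph = tracer.trace(Net())
print([n.op for n in graph.nodes])
```

With the predicate in place the Sequential collapses to a single node, so the trace is just placeholder, one module call, and output.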
>>> import torch
>>> from yobx.torch.new_tracing import trace_model
>>> graph = trace_model(
...     torch.nn.Linear(4, 4),
...     (torch.randn(2, 4),),
... )
>>> print(graph)
graph():
    %weight : [num_users=0] = placeholder[target=weight]
    %bias : [num_users=1] = placeholder[target=bias]
    %input_1 : [num_users=1] = placeholder[target=input]
    %param_1 : [num_users=1] = placeholder[target=param_1]
    %addmm_default : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%bias, %input_1, %param_1), kwargs = {})
    return addmm_default
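Since the return value is a plain torch.fx.Graph, the usual graph-inspection idioms apply. A minimal sketch of walking the nodes, using torch.fx.symbolic_trace as a stand-in producer (trace_model itself is assumed unavailable here; with it, you would iterate the returned Graph the same way):

```python
import torch
import torch.fx

# Stand-in: symbolic_trace also yields a torch.fx graph for a Linear layer.
gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))

# Every torch.fx.Graph is an ordered sequence of Nodes with .op and .target.
for node in gm.graph.nodes:
    print(node.op, node.target)
```

The node ops always start with the placeholders and end with output; the body in between holds the traced calls (here the functional linear op and the get_attr accesses for weight and bias).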