.torch_interpreter.tracing
- class experimental_experiment.torch_interpreter.tracing.CondCCOp
Cannot be imported from torch.ops.higher_order.cond (the function cond overwrites the submodule cond).
- class experimental_experiment.torch_interpreter.tracing.CustomAttribute(root: CustomProxy, attr: str)
Used to trace attribute accesses.
- class experimental_experiment.torch_interpreter.tracing.CustomParameterProxy(tracer: TracerBase, node: Node, name, param)
A special proxy which lets “shape”, “size”, “dim”, and a few other attribute accesses pass through to the underlying module parameter object, so that conditional tests on these attributes do not throw an exception during tracing.
- class experimental_experiment.torch_interpreter.tracing.CustomProxy(node: Node, tracer: TracerBase | None = None)
Defines a custom proxy to trace the execution of a model and convert it into an fx graph. Works with CustomTracer.
- classmethod cat(tensors: List[CustomProxy], dim: int = 0, *, out=None, axis: int | None = None) → CustomProxy
Implements cat for tensors.
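As a rough illustration of why a proxy needs its own cat, the sketch below shows a toy proxy that records operations instead of executing them. ToyProxy, its _emit helper and the tuple-based graph are hypothetical and far simpler than CustomProxy or torch.fx.

```python
class ToyProxy:
    """Records operations instead of executing them (hypothetical stand-in)."""

    def __init__(self, name, graph):
        self.name = name
        self.graph = graph  # shared list of (output, op, inputs) tuples

    def _emit(self, op, inputs):
        # Create the output proxy and append the corresponding node.
        out = ToyProxy(f"{op}_{len(self.graph)}", self.graph)
        self.graph.append((out.name, op, tuple(inputs)))
        return out

    def __add__(self, other):
        return self._emit("add", [self.name, other.name])

    @classmethod
    def cat(cls, tensors, dim=0):
        # All proxies share one graph; the first one gives access to it.
        first = tensors[0]
        return first._emit("cat", [t.name for t in tensors] + [dim])


graph = []
x, y = ToyProxy("x", graph), ToyProxy("y", graph)
z = ToyProxy.cat([x + y, y], dim=1)
```

After these calls, graph holds one add node followed by one cat node, which is the kind of record a tracer turns into a graph.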
- class experimental_experiment.torch_interpreter.tracing.CustomProxyFloat(node: Node, tracer: TracerBase | None = None)
A proxy for a float.
- class experimental_experiment.torch_interpreter.tracing.CustomProxyInt(node: Node, tracer: TracerBase | None = None)
A proxy for an integer.
- class experimental_experiment.torch_interpreter.tracing.CustomTracer(autowrap_modules: typing.Tuple[ModuleType] = (<module 'math' (built-in)>, ), autowrap_functions: typing.Tuple[typing.Callable, ...] = (), param_shapes_constant: bool = False)
Defines a custom tracer to trace the execution of a model and convert it into an fx graph. Works with CustomProxy. Typical usage:
from experimental_experiment.torch_interpreter.tracing import CustomTracer
graph = CustomTracer().trace(model)
- classmethod graph_erase_node(graph: Graph, node: Node)
Removes a node and all its predecessors which are only consumed by this one.
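The removal rule can be sketched on a toy graph stored as a dictionary. This is only an illustration under stated assumptions: the function name and the dictionary representation are hypothetical, the sketch applies the rule transitively, and it does not treat graph inputs specially, which the real method may do differently.

```python
def erase_with_exclusive_predecessors(inputs_of, node):
    """Remove `node`, then any predecessor left without a consumer.

    `inputs_of` maps a node name to the list of node names it consumes
    (a hypothetical toy representation, not torch.fx.Graph).
    Returns the removed names in removal order.
    """
    removed = []
    stack = [node]
    while stack:
        current = stack.pop()
        if current not in inputs_of:
            continue  # already removed
        preds = inputs_of.pop(current)
        removed.append(current)
        for pred in preds:
            # A predecessor survives if another remaining node consumes it.
            still_used = any(pred in ins for ins in inputs_of.values())
            if not still_used:
                stack.append(pred)
    return removed


toy = {
    "x": [],
    "a": ["x"],
    "b": ["a"],  # only consumed by "c"
    "c": ["b"],  # node to erase
    "d": ["a"],  # keeps "a" alive
}
removed = erase_with_exclusive_predecessors(toy, "c")
```

Here "b" disappears together with "c" because nothing else consumes it, while "a" stays because "d" still uses it.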
- proxy(node: torch.fx.node.Node, cls: type[experimental_experiment.torch_interpreter.tracing.CustomProxy] = <class 'experimental_experiment.torch_interpreter.tracing.CustomProxy'>) → Proxy
Overrides this method to replace the default Proxy with CustomProxy.
- register_callable(name: str, fn: Callable) → Node
Registers a function and returns a unique name.
- Parameters:
name – prefix to prepend to the function name
fn – function
- Returns:
new_name
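One way to build a unique name from a prefix and a function is sketched below. The naming scheme (underscore separator, numeric suffix on collision) and the ToyRegistry class are assumptions made for the illustration; the actual method may name and store callables differently.

```python
class ToyRegistry:
    """Hypothetical registry mapping unique names to callables."""

    def __init__(self):
        self.callables = {}

    def register_callable(self, name, fn):
        # Assumed scheme: "<prefix>_<function name>", suffixed on collision.
        base = f"{name}_{fn.__name__}"
        new_name, i = base, 1
        while new_name in self.callables:
            i += 1
            new_name = f"{base}_{i}"
        self.callables[new_name] = fn
        return new_name


def branch(x):
    return x


registry = ToyRegistry()
first = registry.register_callable("cond", branch)
second = registry.register_callable("cond", branch)
```

Registering the same function twice yields two distinct names, so each registration can be referenced unambiguously.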
- classmethod remove_inplace(graph: Graph, exported_program: ExportedProgram | None = None) → int
Removes inplace operations.
- Parameters:
graph – graph to modify
exported_program – if available, it is used in the error message to make it easier to trace back to the source code
- Returns:
number of inplace nodes removed
The most difficult pattern is the following:
%slice_11 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor] (args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
%slice_12 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor] (args = (%slice_11, 1, 0, 9223372036854775807), kwargs = {})
%slice_13 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor] (args = (%slice_12, 2, 0, 9223372036854775807), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default] (args = (%slice_13, %masked_fill), kwargs = {})
- classmethod remove_unnecessary_slices(graph: Graph) → int
Removes unnecessary slices.
- Parameters:
graph – graph to modify
- Returns:
number of slice nodes removed
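Example of an unnecessary slice, which selects the whole first dimension (start 0, end 9223372036854775807):
%slice_11 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor] (args = (%clone, 0, 0, 9223372036854775807), kwargs = {})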
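The test for such a slice can be sketched without torch. The node tuples, target strings and helper names below are hypothetical simplifications of an fx graph; only the argument order (input, dim, start, end, optional step) follows aten.slice.Tensor.

```python
INT64_MAX = 2**63 - 1  # 9223372036854775807, the "no upper bound" end value


def is_full_slice(args):
    """args mirrors aten.slice.Tensor: (input, dim, start, end[, step])."""
    _, _dim, start, end, *rest = args
    step = rest[0] if rest else 1
    return start == 0 and end == INT64_MAX and step == 1


def drop_full_slices(nodes):
    """nodes: list of (name, target, args); returns (kept_nodes, removed_count).

    Uses of a removed slice are rewired to the input of that slice.
    """
    alias, kept = {}, []
    for name, target, args in nodes:
        # Replace references to already removed slices by their input.
        args = tuple(alias.get(a, a) if isinstance(a, str) else a for a in args)
        if target == "aten.slice.Tensor" and is_full_slice(args):
            alias[name] = args[0]
            continue
        kept.append((name, target, args))
    return kept, len(alias)


nodes = [
    ("slice_11", "aten.slice.Tensor", ("clone", 0, 0, INT64_MAX)),
    ("slice_12", "aten.slice.Tensor", ("slice_11", 1, 0, INT64_MAX)),
    ("slice_13", "aten.slice.Tensor", ("slice_12", 2, 0, INT64_MAX)),
    ("copy_", "aten.copy_.default", ("slice_13", "masked_fill")),
]
kept, removed_count = drop_full_slices(nodes)
```

Applied to the four-node pattern listed under remove_inplace, the three slices collapse and copy_ ends up writing directly into clone.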
- trace(root: Module | Callable[[...], Any], concrete_args: Dict[str, Any] | None = None, remove_inplace: bool = True, update_model_with_callable: bool = True) → Graph
Trace root and return the corresponding FX Graph representation. root can either be an nn.Module instance or a Python callable.
Note that after this call, self.root may be different from the root passed in here. For example, when a free function is passed to trace(), we will create an nn.Module instance to use as the root and add embedded constants to.
- Parameters:
root – Either a Module or a function to be traced through. Backwards-compatibility for this parameter is guaranteed.
concrete_args – Concrete arguments that should not be treated as Proxies. This parameter is experimental and its backwards-compatibility is NOT guaranteed.
remove_inplace – Removes inplace nodes
update_model_with_callable – in some cases (control flow), the model needs to be
- Returns:
A Graph representing the semantics of the passed-in root.
- experimental_experiment.torch_interpreter.tracing.replace_problematic_function_before_tracing()
Replaces functions that cannot be traced with the default tracer, such as torch.cat().