.torch_interpreter.tracing
- class experimental_experiment.torch_interpreter.tracing.CondCCOp
Cannot be imported from torch.ops.higher_order.cond because the function cond overwrites the submodule cond.
- class experimental_experiment.torch_interpreter.tracing.CustomAttribute(root: CustomProxy, attr: str)
Traces attribute accesses.
- class experimental_experiment.torch_interpreter.tracing.CustomParameterProxy(tracer: TracerBase, node: Node, name, param)
A special proxy that lets "shape", "size", "dim", and a few other attribute accesses pass through to the underlying module parameter object, so that conditional tests on these attributes do not throw exceptions during tracing.
- class experimental_experiment.torch_interpreter.tracing.CustomProxy(node: Node, tracer: TracerBase | None = None)
Defines a custom proxy that traces the execution of a model and converts it into an FX graph. Works with CustomTracer.
- classmethod cat(tensors: List[CustomProxy], dim: int = 0, *, out=None, axis: int | None = None) → CustomProxy
Implements cat for tensors.
- class experimental_experiment.torch_interpreter.tracing.CustomProxyFloat(node: Node, tracer: TracerBase | None = None)
A proxy for a float.
- class experimental_experiment.torch_interpreter.tracing.CustomProxyInt(node: Node, tracer: TracerBase | None = None)
A proxy for an integer.
- class experimental_experiment.torch_interpreter.tracing.CustomTracer(autowrap_modules: Tuple[ModuleType] = (math,), autowrap_functions: Tuple[Callable, ...] = (), param_shapes_constant: bool = False)
Defines a custom tracer that traces the execution of a model and converts it into an FX graph. Works with CustomProxy.
Example:
    import torch
    from experimental_experiment.torch_interpreter.tracing import CustomTracer
    model = torch.nn.Linear(4, 2)  # any torch.nn.Module or callable to trace
    graph = CustomTracer().trace(model)
- classmethod graph_erase_node(graph: Graph, node: Node)
Removes a node and all its predecessors that are consumed only by this node.
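The cascading removal can be sketched on a toy graph stored as a dict that maps each node name to the names of its inputs. This is an illustration under stated assumptions, not the library's code: it does not use torch.fx, `erase_with_unused_inputs` is a hypothetical name, the erased node is assumed to have no users, and the `protected` set (for example graph inputs) is an assumption added for the sketch.

```python
from typing import Dict, FrozenSet, List


def erase_with_unused_inputs(
    graph: Dict[str, List[str]], node: str, protected: FrozenSet[str] = frozenset()
) -> List[str]:
    """Erases ``node`` and, recursively, every predecessor left without users.

    Returns the erased names in erasure order. Names in ``protected`` are
    never erased.
    """
    erased: List[str] = []
    stack = [node]
    while stack:
        current = stack.pop()
        if current not in graph or current in protected:
            continue  # already erased, unknown, or must be kept
        inputs = graph.pop(current)
        erased.append(current)
        for pred in inputs:
            # A predecessor is erased only if nothing else still consumes it.
            if not any(pred in args for args in graph.values()):
                stack.append(pred)
    return erased


graph = {"x": [], "shape": ["x"], "check": ["shape"], "out": ["x"]}
erased = erase_with_unused_inputs(graph, "check", protected=frozenset({"x"}))
```

Removing `check` also removes `shape`, which had no other consumer, while `x` survives because `out` still uses it.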
- proxy(node: Node, cls: type[CustomProxy] = CustomProxy) → Proxy
Overrides the base method to replace the default Proxy with CustomProxy.
- register_callable(name: str, fn: Callable) → Node
Registers a function and returns a unique name.
- Parameters:
name – prefix to prepend to the function name
fn – function
- Returns:
new_name
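One way to produce unique names from a prefix is sketched below. The naming scheme (prefix, then the function name, then a counter suffix on collision) and the `CallableRegistry` class are assumptions for illustration only; the actual scheme is not documented here. The sketch returns the new name as a string, following the Returns entry above.

```python
from typing import Callable, Dict


class CallableRegistry:
    """Hypothetical registry mapping unique names to functions."""

    def __init__(self) -> None:
        self.functions: Dict[str, Callable] = {}

    def register_callable(self, name: str, fn: Callable) -> str:
        # Assumed scheme: prefix + function name, plus a counter if taken.
        base = f"{name}_{getattr(fn, '__name__', 'fn')}"
        new_name, i = base, 0
        while new_name in self.functions:
            i += 1
            new_name = f"{base}_{i}"
        self.functions[new_name] = fn
        return new_name


def branch(x):
    return x + 1


reg = CallableRegistry()
n1 = reg.register_callable("cond", branch)  # hypothetical prefix
n2 = reg.register_callable("cond", branch)
```

Registering the same function twice yields two distinct names, so each registration can be referenced on its own.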
- classmethod remove_inplace(graph: Graph, exported_program: ExportedProgram | None = None, verbose: int = 0, exc: bool = True) → int
Removes inplace operations.
- Parameters:
graph – graph to modify
exported_program – if available, used in the error message to make it easier to trace back to the source code
verbose – verbosity
exc – raises an exception if removal is not possible, otherwise returns -1
- Returns:
number of inplace nodes removed; a negative number means some inplace nodes remain that this function cannot remove, in which case only decompositions may help
The most difficult pattern is the following. Three nested slices of %clone, each over the full range of one dimension (0 to 9223372036854775807), are views covering the whole tensor, and copy_ writes %masked_fill into the last one, so %clone is modified in place through a chain of views:
    %slice_11 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
    %slice_12 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_11, 1, 0, 9223372036854775807), kwargs = {})
    %slice_13 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_12, 2, 0, 9223372036854775807), kwargs = {})
    %copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%slice_13, %masked_fill), kwargs = {})
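The analysis step this pattern requires can be sketched on a toy graph: starting from the copy_ node, follow its destination back through the chain of slice nodes and check that every slice covers a full dimension, in which case copy_ overwrites the whole base tensor. This is a hypothetical illustration, not how remove_inplace is implemented: nodes are `(target, args)` tuples rather than torch.fx nodes, `full_overwrite_base` is an invented name, the inputs of `masked_fill` are made up, and the sketch stops at the analysis without performing any rewrite.

```python
from typing import Dict, Optional, Tuple

INT64_MAX = 9223372036854775807  # aten.slice end value meaning "to the end"

# Toy node: (target, args); a hypothetical stand-in for a torch.fx Node.
Node = Tuple[str, tuple]


def full_overwrite_base(nodes: Dict[str, Node], copy_name: str) -> Optional[str]:
    """Returns the tensor that copy_ fully overwrites, or None."""
    target, args = nodes[copy_name]
    if target != "aten.copy_.default":
        raise ValueError(f"{copy_name} is not a copy_ node")
    dest = args[0]
    # Walk back through the slices that produced the copy destination.
    while dest in nodes and nodes[dest][0] == "aten.slice.Tensor":
        slice_args = nodes[dest][1]
        step = slice_args[4] if len(slice_args) > 4 else 1
        if slice_args[2] != 0 or slice_args[3] != INT64_MAX or step != 1:
            return None  # partial view: only part of the base is written
        dest = slice_args[0]
    return dest


nodes = {
    "clone": ("aten.clone.default", ("x",)),
    "slice_11": ("aten.slice.Tensor", ("clone", 0, 0, INT64_MAX)),
    "slice_12": ("aten.slice.Tensor", ("slice_11", 1, 0, INT64_MAX)),
    "slice_13": ("aten.slice.Tensor", ("slice_12", 2, 0, INT64_MAX)),
    # Hypothetical inputs; only the copy destination matters here.
    "masked_fill": ("aten.masked_fill.Scalar", ("clone", "mask", 0.0)),
    "copy_": ("aten.copy_.default", ("slice_13", "masked_fill")),
}
base = full_overwrite_base(nodes, "copy_")
```

For the pattern above, `base` is `"clone"`; if any slice in the chain were partial, the function would return None.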
- classmethod remove_unnecessary_slices(graph: Graph) → int
Removes unnecessary slices, such as a slice over the full range of a dimension (from 0 to 9223372036854775807, the largest int64 value):
    %slice_11 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
- Parameters:
graph – graph to modify
- Returns:
number of nodes removed
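A slice from 0 to the largest int64 value with step 1 returns the whole dimension, so its users can read its input directly. The sketch below shows that idea on a toy graph of `(target, args)` tuples; `is_unnecessary_slice` and `drop_identity_slices` are hypothetical names, and the sketch is not the library's implementation.

```python
from typing import Dict, Tuple

INT64_MAX = 9223372036854775807

# Toy node: (target, args); a hypothetical stand-in for a torch.fx Node.
Node = Tuple[str, tuple]


def is_unnecessary_slice(node: Node) -> bool:
    """True for slice(x, dim, 0, INT64_MAX[, 1]), which returns x unchanged."""
    target, args = node
    if target != "aten.slice.Tensor" or len(args) < 4:
        return False
    step = args[4] if len(args) > 4 else 1
    return args[2] == 0 and args[3] == INT64_MAX and step == 1


def drop_identity_slices(nodes: Dict[str, Node]) -> int:
    """Rewires users of identity slices to the slice input, then drops them."""
    removed = 0
    for name in list(nodes):
        if not is_unnecessary_slice(nodes[name]):
            continue
        source = nodes[name][1][0]
        del nodes[name]
        # Every reference to the removed slice now points to its input.
        for other, (target, args) in list(nodes.items()):
            nodes[other] = (target, tuple(source if a == name else a for a in args))
        removed += 1
    return removed


nodes = {
    "clone": ("aten.clone.default", ("x",)),
    "slice_11": ("aten.slice.Tensor", ("clone", 0, 0, INT64_MAX)),
    "slice_p": ("aten.slice.Tensor", ("clone", 1, 0, 5)),  # partial, kept
    "relu": ("aten.relu.default", ("slice_11",)),
}
count = drop_identity_slices(nodes)
```

After the call, `relu` reads `clone` directly and only the partial slice remains.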
- trace(root: Module | Callable[..., Any], concrete_args: Dict[str, Any] | None = None, remove_inplace: bool = True, update_model_with_callable: bool = True) → Graph
Trace root and return the corresponding FX Graph representation. root can either be an nn.Module instance or a Python callable.
Note that after this call, self.root may be different from the root passed in here. For example, when a free function is passed to trace(), an nn.Module instance is created to serve as the root, and embedded constants are added to it.
- Parameters:
root – Either a Module or a function to be traced through. Backwards-compatibility for this parameter is guaranteed.
concrete_args – Concrete arguments that should not be treated as Proxies. This parameter is experimental and its backwards-compatibility is NOT guaranteed.
remove_inplace – removes inplace nodes
update_model_with_callable – in some cases (control flow), the model needs to be
- Returns:
A Graph representing the semantics of the passed-in root.
- class experimental_experiment.torch_interpreter.tracing.LEAVE_INPLACE
Constant indicating inplace removal failed.
- experimental_experiment.torch_interpreter.tracing.replace_problematic_function_before_tracing()
Replaces functions that cannot be traced with the default tracer, such as torch.cat().
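The general technique of swapping a function for a traceable replacement and restoring it afterwards can be sketched with a context manager. This is only an illustration: it is not stated here whether replace_problematic_function_before_tracing is a context manager or patches functions permanently, and `temporarily_replaced`, `fake_torch` and `traceable_cat` are hypothetical names. A plain namespace stands in for torch so the example runs without it.

```python
import contextlib
import types
from typing import Callable, Iterator


@contextlib.contextmanager
def temporarily_replaced(owner: object, attr: str, replacement: Callable) -> Iterator[None]:
    """Swaps ``owner.attr`` for ``replacement`` and always restores it."""
    original = getattr(owner, attr)
    setattr(owner, attr, replacement)
    try:
        yield
    finally:
        setattr(owner, attr, original)


# Hypothetical stand-in for a module whose cat the default tracer cannot handle.
fake_torch = types.SimpleNamespace(cat=lambda tensors, dim=0: "untraceable cat")


def traceable_cat(tensors, dim=0):
    # A replacement that a custom tracer would know how to record.
    return ("cat", tuple(tensors), dim)


with temporarily_replaced(fake_torch, "cat", traceable_cat):
    inside = fake_torch.cat(["a", "b"], dim=1)
after = fake_torch.cat(["a", "b"])
```

The `finally` clause restores the original function even if tracing raises, so code outside the `with` block is unaffected.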