.torch_interpreter.tracing

class experimental_experiment.torch_interpreter.tracing.CondCCOp

This operator cannot be imported from torch.ops.higher_order.cond because the function cond overwrites the submodule cond.
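The name clash described above is ordinary Python import behaviour: once a package attribute is rebound to a function, attribute access returns the function, while the submodule is only reachable through sys.modules. The sketch below reproduces this with a hypothetical in-memory package fakepkg; it does not reflect the actual layout of torch.

```python
import importlib
import sys
import types

# Hypothetical package "fakepkg" with a submodule "fakepkg.cond", built in memory.
pkg = types.ModuleType("fakepkg")
pkg.__path__ = []  # marks the module as a package
sub = types.ModuleType("fakepkg.cond")
sys.modules["fakepkg"] = pkg
sys.modules["fakepkg.cond"] = sub
pkg.cond = sub  # what a regular import of the submodule would do


def cond(pred, true_fn, false_fn, operands):
    """Hypothetical function that has the same name as the submodule."""
    return true_fn(*operands) if pred else false_fn(*operands)


# The package rebinds the name: the attribute now points to the function.
pkg.cond = cond

# Importing the submodule still works because it is cached in sys.modules,
# but attribute access on the package returns the function.
mod = importlib.import_module("fakepkg.cond")
```

Here mod is the submodule, yet pkg.cond is the function, which is why a wrapper class is needed to refer to the operator.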

class experimental_experiment.torch_interpreter.tracing.CustomAttribute(root: CustomProxy, attr: str)

Represents the attribute attr of the proxy root so that attribute accesses can be traced.

class experimental_experiment.torch_interpreter.tracing.CustomParameterProxy(tracer: TracerBase, node: Node, name, param)

A special proxy which lets “shape”, “size”, “dim”, and a few other attribute accesses pass through to the underlying module parameter object, so that conditional tests on these attributes do not throw an exception during tracing.

class experimental_experiment.torch_interpreter.tracing.CustomProxy(node: Node, tracer: TracerBase | None = None)

Custom proxy used to trace the execution of a model and convert it into an FX graph. Works with CustomTracer.

classmethod cat(tensors: List[CustomProxy], dim: int = 0, *, out=None, axis: int | None = None) → CustomProxy

Implements cat for tensors represented by proxies.

instanceof(cls)

Tells whether this proxy represents an instance of a specific class.

length()

Returns a proxy for the length.

class experimental_experiment.torch_interpreter.tracing.CustomProxyFloat(node: Node, tracer: TracerBase | None = None)

A proxy for a float.

instanceof(cls)

Equivalent of isinstance(): tells whether the proxied value is an instance of cls.

class experimental_experiment.torch_interpreter.tracing.CustomProxyInt(node: Node, tracer: TracerBase | None = None)

A proxy for an integer.

instanceof(cls)

Equivalent of isinstance(): tells whether the proxied value is an instance of cls.
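An isinstance-style query on a proxy cannot use isinstance() directly, because the proxy object itself is not an int or a float. A minimal sketch, assuming a hypothetical proxy that stores the Python type of the value it stands for (TypedProxy is illustrative, not the library's class):

```python
class TypedProxy:
    """Hypothetical proxy remembering the Python type of the traced value."""

    def __init__(self, name: str, represented_type: type):
        self.name = name
        self.represented_type = represented_type

    def instanceof(self, cls) -> bool:
        # issubclass accepts a type or a tuple of types, like isinstance does,
        # so this mirrors isinstance(real_value, cls).
        return issubclass(self.represented_type, cls)


n = TypedProxy("n", int)
x = TypedProxy("x", float)
```

For example, n.instanceof(int) is true even though isinstance(n, int) is false, and x.instanceof((int, float)) accepts a tuple as isinstance does.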

class experimental_experiment.torch_interpreter.tracing.CustomTracer(autowrap_modules: Tuple[ModuleType] = (math,), autowrap_functions: Tuple[Callable, ...] = (), param_shapes_constant: bool = False)

Custom tracer that traces the execution of a model and converts it into an FX graph. Works with CustomProxy.

::

    import torch
    from experimental_experiment.torch_interpreter.tracing import CustomTracer

    model = torch.nn.Linear(4, 2)  # any torch.nn.Module or Python callable
    graph = CustomTracer().trace(model)

create_arg(a: Any) → Argument

Overrides torch.fx.Tracer.create_arg() to handle more argument types.

getattr(attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any])

See torch.fx.Tracer.getattr().

classmethod graph_erase_node(graph: Graph, node: Node)

Removes a node and all its predecessors that are only consumed by this node.
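Unlike torch.fx.Graph.erase_node(), which erases a single node that has no remaining users, this removal cascades to predecessors that become unused. A minimal sketch of that rule on a hypothetical toy graph stored as a dictionary from node name to input names (not torch.fx.Graph); the protected set of graph inputs that must never be erased is an assumption of the sketch.

```python
from typing import Dict, List, Set


def erase_with_predecessors(
    graph: Dict[str, List[str]], node: str, protected: Set[str]
) -> None:
    """Removes node, then recursively every predecessor it was the last user of."""
    inputs = graph.pop(node)
    for inp in inputs:
        if inp not in graph or inp in protected:
            continue  # already erased, or a graph input that must stay
        still_used = any(inp in args for args in graph.values())
        if not still_used:
            erase_with_predecessors(graph, inp, protected)


# Toy graph: node name -> names of its inputs.
graph = {
    "x": [],
    "a": ["x"],
    "b": ["a"],
    "c": ["x"],
    "d": ["b", "c"],
    "out": ["c"],
}
erase_with_predecessors(graph, "d", protected={"x"})
# "b" and "a" only fed "d", so they go too; "c" is still used by "out".
```

After the call only x, c and out remain.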

proxy(node: Node, cls: type[CustomProxy] = CustomProxy) → Proxy

Overrides torch.fx.Tracer.proxy() to replace the default Proxy with CustomProxy.

register_callable(name: str, fn: Callable) → Node

Registers a function and returns a unique name.

Parameters:
  • name – prefix to prepend to the function name

  • fn – function to register

Returns:

new_name
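One simple way to build a unique name from a prefix is to prepend the prefix to the function name and append a counter until the name is free. This is a sketch under assumptions: CallableRegistry is hypothetical, the "_" separator is an arbitrary choice, and it returns the name string as the Returns entry above describes.

```python
from typing import Callable, Dict


class CallableRegistry:
    """Hypothetical registry mapping unique names to functions."""

    def __init__(self):
        self.functions: Dict[str, Callable] = {}

    def register_callable(self, name: str, fn: Callable) -> str:
        # Prepend the prefix to the function name, then add a suffix if taken.
        base = f"{name}_{getattr(fn, '__name__', 'fn')}"
        new_name = base
        i = 1
        while new_name in self.functions:
            new_name = f"{base}_{i}"
            i += 1
        self.functions[new_name] = fn
        return new_name


reg = CallableRegistry()
first = reg.register_callable("cond", abs)   # "cond_abs"
second = reg.register_callable("cond", abs)  # "cond_abs_1"
```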

classmethod remove_inplace(graph: Graph, exported_program: ExportedProgram | None = None) → int

Removes inplace operations.

Parameters:
  • graph – graph to modify

  • exported_program – if available, used in error messages to make it easier to trace back to the source code

Returns:

number of inplace nodes removed

The most difficult pattern is the following:

%slice_11 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor]
    (args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
%slice_12 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor]
    (args = (%slice_11, 1, 0, 9223372036854775807), kwargs = {})
%slice_13 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor]
    (args = (%slice_12, 2, 0, 9223372036854775807), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default]
    (args = (%slice_13, %masked_fill), kwargs = {})
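
In this pattern every slice starts at 0 and ends at 9223372036854775807 (2**63 - 1), so each keeps its whole dimension and copy_ overwrites all of %clone with %masked_fill. Assuming matching shapes and dtypes and no other views of %clone, later users of %clone can read %masked_fill directly, and the slices and copy_ become dead. The sketch below applies that rewrite to a hypothetical list-of-tuples graph (not torch.fx); it is one possible approach, not necessarily how remove_inplace() does it, and it assumes each slice has a single user as in the dump above.

```python
from typing import Dict, List, Set, Tuple

INT64_MAX = 9223372036854775807  # end index meaning "up to the end"

# Toy node: (name, op, args); string args name earlier nodes.
Node = Tuple[str, str, tuple]


def remove_full_slice_copy(nodes: List[Node]) -> List[Node]:
    """Rewrites copy_(full slices of base, src) as "use src instead of base"."""
    by_name: Dict[str, Tuple[str, tuple]] = {n: (op, a) for n, op, a in nodes}
    replace: Dict[str, str] = {}  # base -> value to use after the copy_
    dead: Set[str] = set()
    result: List[Node] = []
    for name, op, args in nodes:
        # Later users of an overwritten base read the copied value instead.
        args = tuple(replace.get(a, a) if isinstance(a, str) else a for a in args)
        if op == "copy_":
            dst, src = args
            chain = []
            while by_name[dst][0] == "slice":
                if by_name[dst][1][2:] != (0, INT64_MAX):
                    break  # partial slice: only part of base is written
                chain.append(dst)
                dst = by_name[dst][1][0]
            else:
                # Only full slices: copy_ overwrites the whole base.
                dead.update(chain)
                replace[dst] = src
                continue  # drop the copy_ node itself
        result.append((name, op, args))
    return [node for node in result if node[0] not in dead]


nodes = [
    ("x", "placeholder", ()),
    ("mask", "placeholder", ()),
    ("clone", "clone", ("x",)),
    ("masked_fill", "masked_fill", ("x", "mask", 0.0)),
    ("slice_11", "slice", ("clone", 0, 0, INT64_MAX)),
    ("slice_12", "slice", ("slice_11", 1, 0, INT64_MAX)),
    ("slice_13", "slice", ("slice_12", 2, 0, INT64_MAX)),
    ("copy_", "copy_", ("slice_13", "masked_fill")),
    ("out", "relu", ("clone",)),
]
new_nodes = remove_full_slice_copy(nodes)
```

After the rewrite, out reads masked_fill; the now unused clone would be dropped by a later dead-code pass.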

classmethod remove_unnecessary_slices(graph: Graph) → int

Removes unnecessary slices, such as the slice shown below, which keeps a whole dimension.

Parameters:

graph – graph to modify

Returns:

number of nodes removed

%slice_11 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor]
    (args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
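
The slice above is a no-op: it starts at 0 and ends at 9223372036854775807 (2**63 - 1, the largest int64), so its users can read %clone directly. A minimal sketch of that simplification on a hypothetical toy graph (not torch.fx), assuming a slice is unnecessary exactly when start is 0, end is 2**63 - 1 and step is 1; drop_full_slices is an illustrative name, not the library's method.

```python
from typing import Dict, List, Tuple

INT64_MAX = 9223372036854775807  # 2**63 - 1, used as "up to the end"

# Toy node: (name, op, args); string args name earlier nodes.
Node = Tuple[str, str, tuple]


def is_full_slice(args: tuple) -> bool:
    # args = (input, dim, start, end[, step]); step defaults to 1.
    step = args[4] if len(args) > 4 else 1
    return args[2] == 0 and args[3] == INT64_MAX and step == 1


def drop_full_slices(nodes: List[Node]) -> Tuple[List[Node], int]:
    """Drops slices keeping a whole dimension and rewires their users."""
    alias: Dict[str, str] = {}  # removed slice -> node it forwarded
    kept: List[Node] = []
    for name, op, args in nodes:
        args = tuple(alias.get(a, a) if isinstance(a, str) else a for a in args)
        if op == "slice" and is_full_slice(args):
            alias[name] = args[0]  # args[0] is already resolved, so chains collapse
            continue
        kept.append((name, op, args))
    return kept, len(alias)


nodes = [
    ("clone", "clone", ("x",)),
    ("slice_11", "slice", ("clone", 0, 0, INT64_MAX)),
    ("slice_12", "slice", ("slice_11", 1, 2, 5)),
    ("out", "relu", ("slice_12",)),
]
kept, removed = drop_full_slices(nodes)
```

Here slice_11 disappears and slice_12, a real slice, now reads clone directly.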

trace(root: Module | Callable[..., Any], concrete_args: Dict[str, Any] | None = None, remove_inplace: bool = True, update_model_with_callable: bool = True) → Graph

Trace root and return the corresponding FX Graph representation. root can either be an nn.Module instance or a Python callable.

Note that after this call, self.root may be different from the root passed in here. For example, when a free function is passed to trace(), we will create an nn.Module instance to use as the root and add embedded constants to.

Parameters:
  • root – Either a Module or a function to be traced through. Backwards-compatibility for this parameter is guaranteed.

  • concrete_args – Concrete arguments that should not be treated as Proxies. This parameter is experimental and its backwards-compatibility is NOT guaranteed.

  • remove_inplace – Removes inplace nodes

  • update_model_with_callable – in some cases (control flow), the model needs to be updated with callables

Returns:

A Graph representing the semantics of the passed-in root.

experimental_experiment.torch_interpreter.tracing.replace_problematic_function_before_tracing()

Replaces functions that cannot be traced with the default tracer, such as torch.cat().
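A common, generic way to swap a function for a traceable replacement is to patch the attribute on its namespace and restore the original afterwards. The sketch below shows that pattern with a hypothetical namespace ops and a hypothetical traceable_cat; it does not touch torch and is not necessarily how replace_problematic_function_before_tracing() is implemented.

```python
import contextlib
import types
from typing import Callable, Dict, Iterator

# Hypothetical namespace standing in for a module such as torch.
ops = types.SimpleNamespace(cat=lambda seqs: sum(seqs, []))


def traceable_cat(seqs):
    """Hypothetical replacement that a custom tracer could intercept."""
    return ["traced"] + sum(seqs, [])


@contextlib.contextmanager
def replace_functions(namespace, replacements: Dict[str, Callable]) -> Iterator[None]:
    """Temporarily replaces attributes of namespace, restoring them on exit."""
    saved = {name: getattr(namespace, name) for name in replacements}
    try:
        for name, fn in replacements.items():
            setattr(namespace, name, fn)
        yield
    finally:
        for name, fn in saved.items():
            setattr(namespace, name, fn)


with replace_functions(ops, {"cat": traceable_cat}):
    inside = ops.cat([[1], [2]])   # uses the replacement
outside = ops.cat([[1], [2]])      # original restored
```

Restoring in a finally clause keeps the original function in place even if tracing raises.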