.torch_interpreter.tracing

class experimental_experiment.torch_interpreter.tracing.CondCCOp

Cannot be imported from torch.ops.higher_order.cond: the function cond overwrites the submodule cond.

class experimental_experiment.torch_interpreter.tracing.CustomAttribute(root: CustomProxy, attr: str)

Traces attribute accesses.

class experimental_experiment.torch_interpreter.tracing.CustomParameterProxy(tracer: TracerBase, node: Node, name, param)

A special proxy which lets “shape”, “size”, “dim”, and a few other attribute accesses pass through to the underlying module parameter object, so that conditional tests on these attributes do not throw an exception during tracing.
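The pass-through behaviour described above mirrors the ParameterProxy of torch.fx, which the standard torch.fx.Tracer uses when param_shapes_constant=True. The sketch below relies only on the standard torch.fx API and assumes CustomParameterProxy behaves the same way for these attributes; the module M is hypothetical.

```python
import torch
import torch.fx


class M(torch.nn.Module):
    """Hypothetical module whose forward branches on a parameter shape."""

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3, 4))

    def forward(self, x):
        # With a plain Proxy, this comparison is itself a Proxy and the `if`
        # raises a TraceError; a parameter proxy returns the concrete shape.
        if self.weight.shape[0] == 3:
            return x @ self.weight.T
        return x


# param_shapes_constant=True makes torch.fx wrap parameters in a ParameterProxy.
graph = torch.fx.Tracer(param_shapes_constant=True).trace(M())
ops = [n.op for n in graph.nodes]
print(ops)

# The default tracer cannot evaluate the condition.
try:
    torch.fx.Tracer().trace(M())
    failed = False
except torch.fx.proxy.TraceError:
    failed = True
print(failed)  # True
```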

class experimental_experiment.torch_interpreter.tracing.CustomProxy(node: Node, tracer: TracerBase | None = None)

Defines a custom proxy used to trace the execution of a model and convert it into an FX graph. Works with CustomTracer.

classmethod cat(tensors: List[CustomProxy], dim: int = 0, *, out=None, axis: int | None = None) → CustomProxy

Implements torch.cat() for tensors represented by proxies.

instanceof(cls)

Tells if this proxy represents a specific class.

length()

Returns a proxy for the length.

class experimental_experiment.torch_interpreter.tracing.CustomProxyFloat(node: Node, tracer: TracerBase | None = None)

A proxy for a float.

instanceof(cls)

Equivalent of isinstance for the value represented by this proxy.

class experimental_experiment.torch_interpreter.tracing.CustomProxyInt(node: Node, tracer: TracerBase | None = None)

A proxy for an integer.

instanceof(cls)

Equivalent of isinstance for the value represented by this proxy.
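A proxy object is never an instance of the Python type it stands for, so isinstance(proxy, int) is always False during tracing; an explicit instanceof method works around that. The toy class below is a hypothetical illustration, not the library's implementation: it stores the represented type and answers the query with issubclass.

```python
class ToyProxy:
    """Hypothetical stand-in for a tracing proxy of a scalar value."""

    def __init__(self, represented_type: type):
        # Python type of the value this proxy holds at runtime.
        self.represented_type = represented_type

    def instanceof(self, cls) -> bool:
        # Answers isinstance-like questions about the represented value;
        # like isinstance, cls may be a class or a tuple of classes.
        return issubclass(self.represented_type, cls)


p = ToyProxy(int)
print(isinstance(p, int))          # False: the proxy itself is not an int
print(p.instanceof(int))           # True
print(p.instanceof((float, int)))  # True
```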

class experimental_experiment.torch_interpreter.tracing.CustomTracer(autowrap_modules: Tuple[ModuleType] = (math,), autowrap_functions: Tuple[Callable, ...] = (), param_shapes_constant: bool = False)

Defines a custom tracer used to trace the execution of a model and convert it into an FX graph. Works with CustomProxy.

```python
import torch
from experimental_experiment.torch_interpreter.tracing import CustomTracer

model = torch.nn.Linear(4, 2)  # any torch.nn.Module or callable
graph = CustomTracer().trace(model)
```

create_arg(a: Any) → Argument

Overrides this method to handle more argument types.

getattr(attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any])

See torch.fx.Tracer.getattr().

classmethod graph_erase_node(graph: Graph, node: Node)

Removes a node and all its predecessors which are only consumed by this one.
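The removal rule can be sketched on a toy graph stored as a dict mapping each node name to the names of its inputs. This is a hypothetical, library-independent illustration (the real method works on a torch.fx.Graph): a predecessor is removed only when, once its consumer is gone, nothing else uses it, and nodes listed in keep (for example graph inputs) are never removed.

```python
from typing import AbstractSet, Dict, List, Set


def users_of(graph: Dict[str, List[str]], node: str) -> List[str]:
    # Names of the nodes consuming `node`.
    return [n for n, inputs in graph.items() if node in inputs]


def erase_node_and_dead_inputs(
    graph: Dict[str, List[str]], name: str, keep: AbstractSet[str] = frozenset()
) -> Set[str]:
    """Removes `name` and every predecessor left without users (in place)."""
    if users_of(graph, name):
        raise ValueError(f"{name!r} still has users")
    removed: Set[str] = set()
    stack = [name]
    while stack:
        node = stack.pop()
        # Skip nodes already removed, protected, or unknown.
        if node in removed or node in keep or node not in graph:
            continue
        if users_of(graph, node):
            continue  # still consumed by another node: keep it
        removed.add(node)
        # Its inputs may have just lost their last user: revisit them.
        stack.extend(graph.pop(node))
    return removed


graph = {"x": [], "a": ["x"], "b": ["a"], "c": ["a"], "d": ["b"], "out": ["c"]}
removed = erase_node_and_dead_inputs(graph, "d", keep={"x"})
print(sorted(removed))  # ['b', 'd']: "a" survives because "c" still uses it
```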

proxy(node: Node, cls: type[CustomProxy] = CustomProxy) → Proxy

Overrides this method to replace the default Proxy with CustomProxy.

register_callable(name: str, fn: Callable) → Node

Registers a function and returns a unique name.

Parameters:
  • name – prefix to prepend to the function name

  • fn – function

Returns:

new_name
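A minimal sketch of the name-registration idea, assuming the new name is built from the prefix and the function's __name__, then made unique with a numeric suffix. The CallableRegistry class and this naming scheme are hypothetical, not the tracer's actual implementation.

```python
from typing import Callable, Dict


class CallableRegistry:
    """Hypothetical registry handing out unique names for functions."""

    def __init__(self):
        self.functions: Dict[str, Callable] = {}

    def register_callable(self, name: str, fn: Callable) -> str:
        # Prepend the prefix to the function name.
        base = f"{name}_{fn.__name__}"
        new_name, i = base, 0
        # Append a counter until the name is not taken.
        while new_name in self.functions:
            i += 1
            new_name = f"{base}_{i}"
        self.functions[new_name] = fn
        return new_name


reg = CallableRegistry()
first = reg.register_callable("cond", abs)
second = reg.register_callable("cond", abs)
print(first, second)  # cond_abs cond_abs_1
```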

classmethod remove_inplace(graph: Graph, exported_program: ExportedProgram | None = None, verbose: int = 0, exc: bool = True) → int

Removes inplace operations.

Parameters:
  • graph – graph to modify

  • exported_program – if available, it is used in the error message to make it easier to trace the code source

  • verbose – verbosity

  • exc – raise an exception if not possible, otherwise return -1

Returns:

number of inplace nodes removed; a negative number means some inplace nodes remain that this function is unable to remove, in which case only decompositions may help

The most difficult pattern to remove is the following one, where aten.copy_ writes in place into a chain of slices of %clone:

%slice_11 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor]
    (args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
%slice_12 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor]
    (args = (%slice_11, 1, 0, 9223372036854775807), kwargs = {})
%slice_13 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor]
    (args = (%slice_12, 2, 0, 9223372036854775807), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default]
    (args = (%slice_13, %masked_fill), kwargs = {})
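In eager PyTorch, the graph above corresponds to writing into a full slice of a tensor in place. Every slice spans its whole dimension (9223372036854775807 is the largest int64 value, used as "until the end"), so the copy_ overwrites the entire clone with masked_fill broadcast to the clone's shape. The sketch below uses only standard torch calls to check that one functional formulation gives the same result; tensor names follow the graph and the values are made up. It is not a description of how remove_inplace rewrites the graph.

```python
import torch

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
# masked_fill only needs a shape broadcastable to x, here (3, 4).
mask = torch.arange(12).reshape(3, 4) % 2 == 0
masked_fill = torch.zeros(3, 4).masked_fill(mask, -1.0)

# In-place pattern: slice every dimension fully, then copy_ into the view.
clone = x.clone()
clone[:, :, :].copy_(masked_fill)

# Functional equivalent: a new tensor holding masked_fill broadcast to
# the clone's shape (expand creates a view, clone materializes it).
functional = masked_fill.expand_as(x).clone()

print(torch.equal(clone, functional))  # True
```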

classmethod remove_unnecessary_slices(graph: Graph) → int

Removes unnecessary slices, such as the full slice shown below.

Parameters:

graph – graph to modify

Returns:

number of unnecessary slices removed

%slice_11 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor]
    (args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
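A slice such as %slice_11 keeps everything: it starts at 0 and ends at 9223372036854775807 (2**63 - 1, the largest int64), so its output holds the same values as its input. Below is a minimal sketch of a test such a pass may apply to the arguments of aten.slice.Tensor, ordered (input, dim, start, end, step) with the aten defaults for missing trailing arguments. The function is_unnecessary_slice is hypothetical.

```python
INT64_MAX = 2**63 - 1  # 9223372036854775807


def is_unnecessary_slice(args: tuple) -> bool:
    """Tells if aten.slice.Tensor(input, dim, start, end, step) keeps everything."""
    # Fill missing trailing arguments with the aten defaults.
    defaults = (None, 0, None, None, 1)  # (input, dim, start, end, step)
    _, _, start, end, step = tuple(args) + defaults[len(args):]
    starts_at_0 = start in (None, 0)
    # A smaller end may also cover the whole dimension, but proving it
    # requires the input shape, which this sketch does not use.
    ends_at_end = end is None or end >= INT64_MAX
    return starts_at_0 and ends_at_end and step == 1


print(is_unnecessary_slice(("clone", 0, 0, 9223372036854775807)))     # True
print(is_unnecessary_slice(("clone", 1, 2, 9223372036854775807)))     # False
print(is_unnecessary_slice(("clone", 0, 0, 9223372036854775807, 2)))  # False

# Python slicing follows the same convention: such a slice keeps everything.
data = [1, 2, 3]
print(data[0:INT64_MAX] == data)  # True
```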

trace(root: Module | Callable[..., Any], concrete_args: Dict[str, Any] | None = None, remove_inplace: bool = True, update_model_with_callable: bool = True) → Graph

Trace root and return the corresponding FX Graph representation. root can either be an nn.Module instance or a Python callable.

Note that after this call, self.root may be different from the root passed in here. For example, when a free function is passed to trace(), we will create an nn.Module instance to use as the root and add embedded constants to.

Parameters:
  • root – Either a Module or a function to be traced through. Backwards-compatibility for this parameter is guaranteed.

  • concrete_args – Concrete arguments that should not be treated as Proxies. This parameter is experimental and its backwards-compatibility is NOT guaranteed.

  • remove_inplace – if True, removes inplace nodes from the traced graph

  • update_model_with_callable – in some cases (control flow), the model needs to be

Returns:

A Graph representing the semantics of the passed-in root.
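The note about self.root can be observed with the standard torch.fx.Tracer: tracing a free function produces a fresh nn.Module as root, and the pair (root, graph) is typically wrapped into a torch.fx.GraphModule to obtain something executable. This sketch uses only torch.fx; that CustomTracer.trace behaves identically in every case is an assumption.

```python
import torch
import torch.fx


def f(x):
    # A free function, not an nn.Module.
    return torch.relu(x) + 1


tracer = torch.fx.Tracer()
graph = tracer.trace(f)

# The tracer created an nn.Module to serve as root.
print(isinstance(tracer.root, torch.nn.Module))  # True

# Wrap root and graph together to get a callable module.
gm = torch.fx.GraphModule(tracer.root, graph)
x = torch.tensor([-1.0, 2.0])
print(torch.equal(gm(x), f(x)))  # True
```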

class experimental_experiment.torch_interpreter.tracing.LEAVE_INPLACE

Constant indicating that inplace removal failed.

experimental_experiment.torch_interpreter.tracing.replace_problematic_function_before_tracing()

Replaces functions that cannot be traced with the default tracer, such as torch.cat().
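Replacing a function for the duration of tracing is commonly done by patching the attribute and restoring it afterwards. The sketch below shows that technique with contextlib on a standard-library module; it is an assumption about how such a replacement can work, not the library's code, and the helpers replace_function and traced_floor are hypothetical.

```python
import contextlib
import math
from typing import Any, Callable, Iterator


@contextlib.contextmanager
def replace_function(module: Any, name: str, new_fn: Callable) -> Iterator[None]:
    """Temporarily replaces module.<name> with new_fn, restoring it on exit."""
    original = getattr(module, name)
    setattr(module, name, new_fn)
    try:
        yield
    finally:
        # Restore the original even if tracing raised an exception.
        setattr(module, name, original)


def traced_floor(x):
    # Hypothetical replacement that a tracer could record instead of computing.
    return ("floor", x)


with replace_function(math, "floor", traced_floor):
    print(math.floor(2.5))  # ('floor', 2.5)
print(math.floor(2.5))      # 2
```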

experimental_experiment.torch_interpreter.tracing.setitem_with_transformation(a, b, transformations)

Extended version of setitem that handles inplace modifications.