experimental_experiment.torch_interpreter.tracing

class experimental_experiment.torch_interpreter.tracing.CondCCOp

This operator cannot be imported from torch.ops.higher_order.cond because the function cond overwrites the submodule cond.
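The name clash is ordinary Python behavior. Once a package attribute named cond is rebound to a function, attribute lookup returns the function and the submodule is no longer reachable under that name. Below is a minimal stdlib sketch. The package and module names (pkg, pkg.cond) are hypothetical stand-ins, not PyTorch's actual layout. The cond function only shows eager reference semantics (run one branch on the operands).

```python
import types

# Hypothetical package with a submodule named "cond".
pkg = types.ModuleType("pkg")
submodule = types.ModuleType("pkg.cond")
submodule.marker = "submodule"
pkg.cond = submodule


def cond(pred, true_fn, false_fn, operands):
    """Eager reference semantics: run one branch on the operands."""
    return true_fn(*operands) if pred else false_fn(*operands)


# Rebinding the attribute to a function shadows the submodule.
pkg.cond = cond

print(isinstance(pkg.cond, types.ModuleType))  # False: the function wins
print(pkg.cond(True, lambda x: x + 1, lambda x: x - 1, (41,)))  # 42
```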

class experimental_experiment.torch_interpreter.tracing.CustomAttribute(root: CustomProxy, attr: str)

Traces the access to attribute attr of the proxy root.

class experimental_experiment.torch_interpreter.tracing.CustomParameterProxy(tracer: TracerBase, node: Node, name, param)

A special proxy that lets “shape”, “size”, “dim”, and a few other attribute accesses pass through to the underlying module parameter object. As a result, conditional tests on these attributes do not raise an exception during tracing.

class experimental_experiment.torch_interpreter.tracing.CustomProxy(node: Node, tracer: TracerBase | None = None)

A custom proxy used to trace the execution of a model into an FX graph. Works with CustomTracer.

classmethod cat(tensors: List[CustomProxy], dim: int = 0, *, out=None, axis: int | None = None) → CustomProxy

Implements torch.cat() for tensors represented by proxies.
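However a traced cat is represented, its result follows the shape rule of torch.cat(). All inputs must agree on every dimension except dim, and their sizes along dim add up. The hypothetical helper cat_shape below checks that rule on plain tuples. Treating axis as an alternative spelling of dim is an assumption based only on the signature above.

```python
def cat_shape(shapes, dim=0, axis=None):
    """Output shape of concatenating tensors with the given shapes."""
    if axis is not None:  # assumption: axis is an alias for dim
        dim = axis
    rank = len(shapes[0])
    if dim < 0:  # negative dims count from the end, as in torch.cat
        dim += rank
    total = 0
    for shape in shapes:
        if len(shape) != rank:
            raise ValueError("all inputs must have the same rank")
        for i, (a, b) in enumerate(zip(shape, shapes[0])):
            if i != dim and a != b:
                raise ValueError(f"size mismatch at dimension {i}")
        total += shape[dim]
    return shapes[0][:dim] + (total,) + shapes[0][dim + 1:]


print(cat_shape([(2, 3), (2, 5)], dim=1))    # (2, 8)
print(cat_shape([(2, 3), (4, 3)], axis=-2))  # (6, 3)
```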

instanceof(cls)

Tells whether this proxy represents an instance of class cls.

length()

Returns a proxy for the length.

class experimental_experiment.torch_interpreter.tracing.CustomProxyFloat(node: Node, tracer: TracerBase | None = None)

A proxy for a float.

instanceof(cls)

Equivalent of isinstance() for this proxy.

class experimental_experiment.torch_interpreter.tracing.CustomProxyInt(node: Node, tracer: TracerBase | None = None)

A proxy for an integer.

instanceof(cls)

Equivalent of isinstance() for this proxy.

class experimental_experiment.torch_interpreter.tracing.CustomTracer(autowrap_modules: Tuple[ModuleType] = (math,), autowrap_functions: Tuple[Callable, ...] = (), param_shapes_constant: bool = False)

A custom tracer that traces the execution of a model and converts it into an FX graph. Works with CustomProxy.

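::

import torch
from experimental_experiment.torch_interpreter.tracing import CustomTracer

model = torch.nn.Linear(4, 2)  # any torch.nn.Module or Python callable
graph = CustomTracer().trace(model)  # returns a torch.fx.Graph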

create_arg(a: Any) → Argument

Overrides torch.fx.Tracer.create_arg() to handle more argument types.
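torch.fx.Tracer.create_arg() turns the arguments of a call into values that can be stored in a node. It recurses into containers and replaces each proxy with the node it stands for. The stdlib sketch below shows that recursion only. ToyProxy and the exact set of handled types are hypothetical, and the extra types CustomTracer supports are not listed in this section.

```python
from dataclasses import dataclass


@dataclass
class ToyProxy:
    node: str  # name of the graph node this proxy stands for


def create_arg(a):
    """Recursively replaces proxies by node names inside containers."""
    if isinstance(a, ToyProxy):
        return a.node
    if isinstance(a, tuple):
        return tuple(create_arg(x) for x in a)
    if isinstance(a, list):
        return [create_arg(x) for x in a]
    if isinstance(a, dict):
        return {k: create_arg(v) for k, v in a.items()}
    if isinstance(a, slice):
        return slice(create_arg(a.start), create_arg(a.stop), create_arg(a.step))
    if a is None or isinstance(a, (bool, int, float, str)):
        return a
    raise NotImplementedError(f"argument of type {type(a).__name__}")


print(create_arg(([ToyProxy("x"), 1], {"dim": ToyProxy("d")})))
# (['x', 1], {'dim': 'd'})
```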

getattr(attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any])

See torch.fx.Tracer.getattr().

proxy(node: torch.fx.node.Node, cls: type[CustomProxy] = CustomProxy) → Proxy

Overrides torch.fx.Tracer.proxy() to create an instance of cls (CustomProxy by default) instead of the default Proxy.

register_callable(name: str, fn: Callable) → Node

Registers a function and returns a unique name.

Parameters:
  • name – prefix to prepend to the function name

  • fn – function to register

Returns:

new_name, the unique name assigned to fn
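One way to derive a unique name from a prefix is to append a counter until the name is free, then store the function under that name. The minimal sketch below works under that assumption. The Registry class and the naming pattern are hypothetical, and the real method may store fn differently.

```python
class Registry:
    """Hypothetical holder for functions registered during tracing."""

    def __init__(self):
        self.callables = {}

    def register_callable(self, name, fn):
        base = f"{name}_{fn.__name__}"  # prefix prepended to the function name
        new_name, i = base, 0
        while new_name in self.callables:  # keep names unique
            i += 1
            new_name = f"{base}_{i}"
        self.callables[new_name] = fn
        return new_name


reg = Registry()
print(reg.register_callable("cond", abs))  # cond_abs
print(reg.register_callable("cond", abs))  # cond_abs_1
```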

classmethod remove_inplace(graph: Graph) → int

Removes in-place operations from graph.

Returns:

the number of in-place nodes removed
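Removing in-place operations can be illustrated on a toy graph in SSA form. An in-place node such as add_ is rewritten into its out-of-place counterpart, and every later use of the mutated input is redirected to the new result. The sketch works on plain tuples and follows PyTorch's naming convention, where a trailing underscore marks an in-place op. It is a simplified illustration, not the library's implementation, which works on torch.fx.Graph and may handle more cases such as aliasing.

```python
def remove_inplace(nodes):
    """Rewrites in-place ops out of place; returns (new_nodes, count)."""
    alias = {}  # mutated value name -> name of the node holding its new value

    def resolve(arg):
        # Follow the chain so repeated mutations point to the latest value.
        while isinstance(arg, str) and arg in alias:
            arg = alias[arg]
        return arg

    new_nodes, removed = [], 0
    for name, op, args in nodes:
        args = tuple(resolve(a) for a in args)
        if op.endswith("_") and not op.endswith("__"):
            op = op[:-1]           # add_ -> add
            alias[args[0]] = name  # later uses of the input see this result
            removed += 1
        new_nodes.append((name, op, args))
    return new_nodes, removed


graph = [
    ("x", "placeholder", ()),
    ("n1", "add_", ("x", 1)),
    ("n2", "mul", ("x", 2)),
]
new_graph, count = remove_inplace(graph)
print(count)         # 1
print(new_graph[2])  # ('n2', 'mul', ('n1', 2))
```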

trace(root: Module | Callable[..., Any], concrete_args: Dict[str, Any] | None = None, remove_inplace: bool = True, update_model_with_callable: bool = True) → Graph

Trace root and return the corresponding FX Graph representation. root can either be an nn.Module instance or a Python callable.

Note that after this call, self.root may be different from the root passed in here. For example, when a free function is passed to trace(), we will create an nn.Module instance to use as the root and add embedded constants to.

Parameters:
  • root – Either a Module or a function to be traced through. Backwards-compatibility for this parameter is guaranteed.

  • concrete_args – Concrete arguments that should not be treated as Proxies. This parameter is experimental and its backwards-compatibility is NOT guaranteed.

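  • remove_inplace – if True, removes in-place nodes from the traced graph (see remove_inplace())

  • update_model_with_callable – in some cases (control flow), the model needs to be updated with the callables registered during tracing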

Returns:

A Graph representing the semantics of the passed-in root.

experimental_experiment.torch_interpreter.tracing.replace_problematic_function_before_tracing()

Replaces functions that the default tracer cannot trace, such as torch.cat().
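Replacing a function before tracing is a form of monkey-patching. A common way to keep such a patch contained is a context manager that swaps the attribute and restores the original afterwards. The sketch below does this for math.hypot with a stand-in. The helper name patched and the choice of math.hypot are illustrative assumptions. This section does not say which functions besides torch.cat() are replaced, or whether the originals are restored.

```python
import contextlib
import math


@contextlib.contextmanager
def patched(module, name, replacement):
    """Temporarily replaces module.<name> with replacement."""
    original = getattr(module, name)
    setattr(module, name, replacement)
    try:
        yield original
    finally:
        setattr(module, name, original)  # always restore, even on error


def traceable_hypot(x, y):
    # Stand-in written with basic operations a tracer could record.
    return (x * x + y * y) ** 0.5


with patched(math, "hypot", traceable_hypot):
    print(math.hypot is traceable_hypot)  # True
print(math.hypot is traceable_hypot)      # False
```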