Source code for onnx_diagnostic.export.cf_simple_loop_for

import contextlib
from typing import Callable, List, Optional, Sequence, Tuple, Union
import torch
from torch._C import DispatchKey
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
import torch.utils._pytree as pytree
from torch._higher_order_ops.utils import (
    check_input_alias_and_mutation_return_outputs,
    reenter_make_fx,
    unique_graph_id,
    validate_subgraph_args_types,
)
import torch._dynamo.variables.higher_order_ops as hop
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils._python_dispatch import _get_current_dispatch_mode


class SimpleLoopForOp(HigherOrderOperator):
    """Higher order op for :func:`simple_loop_for`."""

    def __init__(self):
        super().__init__("simple_loop_for")

    def __call__(self, n_iter, body_fn, operands, concatenation_dims=None):
        validate_subgraph_args_types(operands)
        return super().__call__(n_iter, body_fn, operands, concatenation_dims)

    def gen_schema(self, n_iter, body_fn, operands, concatenation_dims):
        from torch._higher_order_ops.schema import HopSchemaGenerator
        from torch._higher_order_ops.utils import materialize_as_graph

        body_gm: torch.fx.GraphModule = materialize_as_graph(  # type: ignore[annotation-unchecked]
            body_fn, (torch.tensor(0, dtype=torch.int64), *operands)
        )
        (
            _,
            _,
            _,
            body_mutated_inputs,
            body_outputs,
        ) = check_input_alias_and_mutation_return_outputs(body_gm)
        mutated_inputs = body_mutated_inputs

        schema_gen = HopSchemaGenerator(self)
        schema_gen.add_arg("n_iter", n_iter)
        schema_gen.add_arg("body_fn", body_gm)
        for idx, arg in enumerate(operands):
            schema_gen.add_arg(f"operand{idx}", arg, is_mutated=idx in mutated_inputs)
        for out in body_outputs:
            schema_gen.add_output(out)
        assert concatenation_dims is None or len(concatenation_dims) == len(body_outputs), (
            f"concatenation_dims={concatenation_dims} but its length should be equal to "
            f"the number of outputs ({len(body_outputs)})"
        )
        schema_gen.add_schema_tree_spec(n_iter, body_fn, operands, concatenation_dims)
        return schema_gen.gen_schema()
simple_loop_for_op = SimpleLoopForOp()


def _simple_loop_for_fn(
    n_iter: torch.Tensor,
    body_fn: Callable,
    operands: Tuple[torch.Tensor, ...] = (),
    concatenation_dims: Optional[Sequence[int]] = None,
) -> Tuple[torch.Tensor, ...]:
    """
    Python implementation of the loop.

    :param n_iter: number of iterations
    :param body_fn: function implementing the body
    :param concatenation_dims: dimensions used to concatenate the lists produced by the loop
    :param operands: arguments to the loop body
    :return: results
    """
    torch._check(
        isinstance(n_iter, (int, torch.Tensor)),
        lambda: f"Unexpected type {type(n_iter)} for n_iter",
    )
    torch._check(callable(body_fn), lambda: f"Unexpected type {type(body_fn)} for body_fn")
    torch._check(
        concatenation_dims is None or isinstance(concatenation_dims, (list, tuple)),
        lambda: f"Unexpected type {type(concatenation_dims)} for concatenation_dims",
    )
    torch._check(
        isinstance(operands, tuple), lambda: f"Unexpected type {type(operands)} for operands"
    )
    res: List[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = []
    for i in torch.arange(
        n_iter, dtype=torch.int64 if isinstance(n_iter, int) else n_iter.dtype
    ):
        r = body_fn(i, *operands)
        if isinstance(r, tuple):
            assert not res or len(r) == len(res[-1]), (
                f"Unexpected number of results {len(r)} for function {body_fn}, "
                f"expected {len(res[-1])}"
            )
            assert all(isinstance(t, torch.Tensor) for t in r), (
                f"Unexpected types {[type(_) for _ in r]} returned by function {body_fn}, "
                f"it must be a tuple of Tensor or a Tensor."
            )
            res.append(r)
        else:
            assert isinstance(r, torch.Tensor), (
                f"Unexpected type {type(r)} coming from function {body_fn}, "
                f"it must be a tuple of Tensor or a Tensor."
            )
            assert not res or len(res[-1]) == 1, (
                f"Unexpected number of results {len(r)} coming from function {body_fn}, "
                f"expected {len(res[-1])}"
            )
            res.append((r,))
    if not res:
        return torch.empty(tuple(), dtype=torch.float32, device=operands[0].device)
    n_res = len(res[0])
    return tuple(
        torch.cat(
            [r[i] for r in res],
            dim=(
                0
                if concatenation_dims is None or i >= len(concatenation_dims)
                else concatenation_dims[i]
            ),
        )
        for i in range(n_res)
    )


def _simple_loop_for(
    n_iter: Union[int, torch.Tensor],
    body_fn: Callable,
    operands: Tuple[torch.Tensor, ...] = (),
    concatenation_dims: Optional[Sequence[int]] = None,
) -> Tuple[torch.Tensor, ...]:
    def _validate_input(n_iter, body_fn, operands, concatenation_dims):
        assert isinstance(
            n_iter, (int, torch.Tensor, torch.SymInt)
        ), f"Expected n_iter to be an int or a tensor, but got {n_iter}."
        assert (
            not isinstance(n_iter, torch.Tensor) or n_iter.numel() == 1
        ), f"Expected n_iter to be an int or a single-element tensor, but got {n_iter}."
        assert callable(body_fn), "Expected body_fn to be callable."
        assert isinstance(operands, (tuple, list)) and pytree.tree_all(
            lambda t: isinstance(t, torch.Tensor), operands
        ), (
            "Expected operands to be a tuple of possibly nested dict/list/tuple that only "
            f"consists of tensor leaves, but got {operands}."
        )
        assert concatenation_dims is None or (
            isinstance(concatenation_dims, (list, tuple))
            and all(isinstance(i, int) for i in concatenation_dims)
        ), (
            f"concatenation_dims should be None or a list of integers but it is "
            f"{concatenation_dims}. Its length should be equal to the number of outputs."
        )

    assert torch._dynamo.is_dynamo_supported(), "simple_loop_for requires dynamo support."
    if torch.compiler.is_dynamo_compiling():
        return simple_loop_for_op(
            n_iter, body_fn, operands, concatenation_dims=concatenation_dims
        )

    if isinstance(n_iter, (bool, int, float)):
        torch._check(
            isinstance(n_iter, int),
            lambda: f"n_iter must be an integer or a tensor, not {type(n_iter)}",
        )
        return _simple_loop_for_fn(
            n_iter, body_fn, operands, concatenation_dims=concatenation_dims
        )

    def _loop_for_op_wrapper(n_iter, body_fn, operands, concatenation_dims):
        return simple_loop_for_op(n_iter, body_fn, operands, concatenation_dims)

    _validate_input(n_iter, body_fn, operands, concatenation_dims)

    # This requires torch>=2.10.
    from torch._higher_order_ops.utils import setup_compilation_env

    with setup_compilation_env() as _backend:
        return _loop_for_op_wrapper(n_iter, body_fn, operands, concatenation_dims)

    # This is needed to support a body that uses module weights or a body
    # defined as a class method. This is yet to be implemented.
    # cpl = torch.compile(_loop_for_op_wrapper, backend=_backend, fullgraph=True)
    # return cpl(n_iter, body_fn, operands, concatenation_dims)
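As an illustration (not part of the module), the eager helper ``_simple_loop_for_fn`` above runs the body once per iteration and concatenates the i-th outputs of all iterations along ``concatenation_dims[i]``, 0 by default. A minimal sketch calling the private helper directly, with a made-up body ``_body``:

import torch
from onnx_diagnostic.export.cf_simple_loop_for import _simple_loop_for_fn


def _body(i, x):
    # two outputs per iteration: shapes (3,) and (1, 3)
    return x + i, (x * i).unsqueeze(0)


x = torch.arange(3, dtype=torch.float32)
out = _simple_loop_for_fn(2, _body, (x,), concatenation_dims=(0, 1))
print(out[0].shape)  # torch.Size([6])    -- two (3,) results concatenated on dim 0
print(out[1].shape)  # torch.Size([1, 6]) -- two (1, 3) results concatenated on dim 1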
def trace_simple_loop_for(
    proxy_mode, func_overload, n_iter, body_fn, operands, concatenation_dims
):
    """See function ``simple_loop_for``."""
    assert isinstance(operands, (list, tuple)) and (
        concatenation_dims is None
        or (
            isinstance(concatenation_dims, (list, tuple))
            and all(isinstance(i, int) for i in concatenation_dims)
        )
    ), (
        f"simple_loop_for operands must be a list or tuple of tensors and SymInts and "
        f"concatenation_dims must be None or a list of integers, "
        f"operands={[type(o) for o in operands]}, "
        f"concatenation_dims={concatenation_dims}"
    )
    body_graph = reenter_make_fx(body_fn)(n_iter, *operands)

    body_outs = []
    for node in body_graph.graph.nodes:
        if node.op == "output":
            body_outs.extend(node.args)
    # flat_body_outs = pytree.arg_tree_leaves(*body_outs)

    _i, body_name = unique_graph_id(proxy_mode, prefix="body_graph")
    proxy_mode.tracer.root.register_module(body_name, body_graph)

    args = (n_iter, body_graph, operands, concatenation_dims)
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
    out_proxy = proxy_mode.tracer.create_proxy("call_function", func_overload, proxy_args, {})

    out = func_overload(n_iter, body_graph, operands, concatenation_dims)
    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
@simple_loop_for_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def loop_for_op_dense(n_iter, body_fn, operands, concatenation_dims=None):
    """Registered eager mode implementation."""
    assert all(isinstance(o, torch.Tensor) for o in operands) and (
        concatenation_dims is None
        or (
            isinstance(concatenation_dims, (list, tuple))
            and all(isinstance(i, int) for i in concatenation_dims)
        )
    ), (
        f"simple_loop_for operands must be a list or tuple of tensors and SymInts and "
        f"concatenation_dims must be None or a list of integers, "
        f"operands={[type(o) for o in operands]}, "
        f"concatenation_dims={concatenation_dims}"
    )
    mode = _get_current_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
    is_fake = isinstance(n_iter, torch._subclasses.fake_tensor.FakeTensor)
    res = _simple_loop_for_fn(n_iter, body_fn, operands, concatenation_dims=concatenation_dims)
    assert is_fake or not any(
        isinstance(r, torch._subclasses.fake_tensor.FakeTensor) for r in res
    ), (
        f"One result is a fake tensor but the inputs were not, type(n_iter)={type(n_iter)}, "
        f"operands: {[type(_) for _ in operands]}, res: {[type(_) for _ in res]}"
    )
    return res
@simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
def inner(mode, n_iter, body_fn, operands, concatenation_dims=None):
    """Registered tracing implementation."""
    return trace_simple_loop_for(
        mode, simple_loop_for_op, n_iter, body_fn, operands, concatenation_dims
    )
@simple_loop_for_op.py_impl(FakeTensorMode)
def simple_loop_for_fake_tensor_mode(mode, n_iter, body_fn, operands, concatenation_dims=None):
    """Registered FakeMode implementation."""
    ignore_fresh_unbacked = contextlib.nullcontext()
    if mode.shape_env:
        ignore_fresh_unbacked = mode.shape_env.ignore_fresh_unbacked_symbols()
    with mode, ignore_fresh_unbacked:
        flat_body_outs, true_body_spec = pytree.tree_flatten(body_fn(n_iter, *operands))
        return pytree.tree_unflatten(flat_body_outs, true_body_spec)
# Registration for autograd.
simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)
class SimpleLoopForHigherOrderVariable(hop.TorchHigherOrderOperatorVariable):
    """
    Replicates the same pattern found for other higher order operators.
    This enables recursive compilation and the use of modules inside a function.
    """

    _HOP_NAME = "simple_loop_for"
    _ALLOW_FALLBACK_TO_EAGER = False
    supports_input_mutation = False
    supports_aliasing = False

    def _call_function(
        self,
        tx: torch._dynamo.symbolic_convert.InstructionTranslator,
        args: list[hop.VariableTracker],
        kwargs: dict[str, hop.VariableTracker],
    ) -> hop.VariableTracker:
        """Main function."""
        args, kwargs = hop.LazyVariableTracker.realize_all((args, kwargs))

        for i, k in enumerate(["n_iter", "body_fn", "operands", "concatenated_dims"]):
            if v := kwargs.pop(k, None):
                assert i == len(args), "did not provide the right number of non-keyword args"
                args.append(v)

        if len(args) != 4 or kwargs:
            hop.unimplemented(
                gb_type="simple_loop_for: improper args/kwargs",
                context=f"args: {args}, kwargs: {kwargs}",
                explanation=f"simple_loop_for expects 4 positional arguments (got {len(args)}) "
                f"and no keyword arguments (got {len(kwargs)})",
                hints=[*hop.graph_break_hints.USER_ERROR],
            )

        # A constant n_iter is rejected, it would unroll the loop.
        n_iter, body_fn, operands, _concatenated_dims = args
        assert type(n_iter) is not hop.ConstantVariable, (
            f"n_iter is a {type(n_iter)}. A constant n_iter would make simple_loop_for "
            f"unroll the loop. A SymInt should be used instead."
        )

        # n_iter
        if type(n_iter.realize()) not in (
            hop.ConstantVariable,
            hop.TensorVariable,
            hop.SymNodeVariable,
        ):
            hop.unimplemented(
                gb_type="simple_loop_for: improper predicate",
                context=str(n_iter),
                explanation=(
                    f"Expected `n_iter` to be an int or an integer "
                    f"tensor with a single item "
                    f"but got {str(type(n_iter))} with original python type "
                    f"{str(n_iter.python_type())}."
                ),
                hints=[*hop.graph_break_hints.USER_ERROR],
            )

        # operands
        if not isinstance(operands, (hop.ListVariable, hop.TupleVariable)):
            hop.unimplemented(
                gb_type="simple_loop_for: improper operands",
                context=str(operands),
                explanation="Expected `operands` to be a list/tuple "
                f"but got {operands.python_type()}.",
                hints=[*hop.graph_break_hints.USER_ERROR],
            )
        operands_seq = operands.unpack_var_sequence(tx)
        if not hop.only_consist_of(
            operands, (hop.TensorVariable, hop.ConstantVariable, hop.SymNodeVariable)
        ):
            hop.unimplemented(
                gb_type="simple_loop_for: improper operands contents",
                context=str(operands),
                explanation=(
                    "Expected `operands` to be a list/tuple of pytrees "
                    "that only consist of tensor leaves."
                ),
                hints=[*hop.graph_break_hints.USER_ERROR],
            )

        # body function
        hop._check_supported_callable_arg(tx, body_fn, "body_fn")

        def speculate_body():
            (
                (ret_val, ret_spec),
                ret_graph,
                ret_lifted_freevars,
            ) = hop.speculate_subgraph(
                tx,
                args[1],
                (args[0], *operands_seq),
                {},
                self._HOP_NAME,
                source_target=self.value,
                should_flatten_outputs=True,
                # TODO - removing consts from control flow ops needs more work
                remove_consts_from_outputs=False,
                supports_input_mutation=self.supports_input_mutation,
                supports_aliasing=self.supports_aliasing,
            )
            # need to ensure we increase epoch so we don't memoize unbacked bindings
            # across different subgraphs which can interfere with runtime assertion
            # generation.
            tx.fake_mode.epoch += 1

            if not hop.only_consist_of(ret_val, (hop.TensorVariable, hop.ConstantVariable)):
                hop.unimplemented(
                    gb_type="simple_loop_for: unsupported branch return type",
                    context=str(ret_val),
                    explanation=(
                        "Expected the body to return a possibly nested "
                        "pytree of tensors or constant ints."
                    ),
                    hints=[*hop.graph_break_hints.USER_ERROR],
                )
            for ret in ret_val.unpack_var_sequence(tx):
                if ret.is_python_constant() and not isinstance(ret.as_python_constant(), int):
                    hop.unimplemented(
                        gb_type=(
                            "simple_loop_for: unsupported branch return type "
                            "(constant non-int)"
                        ),
                        context=str(ret_val),
                        explanation="Constants returned from the body must be ints.",
                        hints=[*hop.graph_break_hints.USER_ERROR],
                    )
            return ret_val, ret_spec, ret_graph, ret_lifted_freevars

        body_r, body_spec, body_graph, body_lifted_freevars = speculate_body()
        body_nn_modules = dict(tx.output.nn_modules)

        same_spec = body_spec.treespec.as_python_constant()
        if same_spec is not NotImplemented and not same_spec:
            hop.unimplemented(
                gb_type="simple_loop_for: differing branch outputs",
                context=f"body_spec: {body_spec.treespec}, same_spec: {same_spec}",
                explanation="Expected the body to return the same pytree structure.",
                hints=[*hop.graph_break_hints.USER_ERROR],
            )

        body_name = tx.output.install_subgraph(
            "loop_body", torch.fx.GraphModule(body_nn_modules, body_graph)
        )
        body_node = hop.make_attr(tx, body_name)

        p_args = (
            n_iter.as_proxy(),
            body_node,
            operands.as_proxy() + tuple(body_lifted_freevars.keys()),
        )
        return hop._call_function_and_unflatten_output(
            tx,
            simple_loop_for,
            p_args,
            {},
            None,
            body_spec,
            body_r,
        )
hop._hop_name_to_variable_class["simple_loop_for"] = SimpleLoopForHigherOrderVariable


# @torch._functorch.utils.exposed_in("torch")
def simple_loop_for(
    n_iter: Union[int, torch.Tensor],
    body_fn: Callable,
    operands: Tuple[torch.Tensor, ...] = (),
    concatenation_dims: Optional[Union[int, Sequence[int]]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    """
    Implements a simple for loop. The body is defined by a function which takes
    the iteration number stored in a tensor, and other tensors. It returns one or
    several tensors in a tuple. The outputs of all iterations are finally concatenated
    along the first dimension, unless ``concatenation_dims`` specifies other dimensions.

    :param n_iter: number of iterations
    :param body_fn: function implementing the body
    :param operands: body arguments
    :param concatenation_dims: dimension or dimensions used to concatenate
        the output sequences
    :return: concatenated outputs, a single tensor if the body produces one output,
        a tuple of tensors otherwise

    An example with one output:

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for


        class Model(torch.nn.Module):
            def forward(self, n_iter, x):
                def body(i, x):
                    return (x[: i.item() + 1].unsqueeze(1),)

                return simple_loop_for(n_iter, body, (x,))


        model = Model()
        n_iter = torch.tensor(4, dtype=torch.int64)
        x = torch.arange(10, dtype=torch.float32)
        ep = torch.export.export(
            model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
        )
        print(ep)

    Another example with two outputs and a final concatenation on different axes:

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for


        class Model(torch.nn.Module):
            def forward(self, n_iter, x):
                def body(i, x):
                    return (x[: i.item() + 1].unsqueeze(1), x[i.item() + 1 :].unsqueeze(0))

                return simple_loop_for(n_iter, body, (x,), (0, 1))


        model = Model()
        n_iter = torch.tensor(4, dtype=torch.int64)
        x = torch.arange(10, dtype=torch.float32)
        ep = torch.export.export(
            model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
        )
        print(ep)
    """
    res = _simple_loop_for(
        n_iter,
        body_fn,
        operands,
        concatenation_dims=(
            (concatenation_dims,)
            if isinstance(concatenation_dims, int)
            else concatenation_dims
        ),
    )
    torch._check(
        isinstance(res, tuple),
        lambda: f"Output of the loop should be a tuple not {type(res)}.",
    )
    return res[0] if len(res) == 1 else res
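In addition to the export examples in the docstring, ``simple_loop_for`` can be called eagerly when ``n_iter`` is a plain ``int``; the call then falls through to the Python implementation ``_simple_loop_for_fn``. A minimal sketch with a made-up body ``body``:

import torch
from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for


def body(i, x):
    # one output per iteration, shape (1, 3)
    return (x + i).unsqueeze(0)


x = torch.arange(3, dtype=torch.float32)
y = simple_loop_for(3, body, (x,))
print(y.shape)  # torch.Size([3, 3]): rows x + 0, x + 1, x + 2 concatenated on dim 0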