Source code for experimental_experiment.torch_interpreter._aten_methods

from typing import Any, Dict, List, Optional, Sequence
import numpy as np
from onnx import TensorProto
from ..helpers import tensor_dtype_to_np_dtype
from ..xbuilder._shape_helper import all_int
from ..xbuilder.graph_builder import GraphBuilder
from ..xbuilder.shape_type_compute import (
    torch_dtype_to_onnx_dtype,
    set_type_shape_binary_op,
    set_type_shape_unary_op,
    set_type_shape_reduce_op,
    set_type_shape_reshape,
)
from ._aten_functions import (
    aten_cos,
    aten_expand,
    aten_eq,
    aten_repeat,
    aten_sin,
    aten_t,
)

T = str  # in the graph builder, tensors are identified by their result name


def aten_meth_bool(
    g: GraphBuilder, sts: Optional[Dict[str, Any]], outputs: List[str], x: T
) -> T:
    "cast"
    import torch

    return aten_meth_to(g, sts, outputs, x, dtype=torch.bool)


def aten_meth_clone(
    g: GraphBuilder, sts: Optional[Dict[str, Any]], outputs: List[str], x: T
) -> T:
    "identity"
    assert (
        x != outputs[0]
    ), f"Input and output are the same x={x!r}, outputs={outputs!r}{g.get_debug_msg()}"
    return g.make_node("Identity", [x], outputs, name=".clone")


def aten_meth_contiguous(
    g: GraphBuilder, sts: Optional[Dict[str, Any]], outputs: List[str], x: T
) -> T:
    "identity"
    return g.make_node("Identity", [x], outputs, name=".contiguous")


def aten_meth_cos(
    g: GraphBuilder, sts: Optional[Dict[str, Any]], outputs: List[str], x: T
) -> T:
    "cos"
    return aten_cos(g, sts, outputs, x)


def aten_meth_cpu(
    g: GraphBuilder, sts: Optional[Dict[str, Any]], outputs: List[str], x: T
) -> T:
    "identity"
    return g.make_node("Identity", [x], outputs, name="cpu")


def aten_meth_eq(
    g: GraphBuilder, sts: Optional[Dict[str, Any]], outputs: List[str], x: T, y: T
) -> T:
    "equal"
    return aten_eq(g, sts, outputs, x, y)


def aten_meth_expand(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    *dims: List[int],
) -> T:
    "expand"
    return aten_expand(g, sts, outputs, x, dims, name=".expand")


def aten_meth_float(
    g: GraphBuilder, sts: Optional[Dict[str, Any]], outputs: List[str], x: T
) -> T:
    "cast"
    import torch

    return aten_meth_to(g, sts, outputs, x, dtype=torch.float32)


def aten_meth_item(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    name: str = "aten_meth_item",
) -> T:
    "float(x)"
    if not g.has_shape(x):
        # The shape is unknown but calling this method means x holds a
        # single element. Let's squeeze it into a scalar.
        res = g.op.Squeeze(x, outputs=outputs, name=name)
    else:
        assert g.get_shape(x) in (tuple(), (1,)), (
            f"Missing shape or unexpected shape for {x!r}: has_shape={g.has_shape(x)}, "
            f"has_rank={g.has_rank(x)}{g.get_debug_msg()}"
        )
        if g.get_shape(x) == (1,):
            res = g.op.SqueezeAnyOpset(
                x, np.array([0], dtype=np.int64), outputs=outputs, name=name
            )
        else:
            res = g.op.Identity(x, outputs=outputs, name=name)
    if not sts:
        if g.has_type(x):
            g.set_type(res, g.get_type(x))
        g.set_shape(res, tuple())
    return res


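# Semantics sketch (numpy analogue, not part of the original module):
# ``item()`` only makes sense when x holds a single element, so the lowering
# is either a Squeeze (shape (1,) -> scalar) or an Identity (already a scalar):
#
#   a = np.array([3.5], dtype=np.float32)
#   a.squeeze(0)     # shape (1,) -> scalar, the SqueezeAnyOpset branch
#   np.float32(3.5)  # shape ()   -> unchanged, the Identity branch

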
def aten_meth_expand_as(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    y: T,
    name: str = "aten_meth_expand_as",
) -> T:
    "expand_as"
    shape = g.op.Shape(y, name=name)
    res = g.op.Expand(x, shape, name=name, outputs=outputs)
    if not sts:
        if g.has_shape(y):
            g.set_shape(res, g.get_shape(y))
        elif g.has_rank(y):
            g.set_rank(res, g.get_rank(y))
        if g.has_type(x):
            g.set_type(res, g.get_type(x))
    return res


def aten_meth_masked_fill(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    mask: T,
    value: Any,
) -> T:
    "masked_fill"
    if isinstance(value, float):
        itype = g.get_type(x)
        value_cast = g.make_initializer(
            "",
            np.array([value], dtype=tensor_dtype_to_np_dtype(itype)),
            source="aten_meth_masked_fill_",
        )
        set_shape_cast = False
    else:
        value_cast = g.op.CastLike(value, x, name=".masked_fill")
        set_shape_cast = True
    res = g.op.Where(mask, value_cast, x, name=".masked_fill")
    if not sts:
        g.set_type(res, g.get_type(x))
        if set_shape_cast:
            g.set_type(value_cast, g.get_type(x))
            if isinstance(value, str):
                if g.has_shape(value):
                    g.set_shape(value_cast, g.get_shape(value))
                elif g.has_rank(value):
                    g.set_rank(value_cast, g.get_rank(value))
            elif isinstance(value, (int, float, bool)):
                g.set_shape(value_cast, tuple())
            elif hasattr(value, "shape"):
                g.set_shape(value_cast, value.shape)
            else:
                raise RuntimeError(f"Unable to guess shape from type {type(value)}")
        set_type_shape_binary_op(g, res, mask, value_cast, x, begin=1)
    return res


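# Semantics sketch (numpy analogue, assuming a boolean mask): the Where node
# above implements masked_fill as "take value where the mask is True, keep x
# elsewhere":
#
#   x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
#   mask = np.array([True, False, True])
#   np.where(mask, np.float32(-1.0), x)   # -> [-1., 2., -1.]

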
def aten_meth_masked_fill_(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    mask: T,
    value: Any,
) -> T:
    "masked"
    raise RuntimeError(
        "These calls should be removed from the fx graph as they are "
        "inplace modifications (aten_meth_masked_fill_)."
    )


def aten_meth_mean(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    dim: T,
    keepdim: bool = False,
) -> T:
    "reducemean"
    if isinstance(dim, int):
        cst = g.make_initializer(
            "", np.array([dim], dtype=np.int64), source="aten_meth_mean.cst.1"
        )
    elif isinstance(dim, tuple):
        cst = g.make_initializer(
            "", np.array(dim, dtype=np.int64), source="aten_meth_mean.cst.2"
        )
    else:
        raise RuntimeError(f"Unexpected type {type(dim)} for dim.")
    res = g.op.ReduceMeanAnyOpset(
        x, cst, outputs=outputs, keepdims=1 if keepdim else 0, name=".mean"
    )
    if not sts:
        set_type_shape_reduce_op(
            g,
            outputs[0],
            x,
            keepdim=keepdim,
            axes=(dim,) if isinstance(dim, int) else dim,
        )
    return res


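# Semantics sketch (numpy analogue): ReduceMean with an explicit int64 axes
# input mirrors numpy's mean over a tuple of axes:
#
#   a = np.arange(6, dtype=np.float32).reshape(2, 3)
#   np.mean(a, axis=(1,), keepdims=True)  # dim=1, keepdim=True -> shape (2, 1)

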
def aten_meth_pow(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    exponent: T,
) -> T:
    "pow"
    assert isinstance(x, str), f"Unexpected type {type(x)} (x={x!r}, exponent={exponent!r})"
    if isinstance(exponent, (int, float)):
        cst = g.make_initializer(
            "",
            np.array(exponent, dtype=tensor_dtype_to_np_dtype(g.get_type(x))),
            source="aten_meth_pow.exponent.scalar",
        )
    elif isinstance(exponent, np.ndarray):
        cst = g.make_initializer(
            "",
            exponent.astype(tensor_dtype_to_np_dtype(g.get_type(x))),
            source="aten_meth_pow.exponent.tensor",
        )
    elif isinstance(exponent, str):
        cst = exponent
    else:
        raise RuntimeError(f"Unexpected type {type(exponent)} for exponent.")
    res = g.make_node("Pow", [x, cst], outputs, name="meth_pow")
    if not sts:
        set_type_shape_unary_op(g, outputs[0], x)
    return res


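# Note on the scalar branch (a sketch, values are hypothetical): Pow takes
# tensor inputs, so a Python scalar exponent is materialized as an initializer
# matching x's element type, e.g. for a float32 x:
#
#   np.array(2, dtype=tensor_dtype_to_np_dtype(TensorProto.FLOAT))  # float32 scalar

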
def aten_meth_repeat(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    *repeats: List[int],
) -> T:
    "repeat"
    return aten_repeat(g, sts, outputs, x, repeats, name=".repeat")


def aten_meth_reshape(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    input_name: T,
    *shape: List[int],
    name: str = "reshape",
) -> T:
    "reshape"
    if all_int(shape):
        # static version
        cst = g.make_initializer(
            "", np.array(shape, dtype=np.int64), source="aten_meth_reshape.shape"
        )
        res = g.make_node("Reshape", [input_name, cst], outputs, name=name)
        if not sts:
            set_type_shape_reshape(g, res, input_name, shape)
        return res
    # dynamic version
    dyn_shape = g.make_shape_from_results(shape, name=name)
    res = g.make_node("Reshape", [input_name, dyn_shape], outputs, name=name)
    if not sts:
        set_type_shape_reshape(g, res, input_name, shape)
    return res


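# Dispatch sketch (hypothetical values): all_int picks between the two
# branches above; a fully static shape becomes a constant int64 initializer,
# anything containing a result name goes through make_shape_from_results,
# which assembles the shape tensor at runtime:
#
#   all_int((2, -1))       # True  -> static initializer [2, -1]
#   all_int((2, "batch"))  # False -> dynamic shape assembled at runtime

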
def aten_meth_sin(
    g: GraphBuilder, sts: Optional[Dict[str, Any]], outputs: List[str], x: T
) -> T:
    "sin"
    return aten_sin(g, sts, outputs, x)


def aten_meth_size(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    dim: Optional[int] = None,
    name: str = ".size",
) -> T:
    "size"
    if dim is None:
        res = g.op.Shape(x, name=f"{name}A", outputs=outputs)
        if not sts:
            g.set_type(res, TensorProto.INT64)
            g.set_shape(res, (g.get_rank(x),))
        return res
    s = g.op.Shape(x, name=name)
    d = g.op.Gather(s, np.array([dim], dtype=np.int64), name=f"{name}B")
    res = g.op.SqueezeAnyOpset(
        d, np.array([0], dtype=np.int64), name=f"{name}B", outputs=outputs
    )
    if not sts:
        g.set_type(res, TensorProto.INT64)
        g.set_shape(res, tuple())
    return res


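# Semantics sketch (numpy analogue): x.size(dim) is lowered to
# Shape -> Gather -> Squeeze, producing an int64 scalar:
#
#   s = np.array((2, 3, 4), dtype=np.int64)  # Shape of a rank-3 tensor
#   d = s[[1]]                                # Gather dim=1 -> array([3])
#   int(d.squeeze(0))                         # Squeeze -> 3

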
def aten_meth_sum(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    axis: T,
    keepdim: bool = False,
    dim: Optional[int] = None,
) -> T:
    "reducesum"
    if axis is not None and isinstance(axis, int):
        axes = np.array([axis], dtype=np.int64)
    elif dim is not None and isinstance(dim, int):
        axes = np.array([dim], dtype=np.int64)
    else:
        raise AssertionError(
            f"Unexpected value for dim={dim!r} or axis={axis!r}{g.get_debug_msg()}"
        )
    res = g.op.ReduceSumAnyOpset(
        x, axes, outputs=outputs, keepdims=1 if keepdim else 0, name=".sum"
    )
    if not sts:
        set_type_shape_reduce_op(
            g,
            outputs[0],
            x,
            keepdim=keepdim,
            axes=tuple(map(int, axes)),
        )
    return res


def aten_meth_t(g: GraphBuilder, sts: Optional[Dict[str, Any]], outputs: List[str], x: T) -> T:
    "transpose"
    return aten_t(g, sts, outputs, x, name=".t")


def aten_meth_to(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    input_name: T,
    *args: List[Any],
    name: str = ".to",
    **kwargs: Dict[str, Any],
) -> T:
    "cast"
    import torch

    dtype = kwargs.get("dtype", None)
    device = kwargs.get("device", None)
    for a in args:
        if isinstance(a, torch.dtype):
            assert dtype is None, f"dtype is specified in args and kwargs {args}, {kwargs}"
            dtype = a
            continue
        if isinstance(a, torch.device):
            assert device is None, f"device is specified in args and kwargs {args}, {kwargs}"
            device = a
            continue
        raise NotImplementedError(f"Unexpected type for argument {type(a)}")
    assert (
        dtype is not None or device is not None
    ), "dtype or device cannot be None for method to"
    if dtype is None:
        return g.op.Identity(input_name, outputs=outputs, name=name)

    onnx_to = torch_dtype_to_onnx_dtype(dtype)
    res = g.make_node("Cast", [input_name], outputs, to=onnx_to, name=name)
    if not sts:
        g.set_type(outputs[0], onnx_to)
        if g.has_shape(input_name):
            g.set_shape(outputs[0], g.get_shape(input_name))
        elif g.has_rank(input_name):
            g.set_rank(outputs[0], g.get_rank(input_name))
    return res


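# Usage sketch (builder ``g`` and result names are hypothetical): ``.to()``
# may carry a dtype, a device, or both, but only the dtype affects the graph;
# a device-only call lowers to Identity, a dtype call to a single Cast:
#
#   import torch
#   aten_meth_to(g, None, ["xb"], "x", dtype=torch.bool)            # Cast(to=BOOL)
#   aten_meth_to(g, None, ["xc"], "x", device=torch.device("cpu"))  # Identity

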
def aten_meth_transpose(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    input_name: T,
    dim0: int,
    dim1: int,
) -> T:
    "transpose"
    assert g.has_rank(input_name), f"{input_name!r} must have a rank{g.get_debug_msg()}"
    perm = list(range(g.rank(input_name)))
    assert max(dim0, dim1) < len(perm), (
        f"aten_meth_transpose: unexpected perm={perm}, dim0={dim0}, dim1={dim1}, "
        f"input_name={input_name!r}, rank={g.rank(input_name)}"
        f"{g.get_debug_msg()}"
    )
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    res = g.make_node("Transpose", [input_name], outputs, perm=perm, name="meth_transpose")
    if not sts:
        g.set_type(outputs[0], g.get_type(input_name))
        if g.has_shape(input_name):
            shape = list(g.get_shape(input_name))
            shape[dim0], shape[dim1] = shape[dim1], shape[dim0]
            g.set_shape(outputs[0], tuple(shape))
        elif g.has_rank(input_name):
            g.set_rank(outputs[0], g.get_rank(input_name))
    return res


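# Permutation sketch (plain Python): swapping dim0 and dim1 in the identity
# permutation is all the Transpose node needs, e.g. for rank 3, dim0=0, dim1=2:
#
#   perm = list(range(3))               # [0, 1, 2]
#   perm[0], perm[2] = perm[2], perm[0]
#   perm                                # [2, 1, 0] -> Transpose(perm=[2, 1, 0])

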
def aten_meth_unsqueeze(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    input_name: T,
    dim: int,
) -> T:
    "unsqueeze"
    new_name = g.unique_name(f"{input_name}_axes")
    g.make_initializer(
        new_name, np.array([dim], dtype=np.int64), source="aten_meth_unsqueeze.axis"
    )
    res = g.make_node("Unsqueeze", [input_name, new_name], outputs, name="meth_unsqueeze")
    if not sts:
        dtype = g.get_type(input_name)
        g.set_type(outputs[0], dtype)
        if g.has_shape(input_name):
            shape = list(g.get_shape(input_name))
            shape.insert(dim, 1)
            g.set_shape(outputs[0], tuple(shape))
        elif g.has_rank(input_name):
            g.set_rank(outputs[0], g.get_rank(input_name) + 1)
    return res


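# Shape sketch (plain Python): inserting a 1 at position dim reproduces what
# the Unsqueeze node does to the static shape:
#
#   shape = [2, 3]
#   shape.insert(1, 1)                  # dim=1 -> [2, 1, 3]

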
def aten_meth_view(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    input_name: T,
    *args: Sequence[int],
) -> T:
    "view"
    if all_int(args):
        # static shape
        new_shape_name = g.unique_name(f"{input_name}_view_shape")
        g.make_initializer(
            new_shape_name, np.array(args, dtype=np.int64), source="aten_meth_view.shape"
        )
        res = g.make_node("Reshape", [input_name, new_shape_name], outputs, name=".view")
        if not sts:
            set_type_shape_reshape(g, res, input_name, args)
        return res

    new_shape_name = g.make_shape_from_results(args, name=".view")
    res = g.make_node("Reshape", [input_name, new_shape_name], outputs, name=".view")
    if not sts:
        g.set_type(new_shape_name, TensorProto.INT64)
        g.set_shape(new_shape_name, (len(args),))
        set_type_shape_reshape(g, res, input_name, new_shape_name)
        assert g.get_rank(res) == len(args), f"error in set_type_shape_reshape args={args!r}"
    return res