Source code for experimental_experiment.xshape._builder_runtime

import contextlib
from itertools import zip_longest
from typing import Dict, Generator, List, Tuple
import numpy as np
from onnx import NodeProto
from ..helpers import (
    string_type,
    tensor_dtype_to_np_dtype,
    dtype_to_tensor_dtype,
    onnx_dtype_to_torch_dtype,
)
from ..xshape._shape_helper import DYNAMIC_SHAPE, STATIC_SHAPE, all_int, all_int_or_str
from ..xshape.simplify_expressions import simplify_expression
from ..xshape._onnx_helper import str_tensor_proto_type


@contextlib.contextmanager
def _unset_fake_temporarily() -> Generator:
    import torch

    old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
    try:
        yield old
    finally:
        if old is not None:
            torch._C._set_dispatch_mode(old)
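
# A minimal usage sketch (not in the original module): the context manager above
# temporarily removes torch's FakeTensor dispatch mode so that real tensors can be
# created and concrete values read, and restores the mode on exit when one was active.
#
#     with _unset_fake_temporarily():
#         t = torch.arange(4)      # a real tensor even under FakeTensorMode
#         value = int(t.sum())     # safe to read concrete values here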


class _BuilderRuntime:
    """
    Computes the output of a small set of node types given their inputs.
    It supports numpy and torch tensors. Most of these functions are used
    while exporting a model, by :meth:`_InferenceRuntime.compute_constant
    <experimental_experiment.xshape._inference_runtime._InferenceRuntime.compute_constant>`.
    """

    def _apply_slice_to_shape(
        self,
        shape: STATIC_SHAPE,
        indices: List[slice],
        axes: List[int],
        expand_axes: List[int],
    ) -> STATIC_SHAPE:
        assert isinstance(shape, tuple), f"Unexpected type {type(shape)} for shape: {shape}"
        assert isinstance(indices, list), f"Unexpected type {type(indices)} for indices: {indices}"
        assert isinstance(axes, list), f"Unexpected type {type(axes)} for axes: {axes}"
        assert len(axes) in (
            1,
            len(indices),
        ), f"Mismatch lengths {len(indices)} != {len(axes)}"

        if all(isinstance(i, slice) for i in indices):
            new_shape = []
            for index, axis_ in zip(indices, axes):
                axis = axis_ if axis_ >= 0 else (axis_ + len(shape)) % len(shape)
                while len(new_shape) < axis:
                    assert shape[len(new_shape)] >= 0, (
                        f"Negative value in shape {shape}, indices={indices}, "
                        f"axes={axes}, expand_axes={expand_axes}"
                    )
                    new_shape.append(shape[len(new_shape)])
                assert axis < len(shape), (
                    f"axis={axis} is out of range (shape={shape}, "
                    f"indices={indices}, axes={axes}){self.get_debug_msg()}"
                )
                n = shape[axis]
                start = index.start or 0
                end = index.stop or n
                diff = end - start
                dim = diff // index.step if index.step else diff
                dim = max(dim, 0)
                assert dim >= 0, (
                    f"Negative dim={dim}, axis={axis}, shape={shape}, indices={indices}, "
                    f"axes={axes}, expand_axes={expand_axes}"
                )
                new_shape.append(dim)
        elif all_int(indices):
            assert len(axes) == 1, (
                f"Unable to guess new shape from shape={shape}, "
                f"indices={indices}, axes={axes}, expand_axes={expand_axes}"
            )
            new_shape = [len(indices), *shape[1:]]
        else:
            raise RuntimeError(
                f"Unable to guess new shape from shape={shape}, "
                f"indices={indices}, axes={axes}, expand_axes={expand_axes}"
            )
        for a in shape[len(new_shape) :]:
            assert a >= 0, (
                f"Negative value in shape {shape}, indices={indices}, "
                f"axes={axes}, expand_axes={expand_axes}"
            )
            new_shape.append(a)
        for e in expand_axes:
            new_shape.insert(e, 1)
        return tuple(new_shape)
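
    # A minimal sketch (not in the original module) of the slice-shape helper above,
    # assuming ``g`` is a graph builder instance exposing ``_apply_slice_to_shape``:
    #
    #     g._apply_slice_to_shape((4, 6), indices=[slice(1, 5)], axes=[1], expand_axes=[])
    #     # -> (4, 4): axis 0 is copied, axis 1 becomes stop - start = 5 - 1 = 4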

    def _apply_reshape_to_shape(
        self, input_shape: DYNAMIC_SHAPE, new_shape: STATIC_SHAPE
    ) -> DYNAMIC_SHAPE:
        """Returns the shape of the output of a node Reshape."""
        assert isinstance(
            input_shape, tuple
        ), f"unexpected type {type(input_shape)} for input_shape."
        assert isinstance(
            new_shape, tuple
        ), f"unexpected type {type(new_shape)} for new_shape."
        assert all_int(new_shape), f"unexpected type for a dimension in {new_shape}"

        # handling zeros --> keeps the original dimension
        new_new_shape = []
        for i, sh in enumerate(new_shape):
            if sh == 0:
                assert i < len(
                    input_shape
                ), f"Unable to apply reshape {new_shape} to input shape {input_shape}"
                new_new_shape.append(input_shape[i])
                continue
            new_new_shape.append(sh)
        new_shape = tuple(new_new_shape)

        if -1 not in new_shape:
            return new_shape
        if all_int(input_shape):
            size = int(np.prod(input_shape))
            div = np.prod([i for i in new_shape if i != -1])
            if div == 0:
                return tuple((int(i) if i >= 0 else 0) for i in new_shape)
            return tuple((int(i) if i >= 0 else int(size // div)) for i in new_shape)

        if all_int_or_str(input_shape) and new_shape == (1, -1):
            # common case
            return (1, "*".join(map(str, input_shape)))

        mul, div = [], []
        muli, divi = 1, 1
        for s, n in zip_longest(input_shape, new_shape):
            if s is None:
                s = 1
            if n is None:
                n = 1
            if isinstance(s, str) and isinstance(n, str):
                if s != n:
                    mul.append(s)
                    div.append(n)
            elif isinstance(s, str):
                mul.append(s)
                if n != -1:
                    divi *= n
            else:
                muli *= s
                if n != -1:
                    divi *= n

        if not mul and not div:
            assert muli % divi == 0, (
                f"Inconsistency between input_shape={input_shape} "
                f"and new_shape={new_shape}{self.get_debug_msg()}"
            )
            rest = muli // divi
        else:

            def _(s):
                return f"({s})" if set(s) & set("+/-*%()") else s

            if muli != 1:
                mul.append(str(muli))
            if divi != 1:
                div.append(str(divi))
            if not mul:
                mul = ["1"]
            if not div:
                rest = mul[0] if len(mul) == 1 else f"{'*'.join(f'{_(s)}' for s in mul)}"
            elif not mul:
                rest = (
                    f"1//{_(div[0])}"
                    if len(div) == 1
                    else f"1//({'*'.join(f'{_(s)}' for s in div)})"
                )
            else:
                rest = (
                    f"(({'*'.join(f'{_(s)}' for s in mul)})"
                    f"//({'*'.join(f'{_(s)}' for s in div)}))"
                )
            rest = simplify_expression(rest)
        return tuple(s if s != -1 else rest for s in new_shape)
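
    # A minimal sketch (not in the original module) of the reshape-shape inference
    # above, assuming ``g`` is a graph builder instance exposing the method:
    #
    #     g._apply_reshape_to_shape((2, 3, 4), (0, -1))
    #     # -> (2, 12): 0 keeps the original dimension, -1 absorbs the remaining size
    #     g._apply_reshape_to_shape(("batch", 3, 4), (1, -1))
    #     # -> (1, 'batch*3*4'): symbolic dimensions are folded into an expression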

    def _apply_expand_to_shape(
        self, input_shape: DYNAMIC_SHAPE, new_shape: STATIC_SHAPE
    ) -> DYNAMIC_SHAPE:
        """Returns the shape of the output of a node Expand."""
        assert isinstance(
            input_shape, tuple
        ), f"unexpected type {type(input_shape)} for input_shape."
        assert isinstance(
            new_shape, tuple
        ), f"unexpected type {type(new_shape)} for new_shape."
        if -1 not in new_shape and 1 not in new_shape:
            return new_shape
        assert len(new_shape) >= len(input_shape), (
            f"inconsistent behaviour, new_shape={new_shape}, "
            f"input_shape={input_shape}{self.get_debug_msg()}"
        )
        if len(input_shape) < len(new_shape):
            input_shape = (1,) * (len(new_shape) - len(input_shape)) + input_shape
        nsh = []
        for i, s in enumerate(new_shape):
            if s == 1:
                assert i < len(input_shape), (
                    f"Unexpected scenario new_shape={new_shape}, "
                    f"input_shape={input_shape}{self.get_debug_msg()}"
                )
                nsh.append(input_shape[i])
                continue
            if s == 0:
                nsh.append(0)
                continue
            if i < len(input_shape):
                if isinstance(s, str) and isinstance(input_shape[i], str):
                    if s != input_shape[i]:
                        return None
                    nsh.append(s)
                    continue
                if isinstance(s, str) and isinstance(input_shape[i], int):
                    if input_shape[i] == 1:
                        nsh.append(s)
                        continue
                    # (1, 1, 1024) with (1, 1, 'input_dim_13'):
                    # the output is 1024 if input_dim_13 is not zero, which we don't know.
                    return None
            assert isinstance(s, int) or (i < len(input_shape) and input_shape[i] == 1), (
                f"Unable to compute expanded shape at position {i} when trying "
                f"to expand shape {input_shape} with {new_shape}{self.get_debug_msg()}"
            )
            nsh.append(s)
        return tuple(nsh)
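
    # A minimal sketch (not in the original module) of the Expand shape inference
    # above, assuming ``g`` is a graph builder instance exposing the method:
    #
    #     g._apply_expand_to_shape(("batch", 1), (1, 4))
    #     # -> ('batch', 4): a 1 in the target shape keeps the input dimension
    #     g._apply_expand_to_shape((1, 1024), (1, "seq"))
    #     # -> None: 1024 against a symbolic dimension cannot be resolved statically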

    def _apply_transpose(
        self,
        node: NodeProto,
        feeds: Dict[str, "torch.Tensor"],  # noqa: F821
    ) -> "torch.Tensor":  # noqa: F821
        perm = None
        for att in node.attribute:
            if att.name == "perm":
                perm = tuple(att.ints)
                break
        assert perm, f"perm not here in node {node}"
        x = feeds[node.input[0]]
        assert len(x.shape) == len(perm), (
            f"Shape mismatch between x.shape={x.shape} and perm={perm!r}, "
            f"node is {self.pretty_node(node)}{self.get_debug_msg()}"
        )
        if isinstance(x, np.ndarray):
            # Type conversion between numpy and torch is not robust.
            itype = dtype_to_tensor_dtype(x.dtype)
            ttype = onnx_dtype_to_torch_dtype(itype)
            x = self.torch.from_numpy(x.copy()).to(ttype)
        return [self.torch.permute(x, perm).to(x.dtype)]

    def _apply_expand(
        self,
        node: NodeProto,
        feeds: Dict[str, "torch.Tensor"],  # noqa: F821
    ) -> "torch.Tensor":  # noqa: F821
        x = feeds[node.input[0]]
        new_shape = feeds[node.input[1]]
        if isinstance(x, self.torch.Tensor):
            if len(x.shape) == 0:
                if len(new_shape) == 0:
                    return x
                import torch

                return [torch.full(tuple(new_shape), x)]
            shape_x = (
                x.shape
                if len(x.shape) == len(new_shape)
                else ((1,) * (len(new_shape) - len(x.shape)) + x.shape)
            )
            try:
                return [x.expand(tuple(max(s, int(i)) for s, i in zip(shape_x, new_shape)))]
            except RuntimeError as e:
                raise RuntimeError(
                    f"Unable to compute the constant, new_shape={new_shape}, "
                    f"x.shape={x.shape}, node={node}\n{self.pretty_text()}"
                ) from e
        ones = np.ones(tuple(int(i) for i in new_shape), dtype=x.dtype)
        return [(x * ones).astype(x.dtype)]

    def _apply_squeeze(
        self,
        node: NodeProto,
        feeds: Dict[str, "torch.Tensor"],  # noqa: F821
    ) -> "torch.Tensor":  # noqa: F821
        x = feeds[node.input[0]]
        if len(node.input) == 1:
            # No axis.
            return [x.squeeze()]
        axis = feeds[node.input[1]]
        if len(axis.shape) == 0:
            return [np.squeeze(x, (int(axis),))]
        return [x.squeeze(tuple(int(i) for i in axis))]

    def _apply_unsqueeze(
        self,
        node: NodeProto,
        feeds: Dict[str, "torch.Tensor"],  # noqa: F821
    ) -> "torch.Tensor":  # noqa: F821
        x = feeds[node.input[0]]
        axis = feeds[node.input[1]]
        if isinstance(x, np.ndarray):
            if len(axis.shape) == 0:
                return [np.expand_dims(x, (int(axis),))]
            return [np.expand_dims(x, tuple(int(i) for i in axis))]
        if isinstance(axis, np.ndarray):
            axis = [int(axis)] if axis.shape == tuple() else axis.tolist()
        if len(axis) == 1:
            if isinstance(x, (np.int64, np.int32)):
                return np.array([x])
            return (
                [x.expand_dims(int(axis[0]))]
                if isinstance(x, np.ndarray)
                else [x.unsqueeze(int(axis[0]))]
            )
        assert len(axis) > 0, f"axis={axis} is null"
        for a in axis:
            x = x.expand_dims(int(a)) if isinstance(x, np.ndarray) else x.unsqueeze(int(a))
        return [x]

    def _apply_cast(
        self,
        node: NodeProto,
        feeds: Dict[str, "torch.Tensor"],  # noqa: F821
    ) -> "torch.Tensor":  # noqa: F821
        x = feeds[node.input[0]]
        if not isinstance(x, np.ndarray) and (
            not hasattr(self, "torch") or not isinstance(x, self.torch.Tensor)
        ):
            # Maybe a float, then we process it as a float, tensor.to only works
            # on tensors.
            assert isinstance(
                x, (float, int, np.float32, np.float64, np.float16, np.int32, np.int64)
            ), f"Unexpected type {type(x)} for {node.input[0]!r} (node.name={node.name!r})"
            res = self._apply_cast(node, {node.input[0]: np.array(x)})
            return [res[0]]

        to, saturate = None, 1
        for att in node.attribute:
            if att.name == "to":
                to = att.i
                break
            if att.name == "saturate":
                saturate = att.i
                break
        assert to, f"to not here in node {node}"
        assert to != 8 and to < 17, f"Cast not implemented for to={to}, {str_tensor_proto_type()}"
        del saturate
        if not hasattr(self, "torch"):
            ttype = tensor_dtype_to_np_dtype(to)
            return [x.astype(ttype)]
        if isinstance(x, np.ndarray):
            # Type conversion between numpy and torch is not robust.
            itype = dtype_to_tensor_dtype(x.dtype)
            ttype = onnx_dtype_to_torch_dtype(itype)
            x = self.make_torch_tensor_from_np_array(x).to(ttype)
        assert "FakeTensor" not in str(type(x)), (
            f"FakeTensor {node.output[0]!r} cannot be a constant {type(x)}, "
            f"node.op_type={node.op_type!r}, type={self.torch.Tensor}"
            f"{self.pretty_text()}"
        )
        assert isinstance(x, self.torch.Tensor), (
            f"Unexpected type {type(x)} for x for node type {node.op_type}, "
            f"name={node.name}, inputs={node.input}, outputs={node.output}"
        )
        ttype = onnx_dtype_to_torch_dtype(to)
        return [x.to(ttype)]

    def _apply_unary_function(
        self,
        node: NodeProto,
        feeds: Dict[str, "torch.Tensor"],  # noqa: F821
    ) -> "torch.Tensor":  # noqa: F821
        x = feeds[node.input[0]]
        itype = dtype_to_tensor_dtype(x.dtype)
        if isinstance(x, np.ndarray):
            ttype = tensor_dtype_to_np_dtype(itype)
            if node.op_type == "Sqrt":
                return [np.sqrt(x).astype(ttype)]
            if node.op_type == "Exp":
                return [np.exp(x).astype(ttype)]
            if node.op_type == "Reciprocal":
                return [(np.array([1], dtype=x.dtype) / x).astype(ttype)]
            raise AssertionError(
                f"Not implemented for op_type={node.op_type!r}, node={node}, feeds={feeds}"
            )
        ttype = onnx_dtype_to_torch_dtype(itype)
        if node.op_type == "Sqrt":
            return [self.torch.sqrt(x).to(ttype)]
        if node.op_type == "Exp":
            return [self.torch.exp(x).to(ttype)]
        if node.op_type == "Reciprocal":
            return [(self.torch.tensor([1], dtype=x.dtype) / x).to(ttype)]
        raise AssertionError(
            f"Not implemented for op_type={node.op_type!r}, node={node}, "
            f"feeds={string_type(feeds, with_shape=True)}"
        )

    def _apply_trilu(
        self,
        node: NodeProto,
        feeds: Dict[str, "torch.Tensor"],  # noqa: F821
    ) -> "torch.Tensor":  # noqa: F821
        upper = True
        for att in node.attribute:
            if att.name == "upper":
                upper = att.i
                break
        assert len(node.input) in (1, 2), (
            f"Unexpected number of inputs (inputs={node.input}) "
            f"for Trilu{self.get_debug_msg()}"
        )
        x = feeds[node.input[0]]
        k = feeds[node.input[1]] if len(node.input) > 1 else np.array(0, dtype=np.int64)
        assert len(x.shape) > 0, (
            f"x cannot be empty but shape is {x.shape}, execution of Trilu "
            f"failed{self.get_debug_msg()}"
        )
        if isinstance(x, self.torch.Tensor):
            assert isinstance(k, self.torch.Tensor), (
                f"Expecting a tensor for {node.input[1]!r} but got "
                f"{type(k)}{self.get_debug_msg()}"
            )
            ak = k.detach().cpu()
            iak = int(ak) if len(ak.shape) == 0 else int(ak[0])
            assert iak <= 1, f"Unexpected value for k={k}{self.get_debug_msg()}"
            return [self.torch.triu(x, iak) if upper else self.torch.tril(x, iak)]
        assert isinstance(k, np.ndarray), (
            f"Expecting a tensor for {node.input[1]!r} but got "
            f"{type(k)}{self.get_debug_msg()}"
        )
        iak = int(k) if len(k.shape) == 0 else int(k[0])
        return [np.triu(x, iak) if upper else np.tril(x, iak)]

    def _apply_binary_op(
        self,
        node: NodeProto,
        feeds: Dict[str, "torch.Tensor"],  # noqa: F821
    ) -> "torch.Tensor":  # noqa: F821
        a, b = feeds[node.input[0]], feeds[node.input[1]]
        if a.dtype != b.dtype:
            a = self._to_torch_tensor(a)
            b = self._to_torch_tensor(b)
        try:
            if node.op_type == "Add":
                return [a + b]
            if node.op_type == "Mul":
                return [a * b]
            if node.op_type == "Sub":
                return [a - b]
            if node.op_type == "Div":
                return [a / b]
            if node.op_type == "Pow":
                return [a**b]
            raise AssertionError(f"{node.op_type!r} not implemented")
        except RuntimeError as e:
            raise AssertionError(
                f"Unable to combine two objects of dtype {a.dtype}, {b.dtype} and "
                f"shapes {a.shape}, {b.shape}, node.op_type={node.op_type!r}, "
                f"node.name={node.name!r}, inputs={node.input}, outputs={node.output}"
            ) from e

    def _apply_where(
        self,
        node: NodeProto,
        feeds: Dict[str, "torch.Tensor"],  # noqa: F821
    ) -> "torch.Tensor":  # noqa: F821
        new_feeds = {}
        for k, v in feeds.items():
            if not hasattr(self, "torch"):
                new_feeds[k] = v
                continue
            if isinstance(v, np.ndarray):
                # Type conversion between numpy and torch is not robust.
                itype = dtype_to_tensor_dtype(v.dtype)
                ttype = onnx_dtype_to_torch_dtype(itype)
                x = self.make_torch_tensor_from_np_array(v.copy()).to(ttype)
                assert "FakeTensor" not in str(type(x)), (
                    f"FakeTensor {node.output[0]!r} cannot be a constant {type(x)}, "
                    f"node.op_type={node.op_type!r}, type={self.torch.Tensor}"
                    f"{self.get_debug_msg()}"
                )
                new_feeds[k] = x
            else:
                new_feeds[k] = v
        if not hasattr(self, "torch"):
            y = np.where(*[new_feeds[k] for k in node.input])
            return [y]
        y = self.torch.where(*[new_feeds[k] for k in node.input])
        return [y]

    def _apply_slice(
        self,
        node: NodeProto,
        feeds: Dict[str, "torch.Tensor"],  # noqa: F821
    ) -> "torch.Tensor":  # noqa: F821
        new_feeds = {}
        for k, v in feeds.items():
            if isinstance(v, np.ndarray):
                # Type conversion between numpy and torch is not robust.
                itype = dtype_to_tensor_dtype(v.dtype)
                ttype = onnx_dtype_to_torch_dtype(itype)
                x = self.torch.from_numpy(v)
                assert x.dtype == ttype, (
                    f"Unexpected conversion from numpy {v.dtype} to "
                    f"{x.dtype} != {ttype}{self.get_debug_msg()}"
                )
                assert "FakeTensor" not in str(type(x)), (
                    f"FakeTensor {node.output[0]!r} cannot be a constant {type(x)}, "
                    f"node.op_type={node.op_type!r}, type={self.torch.Tensor}"
                    f"{self.get_debug_msg()}"
                )
                new_feeds[k] = x
            else:
                new_feeds[k] = v
        assert len(node.input) >= 3, (
            f"Node {node.op_type} (name={node.name!r}) does not have enough "
            f"inputs {node.input}\n{self.pretty_text()}"
        )
        data, starts, ends = [new_feeds[k] for k in node.input[:3]]
        axes = new_feeds[node.input[3]] if len(node.input) > 3 and node.input[3] else None
        steps = new_feeds[node.input[4]] if len(node.input) > 4 and node.input[4] else None
        if axes is None:
            if steps is None:
                slices = [slice(s, e) for s, e in zip(starts, ends)]
            else:
                slices = [slice(s, e, d) for s, e, d in zip(starts, ends, steps)]
        else:
            if steps is None:
                slices = [slice(0, a) for a in data.shape]
                for s, e, a in zip(starts, ends, axes):
                    slices[a] = slice(s, e)
            else:
                slices = [slice(0, a) for a in data.shape]
                for s, e, a, d in zip(starts, ends, axes, steps):
                    slices[a] = slice(s, e, d)
        res = data[tuple(slices)]
        assert len(res.shape) == 0 or min(res.shape) > 0, (
            f"Empty shape found {res.shape} after Slice when x.shape={data.shape}, "
            f"starts={starts}, ends={ends}, axes={axes}, steps={steps}, "
            f"node.name={node.name!r}, input names={node.input}, "
            f"slices={slices}"
        )
        assert len(res.shape) == len(data.shape), (
            f"Shape mismatch input shape is {data.shape}, output shape is {res.shape}, "
            f"axes={axes}, starts={starts}, ends={ends}, steps={steps}, "
            f"node is {self.pretty_node(node)}{self.pretty_text()}"
        )
        return [res]

    def _apply_shape_on_shape(
        self, node: NodeProto, shape: Tuple[int, ...]
    ) -> "torch.Tensor":  # noqa: F821
        if node.attribute:
            start = 0
            end = None
            for att in node.attribute:
                if att.name == "start":
                    start = att.i
                elif att.name == "end":
                    end = att.i
            shape = shape[start:] if end is None else shape[start:end]
        return [self.torch.from_numpy(np.array(shape, dtype=np.int64))]

    def _apply_shape(
        self,
        node: NodeProto,
        feeds: Dict[str, "torch.Tensor"],  # noqa: F821
    ) -> "torch.Tensor":  # noqa: F821
        shape = tuple(map(int, feeds[node.input[0]].shape))
        return self._apply_shape_on_shape(node, shape)
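
# A minimal sketch (not in the original module) of how these kernels are driven:
# each ``_apply_*`` method receives an ONNX node and a feeds dictionary mapping
# input names to tensors, and returns the list of computed outputs. ``g`` is a
# hypothetical builder instance mixing in ``_BuilderRuntime``.
#
#     from onnx import helper
#     node = helper.make_node("Add", ["x", "y"], ["z"])
#     feeds = {
#         "x": np.array([1.0, 2.0], dtype=np.float32),
#         "y": np.array([3.0, 4.0], dtype=np.float32),
#     }
#     (z,) = g._apply_binary_op(node, feeds)   # -> array([4., 6.], dtype=float32)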