from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from onnx.helper import tensor_dtype_to_np_dtype
from ..xbuilder._shape_helper import all_int
from ..xbuilder._dtype_helper import torch_dtype_to_onnx_dtype
from ..xbuilder.graph_builder import GraphBuilder
from ..xbuilder.shape_type_compute import (
broadcast_shape,
set_type_shape_unary_op,
set_type_shape_reduce_op,
)
# Annotation alias for graph values: tensors are handled by their (string) name
# in the graph builder, e.g. ``isinstance(x, str)`` distinguishes a tensor from a
# Python constant in ``prims_where``.
T = str
[docs]
def prims_add(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    y: T,
    name="prims_add",
) -> T:
    """
    add

    Forwards every argument unchanged to ``aten_add``.
    """
    from ._aten_functions import aten_add as _convert

    return _convert(g, sts, outputs, x, y, name=name)
[docs]
def prims_amax(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    dim: Optional[Union[int, List[int], Tuple[int, ...]]] = None,
    keepdim: bool = False,
    output_dtype: Optional["torch.dtype"] = None,  # noqa: F821
    name: str = "prims_amax",
) -> T:
    """
    reducemax

    Computes the maximum of ``x`` with ``ReduceMax``.

    :param g: graph builder
    :param sts: when empty or None, this function sets the output type and shape
    :param outputs: output names
    :param x: input tensor name
    :param dim: None to reduce over all axes, an int, or a list/tuple of ints
    :param keepdim: keep the reduced axes (``keepdims=1``)
    :param output_dtype: must be None, other values are not implemented
    :param name: not forwarded to the node
    :return: output name
    :raises RuntimeError: if ``dim`` has an unsupported type
    """
    assert (
        output_dtype is None
    ), f"not implemented when output_dtype={output_dtype!r}{g.get_debug_msg()}"
    keepdims = 1 if keepdim else 0
    if dim is None:
        res = g.op.ReduceMaxAnyOpset(x, keepdims=keepdims, outputs=outputs)
    else:
        if isinstance(dim, int):
            axes = [dim]
        elif isinstance(dim, (list, tuple)) and all_int(dim):
            # tuples are as valid as lists to describe several axes
            axes = list(dim)
        else:
            raise RuntimeError(f"Unexpected type {type(dim)} for dim{g.get_debug_msg()}")
        res = g.op.ReduceMaxAnyOpset(
            x,
            np.array(axes, dtype=np.int64),
            keepdims=keepdims,
            outputs=outputs,
        )
    if not sts:
        set_type_shape_reduce_op(g, outputs[0], x, keepdim=keepdim)
    return res
[docs]
def prims_broadcast_in_dim(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    a: T,
    shape: List[int],
    broadcast_dimensions: List[int],
) -> T:
    """
    broadcast

    Expands ``a`` to ``shape``. The output dimensions listed in
    ``broadcast_dimensions`` are the ones already present in ``a``; every other
    output dimension is inserted with ``Unsqueeze`` before an ``Expand``.
    Equivalent to::

        s = list(shape)
        for broadcast_dimension in broadcast_dimensions:
            s[broadcast_dimension] = -1

        v = a
        for idx, x in enumerate(s):
            if x != -1:
                v = unsqueeze(v, idx)

        return expand(v, shape)

    :param g: graph builder
    :param sts: when empty or None, this function sets the output type and shape
    :param outputs: output names
    :param a: input tensor name
    :param shape: target static shape
    :param broadcast_dimensions: output dimensions kept from ``a``,
        may be empty (e.g. when ``a`` is a scalar)
    :return: output name
    """
    # max() raises ValueError on an empty sequence, an empty list is valid here.
    assert not broadcast_dimensions or max(broadcast_dimensions) < len(shape), (
        f"Index out of boundary, shape={shape}, "
        f"broadcast_dimensions={broadcast_dimensions}{g.get_debug_msg()}"
    )
    s = list(shape)
    for broadcast_dimension in broadcast_dimensions:
        s[broadcast_dimension] = -1
    uns = [idx for idx, d in enumerate(s) if d != -1]
    unsqueezed = (
        g.op.UnsqueezeAnyOpset(a, np.array(uns, dtype=np.int64), name="broadcast_in_dim")
        if uns
        else a
    )
    res = g.op.Expand(
        unsqueezed,
        np.array(shape, dtype=np.int64),
        name="broadcast_in_dim",
        outputs=outputs,
    )
    if not sts:
        g.set_type(res, g.get_type(a))
        # shapes are given as tuples everywhere else in this module
        g.set_shape(res, tuple(shape))
    return res
[docs]
def prims_cat(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    tensors: Tuple[T, ...],
    dim: int = 0,
    name: str = "prims_cat",
) -> T:
    """
    concat

    Forwards the tensors, the axis and the node name to ``aten_cat``.
    """
    from ._aten_functions import aten_cat as _convert

    return _convert(g, sts, outputs, tensors, dim=dim, name=name)
[docs]
def prims_clone(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    memory_format: Optional[str] = None,
) -> T:
    """
    identity

    Forwards to ``aten_clone``; the node is always named ``prims_clone``.
    """
    from ._aten_functions import aten_clone as _convert

    node_name = "prims_clone"
    return _convert(g, sts, outputs, x, memory_format=memory_format, name=node_name)
[docs]
def prims_convert_element_type(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    dtype: "torch.dtype",  # noqa: F821
    name: str = "prims_convert_element_type",
) -> T:
    """
    cast

    Converts ``x`` to ``dtype``. A ``Cast`` node is emitted only when the
    target ONNX type differs from the known type of ``x``; otherwise the
    input is forwarded through an ``Identity`` node.
    """
    assert (
        dtype is not None
    ), f"dtype cannot be none for prims_convert_element_type{g.get_debug_msg()}"
    target_type = torch_dtype_to_onnx_dtype(dtype)
    if target_type != g.get_type(x):
        return g.make_node("Cast", [x], outputs, to=target_type, name=name)
    # already the requested type: nothing to convert
    return g.op.Identity(x, outputs=outputs, name=name)
[docs]
def prims_collapse_view(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    start: int,
    end: int,
    name: str = "prims_collapse_view",
) -> T:
    """
    reshape

    Merges the dimensions ``start`` to ``end`` (both included, negative values
    count from the last dimension) of ``x`` into a single dimension with a
    ``Reshape``. The shape of ``x`` must be known.

    :param g: graph builder
    :param sts: when empty or None, this function sets the output type and shape
    :param outputs: output names
    :param x: input tensor name
    :param start: first merged dimension
    :param end: last merged dimension (included)
    :param name: node name
    :return: output name
    """
    assert g.has_shape(
        x
    ), f"collapse_view not implemented if x has no shape{g.get_debug_msg()}"
    shape = g.get_shape(x)
    rank = len(shape)
    # the modulo below would raise ZeroDivisionError for a scalar
    assert rank > 0, f"collapse_view not implemented for a scalar{g.get_debug_msg()}"
    start = (start + rank) % rank
    end = (end + rank) % rank
    reshape_dims = []
    collapsed = 1
    for i, d in enumerate(shape):
        if start <= i <= end:
            if i == start:
                # a single -1 lets Reshape infer the merged dimension
                reshape_dims.append(-1)
            collapsed *= d
        else:
            reshape_dims.append(d)
    res = g.op.Reshape(
        x, np.array(reshape_dims, dtype=np.int64), outputs=outputs, name=name
    )
    if not sts:
        g.set_type(res, g.get_type(x))
        # Build the static shape from a separate list: the array given to
        # Reshape may be kept by the builder as a constant, so it must not be
        # modified afterwards.
        out_shape = list(reshape_dims)
        if start <= end:
            out_shape[start] = collapsed
        g.set_shape(res, tuple(out_shape))
    return res
[docs]
def prims_cos(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    name: str = "prims_cos",
) -> T:
    """
    cos

    Forwards every argument unchanged to ``aten_cos``.
    """
    from ._aten_functions import aten_cos as _convert

    return _convert(g, sts, outputs, x, name=name)
[docs]
def prims_div(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    y: T,
    name: str = "prims_div",
) -> T:
    """
    div

    Forwards every argument unchanged to ``aten_div``.
    """
    from ._aten_functions import aten_div as _convert

    return _convert(g, sts, outputs, x, y, name=name)
[docs]
def prims_empty_strided(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    size: T,
    stride: T,
    dtype: Optional["torch.dtype"] = None,  # noqa: F821
    layout=None,
    device: Optional["torch.device"] = None,  # noqa: F821
    requires_grad: bool = False,
    name: str = "prims_empty_strided",
) -> T:
    """
    constantofshape

    Forwards every argument to ``aten_empty_strided``. ``stride`` is passed
    along but, per the original note, presumably unused downstream.
    """
    from ._aten_functions import aten_empty_strided as _convert

    options = dict(
        dtype=dtype,
        layout=layout,
        device=device,
        requires_grad=requires_grad,
        name=name,
    )
    return _convert(g, sts, outputs, size, stride, **options)
[docs]
def prims_eq(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    y: T,
    name: str = "prims_eq",
) -> T:
    """
    equal

    Forwards every argument unchanged to ``aten_eq``.
    """
    from ._aten_functions import aten_eq as _convert

    return _convert(g, sts, outputs, x, y, name=name)
[docs]
def prims_exp(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    name: str = "prims_exp",
) -> T:
    """
    exp

    Forwards every argument unchanged to ``aten_exp``.
    """
    from ._aten_functions import aten_exp as _convert

    return _convert(g, sts, outputs, x, name=name)
[docs]
def prims_ge(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    y: T,
    name: str = "prims_ge",
) -> T:
    """
    greaterorequal

    Forwards every argument unchanged to ``aten_ge``.
    """
    from ._aten_functions import aten_ge

    return aten_ge(g, sts, outputs, x, y, name=name)
[docs]
def prims_gt(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    y: T,
    name: str = "prims_gt",
) -> T:
    """
    greater

    Forwards every argument unchanged to ``aten_gt``.
    """
    from ._aten_functions import aten_gt as _convert

    return _convert(g, sts, outputs, x, y, name=name)
[docs]
def prims_iota(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    length: int,
    start: int = 0,
    step: int = 1,
    dtype: Optional["torch.dtype"] = None,  # noqa: F821
    device: Optional["torch.device"] = None,  # noqa: F821
    requires_grad: bool = False,
) -> T:
    """
    arange

    Delegates to ``aten_arange`` with ``end = start + length * step``.
    ``length``, ``start`` and ``step`` must be Python ints.
    """
    assert isinstance(
        length, int
    ), f"not implemented when length={length!r}{g.get_debug_msg()}"
    assert isinstance(start, int), f"not implemented when start={start!r}{g.get_debug_msg()}"
    assert isinstance(step, int), f"not implemented when step={step!r}{g.get_debug_msg()}"
    stop = start + step * length
    from ._aten_functions import aten_arange as _convert

    return _convert(
        g,
        sts,
        outputs,
        start,
        stop,
        step,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
        name="prims_iota",
    )
[docs]
def prims_lt(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    y: T,
    name: str = "prims_lt",
) -> T:
    """
    less

    Forwards every argument unchanged to ``aten_lt``.
    """
    from ._aten_functions import aten_lt as _convert

    return _convert(g, sts, outputs, x, y, name=name)
[docs]
def prims_mul(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    y: T,
    name: str = "prims_mul",
) -> T:
    """
    mul

    Forwards every argument unchanged to ``aten_mul``.
    """
    from ._aten_functions import aten_mul as _convert

    return _convert(g, sts, outputs, x, y, name=name)
[docs]
def prims_neg(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    name="prims_neg",
) -> T:
    """
    neg

    Forwards every argument unchanged to ``aten_neg``.
    """
    from ._aten_functions import aten_neg as _convert

    return _convert(g, sts, outputs, x, name=name)
[docs]
def prims_pow(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    exponent: T,
    name: str = "prims_pow",
) -> T:
    """
    pow

    Forwards every argument unchanged to ``aten_pow_Tensor_Tensor``.
    """
    from ._aten_functions import aten_pow_Tensor_Tensor as _convert

    return _convert(g, sts, outputs, x, exponent, name=name)
[docs]
def prims_rsqrt(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    name: str = "prims_rsqrt",
) -> T:
    """
    rsqrt

    Computes ``1 / sqrt(x)`` as ``Reciprocal(Sqrt(x))``.

    :param g: graph builder
    :param sts: when empty or None, this function sets the output type and shape
    :param outputs: output names
    :param x: input tensor name
    :param name: name given to both nodes
    :return: output name
    """
    res = g.op.Reciprocal(g.op.Sqrt(x, name=name), name=name, outputs=outputs)
    if not sts:
        set_type_shape_unary_op(g, res, x)
    return res
[docs]
def prims_sin(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    name: str = "prims_sin",
) -> T:
    """
    sin

    Forwards every argument unchanged to ``aten_sin``.
    """
    from ._aten_functions import aten_sin

    return aten_sin(g, sts, outputs, x, name=name)
[docs]
def prims_split_dim(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    dim: int,
    outer_length: int,
    name: str = "prims_split_dim",
) -> T:
    """
    split

    Splits dimension ``dim`` of ``x`` into the two dimensions
    ``(outer_length, shape[dim] // outer_length)`` with a ``Reshape``.
    The shape of ``x`` must be known and ``shape[dim]`` must be static.

    :param g: graph builder
    :param sts: when empty or None, this function sets the output type and shape
    :param outputs: output names, exactly one is expected
    :param x: input tensor name
    :param dim: dimension to split, negative values count from the end
    :param outer_length: size of the first resulting dimension,
        must divide ``shape[dim]``
    :param name: node name
    :return: output name
    """
    assert len(outputs) == 1, f"Expecting 1 outputs but got {outputs}{g.get_debug_msg()}"
    assert g.has_shape(x), f"Not implemented when shape of {x!r} is unknown{g.get_debug_msg()}"
    shape = tuple(g.get_shape(x))
    if dim < 0:
        # otherwise shape[dim + 1:] below would select the wrong tail
        dim += len(shape)
    shape_dim = shape[dim]
    assert isinstance(
        shape_dim, int
    ), f"Not implemented for a dynamic dimension {shape_dim}{g.get_debug_msg()}"
    assert shape_dim % outer_length == 0, (
        f"shape_dim={shape_dim} not a multiple of "
        f"outer_length={outer_length}{g.get_debug_msg()}"
    )
    inner_length = shape_dim // outer_length
    new_shape = shape[:dim] + (outer_length, inner_length) + shape[dim + 1 :]
    # Reshape requires an int64 shape, the default numpy integer is platform dependent
    res = g.op.Reshape(
        x, np.array(new_shape, dtype=np.int64), outputs=outputs, name=name
    )
    if not sts:
        g.set_type(res, g.get_type(x))
        g.set_shape(res, new_shape)
    return res
[docs]
def prims_sub(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    y: T,
    name: str = "prims_sub",
) -> T:
    """
    sub

    Forwards every argument unchanged to ``aten_sub``.
    """
    from ._aten_functions import aten_sub as _convert

    return _convert(g, sts, outputs, x, y, name=name)
[docs]
def prims_sum(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    dim: Optional[Union[int, List[int]]] = None,
    keepdim: bool = False,
    output_dtype: Optional["torch.dtype"] = None,  # noqa: F821
) -> T:
    """
    reducesum

    Delegates to ``aten_sum``; ``output_dtype`` is passed as its ``dtype``
    and the node is named ``prims_sum``.
    """
    from ._aten_functions import aten_sum as _convert

    return _convert(
        g,
        sts,
        outputs,
        x,
        dim,
        keepdim=keepdim,
        dtype=output_dtype,
        name="prims_sum",
    )
[docs]
def prims_transpose(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    input_name: T,
    perm: List[int],
    name: str = "prims_transpose",
) -> T:
    """
    transpose

    Permutes the dimensions of ``input_name`` with a ``Transpose`` node.

    :param g: graph builder
    :param sts: when empty or None, this function sets the output type and
        shape (or rank when only the rank is known)
    :param outputs: output names
    :param input_name: input tensor name
    :param perm: permutation, output dimension ``i`` is input dimension ``perm[i]``
    :param name: node name
    :return: output name
    """
    res = g.make_node("Transpose", [input_name], outputs, perm=list(perm), name=name)
    if not sts:
        g.set_type(outputs[0], g.get_type(input_name))
        if g.has_shape(input_name):
            shape = list(g.get_shape(input_name))
            new_shape = shape.copy()
            for i, p in enumerate(perm):
                new_shape[i] = shape[p]
            g.set_shape(outputs[0], tuple(new_shape))
        elif g.has_rank(input_name):
            # a transposition keeps the rank: pass the rank itself, not has_rank's boolean
            g.set_rank(outputs[0], g.get_rank(input_name))
    return res
[docs]
def prims_view_of(
    g: GraphBuilder, sts: Optional[Dict[str, Any]], outputs: List[str], x: T
) -> T:
    """
    identity

    Lowered to a single ``Identity`` node named ``prims_view_of``.
    """
    node_output = g.op.Identity(x, outputs=outputs, name="prims_view_of")
    return node_output
[docs]
def prims_where(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    condition: T,
    x: T,
    other: T,
) -> T:
    """
    where

    Selects ``x`` where ``condition`` is true and ``other`` elsewhere.
    At most one of ``x`` and ``other`` may be a Python number. That number is
    turned into a one-element constant whose dtype is the type of the other
    (tensor) argument.

    :param g: graph builder
    :param sts: when empty or None, this function sets the output type and
        shape (or rank when only ranks are known)
    :param outputs: output names
    :param condition: condition tensor name
    :param x: tensor name or Python number
    :param other: tensor name or Python number
    :return: output name
    """
    assert not (isinstance(x, (int, float)) and isinstance(other, (int, float))), (
        f"The two last arguments ({x}, {other}) are constant and cannot be used "
        f"to infer types{g.get_debug_msg()}"
    )
    assert not isinstance(x, float) or not isinstance(other, float), (
        f"The output type cannot be guessed if the last two arguments are both floats, "
        f"x={x}, other={other}{g.get_debug_msg()}"
    )
    # the result type is the type of whichever value is a tensor name
    itype = g.get_type(other if isinstance(other, str) else x)
    dtype = tensor_dtype_to_np_dtype(itype)
    ax = x if isinstance(x, str) else np.array([x], dtype=dtype)
    aother = other if isinstance(other, str) else np.array([other], dtype=dtype)
    res = g.op.Where(condition, ax, aother, outputs=outputs, name="prims_where")
    if not sts:
        g.set_type(res, itype)
        # Where broadcasts its three inputs; constants created above have shape (1,).
        inputs = (condition, ax, aother)
        named = [v for v in inputs if isinstance(v, str)]
        if all(g.has_shape(v) for v in named):
            shapes = [g.get_shape(v) if isinstance(v, str) else v.shape for v in inputs]
            shape = shapes[0]
            for other_shape in shapes[1:]:
                shape = broadcast_shape(shape, other_shape, graph_builder=g)
            g.set_shape(res, shape)
        elif all(g.has_rank(v) for v in named):
            g.set_rank(
                res,
                max(g.get_rank(v) if isinstance(v, str) else v.ndim for v in inputs),
            )
    return res