import inspect
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
from onnx import NodeProto
from ...xshape._onnx_helper import (
element_wise_binary_op_types,
element_wise_op_cmp_types,
unary_like_op_types,
)
from ...xshape._shape_helper import all_int, DYNAMIC_SHAPE
from ..patterns_api import MatchResult, PatternOptimization
class ExpandPattern(PatternOptimization):
"""
Checks that an Expand is really needed.
Model with nodes to be fused:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [
oh.make_opsetid("", 18),
]
inputs = []
outputs = []
nodes = []
initializers = []
sparse_initializers = []
functions = []
inputs.append(
oh.make_tensor_value_info("init7_s4_32_2_10_8", onnx.TensorProto.INT64, shape=(4,))
)
inputs.append(
oh.make_tensor_value_info("mul", onnx.TensorProto.FLOAT, shape=(32, 2, 10, 8))
)
nodes.append(
oh.make_node(
"Constant",
[],
["init7_s4_32_2_10_8"],
value=onh.from_array(np.array([32, 2, 10, 8], dtype=np.int64), name="value"),
)
)
nodes.append(oh.make_node("Expand", ["mul", "init7_s4_32_2_10_8"], ["expand"]))
outputs.append(
oh.make_tensor_value_info("expand", onnx.TensorProto.FLOAT, shape=(32, 2, 10, 8))
)
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
Outcome of the fusion:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [
oh.make_opsetid("", 18),
]
inputs = []
outputs = []
nodes = []
initializers = []
sparse_initializers = []
functions = []
inputs.append(
oh.make_tensor_value_info("init7_s4_32_2_10_8", onnx.TensorProto.INT64, shape=(4,))
)
inputs.append(
oh.make_tensor_value_info("mul", onnx.TensorProto.FLOAT, shape=(32, 2, 10, 8))
)
nodes.append(oh.make_node("Identity", ["mul"], ["expand"]))
outputs.append(
oh.make_tensor_value_info("expand", onnx.TensorProto.FLOAT, shape=(32, 2, 10, 8))
)
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
"""
def __init__(self, verbose: int = 0, priority: int = 0):
super().__init__(verbose, priority)
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type != "Expand" or node.domain != "":
return self.none()
if not g.has_shape(node.input[0]):
return self.none(node, inspect.currentframe().f_lineno)
shape = g.get_shape(node.input[0])
if not all_int(shape):
return self.none(node, inspect.currentframe().f_lineno)
if not g.is_constant(node.input[1]):
# It may be a symbolic shape.
return self.none(node, inspect.currentframe().f_lineno)
value = g.get_computed_constant(node.input[1])
if value is None:
return self.none(node, inspect.currentframe().f_lineno)
with g.builder.maybe_disable_fake_tensor_mode():
new_shape = tuple(int(i) for i in value)
if shape != new_shape:
return self.none(node, inspect.currentframe().f_lineno)
return MatchResult(self, [node], self.apply, insert_at=node)
def apply(
self,
g: "GraphBuilder", # noqa: F821
node: NodeProto,
) -> List[NodeProto]:
new_node = g.make_node(
"Identity",
[node.input[0]],
node.output,
name=f"{self.__class__.__name__}--{node.name}",
doc_string=node.doc_string,
)
return [new_node]
class ExpandBroadcastPattern(PatternOptimization):
"""
Checks that an Expand is really needed before an element-wise operator.
The objective is to save one allocation and let the next operator
perform the expansion by broadcasting one input.
Model with nodes to be fused:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [
oh.make_opsetid("", 18),
]
inputs = []
outputs = []
nodes = []
initializers = []
sparse_initializers = []
functions = []
inputs.append(
oh.make_tensor_value_info("mul_25", onnx.TensorProto.FLOAT, shape=(2, 1024, 1))
)
inputs.append(
oh.make_tensor_value_info("input66", onnx.TensorProto.FLOAT, shape=(2, 1024, 1024))
)
nodes.append(
oh.make_node(
"Constant",
[],
["init7_s3_2_1024_1024"],
value=onh.from_array(np.array([2, 1024, 1024], dtype=np.int64), name="value"),
)
)
nodes.append(
oh.make_node("Expand", ["mul_25", "init7_s3_2_1024_1024"], ["expand_11"])
)
nodes.append(
oh.make_node("Mul", ["expand_11", "input66"], ["MulMulMulPattern--mul_27"])
)
outputs.append(
oh.make_tensor_value_info(
"MulMulMulPattern--mul_27", onnx.TensorProto.FLOAT, shape=(2, 1024, 1024)
)
)
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
Outcome of the fusion:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [
oh.make_opsetid("", 18),
]
inputs = []
outputs = []
nodes = []
initializers = []
sparse_initializers = []
functions = []
inputs.append(
oh.make_tensor_value_info("mul_25", onnx.TensorProto.FLOAT, shape=(2, 1024, 1))
)
inputs.append(
oh.make_tensor_value_info("input66", onnx.TensorProto.FLOAT, shape=(2, 1024, 1024))
)
nodes.append(
oh.make_node("Mul", ["mul_25", "input66"], ["MulMulMulPattern--mul_27"])
)
outputs.append(
oh.make_tensor_value_info(
"MulMulMulPattern--mul_27", onnx.TensorProto.FLOAT, shape=(2, 1024, 1024)
)
)
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
"""
_op_types = element_wise_binary_op_types() | element_wise_op_cmp_types()
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type != "Expand" or node.domain != "":
return self.none()
if not g.has_shape(node.input[0]):
return self.none(node, inspect.currentframe().f_lineno)
shape = g.get_shape(node.input[0])
if not all_int(shape):
return self.none(node, inspect.currentframe().f_lineno)
if not g.is_constant(node.input[1]):
# It may be a symbolic shape.
return self.none(node, inspect.currentframe().f_lineno)
value = g.get_computed_constant(node.input[1])
if value is None:
return self.none(node, inspect.currentframe().f_lineno)
with g.builder.maybe_disable_fake_tensor_mode():
new_shape = tuple(int(i) for i in value)
if g.is_used_more_than_once(node.output[0]):
# The output is used more than once; not handled right now.
return self.none(node, inspect.currentframe().f_lineno)
next_nodes = g.next_nodes(node.output[0])
assert len(next_nodes) == 1, "The previous test should have ruled out this case."
next_node = next_nodes[0]
if next_node.op_type not in self._op_types or next_node.domain != "":
# Not an element-wise operator.
return self.none(node, inspect.currentframe().f_lineno)
other = next_node.input[1 if next_node.input[0] == node.output[0] else 0]
if not g.has_shape(other):
return self.none(node, inspect.currentframe().f_lineno)
other_shape = g.get_shape(other)
if new_shape != other_shape:
# Expand does not expand to the shape of the other element.
return self.none(node, inspect.currentframe().f_lineno)
if len(shape) != len(other_shape):
# Different ranks.
return self.none(node, inspect.currentframe().f_lineno)
for a, b in zip(shape, other_shape):
if not (a == b or a == 1 or b == 1):
return self.none(node, inspect.currentframe().f_lineno)
return MatchResult(self, [node, next_node], self.apply, insert_at=next_node)
def apply(
self,
g: "GraphBuilder", # noqa: F821
node: NodeProto,
next_node: NodeProto,
) -> List[NodeProto]:
if next_node.input[0] == node.output[0]:
inputs = [node.input[0], next_node.input[1]]
else:
inputs = [next_node.input[0], node.input[0]]
return [
g.make_node(
next_node.op_type,
inputs,
next_node.output,
name=f"{self.__class__.__name__}--{node.name}",
doc_string=next_node.doc_string,
)
]
class ShapeBasedExpandBroadcastPattern(PatternOptimization):
"""
Similar to
:class:`experimental_experiment.xoptim.patterns.onnx_expand.ExpandBroadcastPattern`,
but it also allows dynamic shapes. It does not look into the second
argument of Expand; it simply infers that an Expand is not needed
before a binary operator that immediately follows.
"""
_op_types = element_wise_binary_op_types() | element_wise_op_cmp_types()
@classmethod
def _is_compatible_shapes_for_expand(
cls,
shape_left: DYNAMIC_SHAPE,
shape_right: DYNAMIC_SHAPE,
output_shape: Optional[DYNAMIC_SHAPE],
) -> bool:
"""
Checks that broadcasting the two input shapes yields output_shape.
If so, no Expand node is needed.
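A sketch of the check, with shapes given as tuples of ints or
dimension names:

.. code-block:: python

    cls._is_compatible_shapes_for_expand(
        ("batch", 1, "seq"), (1, 8, "seq"), ("batch", 8, "seq")
    )  # True: broadcasting alone yields the output shape
    cls._is_compatible_shapes_for_expand(
        (1, "seq"), (1, "seq"), ("batch", "seq")
    )  # False: the output has a dimension no input provides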
"""
if output_shape is None:
return False
if max(len(shape_left), len(shape_right) if shape_right else 0) < len(output_shape):
return False
# Align shapes
if len(shape_left) < len(shape_right):
shape_left = (1,) * (len(shape_right) - len(shape_left)) + shape_left
elif len(shape_left) > len(shape_right):
shape_right = (1,) * (len(shape_left) - len(shape_right)) + shape_right
for left, right, out in zip(shape_left, shape_right, output_shape):
if isinstance(left, int):
if isinstance(right, int):
# static right
if left == 1:
if right != out:
return False
elif right == 1:
if left != out:
return False
else:
if left != right or left != out or right != out:
return False
else:
# dynamic right
if left == 1:
if right != out:
return False
else:
if left != right or left != out or right != out:
return False
else:
# dynamic left
if isinstance(right, int):
# static right
if right == 1:
if left != out:
return False
else:
if left != right or left != out or right != out:
return False
else:
# dynamic right
if left != right or left != out or right != out:
return False
return True
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type not in self._op_types or node.domain != "":
return self.none()
if (
not g.has_shape(node.output[0])
or not g.has_shape(node.input[0])
or not g.has_shape(node.input[1])
):
return self.none(node, inspect.currentframe().f_lineno)
node_left = g.node_before(node.input[0])
node_right = g.node_before(node.input[1])
before = [
None if n is None or n.op_type != "Expand" else n for n in [node_left, node_right]
]
if before == [None, None]:
return self.none(node, inspect.currentframe().f_lineno)
# At least one expand.
node_left, node_right = before
shape_left = g.get_shape_renamed(
node.input[0] if node_left is None else node_left.input[0]
)
shape_right = g.get_shape_renamed(
node.input[1] if node_right is None else node_right.input[0]
)
if self._is_compatible_shapes_for_expand(
shape_left, shape_right, g.get_shape_renamed(node.output[0])
):
if self.verbose:
print(
f"[{self.__class__.__name__}.match] {shape_left} "
f"{node.op_type} {shape_right} -> {g.get_shape_renamed(node.output[0])}"
)
return MatchResult(self, [node_left, node_right, node], self.apply)
# We could end up with the following case.
# shape_left = (1, 1, 'seq_length', 'cache_length + seq_length')
# shape_right = (1, 1, 'seq_length', 'cache_length + seq_length')
# output_shape = ('batch', 1, 'seq_length', 'cache_length + seq_length')
# When this happens, it could also be caught by another pattern.
return self.none(node, inspect.currentframe().f_lineno)
def apply(
self,
g: "GraphBuilder", # noqa: F821
expand_left: NodeProto,
expand_right: NodeProto,
binary_node: NodeProto,
) -> List[NodeProto]:
nodes = []
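# Keep an Expand node when its output is consumed elsewhere in the graph.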
if expand_left is not None and g.is_used_more_than_once(expand_left.output[0]):
nodes.append(expand_left)
if expand_right is not None and g.is_used_more_than_once(expand_right.output[0]):
nodes.append(expand_right)
assert (
not binary_node.attribute
), f"Binary operator should not have any attribute, binary_node={binary_node}"
return [
*nodes,
g.make_node(
binary_node.op_type,
[
binary_node.input[0] if expand_left is None else expand_left.input[0],
binary_node.input[1] if expand_right is None else expand_right.input[0],
],
binary_node.output,
name=f"{self.__class__.__name__}--{binary_node.name}",
doc_string=binary_node.doc_string,
),
]
class ExpandSwapPattern(PatternOptimization):
"""
Tries to move an Expand node forward in the graph.
Expand + Exp can be changed into Exp + Expand.
Exp then applies to a tensor of smaller or equal size.
Model with nodes to be fused:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [
oh.make_opsetid("", 26),
]
inputs = []
outputs = []
nodes = []
initializers = []
sparse_initializers = []
functions = []
inputs.append(oh.make_tensor_value_info("p", onnx.TensorProto.INT64, shape=(1,)))
inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=(1, 5, 7)))
inputs.append(oh.make_tensor_value_info("shape", onnx.TensorProto.INT64, shape=(3,)))
nodes.append(
oh.make_node(
"Constant",
[],
["shape"],
value=onh.from_array(np.array([3, 1, 1], dtype=np.int64), name="value"),
)
)
nodes.append(
oh.make_node(
"Constant",
[],
["p"],
value=onh.from_array(np.array([2], dtype=np.int64), name="value"),
)
)
nodes.append(oh.make_node("Expand", ["X", "shape"], ["xs"]))
nodes.append(oh.make_node("Pow", ["xs", "p"], ["Z"]))
outputs.append(oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=(3, 5, 7)))
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
Outcome of the fusion:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [
oh.make_opsetid("", 26),
]
inputs = []
outputs = []
nodes = []
initializers = []
sparse_initializers = []
functions = []
inputs.append(oh.make_tensor_value_info("p", onnx.TensorProto.INT64, shape=(1,)))
inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=(1, 5, 7)))
inputs.append(oh.make_tensor_value_info("shape", onnx.TensorProto.INT64, shape=(3,)))
nodes.append(oh.make_node("Pow", ["X", "p"], ["ExpandSwapPattern_X"]))
nodes.append(oh.make_node("Expand", ["ExpandSwapPattern_X", "shape"], ["Z"]))
outputs.append(oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=(3, 5, 7)))
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
"""
_op_types = unary_like_op_types()
_other_types = {"NegXplus1", "ReplaceZero", "Pow"}
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type != "Expand" or node.domain != "":
return self.none()
if not g.has_shape(node.input[0]):
return self.none(node, inspect.currentframe().f_lineno)
assert g.is_used(node.output[0]), (
f"The match should not even begin, {node.output[0]!r} "
f"is not used among {node.output} and type={node.op_type!r}"
)
if g.is_used_more_than_once(node.output[0]):
# The output is used more than once, so the Expand is probably required.
return self.none(node, inspect.currentframe().f_lineno)
next_nodes = g.next_nodes(node.output[0])
assert len(next_nodes) == 1, "The previous test should have ruled out this case."
next_node = next_nodes[0]
if next_node.op_type not in self._other_types and (
next_node.op_type not in self._op_types or next_node.domain != ""
):
# Not a unary element-wise operator.
return self.none(node, inspect.currentframe().f_lineno)
return MatchResult(self, [node, next_node], self.apply, insert_at=node)
def apply(
self,
g: "GraphBuilder", # noqa: F821
node: NodeProto,
next_node: NodeProto,
) -> List[NodeProto]:
# We need to create a new name for the intermediate results.
# The optimizer cannot reuse an existing name if the new result
# has a different shape.
new_name = g.unique_name(f"{self.__class__.__name__}_{node.input[0]}")
unary = g.make_node(
next_node.op_type,
[node.input[0], *next_node.input[1:]],
[new_name],
name=f"{self.__class__.__name__}--{node.name}",
domain=next_node.domain,
doc_string=next_node.doc_string,
)
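# Propagate the unary operator's attributes (if any) to the new node.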
unary.attribute.extend(next_node.attribute)
expand = g.make_node(
node.op_type, # Expand
[new_name, node.input[1]],
[next_node.output[0]],
name=f"{self.__class__.__name__}--{node.name}",
doc_string=node.doc_string,
)
return [unary, expand]
class ShapeBasedStaticExpandPattern(PatternOptimization):
"""
Compares input and output shapes to tell if the Expand
can use a constant as its second input.
Model with nodes to be fused:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [oh.make_opsetid("", 18)]
inputs = []
outputs = []
nodes = []
initializers = [onh.from_array(np.array([2], dtype=np.int64), name="two")]
sparse_initializers = []
functions = []
inputs.append(
oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=(2, 3, "d", 1))
)
nodes.append(oh.make_node("Shape", ["X"], ["D2"], start=0, end=-1))
nodes.append(oh.make_node("Concat", ["D2", "two"], ["d"], axis=0))
nodes.append(oh.make_node("Expand", ["X", "d"], ["Y"]))
outputs.append(
oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=(2, 3, "d", 2))
)
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
Outcome of the fusion:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [
oh.make_opsetid("", 18),
]
inputs = []
outputs = []
nodes = []
initializers = []
sparse_initializers = []
functions = []
inputs.append(
oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=(2, 3, "d", 1))
)
nodes.append(
oh.make_node(
"Constant",
[],
["init7_s4_1_1_1_2"],
value=onh.from_array(np.array([1, 1, 1, 2], dtype=np.int64), name="value"),
)
)
nodes.append(oh.make_node("Expand", ["X", "init7_s4_1_1_1_2"], ["Y"]))
outputs.append(
oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=(2, 3, "d", 2))
)
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
"""
def __init__(self, verbose: int = 0, priority: int = 0):
super().__init__(verbose, priority)
@classmethod
def _find_expand_shape(
cls, sh1: Tuple[Union[str, int], ...], sh2: Tuple[Union[str, int], ...]
) -> Optional[Tuple[int, ...]]:
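"""
Returns the shape to pass to Expand so that broadcasting alone
turns ``sh1`` into ``sh2``, or None when a differing dimension
is symbolic. A sketch of the expected behaviour:

.. code-block:: python

    cls._find_expand_shape((2, 3, "d", 1), (2, 3, "d", 2))
    # -> (1, 1, 1, 2), the constant fed to Expand in the example above
"""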
expand_shape = []
for s1, s2 in zip(sh1, sh2):
if s1 == s2:
expand_shape.append(1)
continue
if not isinstance(s1, int) or not isinstance(s2, int):
return None
expand_shape.append(s2)
return tuple(expand_shape)
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type != "Expand" or node.domain != "":
return self.none()
if g.is_constant(node.input[1]):
# already done
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.input[0]):
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.output[0]):
return self.none(node, inspect.currentframe().f_lineno)
sh1 = g.get_shape_renamed(node.input[0])
sh2 = g.get_shape_renamed(node.output[0])
if len(sh1) != len(sh2):
# We ignore that case for the time being.
return self.none(node, inspect.currentframe().f_lineno)
expand_shape = self._find_expand_shape(sh1, sh2)
if expand_shape is None:
return self.none(node, inspect.currentframe().f_lineno)
return MatchResult(self, [node], self.apply, insert_at=node)
def apply(
self,
g: "GraphBuilder",  # noqa: F821
expand_node: NodeProto,
) -> List[NodeProto]:
expand_shape = self._find_expand_shape(
g.get_shape_renamed(expand_node.input[0]), g.get_shape_renamed(expand_node.output[0])
)
new_shape = g.make_initializer(
"",
np.array(expand_shape, dtype=np.int64),
source=f"{self.__class__.__name__}.m1",
)
return [
g.make_node(
"Expand",
[expand_node.input[0], new_shape],
expand_node.output,
name=f"{self.__class__.__name__}--{expand_node.name}",
doc_string=expand_node.doc_string,
)
]
class ShapeBasedExpandSwapPattern(PatternOptimization):
"""
Tries to move an Expand node forward in the graph past a binary
operator. The code is similar to
:class:`experimental_experiment.xoptim.patterns.onnx_expand.ShapeBasedExpandBroadcastPattern`.
Model with nodes to be fused:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [
oh.make_opsetid("", 18),
]
inputs = []
outputs = []
nodes = []
initializers = []
sparse_initializers = []
functions = []
inputs.append(
oh.make_tensor_value_info("full_shape", onnx.TensorProto.INT64, shape=(2,))
)
inputs.append(oh.make_tensor_value_info("Xc", onnx.TensorProto.FLOAT, shape=("d", 1)))
inputs.append(oh.make_tensor_value_info("one", onnx.TensorProto.FLOAT, shape=(1,)))
nodes.append(
oh.make_node(
"Constant",
[],
["one"],
value=onh.from_array(np.array([4.0], dtype=np.float32), name="value"),
)
)
nodes.append(oh.make_node("Expand", ["Xc", "full_shape"], ["Xce"]))
nodes.append(oh.make_node("Add", ["Xce", "one"], ["Y"]))
outputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=("d", "d")))
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
Outcome of the fusion:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [
oh.make_opsetid("", 18),
]
inputs = []
outputs = []
nodes = []
initializers = []
sparse_initializers = []
functions = []
inputs.append(
oh.make_tensor_value_info("full_shape", onnx.TensorProto.INT64, shape=(2,))
)
inputs.append(oh.make_tensor_value_info("Xc", onnx.TensorProto.FLOAT, shape=("d", 1)))
inputs.append(oh.make_tensor_value_info("one", onnx.TensorProto.FLOAT, shape=(1,)))
nodes.append(
oh.make_node("Add", ["Xc", "one"], ["ShapeBasedExpandSwapPattern_Y"])
)
nodes.append(
oh.make_node("Expand", ["ShapeBasedExpandSwapPattern_Y", "full_shape"], ["Y"])
)
outputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=("d", "d")))
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
"""
_op_types = element_wise_binary_op_types() | element_wise_op_cmp_types()
@classmethod
def _broadcast_shape(
cls,
before_expand_shape: DYNAMIC_SHAPE,
other_term_shape: DYNAMIC_SHAPE,
exc: bool = False,
) -> Optional[DYNAMIC_SHAPE]:
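"""
Broadcasts two possibly symbolic shapes, returning None when they
are not broadcastable (or failing if ``exc`` is True). A sketch:

.. code-block:: python

    cls._broadcast_shape(("d", 1), (1, "e"))  # -> ("d", "e")
    cls._broadcast_shape((2,), (3,))  # -> None
"""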
if len(before_expand_shape) != len(other_term_shape):
d = abs(len(before_expand_shape) - len(other_term_shape))
if len(before_expand_shape) < len(other_term_shape):
before_expand_shape = (1,) * d + before_expand_shape
else:
other_term_shape = (1,) * d + other_term_shape
if len(before_expand_shape) != len(other_term_shape):
assert not exc, (
f"Unable to produce a broadcasted shape from "
f"{before_expand_shape} and {other_term_shape}"
)
return None
res = []
for a, b in zip(before_expand_shape, other_term_shape):
if a == b:
res.append(a)
elif a == 1:
res.append(b)
elif b == 1:
res.append(a)
else:
assert not exc, (
f"Unable to produce a broadcasted shape from "
f"{before_expand_shape} and {other_term_shape}"
)
return None
return tuple(res)
@classmethod
def _get_compatible_expand_shape_for_expand_swap(
cls,
before_expand_shape: DYNAMIC_SHAPE,
expanded_shape: DYNAMIC_SHAPE,
other_term_shape: DYNAMIC_SHAPE,
other_expanded_shape: Optional[DYNAMIC_SHAPE],
output_shape: DYNAMIC_SHAPE,
) -> Optional[DYNAMIC_SHAPE]:
"""
Checks whether the Expand can be swapped with the binary operator.
The function returns a truthy value when the swap is possible, or
None when it is not. Something like the following should work:

.. code-block:: python

    _get_compatible_expand_shape_for_expand_swap(
        ("batch", 1, 1, 1),
        ("batch", 1, "seq_length", "cache_length+seq_length"),
        (1,),
        None,
        ("batch", 1, "seq_length", "cache_length+seq_length"),
    )
    # -> "expand_arg", a truthy sentinel: the binary operator can
    # run on the shape before the Expand.
"""
if other_expanded_shape is not None and (
other_expanded_shape != expanded_shape
or expanded_shape != output_shape
or len(before_expand_shape) != len(other_term_shape)
):
return None
if before_expand_shape == expanded_shape or expanded_shape == other_term_shape:
# This pattern is not meant for that.
return None
if output_shape != expanded_shape:
return None
# The same check applies whether or not other_expanded_shape is known.
if not ShapeBasedExpandBroadcastPattern._is_compatible_shapes_for_expand(
before_expand_shape,
other_term_shape,
cls._broadcast_shape(before_expand_shape, other_term_shape, exc=False),
):
return None
if other_expanded_shape is None:
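# Truthy sentinel: match() only tests the returned value for truthiness;
# apply() reuses the matched Expand node's own shape argument.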
return "expand_arg"
max_dim = cls._broadcast_shape(before_expand_shape, other_term_shape)
if max_dim == output_shape:
# Expand is not necessary at all.
return None
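# Tuple of 0/1 flags; like the sentinel above, only its truthiness matters.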
return tuple(1 if a == b else 0 for a, b in zip(max_dim, output_shape))
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type not in self._op_types or node.domain != "":
return self.none()
if (
not g.has_shape(node.output[0])
or not g.has_shape(node.input[0])
or not g.has_shape(node.input[1])
):
return self.none(node, inspect.currentframe().f_lineno)
node_left = g.node_before(node.input[0])
node_right = g.node_before(node.input[1])
before = [
None if n is None or n.op_type != "Expand" else n for n in [node_left, node_right]
]
if before == [None, None]:
return self.none(node, inspect.currentframe().f_lineno)
if None in before:
# Only one expand
node_left, node_right = before
shape_left = g.get_shape_renamed(
node.input[0] if node_left is None else node_left.input[0]
)
shape_right = g.get_shape_renamed(
node.input[1] if node_right is None else node_right.input[0]
)
before_expand_shape = shape_right if node_left is None else shape_left
expanded_shape = (
g.get_shape_renamed(node_right.output[0])
if node_left is None
else g.get_shape_renamed(node_left.output[0])
)
other_term_shape = shape_left if node_left is None else shape_right
output_shape = g.get_shape_renamed(node.output[0])
if self._get_compatible_expand_shape_for_expand_swap(
before_expand_shape, expanded_shape, other_term_shape, None, output_shape
):
if self.verbose:
print(
f"[{self.__class__.__name__}.match.1] {shape_left} "
f"{node.op_type} {shape_right} -> {output_shape}"
)
return MatchResult(self, [node_left, node_right, node], self.apply)
return self.none(node, inspect.currentframe().f_lineno)
# Both inputs come from Expand nodes.
node_left, node_right = before
if node_left.input[1] != node_right.input[1]:
# It could still work when the two Expand nodes have different
# shape arguments, but the code checking that case is not implemented.
return self.none(node, inspect.currentframe().f_lineno)
shape_left = g.get_shape_renamed(node_left.input[0])
shape_right = g.get_shape_renamed(node_right.input[0])
output_shape = g.get_shape_renamed(node.output[0])
expand_arg = self._get_compatible_expand_shape_for_expand_swap(
shape_left,
g.get_shape_renamed(node.input[0]),
shape_right,
g.get_shape_renamed(node.input[1]),
output_shape,
)
if expand_arg:
if self.verbose:
print(
f"[{self.__class__.__name__}.match.2] {shape_left} "
f"{node.op_type} {shape_right} -> {output_shape} with "
f"expand_arg={expand_arg}"
)
return MatchResult(self, [node_left, node_right, node], self.apply)
return self.none(node, inspect.currentframe().f_lineno)
def apply(
self,
g: "GraphBuilder", # noqa: F821
expand_left: NodeProto,
expand_right: NodeProto,
binary_node: NodeProto,
) -> List[NodeProto]:
nodes = []
if expand_left is not None and g.is_used_more_than_once(expand_left.output[0]):
nodes.append(expand_left)
if expand_right is not None and g.is_used_more_than_once(expand_right.output[0]):
nodes.append(expand_right)
assert (
not binary_node.attribute
), f"Binary operator should not have any attribute, binary_node={binary_node}"
new_name = g.unique_name(f"{self.__class__.__name__}_{binary_node.output[0]}")
nodes.append(
g.make_node(
binary_node.op_type,
[
binary_node.input[0] if expand_left is None else expand_left.input[0],
binary_node.input[1] if expand_right is None else expand_right.input[0],
],
[new_name],
name=f"{self.__class__.__name__}--{binary_node.name}",
doc_string=binary_node.doc_string,
),
)
# One or two Expand nodes: the rewriting is identical since the expand argument is the same.
return [
*nodes,
g.make_node(
"Expand",
[
new_name,
expand_left.input[1] if expand_right is None else expand_right.input[1],
],
binary_node.output,
name=f"{self.__class__.__name__}--{binary_node.name}",
doc_string=binary_node.doc_string,
),
]
class ShapeBasedExpandBroadcastMatMulPattern(PatternOptimization):
"""
Similar to
:class:`experimental_experiment.xoptim.patterns.onnx_expand.ShapeBasedExpandBroadcastPattern`,
but works only with MatMul.
Model with nodes to be fused:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [
oh.make_opsetid("", 18),
]
inputs = []
outputs = []
nodes = []
initializers = [onh.from_array(np.array([1, 1], dtype=np.int64), name="o11")]
sparse_initializers = []
functions = []
inputs.append(
oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=(1, "c", "d"))
)
inputs.append(
oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=("a", "b", "c"))
)
nodes.append(oh.make_node("Shape", ["Y"], ["batch"], start=0, end=1))
nodes.append(oh.make_node("Concat", ["batch", "o11"], ["exp"], axis=0))
nodes.append(oh.make_node("Expand", ["Y", "exp"], ["Ye"]))
nodes.append(oh.make_node("MatMul", ["X", "Ye"], ["Z"]))
outputs.append(
oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=("a", "b", "d"))
)
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
Outcome of the fusion:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [
oh.make_opsetid("", 18),
]
inputs = []
outputs = []
nodes = []
initializers = []
sparse_initializers = []
functions = []
inputs.append(
oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=(1, "c", "d"))
)
inputs.append(
oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=("a", "b", "c"))
)
nodes.append(oh.make_node("MatMul", ["X", "Y"], ["Z"]))
outputs.append(
oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=("a", "b", "d"))
)
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
"""
@classmethod
def _is_compatible_shapes_for_expand(
cls,
shape_left: DYNAMIC_SHAPE,
shape_right: DYNAMIC_SHAPE,
output_shape: Optional[DYNAMIC_SHAPE],
) -> bool:
"""
Checks that broadcasting the batch dimensions (all but the last two)
of the two input shapes yields the batch dimensions of output_shape.
If so, no Expand node is needed.
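A sketch: only the leading (batch) dimensions are checked:

.. code-block:: python

    cls._is_compatible_shapes_for_expand(
        ("a", "b", "c"), (1, "c", "d"), ("a", "b", "d")
    )  # True: batch dimensions ("a",) and (1,) broadcast to ("a",)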
"""
if output_shape is None:
return False
if len(shape_left) < 2 or len(shape_right) < 2 or len(output_shape) < 2:
return False
return ShapeBasedExpandBroadcastPattern._is_compatible_shapes_for_expand(
shape_left[:-2], shape_right[:-2], output_shape[:-2]
)
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type != "MatMul" or node.domain != "":
return self.none()
if not g.has_shape(node.output[0]):
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.input[0]):
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.input[1]):
return self.none(node, inspect.currentframe().f_lineno)
node_left = g.node_before(node.input[0])
node_right = g.node_before(node.input[1])
before = [
None if n is None or n.op_type != "Expand" else n for n in [node_left, node_right]
]
if before == [None, None]:
return self.none(node, inspect.currentframe().f_lineno)
# At least one expand.
node_left, node_right = before
shape_left = g.get_shape_renamed(
node.input[0] if node_left is None else node_left.input[0]
)
shape_right = g.get_shape_renamed(
node.input[1] if node_right is None else node_right.input[0]
)
if self._is_compatible_shapes_for_expand(
shape_left, shape_right, g.get_shape_renamed(node.output[0])
):
if self.verbose:
print(
f"[{self.__class__.__name__}.match] {shape_left} "
f"{node.op_type} {shape_right} -> {g.get_shape_renamed(node.output[0])}"
)
return MatchResult(self, [node_left, node_right, node], self.apply)
# We could end up with the following case.
# shape_left = (1, 1, 'seq_length', 'cache_length + seq_length')
# shape_right = (1, 1, 'seq_length', 'cache_length + seq_length')
# output_shape = ('batch', 1, 'seq_length', 'cache_length + seq_length')
# When this happens, it could also be caught by another pattern.
return self.none(node, inspect.currentframe().f_lineno)
def apply(
self,
g: "GraphBuilder", # noqa: F821
expand_left: NodeProto,
expand_right: NodeProto,
binary_node: NodeProto,
) -> List[NodeProto]:
nodes = []
if expand_left is not None and g.is_used_more_than_once(expand_left.output[0]):
nodes.append(expand_left)
if expand_right is not None and g.is_used_more_than_once(expand_right.output[0]):
nodes.append(expand_right)
assert (
not binary_node.attribute
), f"Binary operator should not have any attribute, binary_node={binary_node}"
return [
*nodes,
g.make_node(
binary_node.op_type,
[
binary_node.input[0] if expand_left is None else expand_left.input[0],
binary_node.input[1] if expand_right is None else expand_right.input[0],
],
binary_node.output,
name=f"{self.__class__.__name__}--{binary_node.name}",
doc_string=binary_node.doc_string,
),
]
class ShapeBasedExpandCastWhereSwapPattern(PatternOptimization):
"""
Rewrites ``Where(Cast(Expand(X)), Expand(X), cst)`` into ``Expand(Where(Cast(X), X, cst))``.
Model with nodes to be fused:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [
oh.make_opsetid("", 18),
]
inputs = []
outputs = []
nodes = []
initializers = []
sparse_initializers = []
functions = []
inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=("b", "c")))
inputs.append(oh.make_tensor_value_info("exp", onnx.TensorProto.INT64, shape=(3,)))
inputs.append(oh.make_tensor_value_info("cst", onnx.TensorProto.FLOAT, shape=(1,)))
nodes.append(
oh.make_node(
"Constant",
[],
["cst"],
value=onh.from_array(np.array([-np.inf], dtype=np.float32), name="value"),
)
)
nodes.append(oh.make_node("Expand", ["X", "exp"], ["Xe"]))
nodes.append(oh.make_node("Cast", ["Xe"], ["Xeb"], to=9))
nodes.append(oh.make_node("Where", ["Xeb", "Xe", "cst"], ["Y"]))
outputs.append(
oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=("b", "b", "c"))
)
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
Outcome of the fusion:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [
oh.make_opsetid("", 18),
]
inputs = []
outputs = []
nodes = []
initializers = []
sparse_initializers = []
functions = []
inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=("b", "c")))
inputs.append(oh.make_tensor_value_info("exp", onnx.TensorProto.INT64, shape=(3,)))
inputs.append(oh.make_tensor_value_info("cst", onnx.TensorProto.FLOAT, shape=(1,)))
nodes.append(
oh.make_node(
"Cast", ["X"], ["ShapeBasedExpandCastWhereSwapPattern_Xeb"], to=9
)
)
nodes.append(
oh.make_node(
"Where",
["ShapeBasedExpandCastWhereSwapPattern_Xeb", "X", "cst"],
["ShapeBasedExpandCastWhereSwapPattern_Y"],
)
)
nodes.append(
oh.make_node(
"Expand", ["ShapeBasedExpandCastWhereSwapPattern_Y", "exp"], ["Y"]
)
)
outputs.append(
oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=("b", "b", "c"))
)
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
"""
@classmethod
def _compatible_shapes(
cls,
cond: DYNAMIC_SHAPE,
cst: DYNAMIC_SHAPE,
output: DYNAMIC_SHAPE,
before: DYNAMIC_SHAPE,
) -> bool:
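"""
Checks that the Where output shape is already fixed by the condition
and that broadcasting the pre-Expand shape with the other branch
stays within the output shape. A sketch based on the model above:

.. code-block:: python

    cls._compatible_shapes(
        cond=("b", "b", "c"), cst=(1,), output=("b", "b", "c"), before=("b", "c")
    )  # True: Where can run on ("b", "c") and Expand afterwards
"""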
if cond != output:
return False
if len(before) < len(output):
before = (1,) * (len(output) - len(before)) + before
if len(cst) < len(output):
cst = (1,) * (len(output) - len(cst)) + cst
out = ShapeBasedExpandSwapPattern._broadcast_shape(before, cst)
if out is None or len(out) != len(output) or len(out) != len(before):
return False
return all(not (o != e and o != b) for b, o, e in zip(before, out, output))
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type != "Where" or node.domain != "":
return self.none()
if g.is_used_more_than_once(node.input[0]):
return self.none(node, inspect.currentframe().f_lineno)
cast_node = g.node_before(node.input[0])
if cast_node is None or cast_node.op_type != "Cast" or cast_node.domain != "":
return self.none(node, inspect.currentframe().f_lineno)
if cast_node.input[0] not in node.input[1:]:
return self.none(node, inspect.currentframe().f_lineno)
expand_node = g.node_before(cast_node.input[0])
if expand_node is None or expand_node.op_type != "Expand" or expand_node.domain != "":
return self.none(node, inspect.currentframe().f_lineno)
nodes = g.next_nodes(cast_node.input[0])
if len(nodes) != 2:
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.output[0]):
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.input[0]):
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.input[1]):
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.input[2]):
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(expand_node.input[0]):
return self.none(node, inspect.currentframe().f_lineno)
same_index = list(node.input).index(cast_node.input[0])
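# cast_node.input[0] appears at position 1 or 2 of Where's inputs;
# 3 - same_index picks the other branch.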
if self._compatible_shapes(
g.get_shape_renamed(node.input[0]),
g.get_shape_renamed(node.input[3 - same_index]),
g.get_shape_renamed(node.output[0]),
g.get_shape_renamed(expand_node.input[0]),
):
return MatchResult(self, [expand_node, cast_node, node], self.apply)
return self.none(node, inspect.currentframe().f_lineno)
def apply(
self,
g: "GraphBuilder", # noqa: F821
expand_node: NodeProto,
cast_node: NodeProto,
where_node: NodeProto,
) -> List[NodeProto]:
to = g.get_attribute(cast_node, "to").i
pos_index = list(where_node.input).index(expand_node.output[0])
cast_output = g.unique_name(f"{self.__class__.__name__}_{cast_node.output[0]}")
where_output = g.unique_name(f"{self.__class__.__name__}_{where_node.output[0]}")
return [
g.make_node(
cast_node.op_type,
[expand_node.input[0]],
[cast_output],
to=to,
name=f"{self.__class__.__name__}--{cast_node.name}",
doc_string=cast_node.doc_string,
),
g.make_node(
where_node.op_type,
(
[cast_output, expand_node.input[0], where_node.input[2]]
if pos_index == 1
else [cast_output, where_node.input[1], expand_node.input[0]]
),
[where_output],
name=f"{self.__class__.__name__}--{where_node.name}",
doc_string=where_node.doc_string,
),
g.make_node(
expand_node.op_type,
[where_output, expand_node.input[1]],
[where_node.output[0]],
name=f"{self.__class__.__name__}--{expand_node.name}",
doc_string=expand_node.doc_string,
),
]
class ShapeBasedConcatExpandPattern(PatternOptimization):
"""
Rewrites ``Expand(X, Concat(...))`` when every concatenated dimension but one can be replaced by a constant 1.
Model with nodes to be fused:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [
oh.make_opsetid("", 18),
]
inputs = []
outputs = []
nodes = []
initializers = []
sparse_initializers = []
functions = []
inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=("a", 1)))
inputs.append(oh.make_tensor_value_info("two", onnx.TensorProto.INT64, shape=(1,)))
nodes.append(
oh.make_node(
"Constant",
[],
["two"],
value=onh.from_array(np.array([2], dtype=np.int64), name="value"),
)
)
nodes.append(oh.make_node("Shape", ["X"], ["shx"], start=0, end=1))
nodes.append(oh.make_node("Concat", ["shx", "two"], ["sh2"], axis=0))
nodes.append(oh.make_node("Expand", ["X", "sh2"], ["Y"]))
outputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=("a", 2)))
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
Outcome of the fusion:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [
oh.make_opsetid("", 18),
]
inputs = []
outputs = []
nodes = []
initializers = []
sparse_initializers = []
functions = []
inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=("a", 1)))
inputs.append(oh.make_tensor_value_info("two", onnx.TensorProto.INT64, shape=(1,)))
nodes.append(
oh.make_node(
"Constant",
[],
["init7_12"],
value=onh.from_array(np.array([1], dtype=np.int64), name="value"),
)
)
nodes.append(oh.make_node("Concat", ["init7_12", "two"], ["sh22"], axis=0))
nodes.append(oh.make_node("Expand", ["X", "sh22"], ["Y"]))
outputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=("a", 2)))
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
"""
@classmethod
def _compatible_shapes(
cls,
g: "GraphBuilderPatternOptimization", # noqa: F821
shape: DYNAMIC_SHAPE,
expanded_shape: DYNAMIC_SHAPE,
concat_input: Sequence[str],
) -> Optional[int]:
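"""
Returns the position of the single dimension the Expand changes,
or None when zero or more than one dimension differs. A sketch based
on the model above:

.. code-block:: python

    cls._compatible_shapes(g, ("a", 1), ("a", 2), ["shx", "two"])  # -> 1
"""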
if len(shape) != len(expanded_shape) or len(expanded_shape) != len(concat_input):
return None
position = []
for i, (a, b) in enumerate(zip(shape, expanded_shape)):
if a == b:
continue
position.append(i)
if len(position) != 1:
# The Expand might be an Identity; that case should be caught by another pattern.
return None
return position[0]
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type != "Expand" or node.domain != "":
return self.none()
if g.is_used_more_than_once(node.input[1]):
return self.none(node, inspect.currentframe().f_lineno)
if g.is_constant(node.input[1]):
# no need
return self.none(node, inspect.currentframe().f_lineno)
concat_node = g.node_before(node.input[1])
if concat_node is None or concat_node.op_type != "Concat" or concat_node.domain != "":
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.input[0]) or not g.has_shape(node.output[0]):
return self.none(node, inspect.currentframe().f_lineno)
shape1 = g.get_shape_renamed(node.input[0])
shape2 = g.get_shape_renamed(node.output[0])
index = self._compatible_shapes(g, shape1, shape2, concat_node.input)
if index is None:
return self.none(node, inspect.currentframe().f_lineno)
# If every other Concat input is already a constant 1, there is nothing to simplify.
if all(
(i == index or (g.is_constant(name) and g.get_constant_scalar(name) == 1))
for i, name in enumerate(concat_node.input)
):
return self.none(node, inspect.currentframe().f_lineno)
return MatchResult(self, [concat_node, node], self.apply, insert_at=node)
def apply(
self,
g: "GraphBuilder", # noqa: F821
concat_node: NodeProto,
expand_node: NodeProto,
) -> List[NodeProto]:
shape1 = g.get_shape_renamed(expand_node.input[0])
shape2 = g.get_shape_renamed(expand_node.output[0])
index = self._compatible_shapes(g, shape1, shape2, concat_node.input)
init1 = g.make_initializer(
g.unique_name("init7_1"), g.ONE, source="ShapeBasedConcatExpandPattern.1"
)
new_input = [
(iname if i == index else init1) for i, iname in enumerate(concat_node.input)
]
new_name = g.unique_name(concat_node.output[0])
return [
g.make_node(
"Concat",
new_input,
[new_name],
axis=0,
name=f"{self.__class__.__name__}--{concat_node.name}",
doc_string=concat_node.doc_string,
),
g.make_node(
"Expand",
[expand_node.input[0], new_name],
expand_node.output,
name=f"{self.__class__.__name__}--{expand_node.name}",
doc_string=expand_node.doc_string,
),
]
class SwapExpandReshapePattern(PatternOptimization):
"""
Checks if Expand + Reshape can be swapped.
Model with nodes to be fused:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [
oh.make_opsetid("", 18),
]
inputs = []
outputs = []
nodes = []
initializers = []
sparse_initializers = []
functions = []
inputs.append(
oh.make_tensor_value_info("weight", onnx.TensorProto.FLOAT, shape=(1, 4, 1))
)
inputs.append(oh.make_tensor_value_info("stat", onnx.TensorProto.INT64, shape=(3,)))
inputs.append(oh.make_tensor_value_info("shape", onnx.TensorProto.INT64, shape=(3,)))
nodes.append(
oh.make_node(
"Constant",
[],
["weight"],
value=onh.from_array(
np.array([[[2.0], [3.0], [4.0], [5.0]]], dtype=np.float32), name="value"
),
)
)
nodes.append(
oh.make_node(
"Constant",
[],
["stat"],
value=onh.from_array(np.array([0, 1, -1], dtype=np.int64), name="value"),
)
)
nodes.append(oh.make_node("Expand", ["weight", "shape"], ["resh"]))
nodes.append(oh.make_node("Reshape", ["resh", "stat"], ["Y"]))
outputs.append(
oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=("a", 1, 4))
)
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
Outcome of the fusion:
.. gdot::
:script: DOT-SECTION
:process:
from experimental_experiment.doc import to_dot
import numpy as np
import ml_dtypes
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
opset_imports = [
oh.make_opsetid("", 18),
]
inputs = []
outputs = []
nodes = []
initializers = []
sparse_initializers = []
functions = []
inputs.append(
oh.make_tensor_value_info("weight", onnx.TensorProto.FLOAT, shape=(1, 4, 1))
)
inputs.append(oh.make_tensor_value_info("stat", onnx.TensorProto.INT64, shape=(3,)))
inputs.append(oh.make_tensor_value_info("shape", onnx.TensorProto.INT64, shape=(3,)))
nodes.append(oh.make_node("Reshape", ["weight", "stat"], ["Y2"]))
nodes.append(oh.make_node("Expand", ["Y2", "shape"], ["Y"]))
outputs.append(
oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=("a", 1, 4))
)
graph = oh.make_graph(
nodes,
"pattern",
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
print("DOT-SECTION", to_dot(model))
"""
def __init__(self, verbose: int = 0, priority: int = 0):
super().__init__(verbose, priority)
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type != "Reshape" or node.domain != "":
return self.none()
if not g.is_constant(node.input[1]):
return self.none(node, inspect.currentframe().f_lineno)
if g.is_used_more_than_once(node.input[0]):
return self.none(node, inspect.currentframe().f_lineno)
expand_node = g.node_before(node.input[0])
if expand_node is None or expand_node.op_type != "Expand" or expand_node.domain != "":
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_rank(expand_node.input[0]) or g.get_rank(expand_node.input[0]) != 3:
return self.none(node, inspect.currentframe().f_lineno)
cst = g.get_computed_constant(node.input[1])
if cst is None:
return self.none(node, inspect.currentframe().f_lineno)
shape = g.builder.value_as_shape(expand_node.input[1])
if shape is None:
return self.none(node, inspect.currentframe().f_lineno)
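# Very specific case: the Reshape target must be (0, 1, -1) and the
# Expand may only repeat along the first axis (shape ending with 1, 1).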
if tuple(cst) != (0, 1, -1) or shape[1:] != (1, 1):
return self.none(node, inspect.currentframe().f_lineno)
return MatchResult(self, [expand_node, node], self.apply, insert_at=node)
def apply(
self,
g: "GraphBuilder", # noqa: F821
expand_node: NodeProto,
reshape_node: NodeProto,
) -> List[NodeProto]:
new_name = g.unique_name(reshape_node.output[0])
return [
g.make_node(
"Reshape",
[expand_node.input[0], reshape_node.input[1]],
[new_name],
name=f"{self.__class__.__name__}--{reshape_node.name}",
doc_string=reshape_node.doc_string,
),
g.make_node(
"Expand",
[new_name, expand_node.input[1]],
reshape_node.output,
name=f"{self.__class__.__name__}--{expand_node.name}",
doc_string=expand_node.doc_string,
),
]