Source code for experimental_experiment.xoptim.patterns.onnx_split

import inspect
from typing import List, Optional
import numpy as np
from onnx import NodeProto
from ...helpers import make_idn
from ..patterns_api import MatchResult, PatternOptimization


class SlicesSplitPattern(PatternOptimization):
    """
    Merges multiple parallel slices into a split.

    The pattern only matches when every ``Slice`` consuming the same input
    works on the same constant axis, uses a step of 1 (or no step), and the
    slices, taken in the order returned by ``g.next_nodes``, tile the whole
    dimension: the first one starts at 0, each one starts where the previous
    one ends, and the last one ends at the dimension size (or at
    ``9223372036854775807``).

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
            oh.make_opsetid("com.microsoft", 1),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info(
                "transpose_1", onnx.TensorProto.FLOAT16, shape=(2, 2, 1024, 512)
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init7_s1_0"],
                value=onh.from_array(np.array([0], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init7_s1_256"],
                value=onh.from_array(np.array([256], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init7_s1_3"],
                value=onh.from_array(np.array([3], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init7_s1_9223372036854775807"],
                value=onh.from_array(
                    np.array([9223372036854775807], dtype=np.int64), name="value"
                ),
            )
        )
        nodes.append(
            oh.make_node(
                "Slice",
                ["transpose_1", "init7_s1_0", "init7_s1_256", "init7_s1_3"],
                ["slice_11"],
            )
        )
        nodes.append(
            oh.make_node(
                "Slice",
                ["transpose_1", "init7_s1_256", "init7_s1_9223372036854775807", "init7_s1_3"],
                ["slice_12"],
            )
        )
        outputs.append(
            oh.make_tensor_value_info(
                "slice_11", onnx.TensorProto.FLOAT16, shape=(2, 2, 1024, 256)
            )
        )
        outputs.append(
            oh.make_tensor_value_info(
                "slice_12", onnx.TensorProto.FLOAT16, shape=(2, 2, 1024, 256)
            )
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)

        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
            oh.make_opsetid("com.microsoft", 1),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info(
                "transpose_1", onnx.TensorProto.FLOAT16, shape=(2, 2, 1024, 512)
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init7_s2_256_256"],
                value=onh.from_array(np.array([256, 256], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Split", ["transpose_1", "init7_s2_256_256"], ["slice_11", "slice_12"], axis=3
            )
        )
        outputs.append(
            oh.make_tensor_value_info(
                "slice_11", onnx.TensorProto.FLOAT16, shape=(2, 2, 1024, 256)
            )
        )
        outputs.append(
            oh.make_tensor_value_info(
                "slice_12", onnx.TensorProto.FLOAT16, shape=(2, 2, 1024, 256)
            )
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)

        print("DOT-SECTION", to_dot(model))
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Checks whether all ``Slice`` nodes sharing the input of *node*
        can be replaced by a single ``Split``.

        :param g: graph builder used to query shapes, constants and consumers
        :param node: candidate node, only ``Slice`` from the main domain is considered
        :param matched: results already matched (unused here)
        :return: a :class:`MatchResult` holding every ``Slice`` consuming the
            same input, or the value of ``self.none(...)`` if the pattern
            does not apply
        """
        if node.op_type != "Slice" or node.domain != "":
            return self.none()
        if not g.has_shape(node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)

        users = [
            op for op in g.next_nodes(node.input[0]) if op.op_type == "Slice" and op.domain == ""
        ]
        if len(users) <= 1:
            return self.none(node, inspect.currentframe().f_lineno)

        # Every slice must specify its axis (4 inputs) and, if a step is
        # given (5 inputs), that step must be the constant 1.
        for user in users:
            if len(user.input) == 4:
                continue
            if len(user.input) == 5:
                if not g.is_constant_scalar(user.input[-1]):
                    return self.none(node, inspect.currentframe().f_lineno)
                if g.get_constant_scalar(user.input[-1]) != 1:
                    return self.none(node, inspect.currentframe().f_lineno)
                continue
            return self.none(node, inspect.currentframe().f_lineno)

        # axis: the loop above guarantees input[3] exists for every slice,
        # all of them must hold the same constant value.
        axes = [op.input[3] for op in users]
        if any(not g.is_constant_scalar(a) for a in axes):
            return self.none(node, inspect.currentframe().f_lineno)
        csts = [g.get_constant_scalar(a) for a in axes]
        if len(set(csts)) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        axis = csts[0]

        shape = g.get_shape(node.input[0])
        if not -len(shape) <= axis < len(shape):
            # axis outside the known rank, shape[axis] would fail
            return self.none(node, inspect.currentframe().f_lineno)
        dim = shape[axis]
        if not isinstance(dim, int):
            # the split sizes can only be computed for a static dimension
            return self.none(node, inspect.currentframe().f_lineno)

        # starts, ends
        starts = [op.input[1] for op in users]
        ends = [op.input[2] for op in users]
        if not g.is_constant_scalar(starts[0], 0):
            return self.none(node, inspect.currentframe().f_lineno)
        # apply() reads every start and every end with get_constant_scalar,
        # so all of them must be constant scalars, not only the first start
        # and the last end.
        if any(not g.is_constant_scalar(i) for i in starts) or any(
            not g.is_constant_scalar(i) for i in ends
        ):
            return self.none(node, inspect.currentframe().f_lineno)

        cst_starts = [g.get_constant_scalar(i) for i in starts]
        cst_ends = [g.get_constant_scalar(i) for i in ends]
        if any(v is None for v in cst_starts) or any(v is None for v in cst_ends):
            return self.none(node, inspect.currentframe().f_lineno)

        if cst_ends[-1] not in (dim, 9223372036854775807):
            # 9223372036854775807 is what torch uses to specify the end
            return self.none(node, inspect.currentframe().f_lineno)
        cst_ends[-1] = dim

        # Consecutive slices must be contiguous: each one starts where the
        # previous one ends.
        for i in range(len(users) - 1):
            if ends[i] == starts[i + 1]:
                # same constant name, necessarily the same value
                continue
            if cst_ends[i] != cst_starts[i + 1]:
                return self.none(node, inspect.currentframe().f_lineno)

        # Negative or decreasing boundaries would give negative split sizes
        # in apply(), those cases are left untouched.
        if any(not 0 <= s <= e for s, e in zip(cst_starts, cst_ends)):
            return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(self, users, self.apply)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        *nodes: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the ``Split`` node replacing the matched ``Slice`` nodes.

        :param g: graph builder used to create the initializer and the node
        :param nodes: the ``Slice`` nodes returned by :meth:`match`,
            in the order they tile the axis
        :return: a list with the single ``Split`` node
        """
        # nodes are all slices
        starts = [op.input[1] for op in nodes]
        ends = [op.input[2] for op in nodes]
        cst_starts = [g.get_constant_scalar(a) for a in starts]
        cst_ends = [g.get_constant_scalar(a) for a in ends]
        axis = g.get_constant_scalar(nodes[0].input[3])
        if cst_ends[-1] == 9223372036854775807:
            # 9223372036854775807 is what torch uses to specify the end,
            # it is replaced by the real dimension to compute the last size
            shape = g.get_shape(nodes[0].input[0])
            cst_ends[-1] = shape[axis]

        # size of every chunk along the split axis
        n_els = [end - start for start, end in zip(cst_starts, cst_ends)]

        splits = g.make_initializer(
            "", np.array(n_els, dtype=np.int64), source="SlicesSplitPattern.apply.splits"
        )
        outputs = [op.output[0] for op in nodes]
        node = g.make_node(
            "Split",
            [nodes[0].input[0], splits],
            outputs,
            axis=axis,
            name=f"{self.__class__.__name__}--{nodes[0].name}",
        )
        return [node]
class GathersSplitPattern(PatternOptimization):
    """
    Merges multiple parallel gather into a split followed by squeeze.

    The pattern matches when the ``Gather`` nodes consuming the same input
    share the same axis, use distinct constant scalar indices, and these
    indices are exactly ``0, 1, ..., d - 1`` where ``d`` is the static
    dimension of the input along that axis.

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh
        from onnx_array_api.translate_api.make_helper import make_node_extended

        opset_imports = [
            oh.make_opsetid("", 26),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=("a", 2)))
        nodes.append(
            make_node_extended(
                "Constant",
                [],
                ["zero"],
                value=onh.from_array(np.array(0, dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            make_node_extended(
                "Constant",
                [],
                ["one"],
                value=onh.from_array(np.array(1, dtype=np.int64), name="value"),
            )
        )
        nodes.append(make_node_extended("Gather", ["X", "zero"], ["x1"], axis=1))
        nodes.append(make_node_extended("Gather", ["X", "one"], ["x2"], axis=1))
        outputs.append(oh.make_tensor_value_info("x2", onnx.TensorProto.FLOAT, shape=("a",)))
        outputs.append(oh.make_tensor_value_info("x1", onnx.TensorProto.FLOAT, shape=("a",)))
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)

        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh
        from onnx_array_api.translate_api.make_helper import make_node_extended

        opset_imports = [
            oh.make_opsetid("", 26),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=("a", 2)))
        nodes.append(
            make_node_extended(
                "Constant",
                [],
                ["init7_s1_1"],
                value=onh.from_array(np.array([1], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            make_node_extended(
                "Split",
                ["X"],
                ["GathersSplitPattern--x1", "GathersSplitPattern--x2"],
                axis=1,
                num_outputs=2,
            )
        )
        nodes.append(
            make_node_extended("Squeeze", ["GathersSplitPattern--x1", "init7_s1_1"], ["x1"])
        )
        nodes.append(
            make_node_extended("Squeeze", ["GathersSplitPattern--x2", "init7_s1_1"], ["x2"])
        )
        outputs.append(oh.make_tensor_value_info("x2", onnx.TensorProto.FLOAT, shape=("a",)))
        outputs.append(oh.make_tensor_value_info("x1", onnx.TensorProto.FLOAT, shape=("a",)))
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)

        print("DOT-SECTION", to_dot(model))
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Checks whether the ``Gather`` nodes sharing the input of *node*
        can be replaced by a single ``Split``.

        :param g: graph builder used to query shapes, constants and consumers
        :param node: candidate node, only ``Gather`` from the main domain is considered
        :param matched: results already matched (unused here)
        :return: a :class:`MatchResult` holding the ``Gather`` nodes with a
            constant scalar index, or the value of ``self.none(...)`` if the
            pattern does not apply
        """
        if node.op_type != "Gather" or node.domain != "":
            return self.none()
        if not g.has_shape(node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)

        users = [
            op for op in g.next_nodes(node.input[0]) if op.op_type == "Gather" and op.domain == ""
        ]
        if len(users) <= 1:
            return self.none(node, inspect.currentframe().f_lineno)

        axis = None
        csts = set()
        rank = None
        keep_users = []
        for user in users:
            if len(user.input) != 2:
                continue
            a = g.get_attribute_with_default(user, "axis", default_value=0)
            assert a is not None, f"user={user}"
            # every Gather on this input must use the same axis,
            # including those skipped below
            if axis is not None and a != axis:
                return self.none(node, inspect.currentframe().f_lineno)
            axis = a
            if not g.is_constant_scalar(user.input[1]):
                # non constant index, this Gather is left untouched
                continue
            cst = g.get_constant_scalar(user.input[1])
            if cst is None:
                return self.none(node, inspect.currentframe().f_lineno)
            if cst in csts:
                # two Gather nodes extract the same index
                return self.none(node, inspect.currentframe().f_lineno)
            rk = g.get_rank(user.input[1])
            if rank is not None and rk != rank:
                # apply() handles all indices the same way based on the rank
                # of the first one, mixed ranks are not supported
                return self.none(node, inspect.currentframe().f_lineno)
            rank = rk
            csts.add(cst)
            keep_users.append(user)

        # The filtering above may have removed nodes: the pattern merges
        # *multiple* Gather, and apply() reads gather_nodes[0], so at least
        # two nodes must remain.
        if len(keep_users) <= 1:
            return self.none(node, inspect.currentframe().f_lineno)
        users = keep_users

        # the indices must cover 0..n-1 with no hole
        sorted_indices = sorted(csts)
        if sorted_indices != list(range(len(csts))):
            return self.none(node, inspect.currentframe().f_lineno)

        shape = g.get_shape(node.input[0])
        if axis < 0:
            axis += len(shape)
        if axis >= len(shape):
            return self.none(node, inspect.currentframe().f_lineno)
        if not isinstance(shape[axis], int):
            return self.none(node, inspect.currentframe().f_lineno)
        if shape[axis] != len(sorted_indices):
            # Split must produce exactly one output per index along the axis
            return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(self, users, self.apply)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        *gather_nodes: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the ``Split`` node, and the ``Squeeze`` nodes when needed,
        replacing the matched ``Gather`` nodes.

        :param g: graph builder used to create initializers, names and nodes
        :param gather_nodes: the ``Gather`` nodes returned by :meth:`match`
        :return: the ``Split`` node followed by the ``Squeeze`` nodes, if any
        """
        # nodes are all gathers
        axis = g.get_attribute_with_default(gather_nodes[0], "axis", default_value=0)
        outputs = [None for u in gather_nodes]
        rank = g.get_rank(gather_nodes[0].input[1])
        post_nodes = []
        if rank == 0:
            # shared axes input for all the Squeeze nodes created below
            axis_init = g.make_initializer(
                "", np.array([axis], dtype=np.int64), source=f"{self.__class__.__name__}.axes"
            )
        for user in gather_nodes:
            # the constant index gives the position of the Split output
            cst = g.get_constant_scalar(user.input[1])
            if rank == 1:
                # the Split output directly replaces the Gather output
                outputs[cst] = user.output[0]
            else:
                # the Split output goes through Squeeze on the split axis
                # before replacing the Gather output
                name = g.unique_name(f"{self.__class__.__name__}--{user.output[0]}")
                post_nodes.append(
                    g.make_node(
                        "Squeeze",
                        [name, axis_init],
                        [user.output[0]],
                        name=f"{self.__class__.__name__}--{user.name}",
                    )
                )
                outputs[cst] = name

        node = g.make_node(
            "Split",
            [gather_nodes[0].input[0]],
            outputs,
            axis=axis,
            num_outputs=len(outputs),
            name=f"{self.__class__.__name__}--{gather_nodes[0].name}",
        )
        return [node, *post_nodes]
class SplitConcatPattern(PatternOptimization):
    """
    Replaces Split + Concat into identity if this is equivalent.

    This is the case when all the outputs of ``Split`` are only consumed by
    the same ``Concat`` node, in the same order and along the same axis.

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=("a", "b")))
        nodes.append(oh.make_node("Split", ["X"], ["s1", "s2"], axis=-1, num_outputs=2))
        nodes.append(oh.make_node("Concat", ["s1", "s2"], ["Y"], axis=-1))
        outputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=("a", "b")))
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)

        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=("a", "b")))
        nodes.append(oh.make_node("Identity", ["X"], ["Y"]))
        outputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=("a", "b")))
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)

        print("DOT-SECTION", to_dot(model))
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Checks whether *node* is a ``Split`` whose outputs are immediately
        concatenated back together.

        :param g: graph builder used to query consumers, attributes and ranks
        :param node: candidate node, only ``Split`` from the main domain is considered
        :param matched: results already matched (unused here)
        :return: a :class:`MatchResult` with the ``Split`` and ``Concat``
            nodes, or the value of ``self.none(...)`` if the pattern does
            not apply
        """
        if node.op_type != "Split" or node.domain != "":
            return self.none()

        # Every output of Split must have a single consumer and this
        # consumer must be the same node for all of them.
        only_id = None
        only_node = None
        for o in node.output:
            n = g.next_nodes(o)
            if len(n) != 1:
                return self.none(node, inspect.currentframe().f_lineno)
            i = make_idn(n[0])
            if only_id is None:
                only_id = i
                only_node = n[0]
            elif i != only_id:
                return self.none(node, inspect.currentframe().f_lineno)
        if only_node.op_type != "Concat" or only_node.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)

        # axis is optional for Split (0 if missing), it is read with a
        # default value instead of assuming the attribute exists
        axis_split = g.get_attribute_with_default(node, "axis", default_value=0)
        axis_concat = g.get_attribute(only_node, "axis").i
        # both axes are made comparable when only one of them is negative
        if axis_split < 0 and axis_concat >= 0:
            axis_split += g.get_rank(node.input[0])
        if axis_concat < 0 and axis_split >= 0:
            axis_concat += g.get_rank(node.input[0])
        if axis_split != axis_concat:
            return self.none(node, inspect.currentframe().f_lineno)
        # Concat must take exactly the outputs of Split, in the same order
        if node.output != only_node.input:
            return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(self, [node, only_node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        split_node: NodeProto,
        concat_node: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the ``Identity`` node replacing ``Split`` + ``Concat``.

        :param g: graph builder used to create the node
        :param split_node: the matched ``Split`` node
        :param concat_node: the matched ``Concat`` node
        :return: a list with the single ``Identity`` node
        """
        return [
            g.make_node(
                "Identity",
                # Identity takes a single input: only the data input of
                # Split is forwarded, not its optional ``split`` sizes.
                [split_node.input[0]],
                concat_node.output,
                name=f"{self.__class__.__name__}--{split_node.name}",
            )
        ]