Source code for experimental_experiment.xoptim.patterns.onnx_layer_normalization

import inspect
from typing import List, Optional
import numpy as np
from onnx import NodeProto
from ...helpers import tensor_dtype_to_np_dtype, from_array_extended
from ..patterns_api import MatchResult, PatternOptimization


class LayerNormalizationPattern(PatternOptimization):
    """
    Fuses nodes of a LayerNormalization.

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("add_1", onnx.TensorProto.FLOAT16, shape=(4, 512, 128))
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init7_s1_-1"],
                value=onh.from_array(np.array([-1], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init10_s1_"],
                value=onh.from_array(np.array([2.0], dtype=np.float16), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init10_s1_2"],
                value=onh.from_array(np.array([0.0], dtype=np.float16), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "ReduceMean", ["add_1", "init7_s1_-1"], ["_onx_reducemean0"], keepdims=1
            )
        )
        nodes.append(oh.make_node("Sub", ["add_1", "_onx_reducemean0"], ["_onx_sub0"]))
        nodes.append(oh.make_node("Pow", ["_onx_sub0", "init10_s1_"], ["_onx_pow0"]))
        nodes.append(
            oh.make_node(
                "ReduceMean", ["_onx_pow0", "init7_s1_-1"], ["_onx_reducemean02"], keepdims=1
            )
        )
        nodes.append(
            oh.make_node("Add", ["_onx_reducemean02", "init10_s1_2"], ["_onx_add0"])
        )
        nodes.append(oh.make_node("Sqrt", ["_onx_add0"], ["_onx_sqrt0"]))
        nodes.append(oh.make_node("Div", ["_onx_sub0", "_onx_sqrt0"], ["_onx_div0"]))
        outputs.append(
            oh.make_tensor_value_info(
                "_onx_div0", onnx.TensorProto.FLOAT16, shape=(4, 512, 128)
            )
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("add_1", onnx.TensorProto.FLOAT16, shape=(4, 512, 128))
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["p_model_albert_embeddings_layernorm_weight"],
                value=onh.from_array(
                    np.array([1.0] * 128, dtype=np.float16), name="value"
                ),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["p_model_albert_embeddings_layernorm_bias"],
                value=onh.from_array(
                    np.array([0.0] * 128, dtype=np.float16), name="value"
                ),
            )
        )
        nodes.append(
            oh.make_node(
                "LayerNormalization",
                [
                    "add_1",
                    "p_model_albert_embeddings_layernorm_weight",
                    "p_model_albert_embeddings_layernorm_bias",
                ],
                ["_onx_div0"],
                axis=-1,
                epsilon=0.0,
                stash_type=1,
            )
        )
        outputs.append(
            oh.make_tensor_value_info(
                "_onx_div0", onnx.TensorProto.FLOAT16, shape=(4, 512, 128)
            )
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type != "ReduceMean" or node.domain != "":
            return self.none()
        if len(node.input) != 2:
            # Not defined for opsets older than 18.
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant(node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        axis = g.get_computed_constant(node.input[1])
        if axis.tolist() != [-1]:
            if not g.has_rank(node.input[0]):
                return self.none(node, inspect.currentframe().f_lineno)
            rk = g.get_rank(node.input[0])
            al = axis.tolist()
            if al != list(range(rk - len(al), rk)):
                return self.none(node, inspect.currentframe().f_lineno)

        # before
        pow = g.node_before(node.input[0])
        if pow is None:
            return self.none(node, inspect.currentframe().f_lineno)
        if pow.op_type != "Pow" or len(g.next_nodes(pow.output[0])) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        if (
            not g.is_constant_scalar(pow.input[1], broadcast=True)
            or g.get_constant_scalar(pow.input[1], broadcast=True) != 2
        ):
            return self.none(node, inspect.currentframe().f_lineno)

        sub = g.node_before(pow.input[0])
        if sub is None:
            return self.none(node, inspect.currentframe().f_lineno)
        if sub.op_type != "Sub" or len(g.next_nodes(sub.output[0])) != 2:
            return self.none(node, inspect.currentframe().f_lineno)

        red = g.node_before(sub.input[1])
        if red is None:
            return self.none(node, inspect.currentframe().f_lineno)
        if red.op_type != "ReduceMean" or len(g.next_nodes(red.output[0])) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant(red.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        axis2 = g.get_computed_constant(red.input[1])
        if axis.tolist() != axis2.tolist():
            return self.none(node, inspect.currentframe().f_lineno)
        if sub.input[0] != red.input[0]:
            return self.none(node, inspect.currentframe().f_lineno)
        kp = g.get_attribute(red, "keepdims", exc=False)
        if kp is None or kp.i != 1:
            return self.none(node, inspect.currentframe().f_lineno)

        # after
        add = g.next_nodes(node.output[0])
        if len(add) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        if add[0].op_type == "Add":
            add = add[0]
            if not g.is_constant_scalar(add.input[1], broadcast=True):
                return self.none(node, inspect.currentframe().f_lineno)
            sqrt = g.next_nodes(add.output[0])
        else:
            add = None
        if add is None:
            sqrt = g.next_nodes(node.output[0])
        if len(sqrt) != 1 or sqrt[0].op_type != "Sqrt":
            return self.none(node, inspect.currentframe().f_lineno)
        sqrt = sqrt[0]

        div = g.next_nodes(sqrt.output[0])
        if len(div) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        div = div[0]
        if div.op_type == "Div":
            if len(g.next_nodes(div.input[1])) != 1:
                return self.none(node, inspect.currentframe().f_lineno)
            if div.input[0] != sub.output[0]:
                return self.none(node, inspect.currentframe().f_lineno)
        elif div.op_type == "Reciprocal":
            if div.input[0] != sub.output[0]:
                return self.none(node, inspect.currentframe().f_lineno)
        else:
            return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(
            self, [red, sub, pow, node, add, sqrt, div], self.apply, insert_at=node
        )

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        red: NodeProto,
        sub: NodeProto,
        pow: NodeProto,
        node: NodeProto,
        add: Optional[NodeProto],
        sqrt: NodeProto,
        div: NodeProto,
    ) -> List[NodeProto]:
        itype = g.get_type(red.input[0])
        dtype = tensor_dtype_to_np_dtype(itype)
        axis = g.get_computed_constant(red.input[1]).tolist()

        scale = None
        new_nodes = []
        if axis == [-1]:
            ly_axis = -1
            if g.has_shape(red.input[0]):
                shape = g.get_shape(red.input[0])
                if isinstance(shape[-1], int):
                    scale = g.make_initializer(
                        "",
                        np.ones((shape[-1],), dtype=dtype),
                        source="LayerNormalizationPattern.apply.scale",
                    )
                    bias = g.make_initializer(
                        "",
                        np.zeros((shape[-1],), dtype=dtype),
                        source="LayerNormalizationPattern.apply.bias",
                    )
        else:
            ly_axis = min(axis)

        if scale is None:
            shape = g.unique_name(f"{self.__class__.__name__}_Sh_{red.input[0]}")
            new_nodes.append(
                g.make_node(
                    "Shape",
                    [red.input[0]],
                    [shape],
                    start=ly_axis,
                    name=f"{self.__class__.__name__}--{red.name}",
                )
            )
            scale = g.unique_name(f"{self.__class__.__name__}_Sc_{red.input[0]}")
            new_nodes.append(
                g.make_node(
                    "ConstantOfShape",
                    [shape],
                    [scale],
                    name=f"{self.__class__.__name__}--{red.name}",
                    value=from_array_extended(np.array([1], dtype=dtype)),
                )
            )
            bias = g.unique_name(f"{self.__class__.__name__}_Bi_{red.input[0]}")
            new_nodes.append(
                g.make_node(
                    "ConstantOfShape",
                    [shape],
                    [bias],
                    name=f"{self.__class__.__name__}--{red.name}",
                    value=from_array_extended(np.array([0], dtype=dtype)),
                )
            )

        eps = (
            g.get_constant_scalar(add.input[1], broadcast=True)
            if add
            else 9.999999960041972e-13
        )
        new_nodes.append(
            g.make_node(
                "LayerNormalization",
                [red.input[0], scale, bias],
                [div.output[0]],
                epsilon=float(eps),
                name=f"{self.__class__.__name__}--{node.name}",
                doc_string=node.doc_string,
                stash_type=1,  # itype,
                axis=ly_axis,
            )
        )
        return new_nodes
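

# Illustration (not part of the original module): the fusion above relies on the
# identity LayerNormalization(x, scale=1, bias=0) == (x - mean(x)) /
# sqrt(mean((x - mean(x))**2) + eps) over the last axis. A minimal numpy sketch
# checking that identity; the helper name is hypothetical.
def _sketch_layer_normalization_identity():
    rng = np.random.default_rng(0)
    x = rng.standard_normal((4, 512, 128)).astype(np.float32)
    eps = 9.999999960041972e-13  # default used by the pattern when no Add is matched

    # Subgraph matched by LayerNormalizationPattern:
    # ReduceMean, Sub, Pow, ReduceMean, Add, Sqrt, Div.
    centered = x - x.mean(axis=-1, keepdims=True)
    decomposed = centered / np.sqrt((centered**2).mean(axis=-1, keepdims=True) + eps)

    # LayerNormalization with unit scale and zero bias, following the ONNX formula.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance, i.e. mean((x - mean)**2)
    fused = (x - mean) / np.sqrt(var + eps) * np.ones(128, np.float32) + np.zeros(
        128, np.float32
    )

    assert np.allclose(decomposed, fused, atol=1e-5)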


class LayerNormalizationScalePattern(PatternOptimization):
    """
    Fuses into a LayerNormalization a scale (Mul) and possibly a bias (Add)
    applied just after it.

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 20),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(oh.make_tensor_value_info("scale", onnx.TensorProto.FLOAT, shape=(3,)))
        inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=("a", "b")))
        inputs.append(oh.make_tensor_value_info("s0", onnx.TensorProto.FLOAT, shape=(3,)))
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["s0"],
                value=onh.from_array(
                    np.array(
                        [-0.10000000149011612, -0.009999999776482582, -0.05000000074505806],
                        dtype=np.float32,
                    ),
                    name="value",
                ),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["scale"],
                value=onh.from_array(np.array([2.0, 3.0, 4.0], dtype=np.float32), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "LayerNormalization", ["X", "s0"], ["norm"], epsilon=0.10000000149011612
            )
        )
        nodes.append(oh.make_node("Mul", ["norm", "scale"], ["Y"]))
        outputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=("a", "b")))
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 20),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(oh.make_tensor_value_info("scale", onnx.TensorProto.FLOAT, shape=(3,)))
        inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=("a", "b")))
        inputs.append(oh.make_tensor_value_info("s0", onnx.TensorProto.FLOAT, shape=(3,)))
        nodes.append(
            oh.make_node("Mul", ["s0", "scale"], ["LayerNormalizationScalePattern_s0"])
        )
        nodes.append(
            oh.make_node(
                "LayerNormalization",
                ["X", "LayerNormalizationScalePattern_s0"],
                ["Y"],
                epsilon=0.10000000149011612,
            )
        )
        outputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=("a", "b")))
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type != "LayerNormalization" or node.domain != "":
            return self.none()
        if len(node.output) != 1:
            # No need for the scale.
            return self.none(node, inspect.currentframe().f_lineno)
        nodes = g.next_nodes(node.output[0])
        if len(nodes) != 1 or nodes[0].op_type != "Mul":
            return self.none(node, inspect.currentframe().f_lineno)
        mul_node = nodes[0]

        nodes = g.next_nodes(mul_node.output[0])
        if len(nodes) == 0:
            return MatchResult(self, [node, mul_node, None], self.apply, insert_at=mul_node)

        index = 1 if mul_node.input[0] == node.output[0] else 0
        if not g.has_shape(mul_node.input[index]):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_shape(node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        if g.get_shape(mul_node.input[index]) != g.get_shape(node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)

        nodes = g.next_nodes(mul_node.output[0])
        if len(nodes) != 1 or nodes[0].op_type != "Add":
            return MatchResult(self, [node, mul_node, None], self.apply, insert_at=nodes[0])
        add_node = nodes[0]

        index = 1 if add_node.input[0] == mul_node.output[0] else 0
        if not g.has_shape(add_node.input[index]):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_shape(node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        if g.get_shape(add_node.input[index]) != g.get_shape(node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node, mul_node, add_node], self.apply, insert_at=nodes[0])

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        ln_node: NodeProto,
        mul_node: NodeProto,
        add_node: Optional[NodeProto],
    ) -> List[NodeProto]:
        # scale
        scale = (
            mul_node.input[1] if mul_node.input[0] == ln_node.output[0] else mul_node.input[0]
        )
        new_scale = None
        if g.is_constant_scalar(ln_node.input[1], broadcast=True):
            fscale = g.get_constant_scalar(ln_node.input[1], broadcast=True)
            if fscale == 1:
                new_scale = scale
        new_nodes = []
        if new_scale is None:
            new_scale = g.unique_name(f"{self.__class__.__name__}_{ln_node.input[1]}")
            node = g.make_node(
                "Mul",
                [ln_node.input[1], scale],
                [new_scale],
                name=f"{self.__class__.__name__}--{ln_node.name}",
            )
            new_nodes.append(node)

        if add_node:
            if len(ln_node.input) == 2:
                new_bias = (
                    add_node.input[1]
                    if add_node.input[0] == mul_node.output[0]
                    else add_node.input[0]
                )
            else:
                # There is an existing bias.
                existing_bias = ln_node.input[2]
                mul_cst = (
                    mul_node.input[0]
                    if mul_node.input[1] == ln_node.output[0]
                    else mul_node.input[1]
                )
                add_cst = (
                    add_node.input[0]
                    if add_node.input[1] == mul_node.output[0]
                    else add_node.input[1]
                )
                # new_bias = existing_bias * mul_cst + add_cst
                temp = g.unique_name(f"{self.__class__.__name__}_{ln_node.input[1]}")
                new_bias = g.unique_name(f"{self.__class__.__name__}_{ln_node.input[1]}")
                new_nodes.extend(
                    [
                        g.make_node(
                            "Mul",
                            [mul_cst, existing_bias],
                            [temp],
                            name=f"{self.__class__.__name__}--{ln_node.name}",
                        ),
                        g.make_node(
                            "Add",
                            [temp, add_cst],
                            [new_bias],
                            name=f"{self.__class__.__name__}--{ln_node.name}",
                        ),
                    ]
                )
        else:
            new_bias = ln_node.input[2] if len(ln_node.input) > 2 else None

        kwargs = {}
        axis = g.get_attribute(ln_node, "axis", exc=None)
        if axis:
            kwargs["axis"] = axis.i
        epsilon = g.get_attribute(ln_node, "epsilon", exc=None)
        if epsilon:
            kwargs["epsilon"] = epsilon.f
        stash_type = g.get_attribute(ln_node, "stash_type", exc=None)
        if stash_type:
            kwargs["stash_type"] = stash_type.i

        new_node = g.make_node(
            "LayerNormalization",
            (
                [ln_node.input[0], new_scale, new_bias]
                if new_bias
                else [ln_node.input[0], new_scale]
            ),
            [(add_node or mul_node).output[0]],
            name=f"{self.__class__.__name__}--{ln_node.name}",
            doc_string=ln_node.doc_string,
            **kwargs,
        )
        return [*new_nodes, new_node]
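

# Illustration (not part of the original module): LayerNormalizationScalePattern
# relies on the folding LayerNorm(x, s, b) * w + c == LayerNorm(x, s * w, b * w + c).
# A minimal numpy sketch of that identity; the helper name is hypothetical.
def _sketch_layer_normalization_scale_folding():
    rng = np.random.default_rng(0)
    x = rng.standard_normal((8, 3))
    s, b = rng.standard_normal(3), rng.standard_normal(3)  # existing scale and bias
    w, c = rng.standard_normal(3), rng.standard_normal(3)  # Mul and Add applied after
    eps = 1e-5

    def layer_norm(x, scale, bias):
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        return (x - mean) / np.sqrt(var + eps) * scale + bias

    before = layer_norm(x, s, b) * w + c     # LayerNormalization -> Mul -> Add
    after = layer_norm(x, s * w, b * w + c)  # single fused LayerNormalization
    assert np.allclose(before, after)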


class CastLayerNormalizationCastPattern(PatternOptimization):
    """
    Checks that a Cast is really needed around LayerNormalization.

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(oh.make_tensor_value_info("scale", onnx.TensorProto.FLOAT, shape=(3,)))
        inputs.append(oh.make_tensor_value_info("bias", onnx.TensorProto.FLOAT, shape=(3,)))
        inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT16, shape=(3, 3)))
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["scale"],
                value=onh.from_array(
                    np.array([0.5, 0.6000000238418579, 0.699999988079071], dtype=np.float32),
                    name="value",
                ),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["bias"],
                value=onh.from_array(
                    np.array([-0.5, -0.6000000238418579, -0.699999988079071], dtype=np.float32),
                    name="value",
                ),
            )
        )
        nodes.append(oh.make_node("Cast", ["X"], ["xc"], to=1))
        nodes.append(
            oh.make_node(
                "LayerNormalization", ["xc", "scale", "bias"], ["norm"], stash_type=1
            )
        )
        nodes.append(oh.make_node("Cast", ["norm"], ["Y"], to=10))
        outputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT16, shape=(3, 3)))
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(oh.make_tensor_value_info("scale", onnx.TensorProto.FLOAT, shape=(3,)))
        inputs.append(oh.make_tensor_value_info("bias", onnx.TensorProto.FLOAT, shape=(3,)))
        inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT16, shape=(3, 3)))
        nodes.append(
            oh.make_node(
                "Cast", ["scale"], ["CastLayerNormalizationCastPattern_scale::C10"], to=10
            )
        )
        nodes.append(
            oh.make_node(
                "Cast", ["bias"], ["CastLayerNormalizationCastPattern_bias::C10"], to=10
            )
        )
        nodes.append(
            oh.make_node(
                "LayerNormalization",
                [
                    "X",
                    "CastLayerNormalizationCastPattern_scale::C10",
                    "CastLayerNormalizationCastPattern_bias::C10",
                ],
                ["Y"],
                stash_type=1,
            )
        )
        outputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT16, shape=(3, 3)))
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type not in (
            "LayerNormalization",
            "SimplifiedLayerNormalization",
        ) or node.domain not in ("", "com.microsoft"):
            return self.none()
        if len(node.output) > 1 and g.is_used(node.output[1]):
            # No need for the scale.
            return self.none(node, inspect.currentframe().f_lineno)
        stash_type = g.get_attribute(node, "stash_type", exc=False)
        stash_itype = 1 if stash_type is None else stash_type.i

        cast_before = g.node_before(node.input[0])
        if cast_before is None or cast_before.op_type != "Cast" or cast_before.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        to = g.get_attribute(cast_before, "to")
        if to.i != stash_itype:
            return self.none(node, inspect.currentframe().f_lineno)
        if g.is_used_more_than_once(node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)

        cast_afters = g.next_nodes(node.output[0])
        if len(cast_afters) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        cast_after = cast_afters[0]
        if cast_after.op_type != "Cast" or cast_after.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        to = g.get_attribute(cast_after, "to")
        itype = g.get_type(cast_before.input[0])
        if to.i != itype:
            return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(self, [cast_before, node, cast_after], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        cast_before: NodeProto,
        node: NodeProto,
        cast_after: NodeProto,
    ) -> List[NodeProto]:
        itype = g.get_type(cast_before.input[0])
        other = []
        nodes = []
        for i in node.input[1:]:
            name = g.unique_name(f"{self.__class__.__name__}_{i}::C{itype}")
            other.append(name)
            nodes.append(
                g.make_node(
                    "Cast",
                    [i],
                    [name],
                    to=itype,
                    name=f"{self.__class__.__name__}--cast--{node.name}",
                )
            )
        new_node = g.make_node(
            node.op_type,
            [cast_before.input[0], *other],
            [cast_after.output[0], *node.output[1:]],
            name=f"{self.__class__.__name__}--{node.name}",
            doc_string=node.doc_string,
            domain=node.domain,
        )
        new_node.attribute.extend(node.attribute)
        return [*nodes, new_node]
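

# Illustration (not part of the original module): with stash_type=1, the
# normalization statistics are computed in float32 anyway, so casting the input
# to float32 before LayerNormalization and casting the result back to float16
# is redundant; only the scale and bias still need a Cast. A rough numpy sketch
# (the helper name is hypothetical, and only approximate equality is expected):
def _sketch_cast_layer_normalization_cast():
    rng = np.random.default_rng(0)
    x16 = rng.standard_normal((3, 3)).astype(np.float16)
    scale32 = np.array([0.5, 0.6, 0.7], dtype=np.float32)
    bias32 = np.array([-0.5, -0.6, -0.7], dtype=np.float32)

    def layer_norm_stash_f32(x, scale, bias, eps=1e-5):
        x = x.astype(np.float32)  # the float32 stash
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        return (x - mean) / np.sqrt(var + eps) * scale + bias

    # Cast(to=1) -> LayerNormalization(float32) -> Cast(to=10)
    before = layer_norm_stash_f32(x16.astype(np.float32), scale32, bias32)
    before = before.astype(np.float16)
    # LayerNormalization on the float16 input with casted scale/bias, float32 stash
    after = layer_norm_stash_f32(
        x16, scale32.astype(np.float16), bias32.astype(np.float16)
    ).astype(np.float16)
    assert np.allclose(
        before.astype(np.float32), after.astype(np.float32), atol=1e-2
    )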


class BatchNormalizationPattern(PatternOptimization):
    """
    Checks that a BatchNormalization is really needed.

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=(1024, 16)))
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["scale"],
                value=onh.from_array(
                    np.array([1.0] * 16, dtype=np.float32), name="value"
                ),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["B"],
                value=onh.from_array(
                    np.array([0.0] * 16, dtype=np.float32), name="value"
                ),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["input_mean"],
                value=onh.from_array(
                    np.array([0.0] * 16, dtype=np.float32), name="value"
                ),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["input_var"],
                value=onh.from_array(
                    np.array([1.0] * 16, dtype=np.float32), name="value"
                ),
            )
        )
        nodes.append(
            oh.make_node(
                "BatchNormalization",
                ["X", "scale", "B", "input_mean", "input_var"],
                ["Y"],
                epsilon=0.0,
            )
        )
        outputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=(1024, 16)))
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=(1024, 16)))
        nodes.append(oh.make_node("Identity", ["X"], ["Y"]))
        outputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=(1024, 16)))
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type != "BatchNormalization" or node.domain != "":
            return self.none()
        if len(node.output) > 1 and g.next_nodes(node.output[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        if len(node.output) > 2 and g.next_nodes(node.output[2]):
            return self.none(node, inspect.currentframe().f_lineno)
        momentum = 0.9
        epsilon = 1e-5
        training_mode = 0
        for att in node.attribute:
            if att.name == "momentum":
                momentum = att.f
            elif att.name == "epsilon":
                epsilon = att.f
            elif att.name == "training_mode":
                training_mode = att.i
        if training_mode and momentum != 0:
            return self.none(node, inspect.currentframe().f_lineno)
        if epsilon != 0:
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant(node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant(node.input[2]):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant(node.input[3]):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant(node.input[4]):
            return self.none(node, inspect.currentframe().f_lineno)
        # biases
        for z in node.input[2:4]:
            cst = g.get_computed_constant(z)
            if cst is None:
                return self.none(node, inspect.currentframe().f_lineno)
            if cst.min() == cst.max() == 0:
                continue
            return self.none(node, inspect.currentframe().f_lineno)
        # scales
        for z in [node.input[1], node.input[4]]:
            cst = g.get_computed_constant(z)
            if cst is None:
                return self.none(node, inspect.currentframe().f_lineno)
            if cst.min() == cst.max() == 1:
                continue
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
    ) -> List[NodeProto]:
        new_node = g.make_node(
            "Identity",
            node.input[:1],
            node.output[:1],
            name=f"{self.__class__.__name__}--{node.name}",
            doc_string=node.doc_string,
        )
        return [new_node]
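

# Illustration (not part of the original module): with scale=1, B=0,
# input_mean=0, input_var=1 and epsilon=0, BatchNormalization computes
# (x - 0) / sqrt(1 + 0) * 1 + 0 == x, hence the Identity replacement.
# A minimal numpy sketch; the helper name is hypothetical.
def _sketch_batch_normalization_identity():
    rng = np.random.default_rng(0)
    x = rng.standard_normal((1024, 16)).astype(np.float32)
    scale, bias = np.ones(16, np.float32), np.zeros(16, np.float32)
    mean, var = np.zeros(16, np.float32), np.ones(16, np.float32)
    y = (x - mean) / np.sqrt(var + 0.0) * scale + bias  # BatchNormalization formula
    assert np.allclose(y, x)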


class BatchNormalizationTrainingPattern(PatternOptimization):
    """Checks that a BatchNormalization in training mode can be avoided."""

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type != "BatchNormalization" or node.domain != "":
            return self.none()
        if g.main_opset < 18:
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_rank(node.input[0]) or g.get_rank(node.input[0]) < 2:
            return self.none(node, inspect.currentframe().f_lineno)
        if len(node.output) > 1 and (
            not g.has_rank(node.input[1]) or g.next_nodes(node.output[1])
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        if len(node.output) > 2 and (
            not g.has_rank(node.input[2]) or g.next_nodes(node.output[2])
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        momentum = 0.9
        training_mode = 0
        for att in node.attribute:
            if att.name == "momentum":
                momentum = att.f
            elif att.name == "training_mode":
                training_mode = att.i
        if not training_mode and momentum != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
    ) -> List[NodeProto]:
        nname = f"{self.__class__.__name__}--{node.name}"
        rk = g.get_rank(node.input[0])
        axes = tuple(np.delete(np.arange(rk), 1))
        init_axes = g.make_initializer(
            "",
            np.array(list(axes), dtype=np.int64),
            source="BatchNormalizationTrainingPattern.apply.init_axes",
        )
        mean_name = g.unique_name(f"{self.__class__.__name__}_mean_{node.input[0]}")
        mean = g.make_node(
            "ReduceMean", [node.input[0], init_axes], [mean_name], keepdims=1, name=nname
        )
        centered_name = g.unique_name(f"{self.__class__.__name__}_center_{node.input[0]}")
        sub = g.make_node("Sub", [node.input[0], mean_name], [centered_name], name=nname)
        x2 = g.unique_name(f"{self.__class__.__name__}_x2_{node.input[0]}")
        mul2 = g.make_node("Mul", [centered_name, centered_name], [x2], name=nname)
        var_name = g.unique_name(f"{self.__class__.__name__}_var_{node.input[0]}")
        var = g.make_node("ReduceMean", [x2, init_axes], [var_name], keepdims=1, name=nname)
        dtype = tensor_dtype_to_np_dtype(g.get_type(node.input[0]))
        epsilon = g.get_attributes_with_default(node, epsilon=1e-5)["epsilon"]
        init_epsilon = g.make_initializer(
            "",
            np.array([epsilon], dtype=dtype),
            source="BatchNormalizationTrainingPattern.apply.init_epsilon",
        )
        vare_name = g.unique_name(f"{self.__class__.__name__}_vareps_{node.input[0]}")
        add = g.make_node("Add", [var_name, init_epsilon], [vare_name], name=nname)
        std_name = g.unique_name(f"{self.__class__.__name__}_vareps_{node.input[0]}")
        sqrt = g.make_node("Sqrt", [vare_name], [std_name], name=nname)

        new_shape = [1 for _ in range(rk)]
        new_shape[1] = -1
        new_shape = g.make_initializer(
            "",
            np.array(new_shape, dtype=np.int64),
            source="BatchNormalizationTrainingPattern.apply.new_shape",
        )
        if g.get_rank(node.input[1]) == 1:
            scale_name = g.unique_name(f"{self.__class__.__name__}_scale_{node.input[1]}")
            scale = g.make_node("Reshape", [node.input[1], new_shape], [scale_name], name=nname)
        else:
            scale_name = node.input[1]
            scale = None
        if g.get_rank(node.input[2]) == 1:
            bias_name = g.unique_name(f"{self.__class__.__name__}_bias_{node.input[2]}")
            bias = g.make_node("Reshape", [node.input[2], new_shape], [bias_name], name=nname)
        else:
            bias_name = node.input[2]
            bias = None

        scaled_name = g.unique_name(f"{self.__class__.__name__}_scaled_{node.input[1]}")
        scaled = g.make_node("Div", [centered_name, std_name], [scaled_name], name=nname)
        scaled2_name = g.unique_name(f"{self.__class__.__name__}_scaled2_{node.input[2]}")
        scaled2 = g.make_node("Mul", [scaled_name, scale_name], [scaled2_name], name=nname)
        final = g.make_node("Add", [scaled2_name, bias_name], [node.output[0]], name=nname)
        return [
            _
            for _ in [mean, sub, mul2, var, add, sqrt, scale, bias, scaled, scaled2, final]
            if _ is not None
        ]
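

# Illustration (not part of the original module): the graph emitted by
# BatchNormalizationTrainingPattern.apply normalizes with batch statistics
# computed over every axis except the channel axis (axis 1). A minimal numpy
# sketch of that decomposition against the training-mode BatchNormalization
# formula; the helper name is hypothetical.
def _sketch_batch_normalization_training_decomposition():
    rng = np.random.default_rng(0)
    x = rng.standard_normal((4, 3, 8, 8)).astype(np.float32)
    scale = rng.standard_normal(3).astype(np.float32)
    bias = rng.standard_normal(3).astype(np.float32)
    eps = 1e-5
    axes = (0, 2, 3)  # every axis except axis 1, as computed with np.delete

    # emitted nodes: ReduceMean, Sub, Mul, ReduceMean, Add, Sqrt,
    # Reshape(scale), Reshape(bias), Div, Mul, Add
    mean = x.mean(axis=axes, keepdims=True)
    centered = x - mean
    var = (centered * centered).mean(axis=axes, keepdims=True)
    y = centered / np.sqrt(var + eps) * scale.reshape(1, -1, 1, 1) + bias.reshape(
        1, -1, 1, 1
    )

    # training-mode BatchNormalization reference with biased batch variance
    ref_var = x.var(axis=axes, keepdims=True)
    ref = (x - mean) / np.sqrt(ref_var + eps) * scale.reshape(1, -1, 1, 1) + bias.reshape(
        1, -1, 1, 1
    )
    assert np.allclose(y, ref, atol=1e-5)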


class RMSNormalizationPattern(PatternOptimization):
    """
    Fuses the nodes equivalent to RMSNormalization(23).

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 23),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT16, shape=("a", "D"))
        )
        inputs.append(oh.make_tensor_value_info("axis", onnx.TensorProto.INT64, shape=(1,)))
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["exp"],
                value=onh.from_array(np.array([2.0], dtype=np.float16), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["axis"],
                value=onh.from_array(np.array([-1], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["eps"],
                value=onh.from_array(
                    np.array([9.999999974752427e-07], dtype=np.float32), name="value"
                ),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["one"],
                value=onh.from_array(np.array([1.0], dtype=np.float32), name="value"),
            )
        )
        nodes.append(oh.make_node("Cast", ["X"], ["Xc"], to=1))
        nodes.append(oh.make_node("Pow", ["Xc", "exp"], ["x2"]))
        nodes.append(oh.make_node("ReduceMean", ["x2", "axis"], ["xr"]))
        nodes.append(oh.make_node("Add", ["xr", "eps"], ["xa"]))
        nodes.append(oh.make_node("Sqrt", ["xa"], ["xq"]))
        nodes.append(oh.make_node("Div", ["one", "xq"], ["Z"]))
        nodes.append(oh.make_node("Mul", ["Z", "Xc"], ["Yc"]))
        nodes.append(oh.make_node("Cast", ["Yc"], ["Y"], to=10))
        outputs.append(
            oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT16, shape=("a", "D"))
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 23),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT16, shape=("a", "D"))
        )
        inputs.append(oh.make_tensor_value_info("axis", onnx.TensorProto.INT64, shape=(1,)))
        nodes.append(oh.make_node("Shape", ["X"], ["shape-X"]))
        nodes.append(oh.make_node("Gather", ["shape-X", "axis"], ["gather-shape-X"]))
        nodes.append(
            oh.make_node(
                "ConstantOfShape",
                ["gather-shape-X"],
                ["constantofshape-gather-shape-X"],
                value=onh.from_array(np.array([1.0], dtype=np.float16), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "RMSNormalization",
                ["X", "constantofshape-gather-shape-X"],
                ["Y"],
                axis=-1,
                epsilon=9.999999974752427e-07,
                stash_type=1,
            )
        )
        outputs.append(
            oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT16, shape=("a", "D"))
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if g.main_opset < 23:
            return self.none()
        if node.op_type != "ReduceMean" or node.domain != "":
            return self.none()
        if len(node.input) < 2:
            return self.none(node, inspect.currentframe().f_lineno)
        axis = g.get_constant_or_attribute(node, "axes", input_index=1, cvt=tuple)
        assert isinstance(axis, tuple), f"unexpected type {type(axis)} for axis"
        if len(axis) != 1:
            return self.none(node, inspect.currentframe().f_lineno)

        node_pow = g.node_before(node.input[0])
        if node_pow is None:
            return self.none(node, inspect.currentframe().f_lineno)
        if node_pow.op_type != "Pow" or node_pow.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant_scalar(node_pow.input[1], 2):
            return self.none(node, inspect.currentframe().f_lineno)

        node_add = g.next_node(node.output[0])
        if node_add.op_type != "Add" or node_add.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant_scalar(node_add.input[0]) and not g.is_constant_scalar(
            node_add.input[1]
        ):
            return self.none(node, inspect.currentframe().f_lineno)

        node_sqrt = g.next_node(node_add.output[0])
        if node_sqrt.op_type != "Sqrt" or node_sqrt.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)

        node_reciprocal = g.next_node(node_sqrt.output[0])
        if node_reciprocal.op_type not in ("Reciprocal", "Div") or node_reciprocal.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        if node_reciprocal.op_type == "Div":
            if node_reciprocal.input[1] != node_sqrt.output[0]:
                return self.none(node, inspect.currentframe().f_lineno)
            if not g.is_constant_scalar(node_reciprocal.input[0], 1):
                return self.none(node, inspect.currentframe().f_lineno)

        node_mul = g.next_node(node_reciprocal.output[0])
        if node_mul.op_type != "Mul" or node_mul.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)

        if (
            g.is_used_more_than_once(node_pow.output[0])
            or g.is_used_more_than_once(node.output[0])
            or g.is_used_more_than_once(node_add.output[0])
            or g.is_used_more_than_once(node_sqrt.output[0])
        ):
            # intermediate results are used
            return self.none(node, inspect.currentframe().f_lineno)

        mul_i = set(node_mul.input)
        cmp = {node_pow.input[0], node_reciprocal.output[0]}
        if mul_i != cmp:
            # We check the multiplication node takes the output of the div node
            # and the input of the pow node.
            return self.none(node, inspect.currentframe().f_lineno)

        cast_1 = g.node_before(node_pow.input[0])
        if cast_1 is not None and cast_1.op_type != "Cast":
            cast_1 = None
        cast_2 = None
        if cast_1 is not None:
            to = g.get_attribute(cast_1, "to").i
            if to == g.get_type(node.input[0]):
                cast_2 = g.next_node(node_mul.output[0])
                if cast_2 is None:
                    cast_1 = None
            else:
                cast_1 = None

        nodes = [
            cast_1,
            node_pow,
            node,
            node_add,
            node_sqrt,
            node_reciprocal,
            node_mul,
            cast_2,
        ]
        return MatchResult(self, nodes, self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        cast_1: NodeProto,
        node_pow: NodeProto,
        node_reduce: NodeProto,
        node_add: NodeProto,
        node_sqrt: NodeProto,
        node_reciprocal: NodeProto,
        node_mul: NodeProto,
        cast_2: NodeProto,
    ) -> List[NodeProto]:
        nname = node_reduce.name
        nodes = []
        epsilon = g.get_computed_constant(node_add.input[1])
        shape = (
            g.get_shape(node_reduce.input[0]) if g.has_shape(node_reduce.input[0]) else None
        )
        axis = g.get_constant_or_attribute(node_reduce, "axes", input_index=1)[0]
        assert shape is None or axis < len(
            shape
        ), f"axis={axis} and shape={shape} don't match for {node_reduce.input[0]!r}"
        stash_type = g.get_type(node_reduce.input[0])
        dtype = tensor_dtype_to_np_dtype(
            stash_type if cast_1 is None else g.get_type(cast_1.input[0])
        )
        input_name = node_pow.input[0] if cast_1 is None else cast_1.input[0]
        if shape is not None and isinstance(shape[axis], int):
            # a constant
            scale = g.make_initializer(
                f"ONES{shape[axis]}",
                np.ones((shape[axis],), dtype=dtype),
                source="RMSNormalization.apply.scale.1",
            )
        else:
            sh = (
                g.make_node(
                    "Shape",
                    [input_name],
                    start=axis,
                    end=axis + 1,
                    name=f"{self.__class__.__name__}--{nname}",
                )
                if axis != -1
                else g.make_node(
                    "Shape", [input_name], start=axis, name=f"{self.__class__.__name__}--{nname}"
                )
            )
            cc = g.make_node(
                "ConstantOfShape",
                [sh.output[0]],
                value=from_array_extended(np.array([1], dtype=dtype)),
                name=f"{self.__class__.__name__}--{nname}",
            )
            scale = cc.output[0]
            nodes.extend([sh, cc])
        layer = g.make_node(
            "RMSNormalization",
            [input_name, scale],
            [node_mul.output[0] if cast_2 is None else cast_2.output[0]],
            epsilon=float(epsilon[0] if epsilon.shape else epsilon),
            axis=int(axis),
            stash_type=stash_type,
            name=f"{self.__class__.__name__}--{nname}",
        )
        nodes.append(layer)
        return nodes
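

# Illustration (not part of the original module): the subgraph matched by
# RMSNormalizationPattern computes x * (1 / sqrt(mean(x**2) + eps)), which is
# RMSNormalization with a unit scale; the statistics are kept in float32 when
# the pattern absorbs the surrounding Cast nodes. A minimal numpy sketch; the
# helper name is hypothetical.
def _sketch_rms_normalization_decomposition():
    rng = np.random.default_rng(0)
    x = rng.standard_normal((4, 128)).astype(np.float16)
    eps = 9.999999974752427e-07

    # matched subgraph: Cast(to=1), Pow(2), ReduceMean, Add, Sqrt, Div(1/.), Mul, Cast
    xc = x.astype(np.float32)
    z = 1.0 / np.sqrt((xc**2).mean(axis=-1, keepdims=True) + eps)
    decomposed = (z * xc).astype(np.float16)

    # RMSNormalization(x, scale=1) with a float32 stash, per the opset-23 definition
    rms = np.sqrt((xc * xc).mean(axis=-1, keepdims=True) + eps)
    fused = (xc / rms * np.ones(128, np.float32)).astype(np.float16)
    assert np.allclose(decomposed, fused, atol=1e-3)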