Source code for experimental_experiment.xoptim.patterns.onnx_cast

import inspect
from typing import List, Optional
from onnx import NodeProto, TensorProto
from ..patterns_api import MatchResult, PatternOptimization


class CastPattern(PatternOptimization):
    """
    Checks that a Cast is really needed.
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type != "Cast" or node.domain != "":
            return self.none()

        if not g.has_type(node.input[0]):
            itype = g.try_infer_type(node.input[0], exc=False)
            if itype == 0:
                return self.none(node, inspect.currentframe().f_lineno)
        else:
            itype = g.get_type(node.input[0])

        att = g.get_attribute(node, "to")
        if att.i != itype:
            return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
    ) -> List[NodeProto]:
        new_node = g.make_node(
            "Identity",
            node.input,
            node.output,
            name=f"{self.__class__.__name__}--{node.name}",
            doc_string=node.doc_string,
        )
        return [new_node]
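
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): what CastPattern
# rewrites.  ``_example_cast_pattern`` is a hypothetical helper; it only uses
# onnx.helper to describe the "before" and "after" nodes and does not touch
# the GraphBuilder API.
def _example_cast_pattern():
    from onnx import TensorProto, helper

    # Before: "X" is assumed to already be FLOAT, so this Cast changes nothing.
    before = helper.make_node("Cast", ["X"], ["Y"], to=TensorProto.FLOAT)
    # After: the pattern replaces the useless Cast with an Identity node.
    after = helper.make_node("Identity", ["X"], ["Y"])
    return before, after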


class CastCastBinaryPattern(PatternOptimization):
    """
    Moves two Cast operators past a binary operator.
    Both casts must convert from one float type to another float type.
    """

    _dtypes_allowed = {
        TensorProto.FLOAT16,
        TensorProto.BFLOAT16,
        TensorProto.FLOAT,
        TensorProto.DOUBLE,
    }

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type not in {"Add", "Div", "Mul", "Sub"} or node.domain != "":
            return self.none()
        if g.is_used_more_than_once(node.input[0]) or g.is_used_more_than_once(
            node.input[1]
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_type(node.input[0]) or not g.has_type(node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)

        dtype_left, dtype_right = g.get_type(node.input[0]), g.get_type(node.input[1])
        if dtype_left not in self._dtypes_allowed or dtype_right not in self._dtypes_allowed:
            return self.none(node, inspect.currentframe().f_lineno)

        left, right = g.node_before(node.input[0]), g.node_before(node.input[1])
        if left is None or left.op_type != "Cast" or left.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        if right is None or right.op_type != "Cast" or right.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)

        dtype_left, dtype_right = g.get_type(left.input[0]), g.get_type(right.input[0])
        if dtype_left not in self._dtypes_allowed or dtype_right not in self._dtypes_allowed:
            return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(self, [left, right, node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        left: NodeProto,
        right: NodeProto,
        node: NodeProto,
    ) -> List[NodeProto]:
        to = g.get_attribute(left, "to")
        new_node = g.make_node(
            node.op_type,
            [left.input[0], right.input[0]],
            name=f"{self.__class__.__name__}--{node.name}",
        )
        cast_node = g.make_node(
            "Cast",
            new_node.output,
            node.output,
            to=to.i,
            name=f"{self.__class__.__name__}--{node.name}",
            doc_string=node.doc_string,
        )
        return [new_node, cast_node]
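
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the rewrite done by
# CastCastBinaryPattern.  ``_example_cast_cast_binary`` is a hypothetical
# helper; tensor and intermediate names are made up and only onnx.helper is
# used.
def _example_cast_cast_binary():
    from onnx import TensorProto, helper

    # Before: two float16 inputs are both cast to float32, then added.
    before = [
        helper.make_node("Cast", ["A"], ["a32"], to=TensorProto.FLOAT),
        helper.make_node("Cast", ["B"], ["b32"], to=TensorProto.FLOAT),
        helper.make_node("Add", ["a32", "b32"], ["Y"]),
    ]
    # After: the addition runs on the original float16 inputs and a single
    # Cast converts the result to float32.
    after = [
        helper.make_node("Add", ["A", "B"], ["ab"]),
        helper.make_node("Cast", ["ab"], ["Y"], to=TensorProto.FLOAT),
    ]
    return before, after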


class CastOpCastPattern(PatternOptimization):
    """
    Removes the two Cast operators surrounding another operator.
    """

    _dtypes_allowed = {
        TensorProto.FLOAT16,
        TensorProto.BFLOAT16,
        TensorProto.FLOAT,
        TensorProto.DOUBLE,
    }
    _unary_ops = {"MulSigmoid", "Neg", "Sigmoid", "Softmax"}
    _binary_ops = {"Add", "Sub", "Mul", "Div"}
    _other_ops = {"SoftmaxGrad"}

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if (
            (node.op_type not in self._binary_ops or node.domain != "")
            and node.op_type not in self._other_ops
            and node.op_type not in self._unary_ops
        ):
            return self.none()
        if "ComputationCastOpCastPattern--" in node.name:
            return self.none(node, inspect.currentframe().f_lineno)
        if g.is_used_more_than_once(node.input[0]) or (
            len(node.input) > 1 and g.is_used_more_than_once(node.input[1])
        ):
            return self.none(node, inspect.currentframe().f_lineno)

        next_nodes = g.next_nodes(node.output[0])
        if len(next_nodes) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        cast_out_node = next_nodes[0]
        if cast_out_node.op_type != "Cast" or cast_out_node.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)

        cast_in_left = g.node_before(node.input[0])
        cast_in_right = g.node_before(node.input[1]) if len(node.input) > 1 else None
        if "Cast" not in (
            "" if cast_in_left is None else cast_in_left.op_type,
            "" if cast_in_right is None else cast_in_right.op_type,
        ):
            return self.none(node, inspect.currentframe().f_lineno)

        if cast_out_node is None and (
            cast_in_left is None
            or cast_in_left.op_type != "Cast"
            or cast_in_right is None
            or cast_in_right.op_type != "Cast"
        ):
            # Then we only allow this if the computation type is lower precision.
            compute_type = g.get_type(node.output[0])
            before_type = g.get_type((cast_in_left or cast_in_right).input[0])
            if not (
                compute_type == TensorProto.FLOAT
                and before_type in (TensorProto.FLOAT16, TensorProto.BFLOAT16)
            ):
                return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(
            self,
            [
                (
                    cast_in_left
                    if cast_in_left is not None and cast_in_left.op_type == "Cast"
                    else None
                ),
                (
                    cast_in_right
                    if cast_in_right is not None and cast_in_right.op_type == "Cast"
                    else None
                ),
                node,
                cast_out_node,
            ],
            self.apply,
            insert_at=node,
        )

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        cast_in_left: NodeProto,
        cast_in_right: NodeProto,
        node: NodeProto,
        cast_out_node: NodeProto,
    ) -> List[NodeProto]:
        # to = g.get_attribute(cast_in_left or cast_in_right, "to").i
        to_out = g.get_attribute(cast_out_node, "to").i

        left_input = None
        right_input = None
        new_nodes = []
        if cast_in_left is None:
            left_input = g.unique_name(f"{self.__class__.__name__}--{node.output[0]}")
            new_nodes.append(
                g.make_node(
                    "Cast",
                    [node.input[0]],
                    [left_input],
                    to=to_out,
                    name=f"{self.__class__.__name__}--CastL",
                )
            )
        else:
            left_input = cast_in_left.input[0]

        if cast_in_right is None and len(node.input) > 1:
            right_input = g.unique_name(f"{self.__class__.__name__}--{node.output[0]}")
            new_nodes.append(
                g.make_node(
                    "Cast",
                    [node.input[1]],
                    [right_input],
                    to=to_out,
                    name=f"{self.__class__.__name__}--CastR",
                )
            )
        else:
            right_input = None if cast_in_right is None else cast_in_right.input[0]

        new_node = g.make_node(
            node.op_type,
            [left_input] if right_input is None else [left_input, right_input],
            cast_out_node.output,
            domain=node.domain,
            name=f"{self.__class__.__name__}--{node.name}",
        )
        if node.attribute:
            new_node.attribute.extend(node.attribute)
        new_nodes.append(new_node)

        if g.is_used_more_than_once(node.output[0]):
            final_cast = g.make_node(
                "Cast",
                [new_node.output[0]],
                [node.output[0]],
                to=g.get_type(node.output[0]),
                name=f"{self.__class__.__name__}--{node.name}",
            )
            new_nodes.append(final_cast)

        return new_nodes
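
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the rewrite done by
# CastOpCastPattern when both inputs come from a Cast and the result is cast
# back.  ``_example_cast_op_cast`` is a hypothetical helper; names are made up.
def _example_cast_op_cast():
    from onnx import TensorProto, helper

    # Before: float16 inputs cast to float32, added, then cast back to float16.
    before = [
        helper.make_node("Cast", ["A"], ["a32"], to=TensorProto.FLOAT),
        helper.make_node("Cast", ["B"], ["b32"], to=TensorProto.FLOAT),
        helper.make_node("Add", ["a32", "b32"], ["y32"]),
        helper.make_node("Cast", ["y32"], ["Y"], to=TensorProto.FLOAT16),
    ]
    # After: the surrounding casts disappear and the operator runs directly
    # on the float16 inputs.
    after = [helper.make_node("Add", ["A", "B"], ["Y"])]
    return before, after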


class ComputationCastOpCastPattern(PatternOptimization):
    """
    Changes the computation type to make it faster
    if one of the inputs was just cast beforehand.
    """

    _dtypes_allowed = {
        TensorProto.FLOAT16,
        TensorProto.BFLOAT16,
        TensorProto.FLOAT,
        TensorProto.DOUBLE,
    }
    _binary_ops = {"Add", "Sub", "Mul", "Div"}

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type not in self._binary_ops or node.domain != "":
            return self.none()

        node_left = g.node_before(node.input[0])
        node_right = g.node_before(node.input[1])
        type_left = "" if node_left is None else node_left.op_type
        type_right = "" if node_right is None else node_right.op_type
        if not ((type_left == "Cast") ^ (type_right == "Cast")):
            # only one cast allowed
            return self.none(node, inspect.currentframe().f_lineno)
        if type_left:
            node_right = None
        else:
            node_left = None
        node_before = node_left or node_right

        if not g.has_type(node.output[0]) or not g.has_type(node_before.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        output_type = g.get_type(node.output[0])
        before_type = g.get_type(node_before.input[0])
        if not (
            output_type == TensorProto.FLOAT
            and before_type in (TensorProto.FLOAT16, TensorProto.BFLOAT16)
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        if g.is_used_more_than_once(node_before.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)

        next_nodes = g.next_nodes(node.output[0])
        op_types = [n.op_type for n in next_nodes]
        if "Cast" in op_types:
            return self.none(node, inspect.currentframe().f_lineno)

        # At this stage, we know the computation type is float and one input
        # has a lower type precision. Let's change it.
        return MatchResult(self, [node_left, node_right, node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node_left: Optional[NodeProto],
        node_right: Optional[NodeProto],
        node: NodeProto,
    ) -> List[NodeProto]:
        to_type = g.get_type(node.output[0])

        inputs = []
        if node_left is None:
            before_type = g.get_type(node_right.input[0])
            name = g.unique_name(f"{self.__class__.__name__}--{node.input[0]}")
            cast_node = g.make_node(
                "Cast",
                [node.input[0]],
                [name],
                to=before_type,
                name=f"{self.__class__.__name__}--{node.name}",
            )
            inputs = [name, node_right.input[0]]
        else:
            before_type = g.get_type(node_left.input[0])
            name = g.unique_name(f"{self.__class__.__name__}--{node.input[1]}")
            cast_node = g.make_node(
                "Cast",
                [node.input[1]],
                [name],
                to=before_type,
                name=f"{self.__class__.__name__}--{node.name}",
            )
            inputs = [node_left.input[0], name]

        name = g.unique_name(f"{self.__class__.__name__}--{node.output[0]}")
        new_node = g.make_node(
            node.op_type,
            inputs,
            [name],
            domain=node.domain,
            name=f"{self.__class__.__name__}--{node.name}",
        )
        final_cast = g.make_node(
            "Cast",
            [name],
            [node.output[0]],
            to=to_type,
            name=f"{self.__class__.__name__}--{node.name}",
        )
        return [cast_node, new_node, final_cast]
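
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the rewrite done by
# ComputationCastOpCastPattern when only the left input comes from a Cast.
# ``_example_computation_cast_op_cast`` is a hypothetical helper; names are
# made up and only onnx.helper is used.
def _example_computation_cast_op_cast():
    from onnx import TensorProto, helper

    # Before: "A" is float16 and is cast up to float32 so it can be added to
    # the float32 tensor "B"; the computation runs in float32.
    before = [
        helper.make_node("Cast", ["A"], ["a32"], to=TensorProto.FLOAT),
        helper.make_node("Add", ["a32", "B"], ["Y"]),
    ]
    # After: "B" is cast down instead, the addition runs in float16, and a
    # final Cast restores the float32 output type.
    after = [
        helper.make_node("Cast", ["B"], ["b16"], to=TensorProto.FLOAT16),
        helper.make_node("Add", ["A", "b16"], ["y16"]),
        helper.make_node("Cast", ["y16"], ["Y"], to=TensorProto.FLOAT),
    ]
    return before, after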