Source code for experimental_experiment.xoptim.patterns_ort.missing_kernels

import inspect
from typing import List, Optional
from onnx import NodeProto, TensorProto
from ..patterns_api import MatchResult, PatternOptimization


class MissingRangePattern(PatternOptimization):
    """Replaces Range by Cast Range Cast because of some missing kernels.

    A standard-domain ``Range`` node whose first input has a known type other
    than INT32, INT64 or FLOAT is rewritten as:

    * every input is cast to FLOAT (floating-point inputs) or INT32 (other
      inputs),
    * ``Range`` runs on the cast inputs,
    * the result is cast back to the original input type.
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """Selects a Range node whose input type needs the Cast workaround.

        :param g: graph being optimized
        :param node: candidate node
        :param matched: results already matched (unused here)
        :return: a :class:`MatchResult` rewriting ``node``, or the result of
            ``self.none(...)`` when the pattern does not apply
        """
        if node.op_type != "Range" or node.domain != "":
            return self.none()
        # Unknown types cannot be handled; INT32/INT64/FLOAT need no rewrite.
        if not g.has_type(node.input[0]) or g.get_type(node.input[0]) in {
            TensorProto.INT32,
            TensorProto.INT64,
            TensorProto.FLOAT,
        }:
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
    ) -> List[NodeProto]:
        """Builds the ``Cast* -> Range -> Cast`` replacement for ``node``.

        :param g: graph builder used to create names and nodes
        :param node: the matched Range node
        :return: the new nodes; the last one writes ``node.output[0]``
        """
        prefix = self.__class__.__name__
        to = g.get_type(node.input[0])
        # Floating-point types must stay floating point: routing DOUBLE through
        # INT32 would truncate fractional start/limit/delta values.
        other_to = (
            TensorProto.FLOAT
            if to in {TensorProto.FLOAT16, TensorProto.BFLOAT16, TensorProto.DOUBLE}
            else TensorProto.INT32
        )
        nodes = []
        new_inputs = []
        for name in node.input:
            casted = g.unique_name(f"{prefix}--{name}")
            nodes.append(
                g.make_node("Cast", [name], [casted], to=other_to, name=f"{prefix}--Cast")
            )
            new_inputs.append(casted)
        # Name the intermediate result after the node output rather than after
        # whichever input happened to be last in the loop above.
        range_out = g.unique_name(f"{prefix}--{node.output[0]}")
        nodes.append(
            g.make_node(
                node.op_type,
                new_inputs,
                [range_out],
                domain=node.domain,
                name=f"{prefix}",
            )
        )
        # Restore the original element type on the original output name.
        nodes.append(
            g.make_node(
                "Cast", [range_out], [node.output[0]], to=to, name=f"{prefix}--Cast"
            )
        )
        return nodes
class MissingCosSinPattern(PatternOptimization):
    """Replaces Cos/Sin/Log by Cast Cos/Sin/Log Cast because of some missing kernels.

    A standard-domain ``Cos``, ``Sin`` or ``Log`` node whose input has a known
    type other than DOUBLE, FLOAT16 or FLOAT is rewritten as:

    * the input is cast to FLOAT,
    * the operator runs in FLOAT,
    * the result is cast back to the original type.
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """Selects a Cos/Sin/Log node whose input type needs the Cast workaround.

        :param g: graph being optimized
        :param node: candidate node
        :param matched: results already matched (unused here)
        :return: a :class:`MatchResult` rewriting ``node``, or the result of
            ``self.none(...)`` when the pattern does not apply
        """
        if node.op_type not in {"Cos", "Sin", "Log"} or node.domain != "":
            return self.none()
        # Unknown types cannot be handled; DOUBLE/FLOAT16/FLOAT need no rewrite.
        if not g.has_type(node.input[0]) or g.get_type(node.input[0]) in {
            TensorProto.DOUBLE,
            TensorProto.FLOAT16,
            TensorProto.FLOAT,
        }:
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
    ) -> List[NodeProto]:
        """Builds the ``Cast -> op -> Cast`` replacement for ``node``.

        :param g: graph builder used to create names and nodes
        :param node: the matched Cos/Sin/Log node
        :return: three nodes; the last one writes ``node.output[0]``
        """
        prefix = self.__class__.__name__
        to = g.get_type(node.input[0])
        other_to = TensorProto.FLOAT
        n1 = g.unique_name(f"{prefix}--{node.input[0]}")
        n2 = g.unique_name(f"{prefix}--{node.output[0]}")
        return [
            g.make_node(
                "Cast",
                [node.input[0]],
                [n1],
                to=other_to,
                name=f"{prefix}--Cast",
            ),
            g.make_node(
                node.op_type,
                [n1],
                [n2],
                domain=node.domain,
                name=f"{prefix}",
            ),
            # Restore the original element type on the original output name.
            g.make_node(
                "Cast",
                [n2],
                [node.output[0]],
                to=to,
                name=f"{prefix}--Cast",
            ),
        ]