# Source code for experimental_experiment.xoptim.patterns_exp.unary_operators

import inspect
from typing import List, Optional
from onnx import NodeProto, TensorProto
from ..patterns_api import MatchResult, PatternOptimization


class TransposeCastPattern(PatternOptimization):
    """
    Replaces Cast + Transpose or Transpose + Cast into
    Transpose2DCastFP16 or Transpose2DCastFP32 depending on the output type.

    The pattern only fires when the graph reports a ``"CUDA"`` processor,
    when the Transpose is a standard-domain node with ``perm == [1, 0]``
    and when every element type involved is one of :attr:`_allowed_types`.
    The fused node is created in the domain
    ``onnx_extended.ortops.optim.cuda``.
    """

    # Element types accepted on both sides of the fused Cast/Transpose pair.
    _allowed_types = (TensorProto.FLOAT, TensorProto.FLOAT16)

    def __init__(self, verbose: int = 0, priority: int = 3):
        super().__init__(verbose, priority)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tells whether *node* is a 2D Transpose which can be fused with a Cast.

        :param g: graph being optimized, queried for processors, types
            and neighbouring nodes
        :param node: candidate node, only a standard-domain ``Transpose``
            with ``perm == [1, 0]`` is considered
        :param matched: matches found so far (not used by this pattern)
        :return: a :class:`MatchResult` whose node list is
            ``[cast_before, transpose, cast_after]`` with ``None`` in the
            slot of the Cast which is not part of the match, or the result
            of ``self.none(...)`` if the pattern does not apply
        """
        if not g.has_processor("CUDA"):
            return self.none()
        if node.op_type != "Transpose" or node.domain != "":
            return self.none()

        # NOTE(review): this assumes the ``perm`` attribute is present; ONNX
        # allows Transpose without it - confirm how ``get_attribute`` behaves
        # in that case.
        perm = list(g.get_attribute(node, "perm").ints)
        if perm != [1, 0]:
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_type(node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        if g.get_type(node.input[0]) not in self._allowed_types:
            return self.none(node, inspect.currentframe().f_lineno)

        # First option: Cast + Transpose. The Cast result must feed the
        # Transpose only, otherwise the Cast cannot be absorbed.
        cast_node_before = g.node_before(node.input[0])
        if (
            cast_node_before is not None
            and cast_node_before.op_type == "Cast"
            and cast_node_before.domain == ""
            and not g.is_used_more_than_once(node.input[0])
            and g.has_type(cast_node_before.input[0])
            and g.get_type(cast_node_before.input[0]) in self._allowed_types
        ):
            return MatchResult(
                self, [cast_node_before, node, None], self.apply, insert_at=node
            )

        # Second option: Transpose + Cast. The Transpose result must be
        # consumed by that single Cast only.
        consumers = g.next_nodes(node.output[0])
        if (
            len(consumers) == 1
            and consumers[0].op_type == "Cast"
            and consumers[0].domain == ""
            and not g.is_used_more_than_once(node.output[0])
            and g.has_type(consumers[0].output[0])
            and g.get_type(consumers[0].output[0]) in self._allowed_types
        ):
            return MatchResult(
                self, [None, node, consumers[0]], self.apply, insert_at=node
            )

        return self.none(node, inspect.currentframe().f_lineno)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        cast_node_before: Optional[NodeProto],
        node: NodeProto,
        cast_node_after: Optional[NodeProto],
    ) -> List[NodeProto]:
        """
        Builds the fused node replacing the matched Cast/Transpose pair.

        :param g: graph builder used to read types and create the new node
        :param cast_node_before: Cast feeding the Transpose, or ``None``
        :param node: the Transpose node
        :param cast_node_after: Cast consuming the Transpose, or ``None``
        :return: a list holding the single ``Transpose2DCastFP16`` or
            ``Transpose2DCastFP32`` node
        :raises AssertionError: if the output type is neither FLOAT
            nor FLOAT16
        """
        if cast_node_after is None:
            # Cast + Transpose: Transpose keeps the element type, and
            # ``match`` only verified that the type of the Transpose *input*
            # is known, so the output type is read from there.
            out_type = g.get_type(node.input[0])
        else:
            out_type = g.get_type(cast_node_after.output[0])

        if out_type == TensorProto.FLOAT:
            suffix = "32"
        elif out_type == TensorProto.FLOAT16:
            suffix = "16"
        else:
            raise AssertionError(f"out_type={out_type} must be in {self._allowed_types}")

        # The fused node spans from the input of the first matched node to
        # the output of the last matched node.
        fused_inputs = node.input if cast_node_before is None else cast_node_before.input
        fused_outputs = node.output if cast_node_after is None else cast_node_after.output
        new_node = g.make_node(
            f"Transpose2DCastFP{suffix}",
            fused_inputs,
            fused_outputs,
            domain="onnx_extended.ortops.optim.cuda",
            name=f"{self.__class__.__name__}--{node.name}",
        )
        return [new_node]