class TransposeCastPattern(PatternOptimization):
    """
    Replaces Cast + Transpose or Transpose + Cast with a single fused node,
    Transpose2DCastFP16 or Transpose2DCastFP32, depending on the output type.

    :param verbose: verbosity level, forwarded to the base class
    :param priority: priority of the pattern, forwarded to the base class
    """

    # Output element types for which a fused operator name can be built.
    _allowed_types = (TensorProto.FLOAT, TensorProto.FLOAT16)

    def __init__(self, verbose: int = 0, priority: int = 3):
        super().__init__(verbose, priority)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        cast_node_before: Optional[NodeProto],
        node: NodeProto,
        cast_node_after: Optional[NodeProto],
    ) -> List[NodeProto]:
        """
        Builds the fused node replacing the Transpose and its neighbouring Cast.

        :param g: graph builder used to query result types and create nodes
        :param cast_node_before: node preceding *node*, or None if there is none;
            when given, its inputs become the inputs of the fused node
        :param node: the central node of the pattern
            (presumably the Transpose -- confirm against the matching step)
        :param cast_node_after: node following *node*, or None if there is none;
            when given, its outputs become the outputs of the fused node
        :return: a list holding the single fused node
        :raises AssertionError: if the output type is neither FLOAT nor FLOAT16
        """
        # The fused node spans from the first node of the pattern to the last.
        first_node = node if cast_node_before is None else cast_node_before
        last_node = node if cast_node_after is None else cast_node_after

        # The operator name is selected from the type of the final output.
        out_type = g.get_type(last_node.output[0])
        if out_type == TensorProto.FLOAT:
            suffix = "32"
        elif out_type == TensorProto.FLOAT16:
            suffix = "16"
        else:
            raise AssertionError(
                f"out_type={out_type} must be in {self._allowed_types}"
            )

        new_node = g.make_node(
            f"Transpose2DCastFP{suffix}",
            first_node.input,
            last_node.output,
            domain="onnx_extended.ortops.optim.cuda",
            name=f"{self.__class__.__name__}--{node.name}",
        )
        return [new_node]