[docs]classFusedConvPattern(PatternOptimization):""" Replaces the Conv + Relu into FusedConv. """def__init__(self,verbose:int=0,priority:int=2):super().__init__(verbose,priority)
[docs]defmatch(self,g:"GraphBuilderPatternOptimization",# noqa: F821node:NodeProto,matched:List[MatchResult],)->Optional[MatchResult]:ifnode.op_type!="Conv"ornode.domain!="":returnself.none()next_nodes=g.next_nodes(node.output[0])iflen(next_nodes)!=1:returnself.none(node,inspect.currentframe().f_lineno)op_type=next_nodes[0].op_typeifop_type!="Relu":returnself.none(node,inspect.currentframe().f_lineno)# FusedConv only exists for float32.dtypes=[(g.get_type(i)ifg.has_type(i)elseNone)foriinnode.input]ifTensorProto.FLOATnotindtypes:returnself.none(node,inspect.currentframe().f_lineno)returnMatchResult(self,[node,next_nodes[0]],self.apply,insert_at=next_nodes[0])