class ShapeBasedShapeShapeAddPattern(PatternOptimization):
    """
    Tries to find another way to get a dimension obtained with the addition
    of two shapes, i.e. the subgraph ``Add(Shape(X), Shape(Y))``.

    This pattern is a work in progress: :meth:`match` recognizes the
    ``Add(Shape, Shape)`` structure but never reports a match yet, and
    :meth:`apply` raises :class:`NotImplementedError`.
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        # Both arguments are forwarded unchanged to PatternOptimization.
        super().__init__(verbose, priority)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Checks whether ``node`` is an ``Add`` whose two inputs are both
        produced by ``Shape`` nodes (default domain).

        :param g: graph being optimized, used to find the producer of each input
        :param node: candidate ``Add`` node
        :param matched: results already matched (unused here)
        :return: always ``self.none(...)`` for now; the rewrite is not
            implemented, so the pattern never matches
        """
        if node.op_type != "Add" or node.domain != "":
            # Cheapest rejection first, no line number recorded.
            return self.none()

        shape1 = g.node_before(node.input[0])
        if shape1 is None or shape1.op_type != "Shape" or shape1.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)

        shape2 = g.node_before(node.input[1])
        if shape2 is None or shape2.op_type != "Shape" or shape2.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)

        # TODO: decide when Add(Shape(X), Shape(Y)) can be computed another
        # way. Information that was being explored for this:
        #   ishape1 = g.get_shape_renamed(shape1.input[0])
        #   ishape2 = g.get_shape_renamed(shape2.input[0])
        #   value1 = g.builder.value_as_shape(node.input[0])
        #   value2 = g.builder.value_as_shape(node.input[1])
        #   input_shapes = [g.get_shape_renamed(i) for i in g.builder.input_names]
        #   g.builder._known_value_shape
        #   g.builder.constraints_
        #   g.builder.replacements_dimensions_
        # Until then, never report a match so that apply is never reached.
        return self.none(node, inspect.currentframe().f_lineno)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        shape1_node: NodeProto,
        shape2_node: NodeProto,
        add_node: NodeProto,
    ) -> List[NodeProto]:
        """
        Would rewrite ``Add(Shape(X), Shape(Y))``; not implemented yet.

        :raises NotImplementedError: always
        """
        # ``__class__`` (two underscores); the previous ``___class__`` raised
        # AttributeError instead of the intended NotImplementedError.
        raise NotImplementedError(
            f"{self.__class__.__name__} is not implemented yet."
        )