[docs]classSoftmaxGradPattern(PatternOptimization):""" Replaces the sequence Mul, ReduceSum, Mul, Sub by SoftmaxGrad """
def match(
    self,
    g: "GraphBuilderPatternOptimization",  # noqa: F821
    node: NodeProto,
    matched: List[MatchResult],
) -> Optional[MatchResult]:
    """
    Looks for the expanded softmax-backward subgraph anchored on a ReduceSum::

        prod = Mul(A, B)                       # mul_node
        s    = ReduceSum(prod, axes=[axis])    # node, keepdims=1
        t    = Mul(s, Y)  or  Mul(Y, s)        # next_mul_node, Y in {A, B}
        out  = Sub(prod, t)                    # sub_node

    which is ``dY * Y - Y * ReduceSum(dY * Y, axis, keepdims=1)``,
    i.e. ``Y * (dY - sum(dY * Y))``.

    :param g: graph being optimized
    :param node: candidate node, the pattern only starts from a ReduceSum
    :param matched: matches already found (not used by this pattern)
    :return: a MatchResult over ``[mul_node, node, next_mul_node, sub_node]``
        inserted at ``node``, or the value of ``self.none(...)`` when the
        subgraph does not have the expected shape
    """
    if node.op_type != "ReduceSum" or node.domain != "":
        return self.none()

    # The reduction must be over exactly one statically known axis.
    # Without axes (neither input nor attribute) ReduceSum reduces over all
    # axes; a dynamic axes input cannot be resolved at optimization time.
    if len(node.input) < 2 and not any(att.name == "axes" for att in node.attribute):
        return self.none(node, inspect.currentframe().f_lineno)
    if len(node.input) > 1 and not g.is_constant(node.input[1]):
        return self.none(node, inspect.currentframe().f_lineno)
    axis = g.get_constant_or_attribute(node, "axes", input_index=1, cvt=tuple)
    if not isinstance(axis, tuple) or len(axis) != 1:
        return self.none(node, inspect.currentframe().f_lineno)

    # keepdims=0 drops the reduced axis, so Y * sum would not broadcast as
    # SoftmaxGrad expects (ONNX default for keepdims is 1).
    keepdims = next((att.i for att in node.attribute if att.name == "keepdims"), 1)
    if keepdims != 1:
        return self.none(node, inspect.currentframe().f_lineno)

    # node_before returns None when the input is not produced by a node
    # (graph input, initializer).
    mul_node = g.node_before(node.input[0])
    if mul_node is None or mul_node.op_type != "Mul" or mul_node.domain != "":
        return self.none(node, inspect.currentframe().f_lineno)

    # The ReduceSum output disappears after fusion: it must be consumed once
    # and must not be a graph output. Checked before selecting the consumer
    # so that a shared output rejects the match instead of failing the lookup.
    if g.is_used_more_than_once(node.output[0]):
        return self.none(node, inspect.currentframe().f_lineno)
    sum_consumers = g.next_nodes(node.output[0])
    if len(sum_consumers) != 1:
        return self.none(node, inspect.currentframe().f_lineno)
    next_mul_node = sum_consumers[0]
    if next_mul_node.op_type != "Mul" or next_mul_node.domain != "":
        return self.none(node, inspect.currentframe().f_lineno)

    # The second Mul must rescale the sum by one of the first Mul's
    # operands (the softmax output Y), otherwise this is not SoftmaxGrad.
    scale = [name for name in next_mul_node.input if name != node.output[0]]
    if len(scale) != 1 or scale[0] not in mul_node.input:
        return self.none(node, inspect.currentframe().f_lineno)

    # Same single-use constraint for the second Mul's output.
    if g.is_used_more_than_once(next_mul_node.output[0]):
        return self.none(node, inspect.currentframe().f_lineno)
    scaled_consumers = g.next_nodes(next_mul_node.output[0])
    if len(scaled_consumers) != 1:
        return self.none(node, inspect.currentframe().f_lineno)
    sub_node = scaled_consumers[0]
    if sub_node.op_type != "Sub" or sub_node.domain != "":
        return self.none(node, inspect.currentframe().f_lineno)

    # Sub is not commutative: only ``dY * Y - Y * sum`` is SoftmaxGrad,
    # the swapped order would be its negation.
    if list(sub_node.input) != [mul_node.output[0], next_mul_node.output[0]]:
        return self.none(node, inspect.currentframe().f_lineno)

    # The first Mul's output must feed exactly the ReduceSum and the Sub.
    prod_consumers = g.next_nodes(mul_node.output[0])
    if len(prod_consumers) != 2:
        return self.none(node, inspect.currentframe().f_lineno)
    if {id(n) for n in prod_consumers} != {id(sub_node), id(node)}:
        return self.none(node, inspect.currentframe().f_lineno)

    nodes = [mul_node, node, next_mul_node, sub_node]
    return MatchResult(self, nodes, self.apply, insert_at=node)