class SlicesSplitPattern(PatternOptimization):
    """
    Detects multiple slices into a split.

    Several ``Slice`` nodes reading the same tensor (as their data input)
    are replaced by a single ``Split`` when:

    * they all slice along the same constant axis,
    * their step, if given, is the constant 1,
    * their start/end values are constant scalars and the ranges, taken in
      the order returned by ``g.next_nodes``, tile the whole dimension: the
      first one starts at 0, each one starts where the previous one ends and
      the last one reaches the end of the dimension.

    Start/end values are interpreted like ONNX ``Slice`` does for a positive
    step (negative values count from the end, out-of-range values are
    clamped), so ``x[:-2]`` followed by ``x[-2:]`` produces valid split sizes.
    """

    @staticmethod
    def _clamp_index(value, dim: int) -> int:
        """
        Maps a ``Slice`` start or end value onto ``[0, dim]`` the way ONNX
        does for a positive step: a negative value is shifted by ``dim``,
        then the result is clamped. ``9223372036854775807`` (what torch uses
        to specify the end) therefore becomes ``dim``.

        :param value: integer-like start or end value
        :param dim: size of the sliced dimension
        :return: the effective index in ``[0, dim]``
        """
        index = int(value)
        if index < 0:
            index += dim
        return min(max(index, 0), dim)

    @classmethod
    def _split_sizes(cls, starts: List, ends: List, dim: int) -> Optional[List[int]]:
        """
        Computes the ``Split`` sizes equivalent to the slices
        ``[starts[i]:ends[i]]`` along a dimension of size ``dim``.

        :param starts: start values of the slices, in output order
        :param ends: end values of the slices, in output order
        :param dim: size of the sliced dimension
        :return: the list of sizes, or None if the slices, in the given
            order, do not tile ``[0, dim]`` exactly
        """
        if not starts or len(starts) != len(ends):
            return None
        lows = [cls._clamp_index(s, dim) for s in starts]
        highs = [cls._clamp_index(e, dim) for e in ends]
        if lows[0] != 0 or highs[-1] != dim:
            return None
        if any(high < low for low, high in zip(lows, highs)):
            # a reversed range would lead to a negative split size
            return None
        if any(high != low for high, low in zip(highs, lows[1:])):
            # gap or overlap between two consecutive slices
            return None
        return [high - low for low, high in zip(lows, highs)]

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Checks whether ``node`` and the other ``Slice`` nodes reading the
        same data input can be replaced by one ``Split``.

        :param g: graph being optimized
        :param node: candidate node
        :param matched: previous matches (unused)
        :return: a MatchResult holding all the slices, or ``self.none(...)``
        """
        if node.op_type != "Slice" or node.domain != "":
            return self.none()
        data = node.input[0]
        if not g.has_shape(data):
            return self.none(node, inspect.currentframe().f_lineno)

        # Only slices using ``data`` as their data input (input 0) are
        # candidates: ``data`` may also feed other inputs of other nodes.
        users = [
            op
            for op in g.next_nodes(data)
            if op.op_type == "Slice" and op.domain == "" and op.input[0] == data
        ]
        if len(users) <= 1:
            return self.none(node, inspect.currentframe().f_lineno)

        for user in users:
            # axes (input 3) is needed to know the split axis,
            # steps (input 4), if present, must be the constant 1
            if len(user.input) == 4:
                continue
            if len(user.input) == 5:
                if not g.is_constant_scalar(user.input[4]):
                    return self.none(node, inspect.currentframe().f_lineno)
                if g.get_constant_scalar(user.input[4]) != 1:
                    return self.none(node, inspect.currentframe().f_lineno)
                continue
            return self.none(node, inspect.currentframe().f_lineno)

        # axis: every slice must use the same one (negative axes normalized)
        shape = g.get_shape(data)
        rank = len(shape)
        axes = set()
        for user in users:
            if not g.is_constant_scalar(user.input[3]):
                return self.none(node, inspect.currentframe().f_lineno)
            axis = int(g.get_constant_scalar(user.input[3]))
            if not -rank <= axis < rank:
                return self.none(node, inspect.currentframe().f_lineno)
            axes.add(axis % rank)
        if len(axes) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        axis = axes.pop()

        dim = shape[axis]
        if not isinstance(dim, int):
            return self.none(node, inspect.currentframe().f_lineno)

        # starts, ends: constant scalars tiling the whole dimension
        starts = [user.input[1] for user in users]
        ends = [user.input[2] for user in users]
        if not all(g.is_constant_scalar(name) for name in starts + ends):
            # no constants
            return self.none(node, inspect.currentframe().f_lineno)
        sizes = self._split_sizes(
            [g.get_constant_scalar(name) for name in starts],
            [g.get_constant_scalar(name) for name in ends],
            dim,
        )
        if sizes is None:
            return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(self, users, self.apply)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        *nodes: NodeProto,
    ) -> List[NodeProto]:
        """
        Replaces the matched slices by a single ``Split`` node.

        :param g: graph being optimized
        :param nodes: the Slice nodes validated by :meth:`match`, in order
        :return: the list containing the new ``Split`` node
        """
        # nodes are all slices along the same axis, validated by match
        axis = int(g.get_constant_scalar(nodes[0].input[3]))
        dim = g.get_shape(nodes[0].input[0])[axis]
        n_els = self._split_sizes(
            [g.get_constant_scalar(op.input[1]) for op in nodes],
            [g.get_constant_scalar(op.input[2]) for op in nodes],
            dim,
        )
        if n_els is None:
            raise RuntimeError(
                f"{self.__class__.__name__}: slices {[op.name for op in nodes]} "
                f"no longer tile dimension {axis} of size {dim}."
            )
        splits = g.make_initializer("", np.array(n_els, dtype=np.int64))
        outputs = [op.output[0] for op in nodes]
        node = g.make_node(
            "Split",
            [nodes[0].input[0], splits],
            outputs,
            axis=axis,
            name=f"{self.__class__.__name__}--{nodes[0].name}",
        )
        return [node]