class OpRunControlFlow(OpRun):
    """Base class of the kernels implementing control-flow operators.

    At construction time, every attribute of the node holding a graph is
    wrapped into an evaluator of the same class as *parent* and stored on the
    instance under the name of the attribute.
    """

    @classmethod
    def has_subgraphs(cls) -> bool:
        """Tells that the kernel owns subgraphs (always True for this class)."""
        return True

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        parent: Optional["onnx_diagnostic.reference.TorchOnnxEvaluator"] = None,  # noqa: F821
    ):
        """Builds one evaluator per graph attribute of *node*.

        :param node: node to run
        :param version: forwarded as is to the parent constructor
        :param parent: evaluator creating this kernel, it is mandatory, its
            class and its settings (providers, opsets, functions, verbose)
            are reused to build the evaluators running the subgraphs
        """
        super().__init__(node, version)
        assert (
            parent is not None
        ), f"parent must be specified for operator {self.__class__.__name__!r}"
        for att in node.attribute:
            if att.type != onnx.AttributeProto.GRAPH:
                # Only attributes holding a graph need an evaluator.
                continue
            sub_evaluator = parent.__class__(
                att.g,
                providers=parent.providers,
                opsets=parent.opsets,
                local_functions=parent.functions,
                verbose=parent.verbose,
            )
            setattr(self, att.name, sub_evaluator)

    def run(self, M, cond, *args, context: Optional[Dict[str, Any]] = None):
        """Runs the subgraph ``self.body`` repeatedly and gathers its outputs.

        NOTE(review): this method reads ``self.body``, ``self.K``, ``self.N``
        and ``self.output_index``. The constructor above only creates
        ``self.body`` when the node has a graph attribute with that name and
        never sets the three others -- presumably a subclass does, confirm.

        :param M: None or an object exposing ``dtype`` and ``tensor``,
            when ``M.tensor`` is not None, the loop stops as soon as the
            iteration counter reaches ``M.tensor.item()``
        :param cond: None or an object exposing ``tensor``, the loop goes on
            as long as it is None, its tensor is None or its item is true,
            it is replaced after every iteration by the first output
            of the subgraph
        :param args: the first value is stored under the third input name of
            the subgraph, the remaining ones are mapped to its last input names
        :param context: optional values copied into the subgraph inputs by
            name and forwarded to every call of the subgraph
        :return: tuple of values, padded with empty tensors up to
            the number of outputs of the subgraph
        """
        v_initial = args[0] if args else None
        remaining = args[1:]
        assert M is None or hasattr(
            M, "dtype"
        ), f"M must be empty or an array but its type is {type(M)}."

        body = self.body
        names = body.input_names
        inputs = dict.fromkeys(names)
        if v_initial is not None:
            inputs[names[2]] = v_initial
        cond_name = body.output_names[0]
        if remaining:
            # The remaining values are aligned with the end of the input list.
            first = len(names) - len(remaining)
            inputs.update(zip(names[first:], remaining))
        if context is not None:
            inputs.update(context)

        # One bucket per scan output, filled with one value per iteration.
        k_carried_away = [[] for _ in range(self.K)]  # type: ignore
        it = 0
        while True:
            if not (cond is None or cond.tensor is None or cond.tensor.item()):
                break
            if not (M is None or M.tensor is None or it < M.tensor.item()):
                break

            # The first two inputs of the subgraph, when they exist,
            # receive the iteration number and the current condition.
            if len(names) > 0 and names[0] is not None:
                counter_dtype = None if M is None else M.dtype
                inputs[names[0]] = OpRunTensor(torch.tensor(it, dtype=counter_dtype))
            if len(names) > 1 and names[1] is not None:
                inputs[names[1]] = cond

            feeds = [inputs[name] for name in names]
            outputs = list(body.run_with_values(*feeds, context=context))

            # The last K outputs are accumulated, one per bucket.
            for position, bucket in enumerate(k_carried_away):
                bucket.append(outputs[position - self.K])

            cond = outputs[self.output_index[cond_name]]
            assert (
                cond is not None
            ), f"Condition {cond_name!r} returned by the subgraph cannot be None."

            # Outputs following the condition feed the inputs following
            # the first two ones for the next iteration.
            for state_in, state_out in zip(names[2:], body.output_names[1:]):
                inputs[state_in] = outputs[self.output_index[state_out]]
            it += 1

        # Without any iteration, the initial values are returned unchanged.
        outputs = outputs[1 : 1 + self.N] if it else [inputs[name] for name in names[2:]]
        outputs.extend([OpRunTensor(torch.cat(bucket, axis=0)) for bucket in k_carried_away])

        # Missing outputs are replaced by empty tensors.
        n_missing = len(body.output_names) - len(outputs)
        outputs.extend(OpRunTensor(torch.empty(())) for _ in range(n_missing))
        return tuple(outputs)