class OpRunOpSequence(OpRun):
    """
    Ancestor for kernels using sequences.

    Shared base class for the sequence-related kernels (for instance
    ``ConcatFromSequence_11`` derives from it). It defines nothing of its
    own: every behavior is inherited from :class:`OpRun`.
    """
class ConcatFromSequence_11(OpRunOpSequence):
    """
    ConcatFromSequence

    Concatenates every tensor of an input sequence into a single tensor.

    Attributes read from the node:

    * ``axis`` (required int): axis along which the tensors are concatenated.
    * ``new_axis`` (int, default 0): when 1, each tensor first receives a new
      dimension of size 1 at position ``axis``, and the results are
      concatenated along that new dimension (the tensors are stacked).
    """

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        super().__init__(node, version)
        axis = self.get_attribute_int(node, "axis", None)
        # ``axis`` has no default: the node must define it.
        assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
        self.axis = axis
        self.new_axis = self.get_attribute_int(node, "new_axis", 0)

    def run(self, input_sequence: OpRunSequence) -> OpRunTensor:
        """
        Concatenates the tensors held by ``input_sequence``.

        :param input_sequence: sequence whose ``sequence`` attribute holds the
            tensors to concatenate (they are passed to ``torch.cat``)
        :return: the concatenated tensor wrapped in an :class:`OpRunTensor`
        """
        assert isinstance(
            input_sequence, OpRunSequence
        ), f"Unexpected type {type(input_sequence)} for input_sequence"
        seq = input_sequence.sequence
        if self.new_axis == 1:
            # Insert a size-1 dimension at ``axis`` (numpy's expand_dims).
            # ``Tensor.expand`` broadcasts to a target size instead, which is
            # not what is needed here. ``unsqueeze`` also handles negative axes
            # (-1 appends a trailing dimension) with the same convention that
            # ``torch.cat`` applies to the resulting rank-(r+1) tensors.
            seq = [s.unsqueeze(self.axis) for s in seq]
        return OpRunTensor(torch.cat(seq, dim=self.axis))
def run(
    self,
    input_sequence: OpRunSequence,
    tensor: OpRunTensor,
    position: Optional[OpRunTensor] = None,
) -> OpRunSequence:
    """
    Inserts ``tensor`` into ``input_sequence`` and returns the outcome.

    :param input_sequence: sequence receiving the new tensor
    :param tensor: tensor to insert
    :param position: optional insertion index, forwarded unchanged to
        ``insert_at`` (``None`` presumably means "append" — confirm)
    :return: whatever ``input_sequence.insert_at`` returns
    """
    # NOTE(review): the message calls string_type() on a value that failed
    # the isinstance check; if that value lacks string_type(), an
    # AttributeError will surface instead of this AssertionError.
    assert isinstance(input_sequence, OpRunSequence), (
        f"Unexpected type {type(input_sequence)} for input_sequence: "
        f"{input_sequence.string_type()}"
    )
    # NOTE(review): whether insert_at copies or mutates input_sequence is not
    # visible here; confirm the caller's sequence is left untouched.
    extended = input_sequence.insert_at(tensor, position)
    return extended