Source code for experimental_experiment.reference.ops.op_rotary
from onnx.reference.op_run import OpRun
class Rotary(OpRun):
    """Reference implementation of the ``Rotary`` operator.

    The operator swaps the two halves of the last axis of its input and
    negates one of them; which half gets negated depends on ``side``.
    """

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, X, splits=None, side=None):
        """Swap the halves of the last axis of *X*, negating one of them.

        :param X: input array; it must expose ``shape`` and ``copy()`` and
            accept ellipsis slicing (presumably a numpy array -- confirm
            against the caller)
        :param splits: optional, only validated: when given, it must have
            shape ``(2,)`` with two equal values. The split point itself is
            always taken as ``X.shape[-1] // 2``.
        :param side: ``"left"`` negates the half coming from the first part
            of the axis; any other value (including ``None``) negates the
            half coming from the second part
        :return: a 1-tuple holding the rotated copy, *X* is left untouched
        """
        if splits is not None:
            assert (
                splits.shape == (2,) and splits[0] == splits[1]
            ), f"Unexpected split value {splits}"

        half = X.shape[-1] // 2
        first_half = X[..., :half]
        second_half = X[..., half:]

        # Work on a copy so the views above keep reading the original data.
        rotated = X.copy()
        if side == "left":
            rotated[..., :half] = second_half
            rotated[..., half:] = -first_half
        else:
            rotated[..., :half] = -second_half
            rotated[..., half:] = first_half
        return (rotated,)