Source code for experimental_experiment.reference.ops.op_cast_like
from onnx.helper import np_dtype_to_tensor_dtype
from onnx.onnx_pb import TensorProto
from onnx.reference.op_run import OpRun
from onnx.reference.ops.op_cast import (
    bfloat16,
    cast_to,
    float8e4m3fn,
    float8e4m3fnuz,
    float8e5m2,
    float8e5m2fnuz,
)


def _cast_like(x, y, saturate):
    """Convert ``x`` to the element type carried by ``y``.

    The numpy dtype of ``y`` is translated into an ONNX ``TensorProto``
    element type, and ``x`` is converted with ``cast_to``.

    Args:
        x: the values to convert.
        y: only ``y.dtype`` is read; it selects the target element type.
        saturate: forwarded unchanged to ``cast_to``.

    Returns:
        A 1-tuple containing the converted values.
    """
    # The custom dtypes imported from op_cast can compare equal to plain
    # integer dtypes (np.uint16 == np.uint16 is True as well as
    # np.uint16 == bfloat16), so equality alone is ambiguous. The name of the
    # first field in ``descr`` disambiguates. It is only inspected after the
    # equality test succeeds, and candidates are tried in this exact order.
    candidates = (
        (bfloat16, "bfloat16", TensorProto.BFLOAT16),
        (float8e4m3fn, "e4m3fn", TensorProto.FLOAT8E4M3FN),
        (float8e4m3fnuz, "e4m3fnuz", TensorProto.FLOAT8E4M3FNUZ),
        (float8e5m2, "e5m2", TensorProto.FLOAT8E5M2),
        (float8e5m2fnuz, "e5m2fnuz", TensorProto.FLOAT8E5M2FNUZ),
    )
    dtype = y.dtype
    for special, field, onnx_type in candidates:
        if dtype == special and dtype.descr[0][0] == field:
            break
    else:
        # No special dtype matched: use ONNX's standard numpy mapping.
        onnx_type = np_dtype_to_tensor_dtype(dtype)  # type: ignore
    return (cast_to(x, onnx_type, saturate),)