Source code for onnx_diagnostic.reference.ops.op_complex
import numpy as np
from onnx.reference.op_run import OpRun
[docs]
class ToComplex(OpRun):
    """Reference kernel that packs a real tensor into a complex one.

    The trailing axis of the input carries the components: index 0 is the
    real part and, when that axis has length 2, index 1 is the imaginary
    part. A trailing axis of length 1 produces a zero imaginary part.
    """

    op_domain = "ai.onnx.complex"

    def _run(self, x):
        """Combine the last axis of ``x`` into complex values.

        :param x: array whose last dimension has size 1 or 2
        :return: one-element tuple holding the complex result; the trailing
            axis of ``x`` is consumed, so the result has shape ``x.shape[:-1]``
        :raises AssertionError: if ``x`` has no dimensions or its last
            dimension is neither 1 nor 2
        """
        shape = x.shape
        # Explicit raise instead of ``assert`` so the check is still enforced
        # under ``python -O``; AssertionError is kept for existing callers.
        # An empty shape (0-d input) is rejected here rather than failing
        # later with an unrelated IndexError on ``shape[-1]``.
        if not shape or shape[-1] not in (1, 2):
            raise AssertionError(
                f"Unexpected shape {shape}, it should be a tensor "
                f"(..., 1) or (..., 2)"
            )
        real = x[..., 0]
        if shape[-1] == 1:
            # Adding 0j promotes the real values to a complex dtype.
            return (real + 0j,)
        return (real + 1j * x[..., 1],)
[docs]
class ComplexModule(OpRun):
    """Reference kernel returning the element-wise modulus of a complex tensor."""

    op_domain = "ai.onnx.complex"

    def _run(self, x):
        """Compute ``np.abs`` of a complex input.

        :param x: array of dtype ``complex64`` or ``complex128``
        :return: one-element tuple holding ``np.abs(x)``
        :raises AssertionError: if ``x.dtype`` is not one of the two
            accepted complex dtypes
        """
        # Explicit raise instead of ``assert`` so the dtype check is still
        # enforced under ``python -O``; AssertionError is kept so callers
        # observing the previous failure type are unaffected.
        if x.dtype not in (np.complex64, np.complex128):
            raise AssertionError(
                f"Unexpected type {x.dtype}, it should be a complex tensor"
            )
        return (np.abs(x),)