onnx-array-api: APIs to create ONNX Graphs
onnx-array-api implements APIs to create custom ONNX graphs. The objective is to speed up the implementation of converter libraries.
NumPy API
The sources are available on GitHub (onnx-array-api).
<<<
import numpy as np # A
from onnx_array_api.npx import absolute, jit_onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
def l1_loss(x, y):
    """Return the L1 loss between ``x`` and ``y``.

    The loss is the sum of the absolute element-wise differences,
    ``sum(|x - y|)``. ``absolute`` is imported from ``onnx_array_api.npx``,
    so the expression can also be traced by ``jit_onnx`` into an ONNX graph.
    """
    return absolute(x - y).sum()
def l2_loss(x, y):
    """Return the squared L2 loss between ``x`` and ``y``.

    The loss is the sum of the squared element-wise differences,
    ``sum((x - y) ** 2)``. No square root is taken.
    """
    # ``** 2`` is kept (rather than ``d * d``) so the traced graph stays a Pow.
    return ((x - y) ** 2).sum()
def myloss(x, y):
    """Return an L1 loss on column 0 plus a squared L2 loss on column 1.

    ``x`` and ``y`` are indexed as 2-D arrays with at least two columns.
    Column 0 of both goes to ``l1_loss`` and column 1 goes to ``l2_loss``.
    The two resulting sums are added together.
    """
    return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])
# Two float32 inputs; myloss compares them column by column.
x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)

# Wrap myloss with jit_onnx, evaluate the wrapper on x and y, print the value.
jitted_myloss = jit_onnx(myloss)
res = jitted_myloss(x, y)
print(res)

# Retrieve the ONNX model held by the wrapper and print it as text.
onnx_model = jitted_myloss.get_onnx()
print(onnx_simple_text_plot(onnx_model))
>>>
0.042
opset: domain='' version=18
input: name='x0' type=dtype('float32') shape=['', '']
input: name='x1' type=dtype('float32') shape=['', '']
Constant(value=[1]) -> cst__0
Constant(value=[2]) -> cst__1
Constant(value=[1]) -> cst__2
Slice(x0, cst__0, cst__1, cst__2) -> r__12
Constant(value=[1]) -> cst__3
Constant(value=[2]) -> cst__4
Constant(value=[1]) -> cst__5
Slice(x1, cst__3, cst__4, cst__5) -> r__14
Constant(value=[0]) -> cst__6
Constant(value=[1]) -> cst__7
Constant(value=[1]) -> cst__8
Slice(x0, cst__6, cst__7, cst__8) -> r__16
Constant(value=[0]) -> cst__9
Constant(value=[1]) -> cst__10
Constant(value=[1]) -> cst__11
Slice(x1, cst__9, cst__10, cst__11) -> r__18
Constant(value=[1]) -> cst__13
Squeeze(r__12, cst__13) -> r__20
Constant(value=[1]) -> cst__15
Squeeze(r__14, cst__15) -> r__21
Sub(r__20, r__21) -> r__24
Constant(value=[1]) -> cst__17
Squeeze(r__16, cst__17) -> r__22
Constant(value=[1]) -> cst__19
Squeeze(r__18, cst__19) -> r__23
Sub(r__22, r__23) -> r__25
Abs(r__25) -> r__28
ReduceSum(r__28, keepdims=0) -> r__30
Constant(value=2) -> r__26
CastLike(r__26, r__24) -> r__27
Pow(r__24, r__27) -> r__29
ReduceSum(r__29, keepdims=0) -> r__31
Add(r__30, r__31) -> r__32
output: name='r__32' type=dtype('float32') shape=None
Light API
<<<
import numpy as np
from onnx_array_api.light_api import start
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
# Declare inputs X and Y, subtract them and name the difference "dxy".
diff = start().vin("X").vin("Y").bring("X", "Y").Sub().rename("dxy")

# Add the int64 constant [2] as an initializer named "two".
with_exponent = diff.cst(np.array([2], dtype=np.int64), "two")

# Compute Pow(dxy, two), sum all elements, name the result "Z",
# declare it as the graph output and convert the graph to an ONNX model.
model = (
    with_exponent.bring("dxy", "two")
    .Pow()
    .ReduceSum()
    .rename("Z")
    .vout()
    .to_onnx()
)
print(onnx_simple_text_plot(model))
>>>
opset: domain='' version=20
input: name='X' type=dtype('float32') shape=None
input: name='Y' type=dtype('float32') shape=None
init: name='two' type=dtype('int64') shape=(1,) -- array([2])
Sub(X, Y) -> dxy
Pow(dxy, two) -> r1_0
ReduceSum(r1_0, keepdims=1, noop_with_empty_axes=0) -> Z
output: name='Z' type=dtype('float32') shape=None