onnx-array-api: APIs to create ONNX Graphs

onnx-array-api implements APIs to create custom ONNX graphs. The objective is to speed up the implementation of converter libraries.
Contents
More
Numpy API
Sources are available on GitHub: onnx-array-api.
<<<
import numpy as np # A
from onnx_array_api.npx import absolute, jit_onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
def l1_loss(x, y):
    """Return the L1 loss: the sum of the absolute differences between x and y."""
    diff = x - y
    return absolute(diff).sum()
def l2_loss(x, y):
    """Return the squared L2 loss: the sum of the squared differences between x and y."""
    diff = x - y
    # Keep ``** 2`` (not ``diff * diff``) so the traced ONNX graph uses Pow.
    squared = diff ** 2
    return squared.sum()
def myloss(x, y):
    """Combine an L1 loss on column 0 with a squared L2 loss on column 1.

    x and y are indexed with ``[:, 0]`` and ``[:, 1]``, so both must be
    at least 2-D with at least two columns.
    """
    first_column_loss = l1_loss(x[:, 0], y[:, 0])
    second_column_loss = l2_loss(x[:, 1], y[:, 1])
    return first_column_loss + second_column_loss
# Sample inputs: two 2x2 float32 matrices.
x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)

# Wrap myloss with jit_onnx and evaluate it on the sample inputs.
jitted_myloss = jit_onnx(myloss)
res = jitted_myloss(x, y)
print(res)

# Print a text view of the ONNX model returned by get_onnx().
print(onnx_simple_text_plot(jitted_myloss.get_onnx()))
>>>
0.042
opset: domain='' version=18
input: name='x0' type=dtype('float32') shape=['', '']
input: name='x1' type=dtype('float32') shape=['', '']
Constant(value=[1]) -> cst__0
Constant(value=[2]) -> cst__1
Constant(value=[1]) -> cst__2
Slice(x0, cst__0, cst__1, cst__2) -> r__12
Constant(value=[1]) -> cst__3
Constant(value=[2]) -> cst__4
Constant(value=[1]) -> cst__5
Slice(x1, cst__3, cst__4, cst__5) -> r__14
Constant(value=[0]) -> cst__6
Constant(value=[1]) -> cst__7
Constant(value=[1]) -> cst__8
Slice(x0, cst__6, cst__7, cst__8) -> r__16
Constant(value=[0]) -> cst__9
Constant(value=[1]) -> cst__10
Constant(value=[1]) -> cst__11
Slice(x1, cst__9, cst__10, cst__11) -> r__18
Constant(value=[1]) -> cst__13
Squeeze(r__12, cst__13) -> r__20
Constant(value=[1]) -> cst__15
Squeeze(r__14, cst__15) -> r__21
Sub(r__20, r__21) -> r__24
Constant(value=[1]) -> cst__17
Squeeze(r__16, cst__17) -> r__22
Constant(value=[1]) -> cst__19
Squeeze(r__18, cst__19) -> r__23
Sub(r__22, r__23) -> r__25
Abs(r__25) -> r__28
ReduceSum(r__28, keepdims=0) -> r__30
Constant(value=2) -> r__26
CastLike(r__26, r__24) -> r__27
Pow(r__24, r__27) -> r__29
ReduceSum(r__29, keepdims=0) -> r__31
Add(r__30, r__31) -> r__32
output: name='r__32' type=dtype('float32') shape=None
![digraph{
size=7;
orientation=portrait;
ranksep=0.25;
nodesep=0.05;
x0 [shape=box color=red label="x0\nTensorProto.FLOAT\nshape=['', '']" fontsize=10];
x1 [shape=box color=red label="x1\nTensorProto.FLOAT\nshape=['', '']" fontsize=10];
r__32 [shape=box color=green label="r__32\nTensorProto.FLOAT" fontsize=10];
cst__0 [shape=box label="cst__0" fontsize=10];
Constant [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
Constant -> cst__0;
cst__1 [shape=box label="cst__1" fontsize=10];
Constant1 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[2]" fontsize=10];
Constant1 -> cst__1;
cst__2 [shape=box label="cst__2" fontsize=10];
Constant12 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
Constant12 -> cst__2;
cst__3 [shape=box label="cst__3" fontsize=10];
Constant123 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
Constant123 -> cst__3;
cst__4 [shape=box label="cst__4" fontsize=10];
Constant1234 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[2]" fontsize=10];
Constant1234 -> cst__4;
cst__5 [shape=box label="cst__5" fontsize=10];
Constant12345 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
Constant12345 -> cst__5;
cst__6 [shape=box label="cst__6" fontsize=10];
Constant123456 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[0]" fontsize=10];
Constant123456 -> cst__6;
cst__7 [shape=box label="cst__7" fontsize=10];
Constant1234567 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
Constant1234567 -> cst__7;
cst__8 [shape=box label="cst__8" fontsize=10];
Constant12345678 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
Constant12345678 -> cst__8;
cst__9 [shape=box label="cst__9" fontsize=10];
Constant123456789 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[0]" fontsize=10];
Constant123456789 -> cst__9;
cst__10 [shape=box label="cst__10" fontsize=10];
Constant12345678910 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
Constant12345678910 -> cst__10;
cst__11 [shape=box label="cst__11" fontsize=10];
Constant1234567891011 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
Constant1234567891011 -> cst__11;
r__12 [shape=box label="r__12" fontsize=10];
Slice [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10];
x0 -> Slice;
cst__0 -> Slice;
cst__1 -> Slice;
cst__2 -> Slice;
Slice -> r__12;
cst__13 [shape=box label="cst__13" fontsize=10];
Constant123456789101112 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
Constant123456789101112 -> cst__13;
r__14 [shape=box label="r__14" fontsize=10];
Slice1 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10];
x1 -> Slice1;
cst__3 -> Slice1;
cst__4 -> Slice1;
cst__5 -> Slice1;
Slice1 -> r__14;
cst__15 [shape=box label="cst__15" fontsize=10];
Constant12345678910111213 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
Constant12345678910111213 -> cst__15;
r__16 [shape=box label="r__16" fontsize=10];
Slice12 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10];
x0 -> Slice12;
cst__6 -> Slice12;
cst__7 -> Slice12;
cst__8 -> Slice12;
Slice12 -> r__16;
cst__17 [shape=box label="cst__17" fontsize=10];
Constant1234567891011121314 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
Constant1234567891011121314 -> cst__17;
r__18 [shape=box label="r__18" fontsize=10];
Slice123 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10];
x1 -> Slice123;
cst__9 -> Slice123;
cst__10 -> Slice123;
cst__11 -> Slice123;
Slice123 -> r__18;
cst__19 [shape=box label="cst__19" fontsize=10];
Constant123456789101112131415 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
Constant123456789101112131415 -> cst__19;
r__20 [shape=box label="r__20" fontsize=10];
Squeeze [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10];
r__12 -> Squeeze;
cst__13 -> Squeeze;
Squeeze -> r__20;
r__21 [shape=box label="r__21" fontsize=10];
Squeeze1 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10];
r__14 -> Squeeze1;
cst__15 -> Squeeze1;
Squeeze1 -> r__21;
r__22 [shape=box label="r__22" fontsize=10];
Squeeze12 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10];
r__16 -> Squeeze12;
cst__17 -> Squeeze12;
Squeeze12 -> r__22;
r__23 [shape=box label="r__23" fontsize=10];
Squeeze123 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10];
r__18 -> Squeeze123;
cst__19 -> Squeeze123;
Squeeze123 -> r__23;
r__24 [shape=box label="r__24" fontsize=10];
Sub [shape=box style="filled,rounded" color=orange label="Sub" fontsize=10];
r__20 -> Sub;
r__21 -> Sub;
Sub -> r__24;
r__25 [shape=box label="r__25" fontsize=10];
Sub1 [shape=box style="filled,rounded" color=orange label="Sub" fontsize=10];
r__22 -> Sub1;
r__23 -> Sub1;
Sub1 -> r__25;
r__26 [shape=box label="r__26" fontsize=10];
Constant12345678910111213141516 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=2" fontsize=10];
Constant12345678910111213141516 -> r__26;
r__27 [shape=box label="r__27" fontsize=10];
CastLike [shape=box style="filled,rounded" color=orange label="CastLike" fontsize=10];
r__26 -> CastLike;
r__24 -> CastLike;
CastLike -> r__27;
r__28 [shape=box label="r__28" fontsize=10];
Abs [shape=box style="filled,rounded" color=orange label="Abs" fontsize=10];
r__25 -> Abs;
Abs -> r__28;
r__29 [shape=box label="r__29" fontsize=10];
Pow [shape=box style="filled,rounded" color=orange label="Pow" fontsize=10];
r__24 -> Pow;
r__27 -> Pow;
Pow -> r__29;
r__30 [shape=box label="r__30" fontsize=10];
ReduceSum [shape=box style="filled,rounded" color=orange label="ReduceSum\nkeepdims=0" fontsize=10];
r__28 -> ReduceSum;
ReduceSum -> r__30;
r__31 [shape=box label="r__31" fontsize=10];
ReduceSum1 [shape=box style="filled,rounded" color=orange label="ReduceSum\nkeepdims=0" fontsize=10];
r__29 -> ReduceSum1;
ReduceSum1 -> r__31;
Add [shape=box style="filled,rounded" color=orange label="Add" fontsize=10];
r__30 -> Add;
r__31 -> Add;
Add -> r__32;
}](_images/graphviz-355f7e32b52da8fee86535b86f4700586cf7aac6.png)
Light API
<<<
import numpy as np
from onnx_array_api.light_api import start
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
# Build a model computing ReduceSum((X - Y) ** 2) with the fluent light API.
model = (
    start()
    # Declare the two graph inputs X and Y (printed below as float32, shape=None).
    .vin("X")
    .vin("Y")
    # Use X and Y as operands of the next operator: Sub(X, Y), renamed to dxy.
    .bring("X", "Y")
    .Sub()
    .rename("dxy")
    # Add the int64 initializer `two` = [2], used as the exponent of Pow.
    .cst(np.array([2], dtype=np.int64), "two")
    .bring("dxy", "two")
    .Pow()
    # ReduceSum is added without axes; the printout shows keepdims=1 and
    # noop_with_empty_axes=0 as the resulting attributes.
    .ReduceSum()
    .rename("Z")
    # Mark Z as the graph output, then produce the ONNX model.
    .vout()
    .to_onnx()
)
print(onnx_simple_text_plot(model))
>>>
opset: domain='' version=20
input: name='X' type=dtype('float32') shape=None
input: name='Y' type=dtype('float32') shape=None
init: name='two' type=dtype('int64') shape=(1,) -- array([2])
Sub(X, Y) -> dxy
Pow(dxy, two) -> r1_0
ReduceSum(r1_0, keepdims=1, noop_with_empty_axes=0) -> Z
output: name='Z' type=dtype('float32') shape=None