onnx-array-api: APIs to create ONNX Graphs

Badges: build status https://dev.azure.com/xavierdupre3/onnx-array-api/_apis/build/status/sdpython.onnx-array-api · PyPI version https://badge.fury.io/py/onnx-array-api.svg · GitHub Issues · MIT License · size · code style: black https://img.shields.io/badge/code%20style-black-000000.svg · code coverage https://codecov.io/gh/sdpython/onnx-array-api/branch/main/graph/badge.svg?token=Wb9ZGDta8J

onnx-array-api implements APIs to create custom ONNX graphs. The objective is to speed up the implementation of converter libraries.

NumPy API

Sources are available on GitHub: sdpython/onnx-array-api.

<<<

import numpy as np  # A
from onnx_array_api.npx import absolute, jit_onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot


def l1_loss(x, y):
    """Return the sum of absolute element-wise differences between *x* and *y*."""
    difference = x - y
    return absolute(difference).sum()


def l2_loss(x, y):
    """Return the sum of squared element-wise differences between *x* and *y*."""
    residual = x - y
    squared = residual ** 2
    return squared.sum()


def myloss(x, y):
    """Combine an L1 loss on column 0 with an L2 loss on column 1.

    Both *x* and *y* are indexed as ``[:, j]``, so they must be at least 2-D
    with at least two columns.
    """
    # Evaluate the L1 term first, then the L2 term, as in the original
    # expression, so the traced graph is built in the same order.
    loss_col0 = l1_loss(x[:, 0], y[:, 0])
    loss_col1 = l2_loss(x[:, 1], y[:, 1])
    return loss_col0 + loss_col1


# Wrap the pure-Python loss with jit_onnx; the wrapper exposes the ONNX
# model it used through get_onnx().
jitted_myloss = jit_onnx(myloss)

x = np.array(
    [[0.1, 0.2], [0.3, 0.4]],
    dtype=np.float32,
)
y = np.array(
    [[0.11, 0.22], [0.33, 0.44]],
    dtype=np.float32,
)
res = jitted_myloss(x, y)
print(res)

# Display the ONNX graph produced for the call above.
onnx_model = jitted_myloss.get_onnx()
print(onnx_simple_text_plot(onnx_model))

>>>

    0.042
    opset: domain='' version=18
    input: name='x0' type=dtype('float32') shape=['', '']
    input: name='x1' type=dtype('float32') shape=['', '']
    Constant(value=[1]) -> cst__0
    Constant(value=[2]) -> cst__1
    Constant(value=[1]) -> cst__2
      Slice(x0, cst__0, cst__1, cst__2) -> r__12
    Constant(value=[1]) -> cst__3
    Constant(value=[2]) -> cst__4
    Constant(value=[1]) -> cst__5
      Slice(x1, cst__3, cst__4, cst__5) -> r__14
    Constant(value=[0]) -> cst__6
    Constant(value=[1]) -> cst__7
    Constant(value=[1]) -> cst__8
      Slice(x0, cst__6, cst__7, cst__8) -> r__16
    Constant(value=[0]) -> cst__9
    Constant(value=[1]) -> cst__10
    Constant(value=[1]) -> cst__11
      Slice(x1, cst__9, cst__10, cst__11) -> r__18
    Constant(value=[1]) -> cst__13
      Squeeze(r__12, cst__13) -> r__20
    Constant(value=[1]) -> cst__15
      Squeeze(r__14, cst__15) -> r__21
        Sub(r__20, r__21) -> r__24
    Constant(value=[1]) -> cst__17
      Squeeze(r__16, cst__17) -> r__22
    Constant(value=[1]) -> cst__19
      Squeeze(r__18, cst__19) -> r__23
        Sub(r__22, r__23) -> r__25
          Abs(r__25) -> r__28
            ReduceSum(r__28, keepdims=0) -> r__30
    Constant(value=2) -> r__26
      CastLike(r__26, r__24) -> r__27
        Pow(r__24, r__27) -> r__29
          ReduceSum(r__29, keepdims=0) -> r__31
            Add(r__30, r__31) -> r__32
    output: name='r__32' type=dtype('float32') shape=None
digraph{
  size=7;
  orientation=portrait;
  ranksep=0.25;
  nodesep=0.05;

  x0 [shape=box color=red label="x0\nTensorProto.FLOAT\nshape=['', '']" fontsize=10];
  x1 [shape=box color=red label="x1\nTensorProto.FLOAT\nshape=['', '']" fontsize=10];

  r__32 [shape=box color=green label="r__32\nTensorProto.FLOAT" fontsize=10];


  cst__0 [shape=box label="cst__0" fontsize=10];
  Constant [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant -> cst__0;

  cst__1 [shape=box label="cst__1" fontsize=10];
  Constant1 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[2]" fontsize=10];
  Constant1 -> cst__1;

  cst__2 [shape=box label="cst__2" fontsize=10];
  Constant12 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant12 -> cst__2;

  cst__3 [shape=box label="cst__3" fontsize=10];
  Constant123 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant123 -> cst__3;

  cst__4 [shape=box label="cst__4" fontsize=10];
  Constant1234 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[2]" fontsize=10];
  Constant1234 -> cst__4;

  cst__5 [shape=box label="cst__5" fontsize=10];
  Constant12345 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant12345 -> cst__5;

  cst__6 [shape=box label="cst__6" fontsize=10];
  Constant123456 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[0]" fontsize=10];
  Constant123456 -> cst__6;

  cst__7 [shape=box label="cst__7" fontsize=10];
  Constant1234567 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant1234567 -> cst__7;

  cst__8 [shape=box label="cst__8" fontsize=10];
  Constant12345678 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant12345678 -> cst__8;

  cst__9 [shape=box label="cst__9" fontsize=10];
  Constant123456789 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[0]" fontsize=10];
  Constant123456789 -> cst__9;

  cst__10 [shape=box label="cst__10" fontsize=10];
  Constant12345678910 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant12345678910 -> cst__10;

  cst__11 [shape=box label="cst__11" fontsize=10];
  Constant1234567891011 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant1234567891011 -> cst__11;

  r__12 [shape=box label="r__12" fontsize=10];
  Slice [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10];
  x0 -> Slice;
  cst__0 -> Slice;
  cst__1 -> Slice;
  cst__2 -> Slice;
  Slice -> r__12;

  cst__13 [shape=box label="cst__13" fontsize=10];
  Constant123456789101112 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant123456789101112 -> cst__13;

  r__14 [shape=box label="r__14" fontsize=10];
  Slice1 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10];
  x1 -> Slice1;
  cst__3 -> Slice1;
  cst__4 -> Slice1;
  cst__5 -> Slice1;
  Slice1 -> r__14;

  cst__15 [shape=box label="cst__15" fontsize=10];
  Constant12345678910111213 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant12345678910111213 -> cst__15;

  r__16 [shape=box label="r__16" fontsize=10];
  Slice12 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10];
  x0 -> Slice12;
  cst__6 -> Slice12;
  cst__7 -> Slice12;
  cst__8 -> Slice12;
  Slice12 -> r__16;

  cst__17 [shape=box label="cst__17" fontsize=10];
  Constant1234567891011121314 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant1234567891011121314 -> cst__17;

  r__18 [shape=box label="r__18" fontsize=10];
  Slice123 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10];
  x1 -> Slice123;
  cst__9 -> Slice123;
  cst__10 -> Slice123;
  cst__11 -> Slice123;
  Slice123 -> r__18;

  cst__19 [shape=box label="cst__19" fontsize=10];
  Constant123456789101112131415 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant123456789101112131415 -> cst__19;

  r__20 [shape=box label="r__20" fontsize=10];
  Squeeze [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10];
  r__12 -> Squeeze;
  cst__13 -> Squeeze;
  Squeeze -> r__20;

  r__21 [shape=box label="r__21" fontsize=10];
  Squeeze1 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10];
  r__14 -> Squeeze1;
  cst__15 -> Squeeze1;
  Squeeze1 -> r__21;

  r__22 [shape=box label="r__22" fontsize=10];
  Squeeze12 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10];
  r__16 -> Squeeze12;
  cst__17 -> Squeeze12;
  Squeeze12 -> r__22;

  r__23 [shape=box label="r__23" fontsize=10];
  Squeeze123 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10];
  r__18 -> Squeeze123;
  cst__19 -> Squeeze123;
  Squeeze123 -> r__23;

  r__24 [shape=box label="r__24" fontsize=10];
  Sub [shape=box style="filled,rounded" color=orange label="Sub" fontsize=10];
  r__20 -> Sub;
  r__21 -> Sub;
  Sub -> r__24;

  r__25 [shape=box label="r__25" fontsize=10];
  Sub1 [shape=box style="filled,rounded" color=orange label="Sub" fontsize=10];
  r__22 -> Sub1;
  r__23 -> Sub1;
  Sub1 -> r__25;

  r__26 [shape=box label="r__26" fontsize=10];
  Constant12345678910111213141516 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=2" fontsize=10];
  Constant12345678910111213141516 -> r__26;

  r__27 [shape=box label="r__27" fontsize=10];
  CastLike [shape=box style="filled,rounded" color=orange label="CastLike" fontsize=10];
  r__26 -> CastLike;
  r__24 -> CastLike;
  CastLike -> r__27;

  r__28 [shape=box label="r__28" fontsize=10];
  Abs [shape=box style="filled,rounded" color=orange label="Abs" fontsize=10];
  r__25 -> Abs;
  Abs -> r__28;

  r__29 [shape=box label="r__29" fontsize=10];
  Pow [shape=box style="filled,rounded" color=orange label="Pow" fontsize=10];
  r__24 -> Pow;
  r__27 -> Pow;
  Pow -> r__29;

  r__30 [shape=box label="r__30" fontsize=10];
  ReduceSum [shape=box style="filled,rounded" color=orange label="ReduceSum\nkeepdims=0" fontsize=10];
  r__28 -> ReduceSum;
  ReduceSum -> r__30;

  r__31 [shape=box label="r__31" fontsize=10];
  ReduceSum1 [shape=box style="filled,rounded" color=orange label="ReduceSum\nkeepdims=0" fontsize=10];
  r__29 -> ReduceSum1;
  ReduceSum1 -> r__31;

  Add [shape=box style="filled,rounded" color=orange label="Add" fontsize=10];
  r__30 -> Add;
  r__31 -> Add;
  Add -> r__32;
}

Light API

<<<

import numpy as np
from onnx_array_api.light_api import start
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

# Build Z = ReduceSum((X - Y) ** 2) with the fluent light API.
model = (
    start()
    # declare the two graph inputs X and Y
    .vin("X")
    .vin("Y")
    # dxy = Sub(X, Y)
    .bring("X", "Y")
    .Sub()
    .rename("dxy")
    # the exponent becomes the initializer "two" (int64 array [2])
    .cst(np.array([2], dtype=np.int64), "two")
    # Pow(dxy, two), then ReduceSum, with the final result named Z
    .bring("dxy", "two")
    .Pow()
    .ReduceSum()
    .rename("Z")
    # mark Z as the graph output and convert the result to an ONNX model
    .vout()
    .to_onnx()
)

print(onnx_simple_text_plot(model))

>>>

    opset: domain='' version=20
    input: name='X' type=dtype('float32') shape=None
    input: name='Y' type=dtype('float32') shape=None
    init: name='two' type=dtype('int64') shape=(1,) -- array([2])
    Sub(X, Y) -> dxy
      Pow(dxy, two) -> r1_0
        ReduceSum(r1_0, keepdims=1, noop_with_empty_axes=0) -> Z
    output: name='Z' type=dtype('float32') shape=None