onnx-array-api: APIs to create ONNX Graphs

https://dev.azure.com/xavierdupre3/onnx-array-api/_apis/build/status/sdpython.onnx-array-api https://badge.fury.io/py/onnx-array-api.svg GitHub Issues MIT License size https://img.shields.io/badge/code%20style-black-000000.svg https://codecov.io/gh/sdpython/onnx-array-api/branch/main/graph/badge.svg?token=Wb9ZGDta8J

onnx-array-api implements APIs to create custom ONNX graphs. The objective is to speed up the implementation of converter libraries.

Sources are available on GitHub: sdpython/onnx-array-api.

GraphBuilder API

Almost every converter library (a library converting a machine-learned model to ONNX) implements its own graph builder and customizes it for its needs. Such a builder handles frequent tasks such as giving names to intermediate results and loading or saving ONNX models. It can also be used to extend an existing graph. See GraphBuilder: common API for ONNX.

<<<

import numpy as np
from onnx_array_api.graph_api import GraphBuilder
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

# Build Z = ReduceSum((X - Y) ** 2) one node at a time.
builder = GraphBuilder()

# Two float32 inputs whose two dimensions are left unspecified.
builder.make_tensor_input("X", np.float32, (None, None))
builder.make_tensor_input("Y", np.float32, (None, None))

# No output name is passed: the builder picks one and makes sure it is unique.
diff = builder.make_node("Sub", ["X", "Y"])

# The numpy array is converted into a tensor initializer by the builder.
two = builder.make_initializer(np.array([2], dtype=np.int64))
squared = builder.make_node("Pow", [diff, two])

# Here the output name is passed explicitly because the user wants to choose it,
# then "Z" is declared as the graph output.
builder.make_node("ReduceSum", [squared], outputs=["Z"])
builder.make_tensor_output("Z", np.float32, (None, None))

# Final conversion into an ONNX model.
onx = builder.to_onnx()

print(onnx_simple_text_plot(onx))

>>>

    opset: domain='' version=21
    input: name='X' type=dtype('float32') shape=['', '']
    input: name='Y' type=dtype('float32') shape=['', '']
    init: name='cst' type=dtype('int64') shape=(1,) -- array([2])
    Sub(X, Y) -> _onx_sub0
      Pow(_onx_sub0, cst) -> _onx_pow0
        ReduceSum(_onx_pow0) -> Z
    output: name='Z' type=dtype('float32') shape=['', '']

Light API

The syntax is inspired by Reverse Polish Notation. This kind of API makes it easy to build new graphs but less easy to extend an existing graph. See Light API for ONNX: everything in one line.

<<<

import numpy as np
from onnx_array_api.light_api import start
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

# Each call returns an object on which the next step is chained,
# so the whole graph is described in a single expression.
model = (
    start()
    # declare the two graph inputs (shown as float32 with no shape below)
    .vin("X")
    .vin("Y")
    # bring X and Y, subtract them and name the result "dxy"
    .bring("X", "Y")
    .Sub()
    .rename("dxy")
    # add the constant 2 as an initializer named "two"
    .cst(np.array([2], dtype=np.int64), "two")
    # compute dxy ** two, then apply ReduceSum without explicit axes
    .bring("dxy", "two")
    .Pow()
    .ReduceSum()
    # name the final result "Z" and declare it as the graph output
    .rename("Z")
    .vout()
    .to_onnx()
)

print(onnx_simple_text_plot(model))

>>>

    opset: domain='' version=21
    input: name='X' type=dtype('float32') shape=None
    input: name='Y' type=dtype('float32') shape=None
    init: name='two' type=dtype('int64') shape=(1,) -- array([2])
    Sub(X, Y) -> dxy
      Pow(dxy, two) -> r1_0
        ReduceSum(r1_0, keepdims=1, noop_with_empty_axes=0) -> Z
    output: name='Z' type=dtype('float32') shape=None

Numpy API

Writing ONNX graphs requires knowing the ONNX syntax, unless an existing syntax such as numpy can be reused. This is what this API does. This kind of API makes it easy to build new graphs but is almost impossible to use for extending existing graphs, as that usually requires knowing ONNX. See Numpy API for ONNX.

<<<

import numpy as np  # A
from onnx_array_api.npx import absolute, jit_onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot


def l1_loss(x, y):
    """Return the L1 distance between *x* and *y*: the sum of ``|x - y|``."""
    gap = absolute(x - y)
    return gap.sum()


def l2_loss(x, y):
    """Return the squared L2 distance between *x* and *y*: the sum of ``(x - y) ** 2``."""
    residual = x - y
    return (residual ** 2).sum()


def myloss(x, y):
    """Combine an L1 loss on column 0 with a squared L2 loss on column 1.

    *x* and *y* are indexed with ``[:, j]``, so they are expected to be 2D.
    """
    l1_part = l1_loss(x[:, 0], y[:, 0])
    l2_part = l2_loss(x[:, 1], y[:, 1])
    return l1_part + l2_part


# Two 2x2 float32 inputs: column 0 feeds the L1 term, column 1 the L2 term.
x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)

# Wrap the numpy-style loss so that it is executed through ONNX.
jitted_myloss = jit_onnx(myloss)

# Evaluate the loss on the sample inputs and show the value.
res = jitted_myloss(x, y)
print(res)

# Display the ONNX graph held by the jitted function after the call above.
onnx_graph = jitted_myloss.get_onnx()
print(onnx_simple_text_plot(onnx_graph))

>>>

    0.042
    opset: domain='' version=18
    input: name='x0' type=dtype('float32') shape=['', '']
    input: name='x1' type=dtype('float32') shape=['', '']
    Constant(value=[1]) -> cst__0
    Constant(value=[2]) -> cst__1
    Constant(value=[1]) -> cst__2
      Slice(x0, cst__0, cst__1, cst__2) -> r__12
    Constant(value=[1]) -> cst__3
    Constant(value=[2]) -> cst__4
    Constant(value=[1]) -> cst__5
      Slice(x1, cst__3, cst__4, cst__5) -> r__14
    Constant(value=[0]) -> cst__6
    Constant(value=[1]) -> cst__7
    Constant(value=[1]) -> cst__8
      Slice(x0, cst__6, cst__7, cst__8) -> r__16
    Constant(value=[0]) -> cst__9
    Constant(value=[1]) -> cst__10
    Constant(value=[1]) -> cst__11
      Slice(x1, cst__9, cst__10, cst__11) -> r__18
    Constant(value=[1]) -> cst__13
      Squeeze(r__12, cst__13) -> r__20
    Constant(value=[1]) -> cst__15
      Squeeze(r__14, cst__15) -> r__21
        Sub(r__20, r__21) -> r__24
    Constant(value=[1]) -> cst__17
      Squeeze(r__16, cst__17) -> r__22
    Constant(value=[1]) -> cst__19
      Squeeze(r__18, cst__19) -> r__23
        Sub(r__22, r__23) -> r__25
          Abs(r__25) -> r__28
            ReduceSum(r__28, keepdims=0) -> r__30
    Constant(value=2) -> r__26
      CastLike(r__26, r__24) -> r__27
        Pow(r__24, r__27) -> r__29
          ReduceSum(r__29, keepdims=0) -> r__31
            Add(r__30, r__31) -> r__32
    output: name='r__32' type=dtype('float32') shape=None
digraph{
  size=7;
  nodesep=0.05;
  ranksep=0.25;
  orientation=portrait;

  x0 [shape=box color=red label="x0\nTensorProto.FLOAT\nshape=['', '']" fontsize=10];
  x1 [shape=box color=red label="x1\nTensorProto.FLOAT\nshape=['', '']" fontsize=10];

  r__32 [shape=box color=green label="r__32\nTensorProto.FLOAT" fontsize=10];


  cst__0 [shape=box label="cst__0" fontsize=10];
  Constant [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant -> cst__0;

  cst__1 [shape=box label="cst__1" fontsize=10];
  Constant1 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[2]" fontsize=10];
  Constant1 -> cst__1;

  cst__2 [shape=box label="cst__2" fontsize=10];
  Constant12 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant12 -> cst__2;

  cst__3 [shape=box label="cst__3" fontsize=10];
  Constant123 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant123 -> cst__3;

  cst__4 [shape=box label="cst__4" fontsize=10];
  Constant1234 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[2]" fontsize=10];
  Constant1234 -> cst__4;

  cst__5 [shape=box label="cst__5" fontsize=10];
  Constant12345 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant12345 -> cst__5;

  cst__6 [shape=box label="cst__6" fontsize=10];
  Constant123456 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[0]" fontsize=10];
  Constant123456 -> cst__6;

  cst__7 [shape=box label="cst__7" fontsize=10];
  Constant1234567 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant1234567 -> cst__7;

  cst__8 [shape=box label="cst__8" fontsize=10];
  Constant12345678 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant12345678 -> cst__8;

  cst__9 [shape=box label="cst__9" fontsize=10];
  Constant123456789 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[0]" fontsize=10];
  Constant123456789 -> cst__9;

  cst__10 [shape=box label="cst__10" fontsize=10];
  Constant12345678910 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant12345678910 -> cst__10;

  cst__11 [shape=box label="cst__11" fontsize=10];
  Constant1234567891011 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant1234567891011 -> cst__11;

  r__12 [shape=box label="r__12" fontsize=10];
  Slice [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10];
  x0 -> Slice;
  cst__0 -> Slice;
  cst__1 -> Slice;
  cst__2 -> Slice;
  Slice -> r__12;

  cst__13 [shape=box label="cst__13" fontsize=10];
  Constant123456789101112 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant123456789101112 -> cst__13;

  r__14 [shape=box label="r__14" fontsize=10];
  Slice1 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10];
  x1 -> Slice1;
  cst__3 -> Slice1;
  cst__4 -> Slice1;
  cst__5 -> Slice1;
  Slice1 -> r__14;

  cst__15 [shape=box label="cst__15" fontsize=10];
  Constant12345678910111213 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant12345678910111213 -> cst__15;

  r__16 [shape=box label="r__16" fontsize=10];
  Slice12 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10];
  x0 -> Slice12;
  cst__6 -> Slice12;
  cst__7 -> Slice12;
  cst__8 -> Slice12;
  Slice12 -> r__16;

  cst__17 [shape=box label="cst__17" fontsize=10];
  Constant1234567891011121314 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant1234567891011121314 -> cst__17;

  r__18 [shape=box label="r__18" fontsize=10];
  Slice123 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10];
  x1 -> Slice123;
  cst__9 -> Slice123;
  cst__10 -> Slice123;
  cst__11 -> Slice123;
  Slice123 -> r__18;

  cst__19 [shape=box label="cst__19" fontsize=10];
  Constant123456789101112131415 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];
  Constant123456789101112131415 -> cst__19;

  r__20 [shape=box label="r__20" fontsize=10];
  Squeeze [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10];
  r__12 -> Squeeze;
  cst__13 -> Squeeze;
  Squeeze -> r__20;

  r__21 [shape=box label="r__21" fontsize=10];
  Squeeze1 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10];
  r__14 -> Squeeze1;
  cst__15 -> Squeeze1;
  Squeeze1 -> r__21;

  r__22 [shape=box label="r__22" fontsize=10];
  Squeeze12 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10];
  r__16 -> Squeeze12;
  cst__17 -> Squeeze12;
  Squeeze12 -> r__22;

  r__23 [shape=box label="r__23" fontsize=10];
  Squeeze123 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10];
  r__18 -> Squeeze123;
  cst__19 -> Squeeze123;
  Squeeze123 -> r__23;

  r__24 [shape=box label="r__24" fontsize=10];
  Sub [shape=box style="filled,rounded" color=orange label="Sub" fontsize=10];
  r__20 -> Sub;
  r__21 -> Sub;
  Sub -> r__24;

  r__25 [shape=box label="r__25" fontsize=10];
  Sub1 [shape=box style="filled,rounded" color=orange label="Sub" fontsize=10];
  r__22 -> Sub1;
  r__23 -> Sub1;
  Sub1 -> r__25;

  r__26 [shape=box label="r__26" fontsize=10];
  Constant12345678910111213141516 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=2" fontsize=10];
  Constant12345678910111213141516 -> r__26;

  r__27 [shape=box label="r__27" fontsize=10];
  CastLike [shape=box style="filled,rounded" color=orange label="CastLike" fontsize=10];
  r__26 -> CastLike;
  r__24 -> CastLike;
  CastLike -> r__27;

  r__28 [shape=box label="r__28" fontsize=10];
  Abs [shape=box style="filled,rounded" color=orange label="Abs" fontsize=10];
  r__25 -> Abs;
  Abs -> r__28;

  r__29 [shape=box label="r__29" fontsize=10];
  Pow [shape=box style="filled,rounded" color=orange label="Pow" fontsize=10];
  r__24 -> Pow;
  r__27 -> Pow;
  Pow -> r__29;

  r__30 [shape=box label="r__30" fontsize=10];
  ReduceSum [shape=box style="filled,rounded" color=orange label="ReduceSum\nkeepdims=0" fontsize=10];
  r__28 -> ReduceSum;
  ReduceSum -> r__30;

  r__31 [shape=box label="r__31" fontsize=10];
  ReduceSum1 [shape=box style="filled,rounded" color=orange label="ReduceSum\nkeepdims=0" fontsize=10];
  r__29 -> ReduceSum1;
  ReduceSum1 -> r__31;

  Add [shape=box style="filled,rounded" color=orange label="Add" fontsize=10];
  r__30 -> Add;
  r__31 -> Add;
  Add -> r__32;
}

Older versions