SQL queries to ONNX

sql_to_onnx() converts a SQL query string into a self-contained onnx.ModelProto. Each referenced column becomes a separate 1-D ONNX input, matching the columnar representation used in tabular data pipelines.

This example covers:

  1. Basic SELECT — column pass-through and arithmetic.

  2. WHERE clause — row filtering with comparison predicates.

  3. Aggregations — SUM, AVG, MIN, MAX in the SELECT list.

  4. Custom Python functions — user-defined numpy functions called directly from SQL via the custom_functions parameter; the function body is traced to ONNX nodes using trace_numpy_function().

  5. Graph visualization — rendering the produced ONNX model with plot_dot().

See SQL-to-ONNX Converter Design for the full design discussion.

import numpy as np
import onnxruntime
from yobx.helpers.onnx_helper import pretty_onnx
from yobx.reference import ExtendedReferenceEvaluator
from yobx.sql import sql_to_onnx

Helper

A small helper runs a model with both the reference evaluator and onnxruntime and verifies the results agree.

def run(onx, feeds):
    """Run *onx* through the reference evaluator and ORT; return ref outputs.

    The model is executed twice -- once with ``ExtendedReferenceEvaluator``
    and once with an ``onnxruntime.InferenceSession`` on the CPU provider --
    and every pair of corresponding outputs is compared with
    ``np.testing.assert_allclose(rtol=1e-5, atol=1e-6)``.

    :param onx: the ONNX model to execute (serialized for onnxruntime via
        ``SerializeToString``)
    :param feeds: mapping from graph input name to the array fed to it
    :return: the outputs produced by the reference evaluator
    :raises AssertionError: if the two runtimes return a different number of
        outputs, or if any pair of outputs differs beyond the tolerances
    """
    ref = ExtendedReferenceEvaluator(onx)
    ref_outputs = ref.run(None, feeds)

    sess = onnxruntime.InferenceSession(
        onx.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    ort_outputs = sess.run(None, feeds)

    # ``zip`` stops at the shorter sequence, so compare the counts first;
    # otherwise a missing or extra output would go unnoticed.
    if len(ref_outputs) != len(ort_outputs):
        raise AssertionError(
            f"reference evaluator returned {len(ref_outputs)} output(s), "
            f"onnxruntime returned {len(ort_outputs)}"
        )
    for i, (r, o) in enumerate(zip(ref_outputs, ort_outputs)):
        np.testing.assert_allclose(
            r, o, rtol=1e-5, atol=1e-6, err_msg=f"output {i} differs"
        )
    return ref_outputs

1. Basic SELECT

The simplest query selects two columns and computes their element-wise sum. Each column in input_dtypes becomes a separate ONNX graph input.

# Both columns are declared as float32; each key becomes one graph input.
dtypes = {"a": np.float32, "b": np.float32}
a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
b = np.array([4.0, 5.0, 6.0], dtype=np.float32)

# ``a + b AS total`` compiles to a single Add node whose output is named ``total``.
onx_add = sql_to_onnx("SELECT a + b AS total FROM t", dtypes)

# ``run`` returns one array per graph output; this query has a single output.
(total,) = run(onx_add, {"a": a, "b": b})
print("a + b =", total)

# Inspect the ONNX graph inputs — one per column
print("ONNX inputs:", [inp.name for inp in onx_add.graph.input])
a + b = [5. 7. 9.]
ONNX inputs: ['a', 'b']

The model

print(pretty_onnx(onx_add))
opset: domain='' version=21
input: name='a' type=dtype('float32') shape=['N']
input: name='b' type=dtype('float32') shape=['N']
Add(a, b) -> total
output: name='total' type='NOTENSOR' shape=None

2. WHERE clause (row filtering)

The WHERE clause is translated to a boolean mask followed by Compress nodes that select only the matching rows from every column.

onx_where = sql_to_onnx("SELECT a, b FROM t WHERE a > 1.5", dtypes)
a_filt, b_filt = run(onx_where, {"a": a, "b": b})

print("rows where a > 1.5:")
print("  a =", a_filt)
print("  b =", b_filt)
# The same boolean mask must be applied to *every* selected column, so check
# both outputs against the numpy equivalent rather than only ``a``.
where_mask = a > 1.5
np.testing.assert_allclose(a_filt, a[where_mask])
np.testing.assert_allclose(b_filt, b[where_mask])
rows where a > 1.5:
  a = [2. 3.]
  b = [5. 6.]

The model

print(pretty_onnx(onx_where))
opset: domain='' version=21
input: name='a' type=dtype('float32') shape=['N']
input: name='b' type=dtype('float32') shape=['N']
init: name='filter_mask_r_lit' type=float32 shape=(1,) -- array([1.5], dtype=float32)
CastLike(filter_mask_r_lit, a) -> _onx_castlike_filter_mask_r_lit
  Greater(a, _onx_castlike_filter_mask_r_lit) -> _onx_greater_a
    Compress(a, _onx_greater_a, axis=0) -> output_0
Compress(b, _onx_greater_a, axis=0) -> output_1
output: name='output_0' type='NOTENSOR' shape=None
output: name='output_1' type='NOTENSOR' shape=None

3. Aggregation functions

SUM, AVG, MIN, and MAX in the SELECT list are emitted as ReduceSum, ReduceMean, ReduceMin, and ReduceMax ONNX nodes.

# Each aggregate reduces a whole column to one value (ReduceSum / ReduceMean
# with keepdims=0 in the printed graph below).
onx_agg = sql_to_onnx("SELECT SUM(a) AS s, AVG(b) AS m FROM t", dtypes)
s_arr, m_arr = run(onx_agg, {"a": a, "b": b})
# The reduced outputs are converted to Python floats for formatting.
s = float(s_arr)
m = float(m_arr)

print(f"SUM(a) = {s:.1f}  (expected 6.0)")
print(f"AVG(b) = {m:.1f}  (expected 5.0)")
assert abs(s - 6.0) < 1e-5
assert abs(m - 5.0) < 1e-5
SUM(a) = 6.0  (expected 6.0)
AVG(b) = 5.0  (expected 5.0)

The model

print(pretty_onnx(onx_agg))
opset: domain='' version=21
input: name='a' type=dtype('float32') shape=['N']
input: name='b' type=dtype('float32') shape=['N']
init: name='select_s_axes' type=int64 shape=(1,) -- array([0])
ReduceMean(b, select_s_axes, keepdims=0) -> m
ReduceSum(a, select_s_axes, keepdims=0) -> s
output: name='s' type='NOTENSOR' shape=None
output: name='m' type='NOTENSOR' shape=None

4. WHERE + arithmetic SELECT

Filters and expressions can be combined freely.

# The WHERE mask is applied to both columns before they are added.
onx_combined = sql_to_onnx("SELECT a + b AS total FROM t WHERE a > 0", dtypes)
# a2 contains a negative value so that the filter actually drops a row.
a2 = np.array([1.0, -2.0, 3.0], dtype=np.float32)
b2 = np.array([4.0, 5.0, 6.0], dtype=np.float32)

(total2,) = run(onx_combined, {"a": a2, "b": b2})
print("a + b WHERE a > 0 =", total2)
# The row with a2 == -2 is removed, leaving 1 + 4 and 3 + 6.
np.testing.assert_allclose(total2, np.array([5.0, 9.0], dtype=np.float32))
a + b WHERE a > 0 = [5. 9.]

The model

print(pretty_onnx(onx_combined))
opset: domain='' version=21
input: name='a' type=dtype('float32') shape=['N']
input: name='b' type=dtype('float32') shape=['N']
init: name='filter_mask_r_lit' type=int64 shape=(1,) -- array([0])
CastLike(filter_mask_r_lit, a) -> _onx_castlike_filter_mask_r_lit
  Greater(a, _onx_castlike_filter_mask_r_lit) -> _onx_greater_a
    Compress(a, _onx_greater_a, axis=0) -> _onx_compress_a
Compress(b, _onx_greater_a, axis=0) -> _onx_compress_b
  Add(_onx_compress_a, _onx_compress_b) -> total
output: name='total' type='NOTENSOR' shape=None

5. Custom Python functions

Any numpy-backed Python function can be called by name directly in SQL. Pass it as a custom_functions dictionary entry: the key is the function name as it appears in the SQL string, and the value is the callable.

Under the hood, trace_numpy_function() replaces the real inputs with lightweight proxy objects that record every numpy operation as an ONNX node, so the resulting graph is equivalent to running the Python function at inference time.

def clip_sqrt(x):
    """Return the element-wise square root of *x*, clamping negatives to zero.

    Equivalent to ``sqrt(max(x, 0))``; the zero is a float32 scalar so the
    clamp does not change the dtype of float32 inputs.
    """
    zero = np.float32(0)
    clamped = np.maximum(x, zero)
    return np.sqrt(clamped)


# Single float32 column; a3 includes a negative entry to exercise the clamp.
dtypes_a = {"a": np.float32}
a3 = np.array([4.0, -1.0, 9.0, 0.0], dtype=np.float32)

# The dictionary key ("clip_sqrt") is the name used inside the SQL string.
onx_func = sql_to_onnx(
    "SELECT clip_sqrt(a) AS r FROM t", dtypes_a, custom_functions={"clip_sqrt": clip_sqrt}
)

(r,) = run(onx_func, {"a": a3})
# Calling the plain numpy function gives the reference result for the graph.
expected = clip_sqrt(a3)
print("clip_sqrt(a) =", r)
np.testing.assert_allclose(r, expected, atol=1e-6)
print("clip_sqrt ✓")
clip_sqrt(a) = [2. 0. 3. 0.]
clip_sqrt ✓

The model

print(pretty_onnx(onx_func))
opset: domain='' version=21
input: name='a' type=dtype('float32') shape=['N']
init: name='init1_s_' type=float32 shape=() -- array([0.], dtype=float32)-- Opset.make_node.1/Small
Max(a, init1_s_) -> _onx_max_a
  Sqrt(_onx_max_a) -> r
output: name='r' type='NOTENSOR' shape=None

6. Custom function in WHERE clause

Custom functions also work in WHERE predicates.

# The custom function only feeds the predicate; the selected column is the
# original ``a``, not ``clip_sqrt(a)`` (see the final Compress node below).
onx_where_func = sql_to_onnx(
    "SELECT a FROM t WHERE clip_sqrt(a) > 1", dtypes_a, custom_functions={"clip_sqrt": clip_sqrt}
)

(a_filt2,) = run(onx_where_func, {"a": a3})
print("a WHERE clip_sqrt(a) > 1 =", a_filt2)
# Same filter computed directly with numpy boolean indexing.
np.testing.assert_allclose(a_filt2, a3[clip_sqrt(a3) > 1], atol=1e-6)
print("WHERE custom function ✓")
a WHERE clip_sqrt(a) > 1 = [4. 9.]
WHERE custom function ✓

The model

print(pretty_onnx(onx_where_func))
opset: domain='' version=21
input: name='a' type=dtype('float32') shape=['N']
init: name='init1_s_' type=float32 shape=() -- array([0.], dtype=float32)-- Opset.make_node.1/Small
init: name='filter_mask_r_lit' type=int64 shape=(1,) -- array([1])
Max(a, init1_s_) -> _onx_max_a
  Sqrt(_onx_max_a) -> _onx_sqrt_max_a
    CastLike(filter_mask_r_lit, _onx_sqrt_max_a) -> _onx_castlike_filter_mask_r_lit
    Greater(_onx_sqrt_max_a, _onx_castlike_filter_mask_r_lit) -> _onx_greater_filter_mask_l_clip_sqrt
      Compress(a, _onx_greater_filter_mask_l_clip_sqrt, axis=0) -> output_0
output: name='output_0' type='NOTENSOR' shape=None

7. Two-argument custom function

Functions with more than one argument receive one tensor per argument. Parameters not passed from SQL, such as alpha below, keep their Python default values.

def weighted_sum(x, y, alpha=0.5):
    """Compute ``alpha * x + (1 - alpha) * y``.

    *alpha* is converted to ``np.float32`` once and reused in both terms, so
    both weights share the same precision. Without the cast, a numpy float64
    scalar passed as *alpha* would promote float32 columns to float64 under
    NumPy >= 2 promotion rules (numpy scalars are strongly typed there),
    while only the ``1 - alpha`` weight was cast.

    :param x: first column
    :param y: second column
    :param alpha: weight applied to *x*; ``1 - alpha`` is applied to *y*
    :return: the element-wise weighted combination of *x* and *y*
    """
    weight = np.float32(alpha)
    return weight * x + (np.float32(1) - weight) * y


# Only two arguments are passed from SQL, so ``alpha`` keeps its default
# (0.5, visible as the initializer in the printed graph below).
dtypes_ab = {"a": np.float32, "b": np.float32}
onx_ws = sql_to_onnx(
    "SELECT wsum(a, b) AS ws FROM t", dtypes_ab, custom_functions={"wsum": weighted_sum}
)

(ws,) = run(onx_ws, {"a": a, "b": b})
# Reference result from calling the Python function directly.
expected_ws = weighted_sum(a, b)
print("weighted_sum(a, b) =", ws)
np.testing.assert_allclose(ws, expected_ws, atol=1e-6)
print("Two-argument custom function ✓")
weighted_sum(a, b) = [2.5 3.5 4.5]
Two-argument custom function ✓

The model

print(pretty_onnx(onx_ws))
opset: domain='' version=21
input: name='a' type=dtype('float32') shape=['N']
input: name='b' type=dtype('float32') shape=['N']
init: name='init1_s_' type=float32 shape=() -- array([0.5], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
Mul(init1_s_, a) -> _onx_mul_init1_s_
Mul(init1_s_, b) -> _onx_mul_init1_s_2
  Add(_onx_mul_init1_s_, _onx_mul_init1_s_2) -> ws
output: name='ws' type='NOTENSOR' shape=None

8. Visualize the ONNX node types

The bar chart below compares how many ONNX nodes each query produces and which node types appear in the custom-function query.

from collections import Counter  # noqa: E402

import matplotlib.pyplot as plt  # noqa: E402

# Every model built above, keyed by the label shown on the x axis.
models = {
    "basic add": onx_add,
    "WHERE filter": onx_where,
    "aggregation": onx_agg,
    "WHERE + add": onx_combined,
    "custom func": onx_func,
    "custom WHERE": onx_where_func,
    "two-arg func": onx_ws,
}

# ``graph.node`` is a repeated protobuf field and supports ``len`` directly.
node_counts = [len(m.graph.node) for m in models.values()]

fig, axes = plt.subplots(1, 2, figsize=(11, 4))

# Left: node count per query
ax = axes[0]
bars = ax.bar(list(models.keys()), node_counts, color="#4c72b0")
ax.set_ylabel("Number of ONNX nodes")
ax.set_title("ONNX node count per SQL query")
for bar, count in zip(bars, node_counts):
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + 0.05,
        str(count),
        ha="center",
        va="bottom",
        fontsize=9,
    )
# Seven labels share half the figure width; rotate them so they do not overlap.
ax.tick_params(axis="x", labelrotation=30)

# Right: node types in custom-function model (first-seen order is preserved).
op_types = Counter(node.op_type for node in onx_func.graph.node)

ax2 = axes[1]
ax2.bar(list(op_types.keys()), list(op_types.values()), color="#dd8452")
ax2.set_ylabel("Count")
ax2.set_title("Node types in 'clip_sqrt' query")
ax2.tick_params(axis="x", labelrotation=25)

plt.tight_layout()
plt.show()
ONNX node count per SQL query, Node types in 'clip_sqrt' query

9. Display the ONNX model graph

plot_dot() renders the ONNX graph as an image so you can inspect nodes, edges, and tensor shapes at a glance. Here we visualize the basic SELECT a + b query — a single Add node connecting two inputs to one output.

from yobx.doc import plot_dot  # noqa: E402

# Render the graph of the first model: two inputs feeding a single Add node.
plot_dot(onx_add)
plot sql to onnx

Total running time of the script: (0 minutes 0.591 seconds)

Related examples

Gallery generated by Sphinx-Gallery