SQL-to-ONNX Converter Design#
Overview#
yobx.sql converts a SQL query string into a self-contained
onnx.ModelProto.
The primary design principle is:
Every column referenced in the query is a distinct 1-D ONNX input tensor.
This mirrors the columnar representation used in tabular data pipelines: the caller feeds one vector per column rather than a single 2-D matrix. The ONNX model can therefore be applied to any subset of rows without copying the full table.
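The row-subset property can be illustrated with plain NumPy, independently of the converter. The sketch below is only an illustration: the table dictionary and the row_slice helper are hypothetical names, not part of yobx. It shows that slicing per-column vectors yields views on the original buffers, so a batch of rows can be handed to a model without copying the table.

```python
import numpy as np

# Hypothetical columnar table: one 1-D vector per column.
table = {
    "a": np.array([1.0, -2.0, 3.0, 4.0], dtype=np.float32),
    "b": np.array([4.0, 5.0, 6.0, 7.0], dtype=np.float32),
}


def row_slice(columns, start, stop):
    """Return rows [start, stop) of every column as NumPy views (no copy)."""
    return {name: col[start:stop] for name, col in columns.items()}


batch = row_slice(table, 1, 3)
print(batch["a"])  # rows 1 and 2 of column "a"

# Basic slicing returns views, so each batch column shares memory
# with the corresponding full column.
print(all(np.shares_memory(batch[name], table[name]) for name in table))
```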
Architecture#
The conversion pipeline consists of two stages:
1. Parsing (parse_sql()): turns the SQL string into a ParsedQuery containing an ordered list of SqlOperation objects, one per SQL clause.
2. Emission (sql_to_onnx()): iterates over the operations in execution order and appends ONNX nodes to a GraphBuilder.
SQL string
│
▼
parse_sql() ─── ParsedQuery ──► operations list
│
▼
sql_to_onnx() ─── GraphBuilder ──► ModelProto
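The two-stage split can be mimicked with a toy model in pure Python. Everything in this sketch (ToyOp, ToyParsedQuery, toy_parse, toy_emit and the ORDER table) is hypothetical and deliberately simplistic: it only handles SELECT ... FROM ... [WHERE ...] and emits strings instead of ONNX nodes. It is meant to show the shape of the pipeline, not how yobx implements it.

```python
from dataclasses import dataclass, field


@dataclass
class ToyOp:
    kind: str      # "join", "filter", "groupby" or "select"
    payload: str   # raw text of the clause


@dataclass
class ToyParsedQuery:
    operations: list = field(default_factory=list)


# Assumed execution order, taken from the order described in this document.
ORDER = {"join": 0, "filter": 1, "groupby": 2, "select": 3}


def toy_parse(sql):
    """Stage 1: split the query into one operation per clause."""
    lowered = sql.lower()
    select_part, _, rest = lowered.partition(" from ")
    _, _, where_part = rest.partition(" where ")
    ops = [ToyOp("select", select_part[len("select "):].strip())]
    if where_part:
        ops.append(ToyOp("filter", where_part.strip()))
    # Store the operations in execution order, not in textual order.
    return ToyParsedQuery(sorted(ops, key=lambda op: ORDER[op.kind]))


def toy_emit(parsed):
    """Stage 2: walk the operations and append one 'node' per operation."""
    nodes = []
    for op in parsed.operations:
        nodes.append(f"{op.kind}: {op.payload}")
    return nodes


print(toy_emit(toy_parse("SELECT a + b AS total FROM t WHERE a > 0")))
```

The filter is emitted before the select even though SELECT comes first in the query text, which is the point of keeping an ordered operations list between the two stages.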
Supported SQL clauses#
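| Clause | SqlOperation | ONNX nodes emitted |
|---|---|---|
| SELECT | SelectOp | Add, ReduceSum, …, Identity |
| WHERE | FilterOp | Greater, And, …, Compress |
| GROUP BY | GroupByOp | (groups are processed together with SelectOp) |
| JOIN | JoinOp | Unsqueeze, Equal, ArgMax, Compress, Gather |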
Columnar input convention#
Each column becomes a separate ONNX graph input whose name matches the column name in the SQL query. Column names are normalised to lower-case.
import numpy as np
from yobx.sql import sql_to_onnx
# Columns "a" and "b" → two separate ONNX inputs
dtypes = {"a": np.float32, "b": np.float32}
onx = sql_to_onnx(
    "SELECT a + b AS total FROM t WHERE a > 0",
    dtypes,
)
# Inputs of the ONNX model
for inp in onx.graph.input:
    print(inp.name)  # → "a", then "b"
Execution order#
Operations are applied in the following logical order:
1. JoinOp: merge left and right tables using an equi-join key.
2. FilterOp: apply the WHERE predicate as a boolean mask (Compress) to all column tensors simultaneously.
3. GroupByOp: record the group keys; referenced by aggregation expressions in the SelectOp.
4. SelectOp: compute output expressions (arithmetic, aggregations) over the filtered/joined columns.
inputs: col_a, col_b, col_key_left, col_key_right
│
├── JoinOp: Unsqueeze/Equal/ArgMax → Compress/Gather aligned columns
│
├── FilterOp: Greater/And/… → Compress (row mask applied to all cols)
│
├── GroupByOp: (records group columns, used by aggregations)
│
└── SelectOp: Add/ReduceSum/… → Identity → outputs
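The data flow in the diagram above can be emulated with NumPy to see what each step does to the column vectors. This is a sketch under assumptions, not the converter's code: join_columns is one plausible reading of the Unsqueeze/Equal/ArgMax and Compress/Gather node names (an inner join that keeps the first matching right-hand row), and filter_then_select mirrors the Greater, Compress, Add sequence. Both function names are hypothetical.

```python
import numpy as np


def join_columns(left_key, right_key, left_col, right_col):
    """Inner equi-join on one key, keeping the first matching right row."""
    # Unsqueeze + Equal: pairwise comparison, shape (n_left, n_right).
    eq = left_key[:, None] == right_key[None, :]
    has_match = eq.any(axis=1)        # left rows that have a partner
    first_match = eq.argmax(axis=1)   # ArgMax: index of the first True per row
    # Compress keeps matched left rows; Gather picks their right partners.
    left_aligned = np.compress(has_match, left_col, axis=0)
    right_aligned = right_col[first_match[has_match]]
    return left_aligned, right_aligned


def filter_then_select(a, b):
    """WHERE a > 0 followed by SELECT a + b."""
    mask = a > 0                          # Greater
    a_f = np.compress(mask, a, axis=0)    # same mask applied to every column
    b_f = np.compress(mask, b, axis=0)
    return a_f + b_f                      # Add


left, right = join_columns(
    np.array([1, 2, 3]), np.array([3, 1]),
    np.array([10.0, 20.0, 30.0]), np.array([0.3, 0.1]),
)
print(left, right)

a = np.array([1.0, -2.0, 3.0], dtype=np.float32)
b = np.array([4.0, 5.0, 6.0], dtype=np.float32)
print(filter_then_select(a, b))
```

Applying one mask to all columns keeps the vectors row-aligned, which is why the filter has to run before any expression in the select step is evaluated.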
Parser design#
The parser (yobx.sql.parse) is a hand-written recursive-descent
parser using a single-pass tokeniser. No third-party SQL library is
required.
Tokens#
The tokeniser (yobx.sql.parse._tokenize()) classifies input characters
into four token kinds:
- "num": integer or floating-point literals.
- "str": single- or double-quoted string literals.
- "op": operators and punctuation (+, -, *, /, =, <, >, <=, >=, <>, !=, (, ), ,).
- "id": identifiers (SQL keywords and column/table names).
All identifiers are lowercased so that parsing is case-insensitive.
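A tokeniser with these four kinds can be sketched with a single regular expression. The code below is a minimal illustration, not the implementation of yobx.sql.parse._tokenize(): the names TOKEN_RE and tokenize are hypothetical, and the sketch makes no attempt to handle qualified names such as t.a, escaped quotes, or comments.

```python
import re

# One named group per token kind; two-character operators are listed
# before one-character ones so that "<=" is not split into "<" and "=".
TOKEN_RE = re.compile(
    r"""
      (?P<num>\d+\.\d*|\.\d+|\d+)
    | (?P<str>'[^']*'|"[^"]*")
    | (?P<op><=|>=|<>|!=|[-+*/=<>(),])
    | (?P<id>[A-Za-z_][A-Za-z_0-9]*)
    """,
    re.VERBOSE,
)


def tokenize(sql):
    """Return a list of (kind, text) pairs in a single left-to-right pass."""
    tokens = []
    pos = 0
    while pos < len(sql):
        if sql[pos].isspace():
            pos += 1
            continue
        match = TOKEN_RE.match(sql, pos)
        if match is None:
            raise ValueError(f"unexpected character {sql[pos]!r} at {pos}")
        kind = match.lastgroup
        text = match.group(kind)
        if kind == "id":
            text = text.lower()  # identifiers are case-insensitive
        tokens.append((kind, text))
        pos = match.end()
    return tokens


print(tokenize("SELECT A FROM t WHERE A >= 0.5"))
```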
Expression grammar#
primary ::= agg_func "(" expr ")"
| "(" expr ")"
| number
| string
| identifier
multiplicative ::= primary ( ("*" | "/") primary )*
additive ::= multiplicative ( ("+" | "-") multiplicative )*
expr ::= additive
comparison ::= expr ( "=" | "<" | ">" | "<=" | ">=" | "<>" | "!=" ) expr
and_pred ::= comparison ( "AND" comparison )*
condition ::= and_pred ( "OR" and_pred )*
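The arithmetic part of this grammar maps directly onto a recursive-descent parser with one method per rule. The sketch below is an illustration only: the Parser class, parse_expr and the tuple-based tree are hypothetical and do not reflect the node classes used by yobx. Aggregate functions, string literals and the comparison and condition rules are left out to keep it short.

```python
import re


def tokenize(text):
    """Split an arithmetic expression into numbers, names and operators."""
    return re.findall(r"\d+\.\d*|\d+|[A-Za-z_]\w*|[-+*/()]", text)


class Parser:
    def __init__(self, tokens):
        self.tokens = tokens
        self.pos = 0

    def peek(self):
        return self.tokens[self.pos] if self.pos < len(self.tokens) else None

    def take(self):
        tok = self.peek()
        self.pos += 1
        return tok

    def primary(self):
        # primary ::= "(" expr ")" | number | identifier
        tok = self.take()
        if tok is None:
            raise ValueError("unexpected end of input")
        if tok == "(":
            node = self.additive()
            if self.take() != ")":
                raise ValueError("expected ')'")
            return node
        if tok[0].isdigit():
            return float(tok)
        if tok[0].isalpha() or tok[0] == "_":
            return ("col", tok.lower())
        raise ValueError(f"unexpected token {tok!r}")

    def multiplicative(self):
        # multiplicative ::= primary ( ("*" | "/") primary )*
        node = self.primary()
        while self.peek() in ("*", "/"):
            op = self.take()
            node = (op, node, self.primary())
        return node

    def additive(self):
        # additive ::= multiplicative ( ("+" | "-") multiplicative )*
        node = self.multiplicative()
        while self.peek() in ("+", "-"):
            op = self.take()
            node = (op, node, self.multiplicative())
        return node


def parse_expr(text):
    parser = Parser(tokenize(text))
    node = parser.additive()
    if parser.peek() is not None:
        raise ValueError(f"trailing token {parser.peek()!r}")
    return node


print(parse_expr("a + b * 2"))
```

Because each rule loops over its own operators and delegates to the next-tighter rule, multiplication binds more tightly than addition and operators of equal precedence associate to the left.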
Limitations and future work#
- GROUP BY uses a whole-dataset aggregation (ReduceSum etc.) rather than true per-group aggregation. True per-group semantics require an ONNX Loop or a custom kernel.
- SELECT DISTINCT is parsed but raises NotImplementedError during conversion.
- Only equi-joins on a single key column are supported for JOIN.
- HAVING, ORDER BY, LIMIT, and subqueries are not yet supported.
- String equality (WHERE name = 'alice') is not yet supported (string literals are parsed but ONNX Equal on strings may need a separate handling path).
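The gap between whole-dataset and per-group aggregation is easy to see in NumPy. The sketch below does not use yobx; whole_dataset_sum and per_group_sum are hypothetical helpers. The first reduces the entire column to one value, which is the behaviour described in the GROUP BY limitation, and the second computes what SQL's GROUP BY key with SUM(value) would return.

```python
import numpy as np


def whole_dataset_sum(values):
    """One reduction over the whole column, ignoring the group keys."""
    return values.sum()


def per_group_sum(keys, values):
    """True GROUP BY semantics: one sum per distinct key."""
    uniq, inverse = np.unique(keys, return_inverse=True)
    # inverse maps every row to the index of its key in uniq, so a
    # weighted bincount accumulates the values of each group.
    sums = np.bincount(inverse, weights=values, minlength=len(uniq))
    return uniq, sums


keys = np.array([1, 2, 1, 2])
values = np.array([1.0, 2.0, 3.0, 4.0])

print(whole_dataset_sum(values))    # a single scalar for all rows
print(per_group_sum(keys, values))  # one sum per distinct key
```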
Example#
<<<
import numpy as np
from yobx.helpers.onnx_helper import pretty_onnx
from yobx.sql import sql_to_onnx, parse_sql
from yobx.reference import ExtendedReferenceEvaluator
# ── parse ────────────────────────────────────────────────────────
pq = parse_sql("SELECT a + b AS total FROM t WHERE a > 0")
for op in pq.operations:
    print(type(op).__name__, "—", op)
# ── convert ──────────────────────────────────────────────────────
dtypes = {"a": np.float32, "b": np.float32}
onx = sql_to_onnx(
    "SELECT a + b AS total FROM t WHERE a > 0",
    dtypes,
)
print(pretty_onnx(onx))
# ── run ──────────────────────────────────────────────────────────
ref = ExtendedReferenceEvaluator(onx)
a = np.array([1.0, -2.0, 3.0], dtype=np.float32)
b = np.array([4.0, 5.0, 6.0], dtype=np.float32)
(total,) = ref.run(None, {"a": a, "b": b})
print(total) # → [5. 9.] (rows where a > 0)
>>>
FilterOp — FilterOp(condition=Condition(left=ColumnRef(column='a', table=None), op='>', right=Literal(value=0)))
SelectOp — SelectOp(items=[SelectItem(expr=BinaryExpr(left=ColumnRef(column='a', table=None), op='+', right=ColumnRef(column='b', table=None)), alias='total')], distinct=False)
opset: domain='' version=21
input: name='a' type=dtype('float32') shape=['N']
input: name='b' type=dtype('float32') shape=['N']
init: name='filter_mask_r_lit' type=int64 shape=(1,) -- array([0])
CastLike(filter_mask_r_lit, a) -> _onx_castlike_filter_mask_r_lit
Greater(a, _onx_castlike_filter_mask_r_lit) -> _onx_greater_a
Compress(a, _onx_greater_a, axis=0) -> _onx_compress_a
Compress(b, _onx_greater_a, axis=0) -> _onx_compress_b
Add(_onx_compress_a, _onx_compress_b) -> total
output: name='total' type='NOTENSOR' shape=None
[5. 9.]