SQL-to-ONNX Converter Design

Overview

yobx.sql converts a SQL query string into a self-contained onnx.ModelProto.

The primary design principle is:

Every column referenced in the query is a distinct 1-D ONNX input tensor.

This mirrors the columnar representation used in tabular data pipelines: the caller feeds one vector per column rather than a single 2-D matrix. The ONNX model can therefore be applied to any subset of rows without copying the full table.
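The columnar convention can be illustrated with plain NumPy, independent of yobx. The sketch below is only an illustration of the data layout; the table contents and the `feeds` and `subset` names are made up for the example.

```python
import numpy as np

# A small row-major table: 3 rows, 2 columns ("a", "b").
table = np.array([[1.0, 4.0], [-2.0, 5.0], [3.0, 6.0]], dtype=np.float32)

# Columnar convention: one 1-D vector per column, keyed by column name.
feeds = {"a": table[:, 0], "b": table[:, 1]}

# A row subset is obtained by slicing each column vector independently.
# Basic slices are NumPy views, so no column data is copied here.
subset = {name: col[:2] for name, col in feeds.items()}

print(feeds["a"].shape, subset["b"])
```

A dictionary of this shape is what a runtime expects as the feed for a model whose inputs are named after the columns.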

Architecture

The conversion pipeline consists of two stages:

  1. Parsing (parse_sql()) — turns the SQL string into a ParsedQuery containing an ordered list of SqlOperation objects, one per SQL clause.

  2. Emission (sql_to_onnx()) — iterates over the operations in execution order and appends ONNX nodes to a GraphBuilder.

SQL string
    │
    ▼
parse_sql()         ─── ParsedQuery ──► operations list
                                          │
                                          ▼
sql_to_onnx()       ─── GraphBuilder ──► ModelProto
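The two-stage shape of the pipeline can be sketched with a toy model. Everything below (`ToyOp`, `ToyParsedQuery`, `toy_parse`, `toy_emit`) is hypothetical and written only for this illustration; it is not the yobx implementation, handles a single `SELECT ... FROM t [WHERE ...]` form, and records strings instead of building ONNX nodes.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyOp:
    kind: str  # "filter" or "select"
    text: str  # raw clause text


@dataclass
class ToyParsedQuery:
    operations: List[ToyOp] = field(default_factory=list)


def toy_parse(sql: str) -> ToyParsedQuery:
    """Stage 1: split the query into an ordered list of operations."""
    lowered = sql.strip().lower()
    head, _, where = lowered.partition(" where ")
    select_part = head[len("select "):].split(" from ")[0]
    ops = []
    if where:
        # The filter comes first because it runs before the projection.
        ops.append(ToyOp("filter", where.strip()))
    ops.append(ToyOp("select", select_part.strip()))
    return ToyParsedQuery(ops)


def toy_emit(pq: ToyParsedQuery) -> List[str]:
    """Stage 2: walk the operations in order and record what each would emit."""
    emitted = []
    for op in pq.operations:
        family = "Compress" if op.kind == "filter" else "Identity"
        emitted.append(f"{family} <- {op.text}")
    return emitted


nodes = toy_emit(toy_parse("SELECT a + b AS total FROM t WHERE a > 0"))
print(nodes)
```

Keeping the parsed operations as plain data means the emission stage can be tested without touching SQL text, which is presumably one motivation for separating the two stages.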

Supported SQL clauses

Clause                                     SqlOperation  ONNX nodes emitted
-----------------------------------------  ------------  -------------------------------------------------------------------------
SELECT expr [AS alias], ...                SelectOp      Identity, Add, Sub, Mul, Div, ReduceSum, ReduceMean, ReduceMin, ReduceMax
WHERE condition                            FilterOp      Compress, Equal, Less, Greater, LessOrEqual, GreaterOrEqual, And, Or, Not
GROUP BY col, ...                          GroupByOp     (groups are processed together with SelectOp aggregations)
[INNER|LEFT|RIGHT|FULL] JOIN ON col = col  JoinOp        Unsqueeze, Equal, ArgMax, ReduceMax, Compress, Gather
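To make the JOIN row more concrete, here is a NumPy emulation of how the listed node types could combine into an inner equi-join. This is a sketch under assumptions: it assumes right-side keys are unique, the sample data is invented, and the graph yobx actually emits may be wired differently.

```python
import numpy as np

left_key = np.array([1, 2, 3], dtype=np.int64)
left_val = np.array([10.0, 20.0, 30.0], dtype=np.float32)
right_key = np.array([3, 1, 5], dtype=np.int64)
right_val = np.array([0.3, 0.1, 0.5], dtype=np.float32)

# Unsqueeze + Equal: (n_left, n_right) matrix of key matches.
match = (left_key[:, None] == right_key[None, :]).astype(np.int64)

# ReduceMax over the right axis: 1 where a left row has any match.
has_match = match.max(axis=1).astype(bool)

# ArgMax over the right axis: index of the first matching right row
# (0 when there is no match, which the mask below discards).
right_idx = match.argmax(axis=1)

# Compress: keep only the matched left rows and their right indices.
left_val_joined = np.compress(has_match, left_val, axis=0)
right_idx_joined = np.compress(has_match, right_idx, axis=0)

# Gather: fetch the aligned right-side values.
right_val_joined = np.take(right_val, right_idx_joined, axis=0)

print(left_val_joined, right_val_joined)
```

The match matrix costs O(n_left × n_right) memory, which is the usual trade-off of expressing a join with dense tensor operators.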

Columnar input convention

Each column becomes a separate ONNX graph input whose name matches the column name in the SQL query. Column names are normalised to lower-case.

import numpy as np
from yobx.sql import sql_to_onnx

# Columns "a" and "b" → two separate ONNX inputs
dtypes = {"a": np.float32, "b": np.float32}
onx = sql_to_onnx(
    "SELECT a + b AS total FROM t WHERE a > 0",
    dtypes,
)

# Inputs of the ONNX model
for inp in onx.graph.input:
    print(inp.name)   # → "a", then "b"

Execution order

Operations are applied in the following logical order:

  1. JoinOp — merge left and right tables using an equi-join key.

  2. FilterOp — apply the WHERE predicate as a boolean mask (Compress) to all column tensors simultaneously.

  3. GroupByOp — record the group keys; referenced by aggregation expressions in the SelectOp.

  4. SelectOp — compute output expressions (arithmetic, aggregations) over the filtered/joined columns.

inputs: col_a, col_b, col_key_left, col_key_right
    │
    ├── JoinOp:   Unsqueeze/Equal/ArgMax → Compress/Gather aligned columns
    │
    ├── FilterOp: Greater/And/… → Compress (row mask applied to all cols)
    │
    ├── GroupByOp: (records group columns, used by aggregations)
    │
    └── SelectOp: Add/ReduceSum/… → Identity → outputs
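The filter-then-select ordering can be emulated in NumPy for the query `SELECT a + b AS total FROM t WHERE a > 0`. This is an illustrative sketch, not the converter's code; `np.compress` stands in for the ONNX Compress node and the `columns` dictionary is an assumption of the example.

```python
import numpy as np

columns = {
    "a": np.array([1.0, -2.0, 3.0], dtype=np.float32),
    "b": np.array([4.0, 5.0, 6.0], dtype=np.float32),
}

# FilterOp: evaluate the predicate once to get a boolean row mask.
mask = columns["a"] > 0

# The same mask is applied to every column so rows stay aligned.
filtered = {name: np.compress(mask, col, axis=0) for name, col in columns.items()}

# SelectOp: compute the output expression on the filtered columns.
total = filtered["a"] + filtered["b"]

print(total)
```

Applying the mask before the arithmetic means the select expression only ever sees surviving rows, which matches the order listed above.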

Parser design

The parser (yobx.sql.parse) is a hand-written recursive-descent parser using a single-pass tokeniser. No third-party SQL library is required.

Tokens

The tokeniser (yobx.sql.parse._tokenize()) classifies input characters into four token kinds:

  • "num" — integer or floating-point literals.

  • "str" — single- or double-quoted string literals.

  • "op" — operators and punctuation (+, -, *, /, =, <, >, <=, >=, <>, !=, (, ), ,).

  • "id" — identifiers (SQL keywords and column/table names).

All identifiers are lowercased so that parsing is case-insensitive.
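A tokeniser with these four kinds can be sketched with a single regular expression. The code below is a hypothetical stand-in, not the body of `yobx.sql.parse._tokenize()`; the `tokenize` name, the regex, and the `(kind, text)` tuple shape are assumptions of the example.

```python
import re
from typing import List, Tuple

# One named group per token kind, plus whitespace to skip.
_TOKEN_RE = re.compile(
    r"""
    (?P<num>\d+\.\d*|\.\d+|\d+)
  | (?P<str>'[^']*'|"[^"]*")
  | (?P<op><=|>=|<>|!=|[-+*/=<>(),])
  | (?P<id>[A-Za-z_][A-Za-z_0-9]*)
  | (?P<ws>\s+)
    """,
    re.VERBOSE,
)


def tokenize(sql: str) -> List[Tuple[str, str]]:
    """Single left-to-right pass producing (kind, text) pairs."""
    tokens = []
    pos = 0
    while pos < len(sql):
        m = _TOKEN_RE.match(sql, pos)
        if m is None:
            raise ValueError(f"unexpected character {sql[pos]!r} at offset {pos}")
        kind, text = m.lastgroup, m.group()
        pos = m.end()
        if kind == "ws":
            continue
        if kind == "id":
            text = text.lower()  # identifiers and keywords are case-insensitive
        elif kind == "str":
            text = text[1:-1]  # drop the surrounding quotes, keep the case
        tokens.append((kind, text))
    return tokens


tokens = tokenize("SELECT A + 1.5 FROM t WHERE name <> 'Bob'")
print(tokens)
```

Two-character operators are listed before single characters in the `op` alternative so that `<=` is not split into `<` and `=`.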

Expression grammar

primary  ::= agg_func "(" expr ")"
           | "(" expr ")"
           | number
           | string
           | identifier

multiplicative ::= primary ( ("*" | "/") primary )*
additive       ::= multiplicative ( ("+" | "-") multiplicative )*
expr           ::= additive

comparison ::= expr ( "=" | "<" | ">" | "<=" | ">=" | "<>" | "!=" ) expr
and_pred   ::= comparison ( "AND" comparison )*
condition  ::= and_pred ( "OR" and_pred )*
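The arithmetic part of this grammar maps directly onto one function per rule. The following recursive-descent sketch is hypothetical and simplified: `parse_expr` and the tuple-based tree are inventions of this example, string literals and the comparison/condition rules are omitted, and any identifier followed by `(` is accepted as a function call rather than only aggregate functions.

```python
import re


def parse_expr(text: str):
    """Parse an arithmetic expression into nested tuples."""
    tokens = re.findall(r"\d+\.?\d*|[A-Za-z_]\w*|[-+*/()]", text.lower())
    pos = 0

    def peek():
        return tokens[pos] if pos < len(tokens) else None

    def advance():
        nonlocal pos
        if pos >= len(tokens):
            raise SyntaxError("unexpected end of expression")
        pos += 1
        return tokens[pos - 1]

    def expect(tok):
        if advance() != tok:
            raise SyntaxError(f"expected {tok!r}")

    def primary():
        tok = advance()
        if tok == "(":  # "(" expr ")"
            node = additive()
            expect(")")
            return node
        if tok[0].isdigit():  # number
            return float(tok)
        if not (tok[0].isalpha() or tok[0] == "_"):
            raise SyntaxError(f"unexpected token {tok!r}")
        if peek() == "(":  # func "(" expr ")"
            advance()
            arg = additive()
            expect(")")
            return ("call", tok, arg)
        return tok  # identifier

    def multiplicative():
        node = primary()
        while peek() in ("*", "/"):
            node = (advance(), node, primary())
        return node

    def additive():
        node = multiplicative()
        while peek() in ("+", "-"):
            node = (advance(), node, multiplicative())
        return node

    tree = additive()
    if pos != len(tokens):
        raise SyntaxError(f"unexpected token {tokens[pos]!r}")
    return tree


print(parse_expr("a + b * 2"))
print(parse_expr("SUM(a) - (b - c)"))
```

Because `additive` calls `multiplicative`, which calls `primary`, multiplication binds tighter than addition without any explicit precedence table, and the `while` loops make both levels left-associative.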

Limitations and future work

  • GROUP BY uses a whole-dataset aggregation (ReduceSum etc.) rather than true per-group aggregation. True per-group semantics require an ONNX Loop or a custom kernel.

  • SELECT DISTINCT is parsed but raises NotImplementedError during conversion.

  • Only equi-joins on a single key column are supported for JOIN.

  • HAVING, ORDER BY, LIMIT, and subqueries are not yet supported.

  • String equality (WHERE name = 'alice') is not yet supported: string literals are parsed, but ONNX Equal on string tensors may need a separate handling path.
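The GROUP BY limitation is easiest to see by comparing the two semantics in NumPy. This is an illustration only: the data is invented, and the `np.unique` plus `np.add.at` approach shows what true per-group sums look like, not how yobx would implement them in ONNX.

```python
import numpy as np

key = np.array([0, 1, 0, 1, 0], dtype=np.int64)
val = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)

# Whole-dataset aggregation: a single value, regardless of the group key.
whole = val.sum()

# Per-group aggregation: one value per distinct key.
groups, inverse = np.unique(key, return_inverse=True)
per_group = np.zeros(len(groups), dtype=val.dtype)
np.add.at(per_group, inverse, val)  # unbuffered scatter-add by group index

print(whole, dict(zip(groups.tolist(), per_group.tolist())))
```

For `SELECT SUM(val) FROM t GROUP BY key`, SQL semantics correspond to `per_group`, whereas a whole-dataset ReduceSum corresponds to `whole`.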

Example

<<<

import numpy as np
from yobx.helpers.onnx_helper import pretty_onnx
from yobx.sql import sql_to_onnx, parse_sql
from yobx.reference import ExtendedReferenceEvaluator

# ── parse ────────────────────────────────────────────────────────
pq = parse_sql("SELECT a + b AS total FROM t WHERE a > 0")
for op in pq.operations:
    print(type(op).__name__, "—", op)

# ── convert ──────────────────────────────────────────────────────
dtypes = {"a": np.float32, "b": np.float32}
onx = sql_to_onnx(
    "SELECT a + b AS total FROM t WHERE a > 0",
    dtypes,
)
print(pretty_onnx(onx))

# ── run ──────────────────────────────────────────────────────────
ref = ExtendedReferenceEvaluator(onx)
a = np.array([1.0, -2.0, 3.0], dtype=np.float32)
b = np.array([4.0, 5.0, 6.0], dtype=np.float32)
(total,) = ref.run(None, {"a": a, "b": b})
print(total)  # → [5. 9.]   (rows where a > 0)

>>>

    FilterOp — FilterOp(condition=Condition(left=ColumnRef(column='a', table=None), op='>', right=Literal(value=0)))
    SelectOp — SelectOp(items=[SelectItem(expr=BinaryExpr(left=ColumnRef(column='a', table=None), op='+', right=ColumnRef(column='b', table=None)), alias='total')], distinct=False)
    opset: domain='' version=21
    input: name='a' type=dtype('float32') shape=['N']
    input: name='b' type=dtype('float32') shape=['N']
    init: name='filter_mask_r_lit' type=int64 shape=(1,) -- array([0])
    CastLike(filter_mask_r_lit, a) -> _onx_castlike_filter_mask_r_lit
      Greater(a, _onx_castlike_filter_mask_r_lit) -> _onx_greater_a
        Compress(a, _onx_greater_a, axis=0) -> _onx_compress_a
    Compress(b, _onx_greater_a, axis=0) -> _onx_compress_b
      Add(_onx_compress_a, _onx_compress_b) -> total
    output: name='total' type='NOTENSOR' shape=None
    [5. 9.]