Converting a TFLite/LiteRT model to ONNX

yobx.litert.to_onnx() converts a TFLite/LiteRT .tflite model into an onnx.ModelProto that can be executed with any ONNX-compatible runtime.

The converter:

  1. Parses the binary FlatBuffer stored in the .tflite file with a pure-Python parser, so neither the tensorflow nor the ai_edge_litert package is required.

  2. Walks every operator in the requested subgraph (default: subgraph 0).

  3. Converts each operator to its ONNX equivalent via a registry of op-level converters.
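
The walk-and-dispatch pattern behind steps 2 and 3 can be sketched in a few lines of plain Python. This is only an illustration of the registry idea, not the yobx implementation: the Op class, the REGISTRY dictionary, the register decorator and the dictionary-based "nodes" are all hypothetical names invented for this sketch.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class Op:
    """Hypothetical stand-in for a parsed TFLite operator."""

    name: str
    inputs: List[str]
    outputs: List[str]


# Hypothetical registry: TFLite op name -> function emitting an ONNX-like node.
REGISTRY: Dict[str, Callable[[Op], dict]] = {}


def register(op_name: str):
    """Decorator that records a converter under the given op name."""

    def wrap(fn):
        REGISTRY[op_name] = fn
        return fn

    return wrap


@register("RELU")
def convert_relu(op: Op) -> dict:
    return {"op_type": "Relu", "inputs": op.inputs, "outputs": op.outputs}


@register("ADD")
def convert_add(op: Op) -> dict:
    return {"op_type": "Add", "inputs": op.inputs, "outputs": op.outputs}


def convert_subgraph(operators: List[Op]) -> List[dict]:
    """Walk the operators in order and dispatch each one through the registry."""
    nodes = []
    for op in operators:
        if op.name not in REGISTRY:
            raise NotImplementedError(f"No converter registered for {op.name}")
        nodes.append(REGISTRY[op.name](op))
    return nodes


ops = [Op("ADD", ["x", "y"], ["s"]), Op("RELU", ["s"], ["out"])]
nodes = convert_subgraph(ops)
print([n["op_type"] for n in nodes])  # ['Add', 'Relu']
```

Keeping converters in a lookup table means an unsupported operator fails with a clear error at a single place, and new operators can be added without touching the walking loop.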

This example demonstrates the converter with a hand-crafted minimal TFLite model so that no external TFLite dependencies are needed to run the gallery.

import numpy as np
import onnxruntime
from yobx.doc import plot_dot
from yobx.litert import to_onnx
from yobx.litert.litert_helper import (
    _make_sample_tflite_model,
    parse_tflite_model,
    BuiltinOperator,
)

Build a minimal TFLite FlatBuffer

_make_sample_tflite_model() returns the bytes of a minimal TFLite model with a single RELU operator (input: float32 [1, 4], output: float32 [1, 4]). In a real workflow you would pass the path to a .tflite file produced by the TFLite converter or downloaded from a model hub.

model_bytes = _make_sample_tflite_model()
print(f"Minimal TFLite model size: {len(model_bytes)} bytes")
Minimal TFLite model size: 352 bytes
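
When the bytes come from a file on disk rather than from the helper, a cheap sanity check before conversion can catch a wrong path or a truncated download. The sketch below relies on two facts about the format: a FlatBuffer begins with a 4-byte little-endian root offset, and the TFLite schema declares the file identifier "TFL3", which is stored in bytes 4 to 7 when present. The function names looks_like_tflite and load_tflite_bytes are hypothetical, and whether the hand-crafted sample model above carries the identifier is not stated here, so treat this as a heuristic for files written by the standard converter.

```python
import struct
from pathlib import Path

TFLITE_IDENTIFIER = b"TFL3"


def looks_like_tflite(data: bytes) -> bool:
    """Heuristic check for the TFLite FlatBuffer file identifier."""
    if len(data) < 8:
        return False
    # Bytes 0-3: little-endian offset of the root table.
    (root_offset,) = struct.unpack_from("<I", data, 0)
    # Bytes 4-7: optional file identifier, "TFL3" for TFLite models.
    return data[4:8] == TFLITE_IDENTIFIER and root_offset < len(data)


def load_tflite_bytes(path: str) -> bytes:
    """Read a file and fail early if it does not look like a TFLite model."""
    data = Path(path).read_bytes()
    if not looks_like_tflite(data):
        raise ValueError(f"{path} does not look like a TFLite FlatBuffer")
    return data


# Synthetic 16-byte buffer: root offset 8, identifier, then padding.
fake = struct.pack("<I", 8) + TFLITE_IDENTIFIER + bytes(8)
print(looks_like_tflite(fake))
print(looks_like_tflite(b"\x00" * 16))
```

The check only inspects the header; it does not validate the rest of the buffer.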

Inspect the parsed model

parse_tflite_model() returns a plain Python TFLiteModel object whose subgraphs, tensors and operators can be inspected directly.

tflite_model = parse_tflite_model(model_bytes)
sg = tflite_model.subgraphs[0]

print(f"TFLite model version : {tflite_model.version}")
print(f"Number of subgraphs  : {len(tflite_model.subgraphs)}")
print(f"Number of tensors    : {len(sg.tensors)}")
print(f"Number of operators  : {len(sg.operators)}")
for i, op in enumerate(sg.operators):
    in_names = [sg.tensors[j].name for j in op.inputs if j >= 0]
    out_names = [sg.tensors[j].name for j in op.outputs if j >= 0]
    print(f"  Op {i}: {op.name}  inputs={in_names}  outputs={out_names}")
TFLite model version : 3
Number of subgraphs  : 1
Number of tensors    : 2
Number of operators  : 1
  Op 0: RELU  inputs=['input']  outputs=['relu']
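
The j >= 0 filter in the loop above matters because TFLite marks an omitted optional tensor (for example a missing bias) with index -1. Without the filter, Python's negative indexing would silently return the last tensor of the subgraph. The sketch below shows the difference with hypothetical Tensor and Operator stand-ins; it does not use the real yobx classes.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Tensor:
    """Hypothetical stand-in for a parsed tensor entry."""

    name: str


@dataclass
class Operator:
    """Hypothetical stand-in for a parsed operator entry."""

    name: str
    inputs: List[int]
    outputs: List[int]


def tensor_names(tensors: List[Tensor], indices: List[int]) -> List[str]:
    """Resolve tensor indices to names, skipping omitted (-1) optional inputs."""
    return [tensors[j].name for j in indices if j >= 0]


tensors = [Tensor("x"), Tensor("w"), Tensor("y")]
# The third input (the bias) is omitted, which TFLite encodes as -1.
fc = Operator("FULLY_CONNECTED", inputs=[0, 1, -1], outputs=[2])

print(tensor_names(tensors, fc.inputs))       # ['x', 'w']
print([tensors[j].name for j in fc.inputs])   # ['x', 'w', 'y'], wrong: -1 wrapped around
```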

Convert to ONNX

yobx.litert.to_onnx() accepts either the raw model bytes or a file path. The dummy NumPy input X lets the converter infer the input dtype and mark axis 0 as dynamic.

X = np.random.default_rng(0).standard_normal((1, 4)).astype(np.float32)

onx = to_onnx(model_bytes, (X,), input_names=["input"])

print(f"\nONNX opset version : {onx.opset_import[0].version}")
print(f"Number of nodes    : {len(onx.graph.node)}")
print(f"Node op-types      : {[n.op_type for n in onx.graph.node]}")
print(f"Graph input name   : {onx.graph.input[0].name}")
print(f"Graph output name  : {onx.graph.output[0].name}")
ONNX opset version : 21
Number of nodes    : 1
Node op-types      : ['Relu']
Graph input name   : input
Graph output name  : relu

Run and verify

The ONNX model produces the same output as a manual np.maximum(X, 0).

sess = onnxruntime.InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
(result,) = sess.run(None, {onx.graph.input[0].name: X})

expected = np.maximum(X, 0.0)
print(f"\nInput  : {X[0]}")
print(f"Expected (ReLU): {expected[0]}")
print(f"ONNX output    : {result[0]}")

assert np.allclose(expected, result, atol=1e-6), "Mismatch!"
print("Outputs match ✓")
Input  : [ 0.12573022 -0.13210486  0.64042264  0.10490011]
Expected (ReLU): [0.12573022 0.         0.64042264 0.10490011]
ONNX output    : [0.12573022 0.         0.64042264 0.10490011]
Outputs match ✓
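
For readers wondering what the atol=1e-6 argument actually permits: np.allclose accepts a pair of values a and b when |a - b| <= atol + rtol * |b|, with rtol defaulting to 1e-5. The pure-Python function below is a simplified sketch of that rule for flat lists of finite floats. The name all_close is hypothetical, and NaN handling and broadcasting are deliberately left out.

```python
from typing import Sequence


def all_close(
    a: Sequence[float], b: Sequence[float], rtol: float = 1e-5, atol: float = 1e-8
) -> bool:
    """Simplified element-wise tolerance test mirroring the np.allclose rule."""
    if len(a) != len(b):
        return False
    # Each pair passes when |a - b| <= atol + rtol * |b|.
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


expected = [0.12573022, 0.0, 0.64042264, 0.10490011]
result = [0.12573022, 0.0, 0.64042264, 0.10490011]

print(all_close(expected, result, atol=1e-6))   # True
print(all_close([1.0], [1.00002], atol=1e-6))   # False
```

Because the relative term scales with |b|, the absolute tolerance dominates for outputs near zero, which is the common case after a ReLU.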

Dynamic batch dimension

By default, to_onnx() marks axis 0 as dynamic, so the graph input carries a symbolic batch dimension instead of a fixed size.

batch_dim = onx.graph.input[0].type.tensor_type.shape.dim[0]
print(f"\nBatch dim param : {batch_dim.dim_param!r}")
assert batch_dim.dim_param, "Expected a dynamic batch dimension"

# Verify the converted model works for different batch sizes.
for n in (1, 3, 7):
    X_batch = np.random.default_rng(n).standard_normal((n, 4)).astype(np.float32)
    (out,) = sess.run(None, {onx.graph.input[0].name: X_batch})
    assert np.allclose(np.maximum(X_batch, 0), out, atol=1e-6), f"Mismatch for batch={n}"

print("Dynamic batch verified for batch sizes 1, 3, 7 ✓")
Batch dim param : 'batch'
Dynamic batch verified for batch sizes 1, 3, 7 ✓
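
Conceptually, a symbolic dimension such as 'batch' is a named placeholder: any size is accepted on that axis, while fixed axes must match exactly. The sketch below illustrates that matching rule in plain Python. It is not how onnxruntime validates shapes internally, and bind_shape is a hypothetical helper written for this illustration.

```python
from typing import Dict, Sequence, Union

Dim = Union[int, str]


def bind_shape(declared: Sequence[Dim], actual: Sequence[int]) -> Dict[str, int]:
    """Match a concrete shape against a declared one and return symbol bindings."""
    if len(declared) != len(actual):
        raise ValueError(f"rank mismatch: {declared} vs {actual}")
    bindings: Dict[str, int] = {}
    for axis, (d, a) in enumerate(zip(declared, actual)):
        if isinstance(d, str):
            # Symbolic axis: any size is fine, but a repeated symbol must agree.
            if bindings.setdefault(d, a) != a:
                raise ValueError(f"symbol {d!r} bound to two different sizes")
        elif d != a:
            raise ValueError(f"axis {axis}: expected {d}, got {a}")
    return bindings


for n in (1, 3, 7):
    print(bind_shape(("batch", 4), (n, 4)))
```

A shape such as (3, 5) would be rejected here because axis 1 is fixed at 4, whereas axis 0 accepts any size.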

Custom op converter

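Use extra_converters to override or extend the built-in converters. Here we replace RELU with Clip(0, 6) (i.e. Relu6). Note that this changes the numerical result for any input above 6, so it serves only to demonstrate the override mechanism.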

def custom_relu6(g, sts, outputs, op):
    """Replace RELU with Clip(0, 6)."""
    return g.op.Clip(
        op.inputs[0],
        np.array(0.0, dtype=np.float32),
        np.array(6.0, dtype=np.float32),
        outputs=outputs,
        name="relu6",
    )


onx_custom = to_onnx(
    model_bytes,
    (X,),
    input_names=["input"],
    extra_converters={BuiltinOperator.RELU: custom_relu6},
)

custom_op_types = [n.op_type for n in onx_custom.graph.node]
print(f"\nOp-types with custom converter: {custom_op_types}")
assert "Clip" in custom_op_types, "Expected Clip node from custom converter"
assert "Relu" not in custom_op_types, "Relu should have been replaced"
print("Custom converter verified ✓")
Op-types with custom converter: ['Clip']
Custom converter verified ✓
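
Since Clip(0, 6) is not numerically equivalent to ReLU, it helps to see where the two agree. The plain-Python sketch below compares both functions element by element; relu and relu6 are local helper names defined here, not part of yobx. The first list reuses the input values printed earlier in this example, and the second list is made up to include a value above 6.

```python
def relu(x: float) -> float:
    """ReLU: clamp negative values to zero."""
    return max(x, 0.0)


def relu6(x: float) -> float:
    """Relu6, i.e. Clip(0, 6): clamp to the interval [0, 6]."""
    return min(max(x, 0.0), 6.0)


# Values from the example input above: all below 6, so both functions agree.
sample = [0.12573022, -0.13210486, 0.64042264, 0.10490011]
print([relu(v) == relu6(v) for v in sample])  # [True, True, True, True]

# Made-up values: the two functions differ once an input exceeds 6.
large = [-1.0, 3.0, 6.0, 7.5]
print([relu(v) for v in large])   # [0.0, 3.0, 6.0, 7.5]
print([relu6(v) for v in large])  # [0.0, 3.0, 6.0, 6.0]
```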

Visualise the graph

plot_dot(onx)
[Figure: plot litert to onnx]
