Computation Cost Inference#

Overview#

Knowing the computational cost of an ONNX model — expressed as a count of floating-point operations (FLOPs) — is useful for model comparison, profiling, and guiding optimization decisions.

yobx.xshape.cost_inference implements a lightweight per-operator FLOPs estimator. The main entry point is estimate_node_flops(), which accepts a single ONNX node together with two callables that resolve tensor shapes and integer literals, and returns the estimated FLOPs count.

When input shapes contain symbolic dimensions (strings such as "batch" or "seq"), the returned value is a symbolic arithmetic expression (also a string) that can be evaluated once concrete shapes are known. Static shapes yield plain integer counts.
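As a sketch of what consuming such a symbolic result can look like (the expression string, dimension names, and the `evaluate_flops` helper below are illustrative assumptions, not produced by or part of the library):

```python
# Hypothetical symbolic FLOPs string, e.g. for a MatMul over a
# (batch, seq, 64) x (64, 64) product. Not emitted by the library;
# written here to illustrate the string form described above.
symbolic_flops = "2*batch*seq*64*64"

def evaluate_flops(expr: str, dims: dict) -> int:
    # Bind the symbolic dimension names to concrete integers and
    # evaluate the arithmetic expression; builtins are disabled so
    # only the supplied names are visible.
    return eval(expr, {"__builtins__": {}}, dims)

print(evaluate_flops(symbolic_flops, {"batch": 2, "seq": 128}))  # 2097152
```

With static shapes no evaluation step is needed: the estimator returns the integer directly.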

Integration with BasicShapeBuilder#

BasicShapeBuilder integrates cost inference through the run_model() method. Pass inference=InferenceMode.COST to enable it:

from yobx.xshape import BasicShapeBuilder, InferenceMode

builder = BasicShapeBuilder()
cost_list = builder.run_model(model, inference=InferenceMode.COST)
# cost_list: list of (op_type, flops, node) tuples

Each element of cost_list is a (op_type, flops, node) triple where flops is either an integer, a symbolic string expression, or None (unsupported op or unknown shapes).

To substitute concrete dimension values and obtain integer FLOPs counts, call evaluate_cost_with_true_inputs() with the actual input tensors:

import numpy as np

feeds = {"X": np.random.randn(2, 64, 64).astype("float32")}
concrete = builder.evaluate_cost_with_true_inputs(feeds, cost_list)
total = sum(f or 0 for _, f, _ in concrete)
print(f"Total FLOPs: {total:,}")

For a full worked example including a before/after optimization comparison see Symbolic Cost of a Model: Attention Block.

How FLOPs Are Counted#

The estimator assigns FLOPs to operators using a set of simple, well-established counting conventions. Operators are partitioned into groups, each governed by a uniform formula:

Element-wise unary operators (Relu, Sigmoid, Exp, Sqrt, …)

1 FLOP per output element. Sigmoid is the exception: the formula accounts for its exp + add + div decomposition, giving 3 FLOPs per element. Softmax and LogSoftmax are likewise counted at 3 FLOPs per element.

Element-wise binary operators (Add, Mul, Sub, Div, …)

1 FLOP per output element.

Matrix multiplication (MatMul)

For inputs of shape (..., M, K) and (..., K, N) the formula is:

\text{FLOPs} = 2 \times \prod(\text{batch dims}) \times M \times K \times N

The factor of 2 accounts for one multiply-accumulate per inner-product step.
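The MatMul formula is easy to reproduce directly. The helper below is written from the formula above for illustration; `matmul_flops` is not part of the library API:

```python
from math import prod

def matmul_flops(a_shape, b_shape):
    """FLOPs for (..., M, K) @ (..., K, N): 2 * prod(batch) * M * K * N."""
    *batch, m, k = a_shape
    k2, n = b_shape[-2:]
    assert k == k2, "inner dimensions must match"
    return 2 * prod(batch) * m * k * n

# Batched product (2, 8, 64, 32) @ (2, 8, 32, 16):
# 2 * (2*8) * 64 * 32 * 16
print(matmul_flops((2, 8, 64, 32), (2, 8, 32, 16)))  # 1048576
```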

General matrix multiply (Gemm)

For an alpha * A @ B + beta * C operation with shapes (M, K) and (K, N):

\text{FLOPs} = 2 \times M \times K \times N + M \times N

The additional M*N term models the bias addition.
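Spelled out as a small helper (again illustrative only, not a library function):

```python
def gemm_flops(m, k, n):
    # 2*M*K*N for the matrix product plus M*N for the bias addition,
    # matching the Gemm formula above. Scaling by alpha/beta is not
    # counted separately here.
    return 2 * m * k * n + m * n

print(gemm_flops(128, 256, 512))  # 33619968
```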

Convolution (Conv / ConvTranspose)

For output shape (N, C_out, *spatial_out) and weight shape (C_out, C_in_per_group, *kernel):

\text{FLOPs} = 2 \times N \times C_{out} \times C_{in/group}
               \times \prod(\text{kernel}) \times \prod(\text{spatial\_out})
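The convolution count depends only on the output spatial size, the kernel size, and the per-group input channel count. A minimal sketch of the formula (the `conv_flops` helper is hypothetical, not library API):

```python
from math import prod

def conv_flops(n, c_out, c_in_per_group, kernel, spatial_out):
    # 2 * N * C_out * C_in/group * prod(kernel) * prod(spatial_out),
    # per the Conv formula above; one multiply-accumulate per tap.
    return 2 * n * c_out * c_in_per_group * prod(kernel) * prod(spatial_out)

# A 3x3 convolution, 3 -> 16 channels, 32x32 output, batch 1:
print(conv_flops(1, 16, 3, (3, 3), (32, 32)))  # 884736
```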

Windowed pooling (MaxPool / AveragePool)

\text{FLOPs} = N \times C \times \prod(\text{spatial\_out})
               \times \prod(\text{kernel\_shape})

Global pooling (GlobalAveragePool / GlobalMaxPool)

\text{FLOPs} = N \times C \times \prod(\text{spatial dims of input})
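The two pooling formulas above can be sketched side by side; both helpers are written here for illustration and are not part of the library:

```python
from math import prod

def maxpool_flops(n, c, kernel_shape, spatial_out):
    # Windowed pooling: one comparison (or accumulation) per window
    # element at every output position.
    return n * c * prod(spatial_out) * prod(kernel_shape)

def global_pool_flops(n, c, spatial_in):
    # Global pooling: a single pass over every input element.
    return n * c * prod(spatial_in)

print(maxpool_flops(1, 16, (2, 2), (16, 16)))  # 16384
print(global_pool_flops(1, 16, (32, 32)))      # 16384
```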

Normalization (BatchNormalization)

mean + var + normalise ≈ 2 FLOPs per output element.

Normalization (LayerNormalization / GroupNormalization / InstanceNormalization)

mean + var + sub + div + scale + bias ≈ 6 FLOPs per output element.

Reduction operators (ReduceSum, ReduceMean, ReduceMax, …)

1 FLOP per input element (one comparison or accumulation step).

Recurrent cells

  • LSTM: 2 * seq * batch * (input_size + hidden) * 4 * hidden

  • GRU: 2 * seq * batch * (input_size + hidden) * 3 * hidden

  • RNN: 2 * seq * batch * (input_size + hidden) * hidden
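The three recurrent formulas differ only in a gate factor (4 for LSTM, 3 for GRU, 1 for a vanilla RNN), so they can be sketched with one hypothetical helper (not part of the library):

```python
def rnn_family_flops(seq, batch, input_size, hidden, gates):
    # 2 * seq * batch * (input_size + hidden) * gates * hidden,
    # where gates = 4 (LSTM), 3 (GRU) or 1 (RNN).
    return 2 * seq * batch * (input_size + hidden) * gates * hidden

seq, batch, input_size, hidden = 128, 4, 64, 256
for name, gates in [("LSTM", 4), ("GRU", 3), ("RNN", 1)]:
    print(name, rnn_family_flops(seq, batch, input_size, hidden, gates))
```

For these dimensions an LSTM cell costs exactly four times the vanilla RNN cell, reflecting its four gates.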

Data-movement operators (Cast, Transpose, Gather, Pad, …)

Element count of the first output tensor (one read + one write per element).

Shape-manipulation operators (Reshape, Squeeze, Unsqueeze)

Rank of the output tensor (one metadata operation per dimension). Shape uses the rank of the input tensor.

Zero-cost operators (Identity)

0 FLOPs — these are purely logical copies that a compiler can eliminate.

Operators not covered by any of the above categories return None.

Programmatic Formula Listing#

The function list_op_cost_formulas() returns a sorted dictionary mapping every supported op_type to the symbolic FLOPs expression that the estimator produces on a representative ONNX backend test example. All static input dimensions of that example are replaced by symbolic variables (DIM<n>) before running the estimator, so the result shows the general formula rather than a single concrete number.

The complete table of supported operators and their symbolic FLOPs formulas is generated below by calling list_op_cost_formulas() at documentation-build time:

from yobx.xshape import list_op_cost_formulas

formulas = list_op_cost_formulas()

rows = [
    ".. list-table::",
    "   :header-rows: 1",
    "   :widths: 30 70",
    "",
    "   * - Op type",
    "     - Symbolic FLOPs",
]
for op_type, formula in formulas.items():
    rows.append(f"   * - ``{op_type}``")
    rows.append(f"     - ``{formula}``")

print("\n".join(rows))

Op type                  Symbolic FLOPs
-----------------------  --------------------------------------------------------------
Abs                      DIM3*DIM4*DIM5
Acos                     DIM3*DIM4*DIM5
Acosh                    DIM3*DIM4*DIM5
Add                      DIM3*DIM4*DIM5
And                      DIM3*DIM4
Asin                     DIM3*DIM4*DIM5
Asinh                    DIM3*DIM4*DIM5
Atan                     DIM3*DIM4*DIM5
Atanh                    DIM3*DIM4*DIM5
BatchNormalization       2*DIM2*DIM3*DIM4*DIM5
BitShift                 DIM3
Cast                     DIM3*DIM4
CastLike                 DIM3*DIM4
Ceil                     DIM3*DIM4*DIM5
Celu                     DIM1*DIM3*DIM3*DIM3
Concat                   2*DIM2
Constant                 25
ConstantOfShape          DIM3
Conv                     2*DIM1*DIM1*DIM1*DIM3*DIM3*conv_f3_0(DIM5,3,1)*conv_f3_0(DIM5,3,1)
Cos                      DIM3*DIM4*DIM5
Cosh                     DIM3*DIM4*DIM5
Div                      DIM3*DIM4*DIM5
Elu                      DIM3*DIM4*DIM5
Equal                    DIM3*DIM4*DIM5
Erf                      DIM1*DIM3*DIM32*DIM32
Exp                      DIM3*DIM4*DIM5
Expand                   dim0_data*dim1_data
Flatten                  dim0_a*dim1_a*dim2_a*dim3_a
Floor                    DIM3*DIM4*DIM5
GRU                      2*DIM1*DIM18*(DIM18//3+DIM2)*DIM3
Gather                   DIM2*DIM3*DIM3*DIM4
GatherElements           DIM2*DIM2
GatherND                 DIM2*DIM2*DIM2
Gemm                     2*DIM3*DIM4*DIM5+DIM3*DIM5
GlobalAveragePool        dim0_x*dim1_x*dim2_x*dim3_x
GlobalMaxPool            dim0_x*dim1_x*dim2_x*dim3_x
Greater                  DIM3*DIM4*DIM5
GreaterOrEqual           DIM3*DIM4*DIM5
HardSigmoid              DIM3*DIM4*DIM5
HardSwish                DIM3*DIM4*DIM5
Identity                 0
InstanceNormalization    6*DIM2*DIM3*DIM4*DIM5
LSTM                     2*DIM1*(DIM2+DIM28//4)*DIM28*DIM3
LayerNormalization       6*DIM3*DIM4
LeakyRelu                DIM3*DIM4*DIM5
Less                     DIM3*DIM4*DIM5
LessOrEqual              DIM3*DIM4*DIM5
Log                      DIM3*DIM4*DIM5
LogSoftmax               3*DIM3*DIM4*DIM5
MatMul                   2*DIM3*DIM3*DIM4
MaxPool                  2*DIM1*DIM3*conv_f3_0(DIM32,2,1)
Mish                     DIM10000
Mod                      DIM2*DIM3*(DIM5^DIM1)
Mul                      DIM3*DIM4*DIM5
Neg                      DIM3*DIM4*DIM5
Not                      DIM3*DIM4
OneHot                   DIM3
Or                       DIM3*DIM4
PRelu                    DIM3*DIM4*DIM5
Pad                      DIM1*DIM3*DIM4*DIM5
Pow                      DIM3*DIM4*DIM5
RNN                      2*DIM2*DIM3*(DIM3+DIM5)*DIM5
ReduceL1                 dim0_data*dim1_data*dim2_data
ReduceL2                 dim0_data*dim1_data*dim2_data
ReduceLogSum             dim0_data*dim1_data*dim2_data
ReduceLogSumExp          dim0_data*dim1_data*dim2_data
ReduceMax                dim0_data*dim1_data
ReduceMean               dim0_data*dim1_data*dim2_data
ReduceMin                dim0_data*dim1_data
ReduceProd               dim0_data*dim1_data*dim2_data
ReduceSum                dim0_data*dim1_data*dim2_data
ReduceSumSquare          dim0_data*dim1_data*dim2_data
Relu                     DIM3*DIM4*DIM5
Round                    DIM15
Scatter                  DIM1*DIM5
ScatterElements          DIM1*DIM5
ScatterND                DIM4*DIM4*DIM4
Selu                     DIM3*DIM4*DIM5
Shape                    3
Shrink                   DIM5
Sigmoid                  3*DIM3*DIM4*DIM5
Sign                     DIM11
Sin                      DIM3*DIM4*DIM5
Sinh                     DIM3*DIM4*DIM5
Slice                    DIM10*DIM20*DIM5
Softmax                  3*DIM3*DIM4*DIM5
Softplus                 DIM3*DIM4*DIM5
Softsign                 DIM3*DIM4*DIM5
Split                    (3+DIM7)//4
Sqrt                     DIM3*DIM4*DIM5
Sub                      DIM3*DIM4*DIM5
Tan                      DIM3*DIM4*DIM5
Tanh                     DIM3*DIM4*DIM5
ThresholdedRelu          DIM3*DIM4*DIM5
Tile                     DIM2*DIM3*DIM4*DIM5
Transpose                DIM2*DIM3*DIM4
Xor                      DIM3*DIM4

See also the gallery example Computation Cost: How It Works and Supported Operator Formulas for additional context and worked examples.