yobx.xshape.cost_inference

Per-op FLOPs (floating-point operation) estimators for ONNX nodes.

The public entry point is estimate_node_flops(). All per-op helper functions are private (_flops_*) and dispatched via _OP_HANDLERS.

All estimators use the symbolic-dimension helpers from yobx.xexpressions.operations so that symbolic (dynamic) dimensions are handled correctly. When the shapes are partially or fully unknown, the estimators return None.
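The behaviour described above (integer results for static shapes, symbolic expressions for dynamic ones, None for unknown shapes) can be illustrated with a minimal sketch. The helpers sym_mul, sym_prod and elementwise_flops are hypothetical names invented for this example; they are not the functions exposed by yobx.xexpressions.operations, and the cost model of one FLOP per element is an assumption made only for illustration.

```python
from typing import Optional, Sequence, Union

# A dimension is either a static int or a symbolic name such as "batch".
Dim = Union[int, str]


def sym_mul(a: Dim, b: Dim) -> Dim:
    """Multiply two dimensions that may be ints or symbolic names (hypothetical helper)."""
    if isinstance(a, int) and isinstance(b, int):
        return a * b
    # Simplify trivial factors so the expression stays readable.
    if a == 1:
        return b
    if b == 1:
        return a
    if a == 0 or b == 0:
        return 0
    # Operands are assumed to be plain names or products, so no parentheses are needed.
    return f"{a}*{b}"


def sym_prod(dims: Sequence[Dim]) -> Dim:
    """Product of all dimensions, symbolic if any dimension is symbolic."""
    result: Dim = 1
    for d in dims:
        result = sym_mul(result, d)
    return result


def elementwise_flops(shape: Optional[Sequence[Dim]]) -> Optional[Dim]:
    """Assumed cost model: one FLOP per element; None when the shape is unknown."""
    if shape is None:
        return None
    return sym_prod(shape)


static_cost = elementwise_flops((2, 3, 4))
symbolic_cost = elementwise_flops(("batch", 4))
unknown_cost = elementwise_flops(None)
```

With these definitions a fully static shape gives an int, a shape with a symbolic dimension gives a string expression, and an unknown shape gives None, which mirrors the int | str | None return type documented below.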

Integration with BasicShapeBuilder: use BasicShapeBuilder.estimate_node_flops() to estimate the cost of a node using the shapes already inferred during a run_model() call.

yobx.xshape.cost_inference.estimate_node_flops(node: NodeProto, shape_fn: Callable[[str], Tuple[int | str, ...] | None], literal_fn: Callable[[str], Tuple[int | str, ...] | None]) → int | str | None

Estimates the number of floating-point operations for a single ONNX node.

Returns None when the required shapes are partially or fully unknown, or when the op_type is not covered. When dimensions are symbolic, the result may be a symbolic expression (str).

Parameters:
  • node – ONNX node

  • shape_fn – callable mapping tensor name → shape tuple (from shape inference)

  • literal_fn – callable mapping tensor name → int-value tuple for 1-D integer constant tensors (shape specification tensors); used as a fallback when shape_fn cannot resolve a shape

Returns:

estimated number of FLOPs, or None
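The contract of the two callables can be sketched without yobx or onnx installed. Everything below is a stand-in written for this example: Node mimics only the NodeProto fields the sketch reads, estimate_flops_sketch and the two handlers are hypothetical, the sketch handles static integer shapes only, and the per-op costs (2·M·K·N for a 2-D MatMul, one unit per output element for ConstantOfShape) are assumptions rather than the formulas the library uses.

```python
from collections import namedtuple
from math import prod

# Stand-in for onnx.NodeProto with only the fields this sketch reads.
Node = namedtuple("Node", ["op_type", "input", "output"])


def _flops_matmul(node, shape_fn, literal_fn):
    """Assumed cost for a 2-D MatMul: one multiply and one add per (m, k, n)."""
    a, b = shape_fn(node.input[0]), shape_fn(node.input[1])
    if a is None or b is None or len(a) != 2 or len(b) != 2:
        return None
    m, k = a
    n = b[1]
    return 2 * m * k * n


def _flops_constant_of_shape(node, shape_fn, literal_fn):
    """Assumed cost: one unit per output element."""
    out = shape_fn(node.output[0])
    if out is None:
        # Fallback: read the output shape from the 1-D constant shape tensor.
        out = literal_fn(node.input[0])
    return None if out is None else prod(out)


# Dispatch table keyed by op_type, in the spirit of _OP_HANDLERS.
_HANDLERS = {
    "MatMul": _flops_matmul,
    "ConstantOfShape": _flops_constant_of_shape,
}


def estimate_flops_sketch(node, shape_fn, literal_fn):
    handler = _HANDLERS.get(node.op_type)
    return None if handler is None else handler(node, shape_fn, literal_fn)


# dict.get returns None for unknown names, which matches the callable contract.
shapes = {"A": (8, 16), "B": (16, 32)}
literals = {"target_shape": (4, 8, 16)}
shape_fn, literal_fn = shapes.get, literals.get

matmul_cost = estimate_flops_sketch(Node("MatMul", ["A", "B"], ["Y"]), shape_fn, literal_fn)
fill_cost = estimate_flops_sketch(Node("ConstantOfShape", ["target_shape"], ["Z"]), shape_fn, literal_fn)
unknown_op = estimate_flops_sketch(Node("Foo", ["A"], ["W"]), shape_fn, literal_fn)
```

Using dict.get for both callables keeps the example short: a missing tensor name yields None, and the handler propagates None instead of raising.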

yobx.xshape.cost_inference.list_op_cost_formulas() → Dict[str, str]

Returns a mapping from each supported ONNX op_type to the symbolic FLOPs expression produced by estimate_node_flops() on a representative test case from the ONNX backend test suite.

For every single-node model found in the ONNX backend test data directory, the static input dimensions are replaced by symbolic variables (DIM<n>) using replace_static_dimensions_by_strings(). BasicShapeBuilder is then run with inference=InferenceMode.COST to obtain the symbolic FLOPs expression.

Only the first passing test case per op_type is kept. Operators with no matching backend test case, or whose cost cannot be inferred symbolically, are omitted.

Returns:

{op_type: symbolic_flops_expression} sorted alphabetically.
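The selection rule described above (keep the first passing case per op_type, omit operators without a result, sort alphabetically) might look roughly like the following sketch. The function collect_first_formulas and the sample expressions are hypothetical; the real function obtains its expressions by running BasicShapeBuilder on the backend test models, which this sketch does not attempt.

```python
from typing import Dict, Iterable, Optional, Tuple


def collect_first_formulas(
    cases: Iterable[Tuple[str, Optional[str]]],
) -> Dict[str, str]:
    """Keep the first non-None expression per op_type, sorted by op_type.

    Each case is (op_type, expression), where expression is None when the
    cost could not be inferred for that test case.
    """
    formulas: Dict[str, str] = {}
    for op_type, expr in cases:
        if expr is None or op_type in formulas:
            # Skip failing cases and any case after the first passing one.
            continue
        formulas[op_type] = expr
    # Dicts preserve insertion order, so this yields alphabetical keys.
    return dict(sorted(formulas.items()))


# Made-up sample data, for illustration only.
sample_cases = [
    ("Relu", None),           # failing case, ignored
    ("Relu", "DIM0*DIM1"),    # first passing case for Relu, kept
    ("Add", "DIM0"),
    ("Relu", "DIM2"),         # later case for Relu, ignored
    ("Loop", None),           # no passing case, so Loop is omitted
]
result = collect_first_formulas(sample_cases)
```

A failing case does not block a later passing case for the same operator, which is one reading of "first passing test case per op_type".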