How the float format impacts computation speed

An example with Conv. The floats follow the IEEE 754 single-precision floating-point format. A number is interpreted differently depending on whether its exponent bits are null. When they are, the value is called a denormalized (or subnormal) number. Let’s see the impact of such numbers on the computation time through the operator Conv.
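A float32 is stored as 1 sign bit, 8 exponent bits and 23 mantissa bits. Whenever the exponent bits are all null, the value is subnormal: anything nonzero below the smallest normal value, about 1.18e-38. A quick way to see it, as a minimal sketch using only struct and numpy (not part of the measured script below):

import struct
import numpy

x = numpy.float32(1e-40)  # nonzero but smaller than numpy.finfo(numpy.float32).tiny
bits = int.from_bytes(struct.pack("<f", x), "little")
print(f"{bits:032b}")  # the sign bit and the 8 exponent bits are all zero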

Create one model

import struct
import matplotlib.pyplot as plt
from pandas import DataFrame
from tqdm import tqdm
import numpy
from onnx import TensorProto
from onnx.helper import (
    make_model,
    make_node,
    make_graph,
    make_tensor_value_info,
    make_opsetid,
)
from onnx.checker import check_model
from onnx.numpy_helper import to_array, from_array
from onnxruntime import (
    InferenceSession,
    get_available_providers,
    OrtValue,
    SessionOptions,
    GraphOptimizationLevel,
)
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from onnx_extended.ext_test_case import measure_time, unit_test_going
from onnx_extended.reference import CReferenceEvaluator

try:
    import torch
except ImportError:
    # torch is not available, the torch measurements are skipped
    print("torch is not available")
    torch = None

DIM = 64 if unit_test_going() else 256


def _denorm(x):
    # reinterpret the float32 as its raw 32 bit integer pattern
    i = int.from_bytes(struct.pack("<f", numpy.float32(x)), "little")
    # keep the sign bit and the 23 mantissa bits, clear the 8 exponent bits
    i &= 0x807FFFFF
    return numpy.uint32(i).view(numpy.float32)


denorm = numpy.vectorize(_denorm)
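The mask 0x807FFFFF keeps the sign bit and the 23 mantissa bits and clears the 8 exponent bits: denorm maps any float to a subnormal number with the same mantissa. A quick sanity check (not part of the measured script):

# 1.5 is 0x3FC00000, clearing the exponent leaves 0x00400000, i.e. 2**-127
print(denorm(numpy.float32(1.5)))  # ~5.88e-39, a subnormal number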


def create_model():
    X = make_tensor_value_info("X", TensorProto.FLOAT, [1, DIM, 14, 14])
    Y = make_tensor_value_info("Y", TensorProto.FLOAT, [None, None, None, None])
    B = from_array(numpy.zeros([DIM], dtype=numpy.float32), name="B")
    w = numpy.random.randn(DIM, DIM, 3, 3).astype(numpy.float32)

    # randomly replace about half of the weights with their denormalized version
    mask = numpy.random.randint(2, size=w.shape).astype(numpy.float32)
    d = denorm(w)
    # keep w where mask == 1, use the denormalized value where mask == 0
    w = w * mask - (mask - 1) * d
    W = from_array(w, name="W")

    node1 = make_node(
        "Conv", ["X", "W", "B"], ["Y"], kernel_shape=[3, 3], pads=[1, 1, 1, 1]
    )
    graph = make_graph([node1], "lr", [X], [Y], [W, B])
    onnx_model = make_model(graph, opset_imports=[make_opsetid("", 18)], ir_version=8)
    check_model(onnx_model)
    return onnx_model


onx = create_model()
onnx_file = "plot_op_conv_denorm.onnx"
with open(onnx_file, "wb") as f:
    f.write(onx.SerializeToString())

The model looks like:

print(onnx_simple_text_plot(onx))

opset: domain='' version=18
input: name='X' type=dtype('float32') shape=[1, 256, 14, 14]
init: name='W' type=dtype('float32') shape=(256, 256, 3, 3)
init: name='B' type=dtype('float32') shape=(256,)
Conv(X, W, B, kernel_shape=[3,3], pads=[1,1,1,1]) -> Y
output: name='Y' type=dtype('float32') shape=['', '', '', '']

onnx_model = onnx_file
input_shape = (1, DIM, 14, 14)

CReferenceEvaluator and InferenceSession

Let’s first check that both runtimes produce the same outputs.

sess_options = SessionOptions()
sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL


sess1 = CReferenceEvaluator(onnx_model)
sess2 = InferenceSession(onnx_model, sess_options, providers=["CPUExecutionProvider"])

X = numpy.ones(input_shape, dtype=numpy.float32)

expected = sess1.run(None, {"X": X})[0]
got = sess2.run(None, {"X": X})[0]
diff = numpy.abs(expected - got).max()
print(f"difference: {diff}")
difference: 2.6702880859375e-05

The outputs are close enough; the small difference likely comes from how each implementation accumulates the sums over the denormalized weights.

Time measurement

feeds = {"X": X}

t1 = measure_time(lambda: sess1.run(None, feeds), repeat=2, number=5)
print(f"CReferenceEvaluator: {t1['average']}s")

t2 = measure_time(lambda: sess2.run(None, feeds), repeat=2, number=5)
print(f"InferenceSession: {t2['average']}s")
CReferenceEvaluator: 0.0893450799999755s
InferenceSession: 0.08160096000001432s

Plotting

Let’s modify the weights of the model by multiplying everything by a scalar. Let’s also choose a random input.

has_cuda = "CUDAExecutionProvider" in get_available_providers()
X = numpy.random.random(X.shape).astype(X.dtype)


def modify(onx, scale):
    # multiply the convolution weights by scale, keep the bias unchanged
    t = to_array(onx.graph.initializer[0])
    b = to_array(onx.graph.initializer[1]).copy()
    t = (t * scale).astype(numpy.float32)
    graph = make_graph(
        onx.graph.node,
        onx.graph.name,
        onx.graph.input,
        onx.graph.output,
        [from_array(t, name=onx.graph.initializer[0].name), onx.graph.initializer[1]],
    )
    model = make_model(graph, opset_imports=onx.opset_import, ir_version=onx.ir_version)
    return t, b, model


scales = [2**p for p in range(0, 31, 2)]
data = []
feeds = {"X": X}
expected = sess2.run(None, feeds)[0]
if torch is not None:
    tx = torch.from_numpy(X)

sess_options0 = SessionOptions()
sess_options0.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
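# ask onnxruntime to treat denormal floats as zero for this session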
sess_options0.add_session_config_entry("session.set_denormal_as_zero", "1")

for scale in tqdm(scales):
    w, b, new_onx = modify(onx, scale)
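    # ratio of weights already subnormal: denorm leaves them unchanged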
    n_denorm = (w == denorm(w)).astype(numpy.int32).sum() / w.size

    # sess1 = CReferenceEvaluator(new_onx)
    sess2 = InferenceSession(
        new_onx.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
    )
    sess3 = InferenceSession(
        new_onx.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    sess4 = InferenceSession(
        new_onx.SerializeToString(), sess_options0, providers=["CPUExecutionProvider"]
    )

    # sess1.run(None, feeds)
    got = sess2.run(None, feeds)[0]
    diff = numpy.abs(got / scale - expected).max()
    sess3.run(None, feeds)
    got0 = sess4.run(None, feeds)[0]
    diff0 = numpy.abs(got0 / scale - expected).max()

    # t1 = measure_time(lambda: sess1.run(None, feeds), repeat=2, number=5)
    t2 = measure_time(lambda: sess2.run(None, feeds), repeat=2, number=5)
    t3 = measure_time(lambda: sess3.run(None, feeds), repeat=2, number=5)
    t4 = measure_time(lambda: sess4.run(None, feeds), repeat=2, number=5)
    obs = dict(
        scale=scale,
        ort=t2["average"],
        diff=diff,
        diff0=diff0,
        ort0=t4["average"],
        n_denorm=n_denorm,
    )
    # obs["ref"]=t1["average"]
    obs["ort-opt"] = t3["average"]

    if torch is not None:
        tw = torch.from_numpy(w)
        tb = torch.from_numpy(b)
        torch.nn.functional.conv2d(tx, tw, tb, padding=1)
        t3 = measure_time(
            lambda: torch.nn.functional.conv2d(tx, tw, tb, padding=1),
            repeat=2,
            number=5,
        )
        obs["torch"] = t3["average"]

    if has_cuda:
        sess2 = InferenceSession(
            new_onx.SerializeToString(),
            sess_options,
            providers=["CUDAExecutionProvider"],
        )
        sess3 = InferenceSession(
            new_onx.SerializeToString(), providers=["CUDAExecutionProvider"]
        )
        x_ortvalue = OrtValue.ortvalue_from_numpy(X, "cuda", 0)
        cuda_feeds = {"X": x_ortvalue}
        sess2.run_with_ort_values(None, cuda_feeds)
        sess3.run_with_ort_values(None, cuda_feeds)
        t2 = measure_time(lambda: sess2.run(None, cuda_feeds), repeat=2, number=5)
        t3 = measure_time(lambda: sess3.run(None, cuda_feeds), repeat=2, number=5)
        obs["ort-cuda"] = t2["average"]
        obs["ort-cuda-opt"] = t2["average"]

    data.append(obs)
    if unit_test_going() and len(data) >= 2:
        break

df = DataFrame(data)
df
         scale       ort  diff  diff0      ort0  n_denorm   ort-opt     torch  ort-cuda  ort-cuda-opt
0            1  0.079767   0.0    0.0  0.070050  0.500331  0.140657  0.140096  0.002372      0.002372
1            4  0.028542   0.0    0.0  0.025135  0.161274  0.107145  0.104200  0.002346      0.002346
2           16  0.009844   0.0    0.0  0.010231  0.043706  0.046657  0.044319  0.002480      0.002480
3           64  0.004837   0.0    0.0  0.005221  0.011300  0.015556  0.015679  0.002366      0.002366
4          256  0.003366   0.0    0.0  0.002325  0.002769  0.004922  0.005104  0.002394      0.002394
5         1024  0.003479   0.0    0.0  0.002366  0.000676  0.002414  0.002713  0.002392      0.002392
6         4096  0.002954   0.0    0.0  0.001911  0.000166  0.001641  0.002093  0.002419      0.002419
7        16384  0.002996   0.0    0.0  0.002111  0.000034  0.001219  0.001795  0.002356      0.002356
8        65536  0.002186   0.0    0.0  0.001881  0.000007  0.001379  0.001502  0.002428      0.002428
9       262144  0.003005   0.0    0.0  0.002413  0.000002  0.001323  0.002207  0.002415      0.002415
10     1048576  0.002261   0.0    0.0  0.002653  0.000000  0.001460  0.001618  0.002283      0.002283
11     4194304  0.003134   0.0    0.0  0.001905  0.000000  0.001497  0.001737  0.002545      0.002545
12    16777216  0.003337   0.0    0.0  0.003057  0.000000  0.002312  0.002766  0.002262      0.002262
13    67108864  0.002172   0.0    0.0  0.002091  0.000000  0.001579  0.001828  0.002353      0.002353
14   268435456  0.001799   0.0    0.0  0.002009  0.000000  0.001608  0.001875  0.002493      0.002493
15  1073741824  0.003875   0.0    0.0  0.002095  0.000000  0.001431  0.002047  0.002421      0.002421


Finally, let’s plot the results.

dfp = df.drop(["diff", "diff0", "n_denorm"], axis=1).set_index("scale")
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
dfp.plot(ax=ax[0], logx=True, logy=True, title="Comparison of Conv processing time")
df.set_index("scale")[["n_denorm"]].plot(
    ax=ax[1], logx=True, logy=True, title="Ratio of denormalized numbers"
)

fig.tight_layout()
fig.savefig("plot_op_conv_denorm.png")
# plt.show()
[Figure: comparison of Conv processing time (left) and ratio of denormalized numbers (right).]

Conclusion

Denormalized numbers should be avoided: with about half of the weights denormalized (scale=1), the CPU Conv runs roughly 30 times slower than with normal weights, while the CUDA timings stay flat.
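torch offers a related switch on CPU, torch.set_flush_denormal, which makes subsequent operations treat subnormal numbers as zero. A minimal sketch mirroring the torch documentation (not part of the timed script above):

if torch is not None and torch.set_flush_denormal(True):
    # denormal values are now flushed to zero on CPU
    print(torch.tensor([1e-323], dtype=torch.float64))  # tensor([0.])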

Total running time of the script: (0 minutes 41.074 seconds)
