LayerNormalization implementation cannot be exchanged

This example applies what was illustrated in Reproducible Parallelized Reduction is difficult: reduction operations are sensitive to parallelization.
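To recall why: floating-point addition is not associative, so a reduction split into partial sums, as a parallel kernel does, may round differently from a sequential one. A minimal sketch of that effect, independent of this example:

import torch

# The same values summed in a different order round differently.
x = torch.rand(2**20, dtype=torch.float32)
sequential = x.sum()
# Simulate a parallel reduction: 8 partial sums combined at the end.
parallel = torch.stack([chunk.sum() for chunk in x.chunk(8)]).sum()
print(f"difference={float(sequential - parallel)}")  # usually small but not zero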

We consider a small model made of a layer normalization followed by a matrix multiplication, and we show that replacing one kernel implementation with another can significantly change the output.

The model

import itertools
import pandas
import onnx
import onnx.helper as oh
import onnxruntime
import torch
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_diagnostic.ext_test_case import unit_test_going
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name, onnx_dtype_to_np_dtype
from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
from onnx_diagnostic.helpers.doc_helper import LayerNormalizationOrt, MatMulOrt
from onnx_diagnostic.reference import TorchOnnxEvaluator

TFLOAT = onnx.TensorProto.FLOAT
TFLOAT16 = onnx.TensorProto.FLOAT16


def get_model(itype: int = TFLOAT16):
    return oh.make_model(
        oh.make_graph(
            [
                oh.make_node("LayerNormalization", ["X", "scale", "bias"], ["norm"], axis=-1),
                oh.make_node("MatMul", ["norm", "weights"], ["mm"]),
                oh.make_node("Add", ["mm", "bias2"], ["Z"]),
            ],
            "layer_norm_matmul_add",
            [
                oh.make_tensor_value_info("X", itype, ["a", "b", "c"]),
                oh.make_tensor_value_info("scale", itype, ["c"]),
                oh.make_tensor_value_info("bias", itype, ["c"]),
                oh.make_tensor_value_info("weights", itype, ["c", "c"]),
                oh.make_tensor_value_info("bias2", itype, ["c"]),
            ],
            [oh.make_tensor_value_info("Z", itype, ["a", "b", "c"])],
        ),
        ir_version=9,
        opset_imports=[oh.make_opsetid("", 18)],
    )


model = get_model()
plot_dot(model)
(Figure: the ONNX graph of the model rendered by plot_dot.)
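Before running it, the generated graph can be validated (a small optional check, not part of the original script):

# onnx.checker raises an exception if the graph is malformed.
onnx.checker.check_model(model)
print(f"ir_version={model.ir_version}, opset={model.opset_import[0].version}")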

Let’s compare two runtimes

The two runtimes are onnxruntime and onnx_diagnostic.reference.TorchOnnxEvaluator.

last_dim = 64 if unit_test_going() else 1152


def make_feeds(last_dim: int):
    return {
        "X": (torch.rand((32, 1024, last_dim), dtype=torch.float16) - 0.5) * 120,
        "scale": torch.rand((last_dim,), dtype=torch.float16),
        "bias": torch.rand((last_dim,), dtype=torch.float16),
        "weights": torch.rand((last_dim, last_dim), dtype=torch.float16),
        "bias2": torch.rand((last_dim,), dtype=torch.float16),
    }


def cast_feeds(itype, provider, feeds):
    ttype = onnx_dtype_to_torch_dtype(itype)
    np_dtype = onnx_dtype_to_np_dtype(itype)
    np_feeds = {k: v.detach().numpy() for k, v in feeds.items()}
    if provider == "CUDA":
        if not torch.cuda.is_available():
            return None, None
        tch_feeds = {k: v.to("cuda") for k, v in feeds.items()}
        ort_feeds = np_feeds
    else:
        tch_feeds = feeds.copy()
        tch_feeds["X"] = tch_feeds["X"][:2]  # smaller batch on CPU, the full one takes too long
        ort_feeds = np_feeds.copy()
        ort_feeds["X"] = ort_feeds["X"][:2]
    tch_feeds = {k: v.to(ttype) for k, v in tch_feeds.items()}
    ort_feeds = {k: v.astype(np_dtype) for k, v in ort_feeds.items()}
    return tch_feeds, ort_feeds


feeds = make_feeds(last_dim)
kws = dict(with_shape=True, with_min_max=True, with_device=True)
data = []
baseline = {}

for provider, itype in itertools.product(["CPU", "CUDA"], [TFLOAT, TFLOAT16]):
    ttype = onnx_dtype_to_torch_dtype(itype)
    np_dtype = onnx_dtype_to_np_dtype(itype)
    tch_feeds, ort_feeds = cast_feeds(itype, provider, feeds)
    if tch_feeds is None:
        continue

    model = get_model(itype)
    print()
    print(f"-- running on {provider} with {onnx_dtype_name(itype)}")
    print("-- running with torch")
    torch_sess = TorchOnnxEvaluator(model, providers=[f"{provider}ExecutionProvider"])
    expected = torch_sess.run(None, tch_feeds)
    baseline[itype, provider, "torch"] = expected
    print(f"-- torch: {string_type(expected, **kws)}")

    print("-- running with ort")
    ort_sess = onnxruntime.InferenceSession(
        model.SerializeToString(), providers=[f"{provider}ExecutionProvider"]
    )
    got = ort_sess.run(None, ort_feeds)
    baseline[itype, provider, "ort"] = got
    print(f"-- ort: {string_type(got, **kws)}")
    diff = max_diff(expected, got, hist=True)
    print(f"-- diff {string_diff(diff)}")

    # memorize the data
    diff["dtype"] = onnx_dtype_name(itype)
    diff["provider"] = provider
    diff.update(diff["rep"])
    del diff["rep"]
    del diff["dnan"]
    del diff[">100.0"]
    del diff[">10.0"]
    data.append(diff)
-- running on CPU with FLOAT
-- running with torch
-- torch: #1[CT1s2x1024x1152[240.0281982421875,333.2430725097656:A285.0493383403518]]
-- running with ort
-- ort: #1[A1s2x1024x1152[240.02822875976562,333.2430419921875:A285.0493383583897]]
-- diff abs=0.000152587890625, rel=5.568541537966721e-07, n=2359296.0/#1484455>0.0-#2032>0.0001

-- running on CPU with FLOAT16
-- running with torch
-- torch: #1[CT10s2x1024x1152[240.0,333.25:A285.05135504404706]]
-- running with ort
-- ort: #1[A10s2x1024x1152[240.0,333.25:A285.0493660502964]]
-- diff abs=0.25, rel=0.0009765586853176355, n=2359296.0/#585826>0.0-#585826>0.0001-#585826>0.001-#585826>0.01-#585826>0.1

-- running on CUDA with FLOAT
-- running with torch
-- torch: #1[GT1s32x1024x1152[234.32650756835938,335.265869140625:A285.04125916001027]]
-- running with ort
-- ort: #1[A1s32x1024x1152[234.3227996826172,335.26470947265625:A285.03857064794204]]
-- diff abs=0.019561767578125, rel=6.741366235333735e-05, n=37748736.0/#37648559>0.0-#37048890>0.0001-#31279193>0.001-#522757>0.01

-- running on CUDA with FLOAT16
-- running with torch
-- torch: #1[GT10s32x1024x1152[234.375,335.25:A285.04108573993045]]
-- running with ort
-- ort: #1[A10s32x1024x1152[234.375,335.25:A285.0410857167509]]
-- diff abs=0.5, rel=0.00187969218160834, n=37748736.0/#1601>0.0-#1601>0.0001-#1601>0.001-#1601>0.01-#1601>0.1
df = pandas.DataFrame(data).set_index(["provider", "dtype"])
print(df)
                       abs           rel            sum  ...   >0.01    >0.1  >1.0
provider dtype                                           ...
CPU      FLOAT    0.000153  5.568542e-07      57.637527  ...       0       0     0
         FLOAT16  0.250000  9.765587e-04  146385.375000  ...  585826  585826     0
CUDA     FLOAT    0.019562  6.741366e-05  130786.453049  ...  522757       0     0
         FLOAT16  0.500000  1.879692e-03     400.375000  ...    1601    1601     0

[4 rows x 10 columns]
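To help read the abs/rel columns: max_diff returns a dictionary of error statistics and string_diff formats it. A tiny illustration reusing the helpers imported above (the exact set of reported fields may differ from this sketch):

# Two outputs that differ in a single value, passed as lists like the runtime outputs above.
a = [torch.tensor([1.0, 2.0, 4.0])]
b = [torch.tensor([1.0, 2.0, 4.5])]
print(string_diff(max_diff(a, b)))  # reports the maximum absolute and relative differences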

Visually.

df["abs"].plot.bar(title="Discrepancies ORT / torch for LayerNorm(X) @ W + B")
(Figure: Discrepancies ORT / torch for LayerNorm(X) @ W + B.)

The discrepancies are significant, especially on CUDA, and higher for float16. Let's see which operator is responsible for them, LayerNormalization or MatMul.

Where do the discrepancies come from?

We mix torch and onnxruntime kernels in the same run: ORT-TORCH replaces only LayerNormalization with the onnxruntime kernel, TORCH-ORT replaces only MatMul.

data = []

for mod, provider, itype in itertools.product(
    ["ORT-TORCH", "TORCH-ORT"], ["CPU", "CUDA"], [TFLOAT, TFLOAT16]
):
    ttype = onnx_dtype_to_torch_dtype(itype)
    np_dtype = onnx_dtype_to_np_dtype(itype)
    tch_feeds, _ = cast_feeds(itype, provider, feeds)
    if tch_feeds is None:
        continue

    custom_kernels = (
        {("", "LayerNormalization"): LayerNormalizationOrt}
        if mod == "ORT-TORCH"
        else {("", "MatMul"): MatMulOrt}
    )

    model = get_model(itype)
    print()
    print(f"-- {mod} running on {provider} with {onnx_dtype_name(itype)}")
    sess = TorchOnnxEvaluator(
        model,
        custom_kernels=custom_kernels,
        providers=[f"{provider}ExecutionProvider"],
    )
    got = sess.run(None, tch_feeds)
    print(f"-- {mod}: {string_type(got, **kws)}")

    difft = max_diff(baseline[itype, provider, "torch"], got)
    print(f"-- diff with torch {string_diff(difft)}")
    diffo = max_diff(baseline[itype, provider, "ort"], got)
    print(f"-- diff with ort {string_diff(diffo)}")

    data.append(
        dict(
            model=mod,
            dtype=onnx_dtype_name(itype),
            provider=provider,
            diff_ort=diffo["abs"],
            diff_torch=difft["abs"],
        )
    )
-- ORT-TORCH running on CPU with FLOAT
-- ORT-TORCH: #1[CT1s2x1024x1152[240.02818298339844,333.2430725097656:A285.04933834518954]]
-- diff with torch abs=9.1552734375e-05, rel=3.225682013818846e-07, n=2359296.0
-- diff with ort abs=0.000152587890625, rel=5.813544669095066e-07, n=2359296.0

-- ORT-TORCH running on CPU with FLOAT16
-- ORT-TORCH: #1[CT10s2x1024x1152[240.0,333.25:A285.0513520770603]]
-- diff with torch abs=0.5, rel=0.001848422002136776, n=2359296.0
-- diff with ort abs=0.25, rel=0.0009765586853176355, n=2359296.0

-- ORT-TORCH running on CUDA with FLOAT
-- ORT-TORCH: #1[GT1s32x1024x1152[234.32652282714844,335.2658386230469:A285.04125915866217]]
-- diff with torch abs=0.0001220703125, rel=4.5536574584636475e-07, n=37748736.0
-- diff with ort abs=0.019561767578125, rel=6.741820726159815e-05, n=37748736.0

-- ORT-TORCH running on CUDA with FLOAT16
-- ORT-TORCH: #1[GT10s32x1024x1152[234.375,335.25:A285.04108571675084]]
-- diff with torch abs=0.5, rel=0.00187969218160834, n=37748736.0
-- diff with ort abs=0, rel=0

-- TORCH-ORT running on CPU with FLOAT
-- TORCH-ORT: #1[CT1s2x1024x1152[240.0251007080078,333.2402648925781:A285.04672243771427]]
-- diff with torch abs=0.018798828125, rel=6.496689800285014e-05, n=2359296.0
-- diff with ort abs=0.018798828125, rel=6.496689800285014e-05, n=2359296.0

-- TORCH-ORT running on CPU with FLOAT16
-- TORCH-ORT: #1[CT10s2x1024x1152[240.0,333.25:A285.04913934071857]]
-- diff with torch abs=0.5, rel=0.0018867853328855364, n=2359296.0
-- diff with ort abs=0.25, rel=0.0009765586853176355, n=2359296.0

-- TORCH-ORT running on CUDA with FLOAT
-- TORCH-ORT: #1[GT1s32x1024x1152[234.32289123535156,335.26470947265625:A285.03857107711394]]
-- diff with torch abs=0.019561767578125, rel=6.741366235333735e-05, n=37748736.0
-- diff with ort abs=0.002410888671875, rel=8.311597291643774e-06, n=37748736.0

-- TORCH-ORT running on CUDA with FLOAT16
-- TORCH-ORT: #1[GT10s32x1024x1152[234.375,335.25:A285.04108573993045]]
-- diff with torch abs=0, rel=0
-- diff with ort abs=0.5, rel=0.0018832320782219277, n=37748736.0
df = pandas.DataFrame(data).set_index(["model", "provider", "dtype"])
df = df.sort_index()
print(df)
                            diff_ort  diff_torch
model     provider dtype
ORT-TORCH CPU      FLOAT    0.000153    0.000092
                   FLOAT16  0.250000    0.500000
          CUDA     FLOAT    0.019562    0.000122
                   FLOAT16  0.000000    0.500000
TORCH-ORT CPU      FLOAT    0.018799    0.018799
                   FLOAT16  0.250000    0.500000
          CUDA     FLOAT    0.002411    0.019562
                   FLOAT16  0.500000    0.000000

Visually.

df[["diff_ort", "diff_torch"]].plot.bar(
    title="ORT/Torch or Torch/ORT for LayerNorm(X) @ W + B"
)
(Figure: ORT/Torch or Torch/ORT for LayerNorm(X) @ W + B.)

On CUDA with float16, replacing only LayerNormalization with the onnxruntime kernel reproduces onnxruntime's output exactly (diff with ort is 0), while replacing only MatMul leaves torch's output unchanged (diff with torch is 0). The whole discrepancy therefore comes from the LayerNormalization kernel, which cannot be exchanged without changing the model's output.
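One way to see why the LayerNormalization kernel is so sensitive: the mean and variance reductions can be accumulated in float16 or in float32 before casting back, and the two choices do not produce the same values. A minimal sketch of that effect with a manual layer normalization (not the kernels compared above, and assuming float16 reductions are available on the device):

def manual_layer_norm(x, scale, bias, acc_dtype, eps=1e-5):
    # Accumulate the reductions in acc_dtype, return float16 like the model above.
    xs = x.to(acc_dtype)
    mean = xs.mean(dim=-1, keepdim=True)
    var = ((xs - mean) ** 2).mean(dim=-1, keepdim=True)
    norm = (xs - mean) / (var + eps).sqrt()
    return (norm * scale.to(acc_dtype) + bias.to(acc_dtype)).to(torch.float16)


x16 = (torch.rand((1024, 1152), dtype=torch.float16) - 0.5) * 120
s16 = torch.rand((1152,), dtype=torch.float16)
b16 = torch.rand((1152,), dtype=torch.float16)

diff = (
    manual_layer_norm(x16, s16, b16, torch.float16)
    - manual_layer_norm(x16, s16, b16, torch.float32)
)
print(f"max difference: {float(diff.abs().max())}")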

Total running time of the script: (1 minutes 8.626 seconds)

Related examples

Reproducible Parallelized Reduction is difficult
