Gemm or MatMul + Add

Order of computation matters: in floating-point arithmetic, (1 + 1e-20) - 1 != (1 - 1) + 1e-20 once the finite precision of the computation is taken into account. This example looks at the operator Gemm in onnxruntime, the simplest way to represent a linear neural layer, and compares it to the equivalent MatMul + Add decomposition.
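
Written as a quick check in plain Python (an illustration added here, independent of the model built below):

x = 1e-20
print((1.0 + x) - 1.0)  # prints 0.0: the tiny term is absorbed by 1.0 before the subtraction
print((1.0 - 1.0) + x)  # prints 1e-20: the tiny term survives because it is added last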

A model with many choices

import cpuinfo
import numpy as np
import pandas
import matplotlib.pyplot as plt
import onnx
import onnx.helper as oh
import torch
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
from onnx_diagnostic.reference import OnnxruntimeEvaluator
from onnxruntime import (
    InferenceSession,
    SessionOptions,
    __version__ as version_onnxruntime,
    GraphOptimizationLevel,
)

print(f"onnxruntime version = {version_onnxruntime}")
print(f"cpu name = {cpuinfo.get_cpu_info()['brand_raw']}")
if torch.cuda.is_available():
    print(f"gpu name = {torch.cuda.get_device_name(0)}")
    print(f"cuda version = {torch.version.cuda}")
onnxruntime version = 1.24.0
cpu name = 13th Gen Intel(R) Core(TM) i7-13800H
gpu name = NVIDIA GeForce RTX 4060 Laptop GPU
cuda version = 13.0

The version is important. Numerical differences are observed with onnxruntime<=1.22. Let’s see how to make them happen.
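
If the same script has to run across environments, a small version guard helps interpret the results; this sketch assumes the packaging module is available:

from packaging.version import Version

if Version(version_onnxruntime) <= Version("1.22"):
    print("onnxruntime <= 1.22: expect larger differences between Gemm and MatMul + Add")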

def make_model_gemm(itype: int) -> onnx.ModelProto:
    # builds a model computing A @ X + B in several ways: Gemm with bias,
    # Gemm followed by Add, MatMul followed by Add, FusedMatMul followed by Add,
    # plus the same Gemm paths computed in float32 and cast back to itype
    return oh.make_model(
        oh.make_graph(
            [
                oh.make_node("Gemm", ["A", "X", "B"], ["GemmOnly"]),
                oh.make_node("Gemm", ["A", "X"], ["gmm"]),
                oh.make_node("Add", ["gmm", "B"], ["GemmAdd"]),
                oh.make_node("MatMul", ["A", "X"], ["mm"]),
                oh.make_node("Add", ["mm", "B"], ["MatMulAdd"]),
                oh.make_node("FusedMatMul", ["A", "X"], ["fmm"], domain="com.microsoft"),
                oh.make_node("Add", ["fmm", "B"], ["FusedMatMulAdd"]),
                oh.make_node("Cast", ["A"], ["Afloat"], to=onnx.TensorProto.FLOAT),
                oh.make_node("Cast", ["B"], ["Bfloat"], to=onnx.TensorProto.FLOAT),
                oh.make_node("Cast", ["X"], ["Xfloat"], to=onnx.TensorProto.FLOAT),
                oh.make_node("Gemm", ["Afloat", "Xfloat"], ["gmmfloat"]),
                oh.make_node("Add", ["gmmfloat", "Bfloat"], ["gemmaddfloat"]),
                oh.make_node("Cast", ["gemmaddfloat"], ["CastGemmAddCast"], to=itype),
                oh.make_node("Gemm", ["Afloat", "Xfloat", "Bfloat"], ["GemmOnlyfloat"]),
                oh.make_node("Cast", ["GemmOnlyfloat"], ["CastGemmOnlyCast"], to=itype),
            ],
            "test",
            [
                oh.make_tensor_value_info("A", itype, ["a", "b"]),
                oh.make_tensor_value_info("X", itype, ["b", "c"]),
                oh.make_tensor_value_info("B", itype, ["c"]),
            ],
            [
                oh.make_tensor_value_info("GemmOnly", itype, ["a", "c"]),
                oh.make_tensor_value_info("GemmAdd", itype, ["a", "c"]),
                oh.make_tensor_value_info("FusedMatMulAdd", itype, ["a", "c"]),
                oh.make_tensor_value_info("MatMulAdd", itype, ["a", "c"]),
                oh.make_tensor_value_info("CastGemmAddCast", itype, ["a", "c"]),
                oh.make_tensor_value_info("CastGemmOnlyCast", itype, ["a", "c"]),
            ],
        ),
        opset_imports=[oh.make_opsetid("", 22)],
        ir_version=10,
    )


def matrix_diff(tensors):
    # pairwise maximum absolute difference between all results
    mat = np.zeros((len(tensors), len(tensors)), dtype=np.float32)
    for i, t in enumerate(tensors):
        for j in range(i + 1, len(tensors)):
            mat[i, j] = max_diff(t, tensors[j])["abs"]
            mat[j, i] = mat[i, j]
    return mat


itype = onnx.TensorProto.FLOAT16
dtype = np.float16
model = make_model_gemm(itype)

A = np.random.randn(1280, 256).astype(dtype)
X = np.random.randn(256, 256).astype(dtype)
B = np.random.randn(256).astype(dtype)
feeds = dict(A=A, X=X, B=B)

We disable all the optimizations made by onnxruntime so that the computation follows exactly what we want to verify.

opts = SessionOptions()
opts.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
opts.optimized_model_filepath = "plot_gemm_or_matmul.optimized.onnx"
sess = InferenceSession(model.SerializeToString(), opts, providers=["CPUExecutionProvider"])
results = [A @ X + B, *sess.run(None, feeds)]
diffs = matrix_diff(results)

print(diffs)
# the optimized model saved on disk shows the Cast nodes onnxruntime inserted
print(pretty_onnx(onnx.load(opts.optimized_model_filepath)))
[[0.         0.0625     0.0625     0.0625     0.0625     0.0625
  0.0625    ]
 [0.0625     0.         0.00097656 0.00097656 0.00097656 0.00097656
  0.        ]
 [0.0625     0.00097656 0.         0.         0.         0.
  0.00097656]
 [0.0625     0.00097656 0.         0.         0.         0.
  0.00097656]
 [0.0625     0.00097656 0.         0.         0.         0.
  0.00097656]
 [0.0625     0.00097656 0.         0.         0.         0.
  0.00097656]
 [0.0625     0.         0.00097656 0.00097656 0.00097656 0.00097656
  0.        ]]
opset: domain='' version=22
opset: domain='ai.onnx.ml' version=5
opset: domain='ai.onnx.training' version=1
opset: domain='ai.onnx.preview.training' version=1
opset: domain='com.microsoft' version=1
opset: domain='com.microsoft.experimental' version=1
opset: domain='com.microsoft.nchwc' version=1
opset: domain='org.pytorch.aten' version=1
input: name='A' type=dtype('float16') shape=['a', 'b']
input: name='X' type=dtype('float16') shape=['b', 'c']
input: name='B' type=dtype('float16') shape=['c']
Cast(X, saturate=1, to=1) -> Xfloat
Cast(B, saturate=1, to=1) -> InsertedPrecisionFreeCast_B
Cast(X, saturate=1, to=1) -> InsertedPrecisionFreeCast_X
Cast(A, saturate=1, to=1) -> InsertedPrecisionFreeCast_A
  Gemm(InsertedPrecisionFreeCast_A, InsertedPrecisionFreeCast_X, InsertedPrecisionFreeCast_B, transA=0, alpha=1.00, beta=1.00, transB=0) -> InsertedPrecisionFreeCast_GemmOnly
    Cast(InsertedPrecisionFreeCast_GemmOnly, saturate=1, to=10) -> GemmOnly
  Gemm(InsertedPrecisionFreeCast_A, InsertedPrecisionFreeCast_X, transA=0, alpha=1.00, beta=1.00, transB=0) -> InsertedPrecisionFreeCast_gmm
  Add(InsertedPrecisionFreeCast_gmm, InsertedPrecisionFreeCast_B) -> InsertedPrecisionFreeCast_GemmAdd
    Cast(InsertedPrecisionFreeCast_GemmAdd, saturate=1, to=10) -> GemmAdd
  MatMul(InsertedPrecisionFreeCast_A, InsertedPrecisionFreeCast_X) -> InsertedPrecisionFreeCast_mm
  Add(InsertedPrecisionFreeCast_mm, InsertedPrecisionFreeCast_B) -> InsertedPrecisionFreeCast_MatMulAdd
    Cast(InsertedPrecisionFreeCast_MatMulAdd, saturate=1, to=10) -> MatMulAdd
  FusedMatMul[com.microsoft](InsertedPrecisionFreeCast_A, InsertedPrecisionFreeCast_X, transA=0, transB=0, alpha=1.00, transBatchB=0, transBatchA=0) -> InsertedPrecisionFreeCast_fmm
  Add(InsertedPrecisionFreeCast_fmm, InsertedPrecisionFreeCast_B) -> InsertedPrecisionFreeCast_FusedMatMulAdd
    Cast(InsertedPrecisionFreeCast_FusedMatMulAdd, saturate=1, to=10) -> FusedMatMulAdd
Cast(B, saturate=1, to=1) -> Bfloat
Cast(A, saturate=1, to=1) -> Afloat
  Gemm(Afloat, Xfloat, Bfloat, transA=0, alpha=1.00, beta=1.00, transB=0) -> GemmOnlyfloat
    Cast(GemmOnlyfloat, saturate=1, to=10) -> CastGemmOnlyCast
  Gemm(Afloat, Xfloat, transA=0, alpha=1.00, beta=1.00, transB=0) -> gmmfloat
  Add(gmmfloat, Bfloat) -> gemmaddfloat
    Cast(gemmaddfloat, saturate=1, to=10) -> CastGemmAddCast
output: name='GemmOnly' type=dtype('float16') shape=['a', 'c']
output: name='GemmAdd' type=dtype('float16') shape=['a', 'c']
output: name='FusedMatMulAdd' type=dtype('float16') shape=['a', 'c']
output: name='MatMulAdd' type=dtype('float16') shape=['a', 'c']
output: name='CastGemmAddCast' type=dtype('float16') shape=['a', 'c']
output: name='CastGemmOnlyCast' type=dtype('float16') shape=['a', 'c']

It seems some Cast nodes were still inserted.
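
One way to confirm it programmatically is to load the optimized model written by onnxruntime above and count the Cast nodes (a small sketch reusing the file saved earlier):

optimized = onnx.load("plot_gemm_or_matmul.optimized.onnx")
n_cast = sum(1 for node in optimized.graph.node if node.op_type == "Cast")
print(f"Cast nodes inserted by onnxruntime: {n_cast}")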

Let’s also try float32, and CUDA if it is available.

A = torch.randn((1280, 1280), dtype=torch.float32)
X = torch.randn((1280, 1280), dtype=torch.float32)
B = torch.randn((1280), dtype=torch.float32)

for itype, dtype, device in [
    (onnx.TensorProto.FLOAT16, torch.float16, "cpu"),
    (onnx.TensorProto.FLOAT, torch.float32, "cpu"),
    (onnx.TensorProto.FLOAT16, torch.float16, "cuda"),
    (onnx.TensorProto.FLOAT, torch.float32, "cuda"),
]:
    if device == "cuda" and not torch.cuda.is_available():
        continue
    a = A.to(dtype).to(device)
    x = X.to(dtype).to(device)
    b = B.to(dtype).to(device)
    feeds = dict(A=a, X=x, B=b)
    model = make_model_gemm(itype)

    sess = OnnxruntimeEvaluator(model, whole=True)
    results = sess.run(None, feeds)
    diffs = matrix_diff(results)
    print(f"------ dtype={dtype}, device={device!r}")
    print(diffs)
------ dtype=torch.float16, device='cpu'
[[0.     0.0625 0.0625 0.     0.0625 0.    ]
 [0.0625 0.     0.     0.0625 0.     0.0625]
 [0.0625 0.     0.     0.0625 0.     0.0625]
 [0.     0.0625 0.0625 0.     0.0625 0.    ]
 [0.0625 0.     0.     0.0625 0.     0.0625]
 [0.     0.0625 0.0625 0.     0.0625 0.    ]]
------ dtype=torch.float32, device='cpu'
[[0.0000000e+00 3.0517578e-05 3.0517578e-05 0.0000000e+00 3.0517578e-05
  0.0000000e+00]
 [3.0517578e-05 0.0000000e+00 0.0000000e+00 3.0517578e-05 0.0000000e+00
  3.0517578e-05]
 [3.0517578e-05 0.0000000e+00 0.0000000e+00 3.0517578e-05 0.0000000e+00
  3.0517578e-05]
 [0.0000000e+00 3.0517578e-05 3.0517578e-05 0.0000000e+00 3.0517578e-05
  0.0000000e+00]
 [3.0517578e-05 0.0000000e+00 0.0000000e+00 3.0517578e-05 0.0000000e+00
  3.0517578e-05]
 [0.0000000e+00 3.0517578e-05 3.0517578e-05 0.0000000e+00 3.0517578e-05
  0.0000000e+00]]
------ dtype=torch.float16, device='cuda'
[[0.    0.125 0.125 0.    0.125 0.125]
 [0.125 0.    0.    0.125 0.125 0.125]
 [0.125 0.    0.    0.125 0.125 0.125]
 [0.    0.125 0.125 0.    0.125 0.125]
 [0.125 0.125 0.125 0.125 0.    0.   ]
 [0.125 0.125 0.125 0.125 0.    0.   ]]
------ dtype=torch.float32, device='cuda'
[[0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]

A weird bias

In the previous example, the coefficients of the bias have the same magnitude as the other coefficients. What if we make them a lot larger?

A = A / A.max()
X = X / X.max()
B = (torch.arange(1280, dtype=torch.float32) + 1) / 1280 * 16
labels = ["F.linear", *[o.name for o in model.graph.output], "a @ x + b"]
all_results = {}

for itype, dtype, device in [
    (onnx.TensorProto.FLOAT, torch.float32, "cpu"),
    (onnx.TensorProto.FLOAT16, torch.float16, "cpu"),
    # missing implementation in onnxruntime
    # (onnx.TensorProto.BFLOAT16, torch.bfloat16, "cpu"),
    (onnx.TensorProto.FLOAT, torch.float32, "cuda"),
    (onnx.TensorProto.FLOAT16, torch.float16, "cuda"),
    (onnx.TensorProto.BFLOAT16, torch.bfloat16, "cuda"),
]:
    if device == "cuda" and not torch.cuda.is_available():
        continue
    a = A.to(dtype).to(device)
    x = X.to(dtype).to(device)
    b = B.to(dtype).to(device)
    feeds = dict(A=a, X=x, B=b)
    model = make_model_gemm(itype)

    filename = f"plot_gemm_or_matmul.{itype}.{device}.onnx"
    sess = OnnxruntimeEvaluator(
        model,
        whole=True,
        graph_optimization_level=GraphOptimizationLevel.ORT_DISABLE_ALL,
        optimized_model_filepath=filename,
    )
    results = [torch.nn.functional.linear(a, x.T, b), *sess.run(None, feeds), a @ x + b]
    all_results[device, dtype] = results
    has_cast = "Cast" in [n.op_type for n in onnx.load(filename).graph.node]
    diffs = matrix_diff(results)
    df = pandas.DataFrame(diffs, columns=labels, index=labels)
    print(f"------ has_cast={has_cast}, dtype={dtype}, device={device!r}, max(b)={b.max()}")
    print(df)
------ has_cast=True, dtype=torch.float32, device='cpu', max(b)=16.0
                  F.linear  GemmOnly   GemmAdd  FusedMatMulAdd  MatMulAdd  CastGemmAddCast  CastGemmOnlyCast  a @ x + b
F.linear          0.000000  0.000008  0.000006        0.000006   0.000006         0.000006          0.000008   0.000006
GemmOnly          0.000008  0.000000  0.000008        0.000008   0.000008         0.000008          0.000000   0.000008
GemmAdd           0.000006  0.000008  0.000000        0.000000   0.000000         0.000000          0.000008   0.000004
FusedMatMulAdd    0.000006  0.000008  0.000000        0.000000   0.000000         0.000000          0.000008   0.000004
MatMulAdd         0.000006  0.000008  0.000000        0.000000   0.000000         0.000000          0.000008   0.000004
CastGemmAddCast   0.000006  0.000008  0.000000        0.000000   0.000000         0.000000          0.000008   0.000004
CastGemmOnlyCast  0.000008  0.000000  0.000008        0.000008   0.000008         0.000008          0.000000   0.000008
a @ x + b         0.000006  0.000008  0.000004        0.000004   0.000004         0.000004          0.000008   0.000000
------ has_cast=True, dtype=torch.float16, device='cpu', max(b)=16.0
                  F.linear  GemmOnly   GemmAdd  FusedMatMulAdd  MatMulAdd  CastGemmAddCast  CastGemmOnlyCast  a @ x + b
F.linear          0.000000  0.015625  0.015625        0.015625   0.015625         0.015625          0.015625   0.015625
GemmOnly          0.015625  0.000000  0.015625        0.015625   0.015625         0.015625          0.000000   0.015625
GemmAdd           0.015625  0.015625  0.000000        0.000000   0.000000         0.000000          0.015625   0.015625
FusedMatMulAdd    0.015625  0.015625  0.000000        0.000000   0.000000         0.000000          0.015625   0.015625
MatMulAdd         0.015625  0.015625  0.000000        0.000000   0.000000         0.000000          0.015625   0.015625
CastGemmAddCast   0.015625  0.015625  0.000000        0.000000   0.000000         0.000000          0.015625   0.015625
CastGemmOnlyCast  0.015625  0.000000  0.015625        0.015625   0.015625         0.015625          0.000000   0.015625
a @ x + b         0.015625  0.015625  0.015625        0.015625   0.015625         0.015625          0.015625   0.000000
------ has_cast=True, dtype=torch.float32, device='cuda', max(b)=16.0
                  F.linear  GemmOnly   GemmAdd  FusedMatMulAdd  MatMulAdd  CastGemmAddCast  CastGemmOnlyCast  a @ x + b
F.linear          0.000000  0.002478  0.002478        0.002478   0.002478         0.002478          0.002478   0.000006
GemmOnly          0.002478  0.000000  0.000000        0.000000   0.000000         0.000000          0.000000   0.002479
GemmAdd           0.002478  0.000000  0.000000        0.000000   0.000000         0.000000          0.000000   0.002479
FusedMatMulAdd    0.002478  0.000000  0.000000        0.000000   0.000000         0.000000          0.000000   0.002479
MatMulAdd         0.002478  0.000000  0.000000        0.000000   0.000000         0.000000          0.000000   0.002479
CastGemmAddCast   0.002478  0.000000  0.000000        0.000000   0.000000         0.000000          0.000000   0.002479
CastGemmOnlyCast  0.002478  0.000000  0.000000        0.000000   0.000000         0.000000          0.000000   0.002479
a @ x + b         0.000006  0.002479  0.002479        0.002479   0.002479         0.002479          0.002479   0.000000
------ has_cast=True, dtype=torch.float16, device='cuda', max(b)=16.0
                  F.linear  GemmOnly   GemmAdd  FusedMatMulAdd  MatMulAdd  CastGemmAddCast  CastGemmOnlyCast  a @ x + b
F.linear          0.000000  0.015625  0.015625        0.015625   0.015625         0.015625          0.015625   0.015625
GemmOnly          0.015625  0.000000  0.015625        0.015625   0.015625         0.015625          0.015625   0.015625
GemmAdd           0.015625  0.015625  0.000000        0.000000   0.000000         0.015625          0.015625   0.000000
FusedMatMulAdd    0.015625  0.015625  0.000000        0.000000   0.000000         0.015625          0.015625   0.000000
MatMulAdd         0.015625  0.015625  0.000000        0.000000   0.000000         0.015625          0.015625   0.000000
CastGemmAddCast   0.015625  0.015625  0.015625        0.015625   0.015625         0.000000          0.000000   0.015625
CastGemmOnlyCast  0.015625  0.015625  0.015625        0.015625   0.015625         0.000000          0.000000   0.015625
a @ x + b         0.015625  0.015625  0.000000        0.000000   0.000000         0.015625          0.015625   0.000000
------ has_cast=True, dtype=torch.bfloat16, device='cuda', max(b)=16.0
                  F.linear  GemmOnly  GemmAdd  FusedMatMulAdd  MatMulAdd  CastGemmAddCast  CastGemmOnlyCast  a @ x + b
F.linear             0.000     0.125    0.125           0.125      0.125            0.125             0.125      0.125
GemmOnly             0.125     0.000    0.125           0.125      0.125            0.125             0.125      0.125
GemmAdd              0.125     0.125    0.000           0.000      0.000            0.125             0.125      0.000
FusedMatMulAdd       0.125     0.125    0.000           0.000      0.000            0.125             0.125      0.000
MatMulAdd            0.125     0.125    0.000           0.000      0.000            0.125             0.125      0.000
CastGemmAddCast      0.125     0.125    0.125           0.125      0.125            0.000             0.000      0.125
CastGemmOnlyCast     0.125     0.125    0.125           0.125      0.125            0.000             0.000      0.125
a @ x + b            0.125     0.125    0.000           0.000      0.000            0.125             0.125      0.000

Cast nodes are inserted on CPU because some kernels are not available for float16. Even so, we can see large discrepancies.
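
The effect of a large bias can be reproduced with numpy alone: the two orderings round at different points. A minimal sketch, independent of onnxruntime kernels:

rng = np.random.default_rng(0)
acc = rng.standard_normal(1280).astype(np.float32)          # stands for the matmul result kept in float32
bias = (np.arange(1280, dtype=np.float32) + 1) / 1280 * 16  # same large bias as above

# Gemm-like: the bias is added in float32, the result is cast to float16 once
gemm_like = (acc + bias).astype(np.float16)
# MatMul + Add-like: the matmul result is cast to float16 first, the bias is added in float16
matmul_add_like = acc.astype(np.float16) + bias.astype(np.float16)

# the difference is typically around one float16 ulp at the magnitude of the bias
print(np.abs(gemm_like.astype(np.float32) - matmul_add_like.astype(np.float32)).max())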

Bias value vs discrepancies

Let’s compare torch linear with GemmOnly.

def make_figure_axis(all_results, i, j):
    labs = labels[i], labels[j]
    fig, ax = plt.subplots(len(all_results), 2, figsize=(12, 4 * len(all_results)))
    for pos, ((device, dtype), results) in enumerate(all_results.items()):
        m1, m2 = results[i], results[j]
        diff = torch.abs(m1.to(torch.float32) - m2.to(torch.float32)).max(dim=0)[0]
        print(f"labels={labs}, {device}/{dtype}: max(diff)={diff.max()}")
        expand = 0.5 if diff.max() >= 1 else diff.max().detach().cpu() / 2
        ax[pos, 0].plot(
            B.tolist(), (diff.detach().cpu() + torch.rand(1280) * expand).tolist(), "."
        )
        ax[pos, 0].set_title(f"{labs[0]}-{labs[1]} {device}/{dtype}", fontsize=10)

        corr = matrix_diff(results)
        ax[pos, 1].imshow(corr, cmap="Wistia", vmin=0, vmax=corr.max())
        # ax[pos,1].colorbar(label=f'Discrepancies {device}/{dtype}')
        ax[pos, 1].set_xticks(range(len(labels)), labels, rotation=45, ha="right", fontsize=10)
        ax[pos, 1].set_yticks(range(len(labels)), labels, fontsize=10)
        ax[pos, 1].set_title(f"max={diff.max():1.2g}", fontsize=10)
        for _i in range(corr.shape[0]):
            for _j in range(corr.shape[1]):
                ax[pos, 1].text(
                    _j,
                    _i,
                    f"{corr[_i, _j]:1.1g}",
                    ha="center",
                    va="center",
                    color="black",
                    fontsize=8,
                )
    fig.suptitle(
        f"Left column: discrepancies {labs[0]} VS {labs[1]}\n"
        f"Right column: max absolute error, across all configuration\n"
        f"white is good, orange is not"
    )
    return fig, ax


fig, ax = make_figure_axis(all_results, 0, 1)
fig.tight_layout()
fig.savefig("plot_gemm_or_matmul_add1.png")
[Figure plot_gemm_or_matmul_add1.png — left column: discrepancies F.linear vs GemmOnly as a function of the bias value; right column: matrix of maximum absolute errors for each configuration (white is good, orange is not).]
labels=('F.linear', 'GemmOnly'), cpu/torch.float32: max(diff)=7.62939453125e-06
labels=('F.linear', 'GemmOnly'), cpu/torch.float16: max(diff)=0.015625
labels=('F.linear', 'GemmOnly'), cuda/torch.float32: max(diff)=0.0024776458740234375
labels=('F.linear', 'GemmOnly'), cuda/torch.float16: max(diff)=0.015625
labels=('F.linear', 'GemmOnly'), cuda/torch.bfloat16: max(diff)=0.125

Let’s compare with A @ X + B.

fig, ax = make_figure_axis(all_results, -1, 1)
fig.tight_layout()
fig.savefig("plot_gemm_or_matmul_add2.png")
[Figure plot_gemm_or_matmul_add2.png — left column: discrepancies a @ x + b vs GemmOnly as a function of the bias value; right column: matrix of maximum absolute errors for each configuration (white is good, orange is not).]
labels=('a @ x + b', 'GemmOnly'), cpu/torch.float32: max(diff)=7.62939453125e-06
labels=('a @ x + b', 'GemmOnly'), cpu/torch.float16: max(diff)=0.015625
labels=('a @ x + b', 'GemmOnly'), cuda/torch.float32: max(diff)=0.0024785995483398438
labels=('a @ x + b', 'GemmOnly'), cuda/torch.float16: max(diff)=0.015625
labels=('a @ x + b', 'GemmOnly'), cuda/torch.bfloat16: max(diff)=0.125

Discrepancies do not show up every time, but they are very likely to happen. The fused Gemm with a non-null bias should only be used when torch performs the same fusion, and the outcome also seems to depend on the dtype. The difference is even larger for bfloat16.
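
If an exported graph must match torch's a @ x + b exactly, one option is to decompose Gemm into MatMul followed by Add. A hedged sketch, valid only for Gemm nodes with default attributes (no transpose, alpha=beta=1) and an explicit bias, as in this example:

def decompose_gemm(node: onnx.NodeProto) -> list[onnx.NodeProto]:
    # only valid for Gemm nodes with default attributes and an explicit bias
    assert node.op_type == "Gemm" and len(node.input) == 3 and not node.attribute
    tmp = node.output[0] + "_mm"
    return [
        oh.make_node("MatMul", [node.input[0], node.input[1]], [tmp]),
        oh.make_node("Add", [tmp, node.input[2]], [node.output[0]]),
    ]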

Total running time of the script: (0 minutes 26.400 seconds)

Related examples

Reproducible Parallelized Reduction is difficult

LayerNormalization implementation cannot be exchanged

From a LLM to processing a prompt