from enum import IntEnum
from logging import getLogger
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from onnx import FunctionProto, NodeProto, TensorProto
from onnx.helper import (
make_node,
make_tensor,
make_tensor_value_info,
tensor_dtype_to_np_dtype,
)
from onnx.numpy_helper import float8e4m3_to_float32, float8e5m2_to_float32
from onnx.reference.custom_element_types import float8e4m3fn
from onnx.reference.op_run import to_array_extended
from onnx.onnx_cpp2py_export.defs import SchemaError
try:
from onnx.reference.ops.op_cast import Cast_19 as Cast
from onnx.reference.ops.op_quantize_linear import (
QuantizeLinear_19 as QuantizeLinear,
)
except ImportError:
from onnx.reference.ops.op_cast import Cast
from onnx.reference.ops.op_quantize_linear import QuantizeLinear
from ...helper import (
make_dynamic_quantize_linear_function_proto,
make_matmul_reshape_transpose_back_function_proto,
make_matmul_reshape_transpose_function_proto,
make_simple_dynamic_quantize_linear_function_proto,
)
from ...reference import from_array_extended
from ...validation.cython.fp8 import cast_float32_to_e4m3fn
from .errors import QuantizationError
from .onnx_graph_struct import _get_shape, Graph, Node, NodeKind
# Module-level logger used by the transformations below to report
# their progress (name: `onnx-extended/transformer`).
logger = getLogger("onnx-extended/transformer")
class QuantizeOptions(IntEnum):
    """
    Quantization options.

    * `NONE`: no option, the generic implementation is used
    * `OPTIMIZE`: assumes there is no nan values and
      chooses less generic functions such as a specialized
      implementation of DynamicQuantizeLinear
    """

    NONE = 0
    OPTIMIZE = 1
def estimation_quantization_scale(
    coef: np.ndarray,
    to: int = TensorProto.FLOAT8E4M3FN,
    threshold: float = 0.99999,
) -> Tuple[float, float]:
    """
    Estimates the scale parameter for the quantization to float 8 assuming
    the distribution of the coefficients is gaussian.

    :param coef: coefficients to quantize
    :param to: element type to quantize into, one of the four float 8 types,
        `TensorProto.UINT8` or `TensorProto.INT8`
    :param threshold: quantile used to estimate the range of the coefficients,
        the interval `[1 - threshold, threshold]` is mapped onto the
        integer range, it is only used for `UINT8` and `INT8`
    :return: a tuple `(1 / scale, -zero)`, the first item is a numpy array
        with no dimension and the same dtype as *coef*
    :raises ValueError: if *to* is not one of the supported types
    """
    # One decoder per float 8 type: it maps a byte (0..255) to its float value.
    float8_decoders = {
        TensorProto.FLOAT8E4M3FN: float8e4m3_to_float32,
        TensorProto.FLOAT8E4M3FNUZ: lambda x: float8e4m3_to_float32(x, uz=True),
        TensorProto.FLOAT8E5M2: float8e5m2_to_float32,
        TensorProto.FLOAT8E5M2FNUZ: lambda x: float8e5m2_to_float32(
            x, uz=True, fn=True
        ),
    }

    if to in float8_decoders:
        decode = float8_decoders[to]
        float8 = [decode(i) for i in range(256)]
        # nan and infinities are not usable quantized values
        quant_float = [f for f in float8 if not np.isnan(f) and not np.isinf(f)]
        # The scale aligns the root mean square of the coefficients
        # with the standard deviation of the representable values.
        std_coef = np.mean(coef.ravel() ** 2) ** 0.5
        std_quant = np.std(np.array(quant_float, dtype=np.float32))
        zero = 0.0
        scale = std_quant.astype(coef.dtype) / std_coef.astype(coef.dtype)
    elif to == TensorProto.UINT8:
        qu = np.quantile(coef.ravel(), [1 - threshold, threshold])
        scale = 255 / (qu[1] - qu[0])
        zero = qu[0] * scale
    elif to == TensorProto.INT8:
        qu = np.quantile(coef.ravel(), [1 - threshold, threshold])
        scale = 254 / (qu[1] - qu[0])
        zero = (qu[0] + qu[1]) / 2 * scale
    else:
        raise ValueError(f"Unexpected quantization type for to={to}.")

    return np.array(1.0 / scale, dtype=coef.dtype), -zero
def quantize_weights(
    node: Node,
    elem_type: int,
    index: int,
    transpose: bool = False,
) -> Tuple[Node, Node, Node, Union[bool, Tuple[int, ...]], Tuple[int, ...]]:
    """
    Quantizes a tensor into a tensor of element type *elem_type*.

    :param node: Node to quantize, it must hold a constant tensor
    :param elem_type: element type
    :param index: position of the tensor among the inputs of the
        matrix multiplication, 0 for the first input, it decides
        how a tensor with more than two dimensions is flattened
    :param transpose: transpose the weight before doing it
    :return: three new nodes, quantized weights, scale, zero point,
        then the original shape if the constant had to be reshaped
        into a matrix (False otherwise) and the original shape of the constant
    :raises NotImplementedError: if the constant has less than two dimensions

    See function :func:`quantize_float8_matmul`
    """
    tensor = node.get_tensor()
    values = to_array_extended(tensor)
    shape = values.shape
    # A scalar cannot be used as a matrix either, it is rejected
    # the same way as a vector.
    if len(values.shape) < 2:
        raise NotImplementedError(
            f"Input {index} is a constant, it must have at "
            f"least 2 dimensions not {values.shape}."
        )
    if len(values.shape) > 2:
        # Needs to be reshaped into a matrix.
        reshaped = values.shape
        if index == 0:
            # first input: abij ~> Cj
            values = values.reshape((-1, values.shape[-1]))
        else:
            # second input: abjk ~> Cjk ~> jCk ~> jD
            values = (
                values.reshape((-1,) + values.shape[-2:])
                .transpose((1, 0, 2))
                .reshape((values.shape[-2], -1))
            )
    else:
        reshaped = False
    if transpose:
        values = values.T

    scale, zp = estimation_quantization_scale(values, to=elem_type)
    # The zero point goes through a TensorProto to get an array
    # of the quantized type.
    zpt = make_tensor("zp", elem_type, [], [zp])
    zpa = to_array_extended(zpt)

    if elem_type == TensorProto.FLOAT8E4M3FN:
        new_values = cast_float32_to_e4m3fn(values / scale).astype(float8e4m3fn)
    else:
        new_values = QuantizeLinear.eval(values, scale, zpa)

    new_tensor = from_array_extended(
        new_values, name=node.parent.generate_name(node.outname)
    )
    node_weight = node.create_with_new_values(new_tensor)

    new_scale = from_array_extended(
        scale, name=node.parent.generate_name(f"{node.outname}_scale")
    )
    node_scale = node.create_with_new_values(new_scale)

    new_zp = from_array_extended(
        zpa, name=node.parent.generate_name(f"{node.outname}_zp")
    )
    node_zp = node.create_with_new_values(new_zp)
    return node_weight, node_scale, node_zp, reshaped, shape
class _MainQuantizeState:
    """
    Holds the settings shared by all the nodes quantized
    in one call: the gemm operator to insert and its domain,
    the domain of the local functions, the local functions themselves,
    the main opset and which inputs are transposed.
    """

    # (version, gemm operator name, default domain of this operator)
    _GEMM_PER_VERSION = (
        ("onnxruntime", "GemmFloat8", "com.microsoft"),
        (
            "onnx-extended",
            "CustomGemmFloat8E4M3FN",
            "onnx_extented.ortops.tutorial.cuda",
        ),
    )

    def __init__(self, version, domain_ops, local_functions, opset, index_transpose):
        self.domain_dq = "local.quant.domain"
        self.domain_ops = {} if domain_ops is None else domain_ops

        for known_version, op_name, default_domain in self._GEMM_PER_VERSION:
            if version == known_version:
                self.op_gemm = op_name
                # domain_ops may override the domain of the operator
                self.domain_gemm = self.domain_ops.get(op_name, default_domain)
                break
        else:
            raise ValueError(f"Unexpected value {version!r} for version.")

        if local_functions is None:
            raise ValueError("local_functions cannot be None")
        self.local_functions = local_functions
        self.opset = opset
        self.index_transpose = index_transpose
class _QuantizeState:
    """
    Accumulates the changes needed to replace one matrix multiplication
    by a quantized gemm: nodes to remove, nodes to add and
    the names of the quantized inputs.
    """

    def __init__(self, logger, quantize_options, main_state):
        # nodes to remove from the graph
        self.removed = []
        # protos to add to the graph
        self.added = []
        # one item per input: names of the quantized tensor, scale, zero point
        self.input_names = []
        # one item per input: False if the input was not reshaped,
        # the original shape (tuple) for a constant,
        # the name of the input (str) for a dynamic input
        self.was_reshaped = [False, False]
        # one item per input: shape of the input as known at conversion time
        self.shapes = []
        self.logger = logger
        self.quantize_options = quantize_options
        self.main_state = main_state

    def quantize_connstant_weights(self, node, name, elem_type, index, do_transpose):
        """
        Quantizes a constant input and registers the new constants
        (quantized weights, scale, zero point) replacing it.

        :param node: matrix multiplication being converted
        :param name: name of the constant input
        :param elem_type: element type to quantize into
        :param index: position of the input
        :param do_transpose: transpose the weights before quantizing them
        """
        cst = node.parent.get_node_producer(name)
        weight, scale, zero_point, reshaped, shape = quantize_weights(
            cst, elem_type, index, transpose=do_transpose
        )
        self.added.extend([weight.proto, scale.proto, zero_point.proto])
        self.input_names.append([weight.outname, scale.outname, zero_point.outname])
        self.removed.append(cst)
        self.was_reshaped[index] = reshaped
        self.shapes.append(shape)
        self.logger.debug("[quantize_weights] static quantize shape=%r", shape)

    # Correctly spelled alias, the original name is kept for existing callers.
    quantize_constant_weights = quantize_connstant_weights

    def quantize_dynamic(self, node, name, elem_type, index, do_transpose):
        """
        Adds the nodes quantizing an input which is not a constant.

        :param node: matrix multiplication being converted
        :param name: name of the input
        :param elem_type: element type to quantize into
        :param index: position of the input
        :param do_transpose: transpose the input before quantizing it
        :raises QuantizationError: if the shape of the input is unknown
        """
        # Not a constant
        shape = node.parent.get_shape(name)
        self.shapes.append(shape)
        self.logger.debug("[quantize_float8_matmul] quantize dynamic shape %r", shape)
        # The shape must be checked before calling len, otherwise an unknown
        # shape fails with a TypeError the caller does not know how to skip.
        if shape is None:
            raise QuantizationError(
                f"Shape is unknown for result {name!r} in node {node}. "
                f"This input cannot be transposed with certainty."
            )
        if len(shape) == 2:
            if do_transpose:
                # a matrix only needs a plain transposition
                temp_name = node.parent.generate_name(f"{name}_tr")
                self.added.append(
                    make_node(
                        "Transpose",
                        [name],
                        [temp_name],
                        perm=[1, 0],
                        name=node.parent.generate_node_name(f"tra8_{name}"),
                    )
                )
            else:
                # no need for any additional node
                temp_name = name
        else:
            # The input is not a matrix, a local function turns it into one.
            temp_name = node.parent.generate_name(f"{name}_tr")
            fname = f"MatMulReshapeTranspose{'T' if do_transpose else 'N'}{index}"
            self.added.append(
                make_node(
                    fname,
                    [name],
                    [temp_name],
                    perm=[1, 0],
                    domain=self.main_state.domain_dq,
                    name=node.parent.generate_node_name("MMRT"),
                )
            )
            # the name is used later to retrieve the shape to restore
            self.was_reshaped[index] = name
            if (
                self.main_state.domain_dq,
                fname,
            ) not in self.main_state.local_functions:
                # use local functions
                self.main_state.local_functions[
                    self.main_state.domain_dq, fname
                ] = make_matmul_reshape_transpose_function_proto(
                    domain=self.main_state.domain_dq,
                    opset=self.main_state.opset,
                    index=index,
                    transpose=do_transpose,
                )

        new_name = node.parent.generate_name(f"{name}_f8")
        scale = node.parent.generate_name(f"{name}_scale")
        zero_point = node.parent.generate_name(f"{name}_zp")
        if self.quantize_options == QuantizeOptions.NONE:
            fname, fmake = (
                "DynamicQuantizeLinear",
                make_dynamic_quantize_linear_function_proto,
            )
        elif self.quantize_options == QuantizeOptions.OPTIMIZE:
            # one specialized function per float 8 type
            suffix = {
                TensorProto.FLOAT8E4M3FN: "E4M3FN",
                TensorProto.FLOAT8E4M3FNUZ: "E4M3FNUZ",
                TensorProto.FLOAT8E5M2: "E5M2",
                TensorProto.FLOAT8E5M2FNUZ: "E5M2FNUZ",
            }
            fname, fmake = (
                f"DynamicQuantizeLinear{suffix[elem_type]}",
                lambda domain=None, opset=None: (
                    make_simple_dynamic_quantize_linear_function_proto(
                        domain=domain, opset=opset, to=elem_type
                    )
                ),
            )
        else:
            raise RuntimeError(
                f"Unexpected value {self.quantize_options!r} for quantize_options."
            )
        proto = make_node(
            fname,
            [temp_name],
            [new_name, scale, zero_point],
            to=elem_type,
            domain=self.main_state.domain_dq,
            name=node.parent.generate_node_name("DQL"),
        )
        dql = Node(None, node.parent, proto, NodeKind.NODE)
        self.added.append(dql.proto)
        self.input_names.append(dql.outputs)
        if (
            self.main_state.domain_dq == "local.quant.domain"
            and (self.main_state.domain_dq, fname)
            not in self.main_state.local_functions
        ):
            # use local functions
            self.main_state.local_functions[self.main_state.domain_dq, fname] = fmake(
                domain=self.main_state.domain_dq, opset=self.main_state.opset
            )

    def finalize(self, node, name, output_type):
        """
        Adds the gemm node once both inputs are quantized and,
        if one input was reshaped, the nodes restoring the expected shape.

        :param node: matrix multiplication being converted
        :param name: prefix used to generate the names of the new results
        :param output_type: output type of the gemm
        :return: opsets needed by the new nodes, a dictionary `{domain: version}`
        :raises QuantizationError: if both inputs were reshaped
        """
        if self.was_reshaped[0] and self.was_reshaped[1]:
            raise QuantizationError(
                f"MatMul cannot be replaced by operator Gemm as both inputs "
                f"are not matrices. Their shapes are {self.shapes}. "
                f"Node name is {node.name!r}."
            )
        if output_type in {
            TensorProto.INT8,
            TensorProto.UINT8,
            TensorProto.FLOAT8E4M3FN,
            TensorProto.FLOAT8E4M3FNUZ,
            TensorProto.FLOAT8E5M2,
            TensorProto.FLOAT8E5M2FNUZ,
        }:
            # output is quantized, there is a need for a scale
            scale_out = node.parent.generate_name(f"{name}_scaleout")
            self.added.append(
                make_node(
                    "Mul",
                    [self.input_names[0][1], self.input_names[1][1]],
                    [scale_out],
                    name=node.parent.generate_node_name(f"mul8_{name}"),
                )
            )
        else:
            # output is not quantized, no need for an output scale
            scale_out = ""
        gemm_inputs = [
            self.input_names[0][0],  # A
            self.input_names[1][0],  # B
            "",  # C
            self.input_names[0][1],  # scaleA
            self.input_names[1][1],  # scaleB
            scale_out,  # scaleR
        ]
        do_reshape = bool(self.was_reshaped[0] or self.was_reshaped[1])
        if do_reshape:
            # the gemm produces an intermediate result reshaped afterwards
            gemm_outputs = [node.parent.generate_name(f"{name}_gemm")]
        else:
            gemm_outputs = node.outputs

        self.added.append(
            make_node(
                self.main_state.op_gemm,
                gemm_inputs,
                gemm_outputs,
                rowMajor=0,
                dtype=output_type,
                transA=1 if (self.main_state.index_transpose & 1) else 0,
                transB=1 if (self.main_state.index_transpose & 2) else 0,
                domain=self.main_state.domain_gemm,
                computeType="CUBLAS_COMPUTE_32F_FAST_TF32",
                name=node.parent.generate_node_name("GEMMFP8"),
            )
        )
        if do_reshape:
            # One of the inputs had more than two dimensions.
            index = 0 if self.was_reshaped[0] else 1
            fname = f"MatMulReshapeTransposeBack{index}"
            mmshape = node.parent.generate_name(f"{name}_mmshape")
            shape = self.was_reshaped[index]
            if isinstance(shape, tuple):
                # constant input, the original shape is known
                self.added.append(
                    make_node(
                        "Constant",
                        [],
                        [mmshape],
                        value=make_tensor(
                            mmshape, TensorProto.INT64, [len(shape)], list(shape)
                        ),
                    )
                )
            elif isinstance(shape, str):
                # dynamic input, the shape is retrieved at runtime
                self.added.append(make_node("Shape", [shape], [mmshape]))
            else:
                raise TypeError(
                    f"Unexpected shape={shape} in was_reshaped={self.was_reshaped}."
                )
            self.added.append(
                make_node(
                    fname,
                    [gemm_outputs[0], mmshape],
                    node.outputs,
                    perm=[1, 0],
                    domain=self.main_state.domain_dq,
                    name=node.parent.generate_node_name("MMRTB"),
                )
            )
            if (
                self.main_state.domain_dq,
                fname,
            ) not in self.main_state.local_functions:
                # use local functions
                self.main_state.local_functions[
                    self.main_state.domain_dq, fname
                ] = make_matmul_reshape_transpose_back_function_proto(
                    domain=self.main_state.domain_dq,
                    opset=self.main_state.opset,
                    index=index,
                )

        self.removed.append(node)
        opsets = {self.main_state.domain_gemm: 1}
        if self.main_state.domain_dq != "":
            opsets.update({self.main_state.domain_dq: 1})
        return opsets
def quantize_float8_matmul(
    node: Node,
    elem_type: int,
    output_type: int,
    version: str,
    opset: int,
    local_functions: Optional[Dict[str, FunctionProto]] = None,
    index_transpose: int = 2,
    domain_ops: Optional[Dict[str, str]] = None,
    quantize_options: QuantizeOptions = QuantizeOptions.NONE,
) -> Optional[TransformResults]:
    """
    Quantizes matrix multiplications.

    :param node: matrix multiplication
    :param elem_type: float 8 type to quantize into
    :param output_type: output type, result of the quantization
    :param version: `'onnxruntime'` to use operators from onnx and onnxruntime,
        `'onnx-extended'` to use experimental operators
    :param opset: main opset (used to specify local functions)
    :param local_functions: a dictionary with the existing local functions,
        it is updated in place with the local functions the new nodes need,
        it cannot be None
    :param index_transpose: which input to transpose before calling gemm:
        0 (none), 1 (first), 2 (second), 3 for both
    :param domain_ops: domain to use for operators used as keys in the dictionary
    :param quantize_options: see :class:`QuantizeOptions`
    :return: nodes to remove, nodes to add, new opsets, local functions
    :raises ValueError: if *version* is unknown or *local_functions* is None
    :raises NotImplementedError: if the node is not a `MatMul`

    **About tensors with more than two dimensions.**

    Gemm only supports matrices or 2D tensors.
    When one input does not follow that schema,
    it uses the following tricks. First one when the first input
    has more than two dimensions. We could describe that
    with the Einstein summation. We want `abij, jk -> abik`.

    ::

        abij, jk -> abik
        abij ~> Cj
        Cj, jk -> Ck  # gemm
        Ck ~> abik

    In python:

    .. runpython::
        :showcode:

        import numpy as np

        a = np.arange(16).reshape((2, 2, 2, 2))
        b = np.arange(4).reshape((2, 2)) * 10
        expected = a @ b

        a2 = a.reshape((-1, a.shape[-1]))
        b2 = b
        res = a2 @ b2
        final = res.reshape(a.shape[:-2] + (-1, res.shape[-1]))

        print("------")
        print(a)
        print(b)
        print("------")
        print(expected)
        print("------")
        print(final)
        assert expected.tolist() == final.tolist()

    For the other, we use the trick `a @ b = (b' @ a')'`.
    Second one when the second input has more than two dimensions.
    We want `ij, abjk -> abik`.

    ::

        ij, abjk -> abik
        abjk ~> Cjk ~> jCk ~> jD
        ij, jD -> iD  # gemm
        iD ~> iCk ~> Cik ~> abik

    .. runpython::
        :showcode:

        import numpy as np

        a = np.arange(4).reshape((2, 2)) * 10
        b = np.arange(16).reshape((2, 2, 2, 2))
        expected = a @ b

        a2 = a
        b2 = (
            b.reshape((-1,) + b.shape[-2:])
            .transpose((1, 0, 2))
            .reshape((b.shape[-2], -1))
        )
        res = a2 @ b2
        final = (
            res.reshape(a.shape[0], -1, b.shape[-1])
            .transpose((1, 0, 2))
            .reshape(b.shape[:-2] + (-1, b.shape[-1]))
        )

        print("------")
        print(a)
        print(b)
        print("------")
        print(expected)
        print("------")
        print(final)
        assert expected.tolist() == final.tolist()

    Both sides cannot have more than two dimensions in the current implementation.
    """
    # The arguments are validated even if the node cannot be converted.
    main_state = _MainQuantizeState(
        version, domain_ops, local_functions, opset, index_transpose
    )

    if node.op_type != "MatMul":
        raise NotImplementedError(
            f"Quantization into float 8 not yet implemented for {node.op_type!r}."
        )

    quantize_state = _QuantizeState(logger, quantize_options, main_state)
    for index, name in enumerate(node.inputs):
        # bit `index` of index_transpose tells if this input is transposed
        do_transpose = bool((1 << index) & index_transpose)
        if node.parent.is_constant(name):
            # Quantized constant weights
            quantize_state.quantize_connstant_weights(
                node, name, elem_type, index, do_transpose
            )
        else:
            quantize_state.quantize_dynamic(
                node, name, elem_type, index, do_transpose
            )

    # `name` is the name of the last input, it is only used
    # as a prefix to name the new results.
    opsets = quantize_state.finalize(node, name, output_type)
    return TransformResults(
        removed_nodes=quantize_state.removed,
        added_nodes=quantize_state.added,
        new_opsets=opsets,
        local_functions=main_state.local_functions,
    )
def quantize_float8(
    graph: Graph,
    elem_type: int = TensorProto.FLOAT8E4M3FN,
    output_type: int = TensorProto.FLOAT,
    early_stop: int = -1,
    version: str = "onnxruntime",
    quiet: bool = False,
    index_transpose: int = 2,
    domain_ops: Optional[Dict[str, str]] = None,
    exceptions: Optional[List[Dict[str, str]]] = None,
    quantize_options: QuantizeOptions = QuantizeOptions.NONE,
) -> Optional[Graph]:
    """
    Transforms a graph to introduce quantized weights.
    This transformation requires opset 19. The graph is
    upgraded if the main opset is below. It is better to do
    it before calling this function.

    :param graph: Graph
    :param elem_type: quantization type
    :param output_type: output type
    :param early_stop: -1 to go through all nodes or a value `n > 0`
        to stop after n changes
    :param version: `'onnxruntime'` to use operators from onnx and onnxruntime,
        `'onnx-extended'` to use experimental operators
    :param quiet: catch exception and silently skip failing nodes
    :param index_transpose: which input to transpose before calling gemm:
        0 (none), 1 (first), 2 (second), 3 for both
    :param domain_ops: domain to use for operators used as keys in the dictionary
    :param exceptions: exclude nodes from the quantization,
        `[{"name": "node_name1"}, {"name": "node_name2"}]` will exclude
        these two node names from the quantization
    :param quantize_options: see :class:`QuantizeOptions`
    :return: Graph or None if not modified
    :raises NotImplementedError: if the graph contains local functions,
        or if a node cannot be quantized and *quiet* is False

    Transformation are logged with logger `onnx-extended/transformer`.
    The graph is modified inplace.
    Enables the logs gives a better idea of the progress.
    """
    main_opset = graph.get_opset("")
    if main_opset < 19:
        logger.info(
            "[quantize_float8] upgrade model from opset %d to %s", main_opset, 19
        )
        graph.upgrade_opsets({"": 19})
        main_opset = 19

    if len(graph.functions) > 0:
        raise NotImplementedError("Quantization of local functions is not implemented.")
    local_functions = graph.functions.copy()
    n_local_functions = len(local_functions)

    new_opsets = {}
    to_add = []
    n_nodes = len(graph)
    n_changes = 0
    for index, node in enumerate(graph):
        # checks if a node is not excluded from the list to convert
        if exceptions is not None and any(node.match(exc) for exc in exceptions):
            continue

        if node.op_type in {"MatMul", "Gemm"}:
            logger.info("[quantize_float8] %d/%d quantize %s", index, n_nodes, node)
            try:
                results = quantize_float8_matmul(
                    node,
                    elem_type,
                    output_type,
                    version=version,
                    opset=main_opset,
                    local_functions=local_functions,
                    index_transpose=index_transpose,
                    domain_ops=domain_ops,
                    quantize_options=quantize_options,
                )
            except (QuantizationError, NotImplementedError) as e:
                if quiet:
                    logger.warning(
                        "[quantize_float8] %d/%d failed to quantize due to %s",
                        index,
                        n_nodes,
                        e,
                    )
                    continue
                raise

            # This test must come first, the next one reads results.
            if results is None:
                continue

            # The local functions must be added to the dictionary given
            # to the function and not to a copy.
            if len(local_functions) != len(results.local_functions) or id(
                local_functions
            ) != id(results.local_functions):
                raise RuntimeError(
                    f"Inconsistency lenghts: "
                    f"{len(local_functions)} != {len(results.local_functions)}, "
                    f"ids: {id(local_functions)} != {id(results.local_functions)}."
                )

            rem, add = results.removed_nodes, results.added_nodes
            to_add.append((rem, add))
            if len(results.new_opsets) > 0:
                n_changes += 1
                new_opsets.update(results.new_opsets)

        if early_stop > 0 and n_changes >= early_stop:
            break

    if len(to_add) == 0:
        return None

    for rem, add in to_add:
        for r in rem:
            logger.debug("[quantize_float8] del %s", r)
        added = graph.replace_nodes([r.index for r in rem], add, new_opsets)
        for a in added:
            logger.debug("[quantize_float8] add %s", a)

    graph.simplify()
    if local_functions is not None and len(local_functions) > n_local_functions:
        graph.add_functions(local_functions.values())
    return graph
def _cast_constant(
    node: Node,
    from_type: int,
    to_type: int,
) -> Optional[TransformResults]:
    """
    Builds the replacement of a node whose element type is *from_type*
    so that it uses *to_type* instead. It handles constants, initializers,
    graph inputs, graph outputs and nodes `Cast`.

    :param node: node
    :param from_type: element type to replace
    :param to_type: type to cast into
    :return: TransformResults, or None if the node does not use *from_type*
    :raises RuntimeError: if the kind of node is not supported
    """
    kind = node.op_type

    if kind in {"Constant", "initializer"}:
        tensor = node.get_tensor()
        if tensor.data_type != from_type:
            return None
        values = to_array_extended(tensor)
        try:
            converted = Cast.eval(values, to=to_type)
        except SchemaError:
            # cast does not work, numpy does the conversion
            converted = values.astype(tensor_dtype_to_np_dtype(to_type))

        if kind == "initializer":
            replacement = from_array_extended(converted, name=node.proto.name)
        else:
            replacement = make_node(
                "Constant",
                [],
                node.outputs,
                value=from_array_extended(converted, name=node.outputs[0]),
            )
        return TransformResults(removed_nodes=[node], added_nodes=[replacement])

    if kind in {"input", "output"}:
        if isinstance(node.proto, str):
            # A FunctionProto input, nothing to do.
            return None
        ttype = node.proto.type
        if not ttype.tensor_type or ttype.tensor_type.elem_type != from_type:
            return None
        value_info = make_tensor_value_info(
            node.proto.name,
            to_type,
            shape=_get_shape(ttype),
            doc_string=node.proto.doc_string,
            shape_denotation=ttype.denotation,
        )
        return TransformResults(removed_nodes=[node], added_nodes=[value_info])

    if kind == "Cast":
        if node.getattr("to", int) != from_type:
            return None
        new_cast = make_node("Cast", node.inputs, node.outputs, to=to_type)
        return TransformResults(removed_nodes=[node], added_nodes=[new_cast])

    raise RuntimeError(f"Unexpected node type {kind!r}.")
def cast_constant(
    graph: Graph,
    from_type: int = TensorProto.FLOAT,
    to_type: int = TensorProto.FLOAT16,
    quiet: bool = False,
) -> Optional[Graph]:
    """
    Converts all constants and initializers to the same type.
    It also modifies the inputs, the outputs and the nodes `Cast`.

    :param graph: Graph
    :param from_type: type of the constants to convert
    :param to_type: new type for the constants
    :param quiet: accepted for consistency with the other transformations,
        the current implementation does not use it, exceptions are not caught
    :return: None if not modified, otherwise the result
        of `graph.simplify()`
    :raises NotImplementedError: if the graph contains local functions

    Transformation are logged with logger `onnx-extended/transformer`.
    The graph is modified inplace.
    Enables the logs gives a better idea of the progress.
    """
    if len(graph.functions) > 0:
        raise NotImplementedError("Conversion of local functions is not implemented.")

    convertible = {"Constant", "initializer", "input", "output", "Cast"}
    to_add = []
    n_nodes = len(graph)
    for index, node in enumerate(graph):
        if node.op_type not in convertible:
            continue
        logger.info("[cast_constant] %d/%d convert %s", index, n_nodes, node)
        results = _cast_constant(node, from_type, to_type)
        if results is None:
            # the node does not use from_type
            continue
        to_add.append((results.removed_nodes, results.added_nodes))

    if len(to_add) == 0:
        return None

    for rem, add in to_add:
        for r in rem:
            logger.debug("[cast_constant] del %s", r)
        added = graph.replace_nodes([r.index for r in rem], add)
        for a in added:
            logger.debug("[cast_constant] add %s", a)

    return graph.simplify()