import os
import re
import io
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from onnx import (
FunctionProto,
GraphProto,
ModelProto,
TensorProto,
SequenceProto,
ValueInfoProto,
load,
)
from onnx.helper import tensor_dtype_to_np_dtype
from onnx.numpy_helper import to_array
from onnx.reference import ReferenceEvaluator
def _type_shape(
input_def: Union[str, ValueInfoProto]
) -> Tuple[Any, Tuple[Union[int, str], ...], Any]:
if isinstance(input_def, str):
reg = re.compile(
"([a-z][a-z0-9]*)?([(]([ a-zA-Z,0-9]+)[)])?(:([A-Z][A-Z0-9]*))?"
)
search = reg.match(input_def)
if search is None:
raise ValueError(f"Unable to interpret string {input_def!r}.")
grs = search.groups()
dtype = grs[0]
shape = None if grs[2] is None else grs[2].replace(" ", "").split(",")
law = grs[-1]
new_shape = []
if shape is not None:
for i in shape:
try:
vi = int(i)
new_shape.append(vi)
except ValueError:
new_shape.append(i)
rshape = tuple(new_shape)
else:
rshape = None
dt = None if dtype is None else getattr(np, dtype)
return dt, rshape, law
if isinstance(input_def, ValueInfoProto):
try:
ttype = input_def.type.tensor_type
except AttributeError:
raise ValueError(f"Unsupported input type {input_def!r}.")
dt = ttype.elem_type
new_shape = []
for d in ttype.shape.dim:
if d.dim_param:
new_shape.append(d.dim_param)
else:
new_shape.append(d.dim_value)
ndt = tensor_dtype_to_np_dtype(dt)
return ndt, tuple(new_shape), None
raise TypeError(f"Unexpected type {type(input_def)} for input_def.")
def _generate_random_inputs(
dtype: Any,
shape: Tuple[Union[int, str], ...],
law: Optional[str] = None,
dims: Optional[Dict[str, int]] = None,
) -> Tuple[np.ndarray, Dict[str, int]]:
"""
Creates random or specific inputs.
:param dtype: numpy dtype
:param shape: expected shape
:param law: law of the coefficients, default is 'U10', uniform law
:param dims: letter are allowed, contains the named dimensions already
mapped to a specific value
:return: tuple (array, updated dims)
Dimensions are modified inplace.
"""
if dims is None:
dims = {}
if law is None:
law = "U10"
new_shape = []
for sh in shape:
if isinstance(sh, int):
new_shape.append(sh)
elif isinstance(sh, str):
if sh not in dims:
dims[sh] = 8
new_shape.append(dims[sh])
final_shape = tuple(new_shape)
if law == "U10":
res = np.random.random(final_shape).astype(dtype)
return res, dims
raise ValueError(f"Unexpected value for law={law!r}.")
[docs]def print_proto(proto: str, fmt: str = "raw", external: bool = True):
"""
Shows an onnx model or a protobuf string on stdout.
Extension '.onnx' is considered a model,
extension '.proto' or '.pb' is a protobuf string.
:param proto: a file
:param fmt: format to use to print the model,
`raw` prints out the string produced by `print(model)`,
`nodes` only prints out the node name
:param external: loads with external data
"""
if isinstance(proto, str):
if not os.path.exists(proto):
raise FileNotFoundError(f"Unable to find file {proto!r}.")
ext = os.path.splitext(proto)[-1]
if ext == ".onnx":
with open(proto, "rb") as f:
proto_loaded = load(f, load_external_data=external)
elif ext in (".pb", ".proto"):
with open(proto, "rb") as f:
content = f.read()
exc = []
proto_loaded = None
for cls in [
TensorProto,
SequenceProto,
FunctionProto,
ModelProto,
GraphProto,
]:
inst = cls()
try:
inst.ParseFromString(content)
proto_loaded = inst
break
except Exception as e:
exc.append((cls, e))
if proto_loaded is None:
msg = "\n".join(f"type: {c}: {e}" for c, e in exc)
raise RuntimeError(f"Unable to load {proto!r}, tried:\n{msg}")
else:
raise ValueError(f"Unexpected file extension {ext!r} for file {proto!r}.")
else:
proto_loaded = proto
print(f"Type: {type(proto_loaded)}")
if fmt == "raw":
print(proto_loaded)
elif fmt == "nodes":
from .tools.graph.onnx_graph_struct import Graph
if proto_loaded is None:
raise ValueError(f"Unable to load {proto!r}.")
graph = Graph(proto_loaded)
for node in graph:
print(str(node).replace("<parent>, ", ""))
else:
raise ValueError(f"Unexpected value for fmt={fmt!r}.")
[docs]def cmd_quantize(
model: Union[ModelProto, str],
output: Optional[str] = None,
kind: str = "fp8",
scenario: str = "onnxruntime",
early_stop: Optional[int] = None,
quiet: bool = False,
verbose: int = 0,
index_transpose: int = 2,
exceptions: Optional[List[Dict[str, str]]] = None,
options: Optional["QuantizeOptions"] = None, # noqa: F821
):
"""
Quantizes a model
:param model: path to a model or ModelProto
:param output: output file
:param kind: kind of quantization
:param scenario: depends on the quantization
:param early_stop: stops early to see the preliminary results
:param quiet: do not stop an exception
:param verbose: verbosity level
:param index_transpose: which input to transpose before calling gemm:
0 (none), 1 (first), 2 (second), 3 for both
:param exceptions: exclude nodes from the quantization,
`[{"name": "node_name1"}, {"name": "node_name2"}]` will exclude
these two node names from the quantization
:param options: quantization options, see class
:class:`QuantizeOptions <onnx_extended.tools.graph.QuantizeOptions>`
"""
from .tools.graph import Graph, QuantizeOptions
if options is None:
options = QuantizeOptions.NONE
if isinstance(model, str):
if not os.path.exists(model):
raise FileNotFoundError(f"Unable to find file {model!r}.")
ext = os.path.splitext(model)[-1]
if ext == ".onnx":
with open(model, "rb") as f:
proto_loaded = load(f)
else:
proto_loaded = model
graph = Graph(proto_loaded)
if verbose:
logging.basicConfig(
level=logging.WARN
if verbose > 2
else (logging.DEBUG if verbose > 1 else logging.INFO)
)
if kind == "fp8":
from .tools.graph import quantize_float8
logger = logging.getLogger("onnx-extended")
logger.info("Model initial size: %d", len(proto_loaded.SerializeToString()))
new_graph = quantize_float8(
graph,
early_stop=early_stop or -1,
quiet=quiet,
version=scenario,
index_transpose=index_transpose,
exceptions=exceptions,
quantize_options=options,
)
if new_graph is None:
logger.warning("No node was quantized.")
return
onx2 = new_graph.to_onnx()
seq = onx2.SerializeToString()
logger.info("Model quantized size: %d", len(seq))
if output is not None:
with open(output, "wb") as f:
f.write(seq)
return
if kind == "fp16":
from .tools.graph import cast_constant
logger = logging.getLogger("onnx-extended")
logger.info("Model initial size: %d", len(proto_loaded.SerializeToString()))
new_graph = cast_constant(
graph,
quiet=quiet,
from_type=TensorProto.FLOAT,
to_type=TensorProto.FLOAT16,
)
if new_graph is None:
logger.warning("No node was modified.")
return
onx2 = new_graph.to_onnx()
seq = onx2.SerializeToString()
logger.info("Model reduced size: %d", len(seq))
if output is not None:
with open(output, "wb") as f:
f.write(seq)
return
raise ValueError(f"Unexpected value {kind!r} for kind.")
[docs]def cmd_select(
model: Union[ModelProto, str],
save: Optional[str] = None,
inputs: Optional[Union[str, List[str]]] = None,
outputs: Optional[Union[str, List[str]]] = None,
verbose: int = 0,
):
"""
Selects a subgraph in a model.
:param model: path to a model or ModelProto
:param save: model ot save in this file
:param inputs: list of inputs or empty to keep the original inputs
:param outputs: list of outputs or empty to keep the original outputs
:param verbose: verbosity level
"""
from .tools.onnx_manipulations import select_model_inputs_outputs
if isinstance(model, str):
if not os.path.exists(model):
raise FileNotFoundError(f"Unable to find file {model!r}.")
ext = os.path.splitext(model)[-1]
if ext == ".onnx":
with open(model, "rb") as f:
proto_loaded = load(f)
else:
proto_loaded = model
if verbose:
logging.basicConfig(
level=logging.WARN
if verbose > 2
else (logging.DEBUG if verbose > 1 else logging.INFO)
)
if isinstance(inputs, str):
inputs = inputs.strip().split(",")
if isinstance(outputs, str):
outputs = outputs.strip().split(",")
logger = logging.getLogger("onnx-extended")
logger.info("Initial model size: %d", len(proto_loaded.SerializeToString()))
onx2 = select_model_inputs_outputs(
proto_loaded,
inputs=inputs,
outputs=outputs,
verbose=verbose,
)
seq = onx2.SerializeToString()
logger.info("Selected model size: %d", len(seq))
if save is not None:
with open(save, "wb") as f:
f.write(seq)
def plot_profile(
filename: str,
kind: str,
out_csv: Optional[str] = None,
out_png: Optional[str] = None,
title: Optional[str] = None,
with_shape: bool = False,
verbose: int = 0,
):
"""
Plots a profiling.
:param filename: raw data to load
:param kind: kind of plot to so, see below
:param out_csv: output the data into that csv file
:param out_png: output the graph in that file
:param with_shape: consider input shape when showing results
:param title: title (optional)
:param verbose: verbosity, if > 0, prints out the data in csv format
"""
import matplotlib.pyplot as plt
from .tools.js_profile import (
js_profile_to_dataframe,
plot_ort_profile,
_preprocess_graph1,
_preprocess_graph2,
)
if verbose:
print(f"[plot_profile] load {filename!r}")
if kind == "profile_op":
df = js_profile_to_dataframe(filename, first_it_out=True, with_shape=with_shape)
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
plot_ort_profile(df, ax, title=title)
df = _preprocess_graph1(df)
elif kind == "profile_node":
df = js_profile_to_dataframe(
filename, first_it_out=True, agg=True, with_shape=with_shape
)
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
plot_ort_profile(df, ax, title=title)
df = _preprocess_graph2(df)
else:
raise ValueError(f"Unexpected kind {kind:r}.")
if verbose:
st = io.StringIO()
df.to_csv(st)
print(st.getvalue())
if out_csv not in {"", None}:
if verbose:
print(f"[plot_profile] save {out_csv!r}")
df.to_csv(out_csv)
if out_png not in {"", None}:
if verbose:
print(f"[plot_profile] save {out_png!r}")
fig.savefig(out_png)
[docs]def cmd_plot(
filename: str,
kind: str,
out_csv: Optional[str] = None,
out_png: Optional[str] = None,
title: Optional[str] = None,
with_shape: bool = False,
verbose: int = 0,
):
"""
Plots a graph.
:param filename: raw data to load
:param kind: kind of plot to so, see below
:param out_csv: output the data into that csv file
:param out_png: output the graph in that file
:param title: title (optional)
:param with_shape: keep the shape to aggregate
:param verbose: verbosity, if > 0, prints out the data in csv format
Kinds of plots:
* `'profile_op'`: draws the profiling per node type
* `'profile_node'`: draws the profiling per node
"""
if not os.path.exists(filename):
raise FileNotFoundError(f"Unable to find {filename!r}.")
allowed = {"profile_op", "profile_node"}
if kind in allowed:
plot_profile(filename, kind, out_csv, out_png, title=title, verbose=verbose)
else:
raise ValueError(f"Unexpected kind {kind:r}, it should be {allowed}.")