Source code for experimental_experiment.convert.convert_helper
import time
from typing import Any, Dict, List, Optional, Union
from onnx import ModelProto, helper as oh, load as onnx_load
from onnx.inliner import inline_local_functions

def inline_model_proto(model_proto: ModelProto) -> ModelProto:
    """
    Inlines a model.

    :param model_proto: ModelProto
    :return: inlined model
    """
    # model = onnx.load(input_file_name, load_external_data=False)
    return inline_local_functions(model_proto)
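
# A minimal usage sketch (not part of the original module): inlines the local
# functions of a model loaded from disk. The path "model.onnx" is hypothetical.
def _example_inline_model() -> ModelProto:
    model = onnx_load("model.onnx")
    return inline_model_proto(model)
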
def _fix_details(model: ModelProto, verbose: int = 0) -> ModelProto:
    # Post-processes ScatterND and ATen nodes left by the exporter.
    if verbose:
        print("[_fix_details] START")
    for node in model.graph.node:
        if node.op_type == "ScatterND":
            # Forces attribute reduction="add" on every ScatterND node.
            if len(node.attribute) == 0:
                if verbose:
                    print("[_fix_details] ScatterND, add reduction to add")
                node.attribute.append(oh.make_attribute("reduction", "add"))
            else:
                red = node.attribute[0].s
                if red != b"add":
                    if verbose:
                        print("[_fix_details] ScatterND, change reduction to add")
                    del node.attribute[:]
                    node.attribute.append(oh.make_attribute("reduction", "add"))
        elif node.op_type == "ATen":
            fname = None
            for att in node.attribute:
                if att.name == "operator":
                    fname = att.s
            if fname == b"_scaled_dot_product_efficient_attention_backward":
                if verbose:
                    print(
                        "[_fix_details] ATen, delete last output for "
                        "_scaled_dot_product_efficient_attention_backward"
                    )
                # Empties the name of the last output so that it is not computed.
                outputs = list(node.output)
                del node.output[:]
                outputs[-1] = ""
                node.output.extend(outputs)
    if verbose:
        print("[_fix_details] DONE")
    return model
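
# A minimal sketch (not part of the original module) showing the effect of
# ``_fix_details`` on a ScatterND node missing its ``reduction`` attribute.
def _example_fix_scatternd() -> str:
    node = oh.make_node("ScatterND", ["data", "indices", "updates"], ["out"])
    model = oh.make_model(oh.make_graph([node], "g", [], []))
    _fix_details(model)
    # The node now carries reduction="add".
    return model.graph.node[0].attribute[0].s.decode()
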
def optimize_model_proto_oxs(
    model_proto: ModelProto,
    verbose: int = 0,
    onnx_shape_inference: bool = False,
    inplace: bool = True,
    stats: Optional[Dict[str, Any]] = None,
) -> ModelProto:
    """
    Optimizes a model proto for onnxruntime.

    :param model_proto: ModelProto
    :param verbose: verbosity
    :param onnx_shape_inference: enables shape inference
    :param inplace: if True, the function also modifies the input proto in place
    :param stats: if not None, stores the duration of every optimization step
    :return: optimized model

    The model must be exported before calling this function::

        onnx_model = exported.to_model_proto(
            opset_version=self._resolved_onnx_exporter_options.onnx_registry.opset_version
        )

        from experimental_experiment.convert.convert_helper import optimize_model_proto_oxs

        onnx_model = optimize_model_proto_oxs(onnx_model)
    """
    from onnxscript.optimizer import optimize
    from onnxscript.rewriter import rewrite

    if verbose:
        print(
            f"[optimize_model_proto_oxs] starts optimize with "
            f"{len(model_proto.graph.node)} nodes and "
            f"{len(model_proto.functions)} local functions"
        )
    first_model_proto = model_proto

    begin = time.perf_counter()
    model_proto = optimize(
        model_proto,
        num_iterations=2,
        onnx_shape_inference=onnx_shape_inference,
    )
    if stats is not None:
        stats["oxs_optimize_time"] = time.perf_counter() - begin
    if verbose:
        print(
            f"[optimize_model_proto_oxs] optimize done in "
            f"{time.perf_counter() - begin} seconds."
        )
        print(
            f"[optimize_model_proto_oxs] starts rewrite with "
            f"{len(model_proto.graph.node)} nodes and "
            f"{len(model_proto.functions)} local functions"
        )

    begin = time.perf_counter()
    model_proto = rewrite(model_proto)
    if stats is not None:
        stats["oxs_rewrite_time"] = time.perf_counter() - begin
    if verbose:
        print(
            f"[optimize_model_proto_oxs] rewrite done in {time.perf_counter() - begin} "
            f"seconds with {len(model_proto.graph.node)} nodes and "
            f"{len(model_proto.functions)} local functions"
        )
        print(
            f"[optimize_model_proto_oxs] starts inlining with "
            f"{len(model_proto.graph.node)} nodes and "
            f"{len(model_proto.functions)} local functions"
        )

    begin = time.perf_counter()
    model_proto = inline_local_functions(model_proto)
    if stats is not None:
        stats["oxs_inline_time"] = time.perf_counter() - begin
    if verbose:
        print(
            f"[optimize_model_proto_oxs] inlining done in {time.perf_counter() - begin} "
            f"seconds with {len(model_proto.graph.node)} nodes and "
            f"{len(model_proto.functions)} local functions"
        )

    # _fix_details(model_proto)
    if inplace:
        # Copies the optimized pieces back into the original proto.
        del first_model_proto.graph.node[:]
        del first_model_proto.functions[:]
        del first_model_proto.graph.initializer[:]
        del first_model_proto.opset_import[:]
        first_model_proto.graph.node.extend(model_proto.graph.node)
        first_model_proto.functions.extend(model_proto.functions)
        first_model_proto.graph.initializer.extend(model_proto.graph.initializer)
        first_model_proto.opset_import.extend(model_proto.opset_import)
    return model_proto
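
# A minimal usage sketch (not part of the original module; "model.onnx" is a
# hypothetical path): runs the three passes and collects their durations.
def _example_optimize_oxs() -> Dict[str, Any]:
    model = onnx_load("model.onnx")
    stats: Dict[str, Any] = {}
    optimize_model_proto_oxs(model, stats=stats)
    # stats now holds oxs_optimize_time, oxs_rewrite_time, oxs_inline_time.
    return stats
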
def ort_optimize(
    onnx_model: Union[str, ModelProto],
    output: str,
    providers: Union[str, List[str]] = "cpu",
    disable_aot: bool = False,
):
    """
    Optimizes the model with onnxruntime.

    :param onnx_model: ModelProto or file path
    :param output: path for the output
    :param providers: providers, "cpu", "cuda" or a list of providers
    :param disable_aot: disables AOT function inlining
    """
    import onnxruntime

    from .ort_helper import append_custom_libraries

    opts = onnxruntime.SessionOptions()
    opts.optimized_model_filepath = output
    if disable_aot:
        opts.add_session_config_entry("session.disable_aot_function_inlining", "1")

    if providers == "cpu":
        providers = ["CPUExecutionProvider"]
    elif not isinstance(providers, list) and providers.startswith("cuda"):
        # Accepts "cuda" or "cuda:<device_id>".
        device_id = 0 if ":" not in providers else int(providers.split(":")[1])
        providers = [
            ("CUDAExecutionProvider", {"device_id": device_id}),
            ("CPUExecutionProvider", {}),
        ]
    assert isinstance(providers, list), f"Unexpected value for providers={providers!r}"

    if isinstance(onnx_model, str):
        onnx_model = onnx_load(onnx_model)

    append_custom_libraries(onnx_model, opts)

    # Creating the session is enough: onnxruntime writes the optimized model
    # to opts.optimized_model_filepath.
    onnxruntime.InferenceSession(
        onnx_model.SerializeToString(),
        opts,
        providers=providers,
    )
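
# A minimal usage sketch (not part of the original module; the paths are
# hypothetical): writes the onnxruntime-optimized model to a new file.
def _example_ort_optimize() -> None:
    ort_optimize("model.onnx", "model.opt.onnx", providers="cuda:0")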