import time
from itertools import product
from typing import Any, Callable, Dict, List, Optional, Union
import numpy
from onnx import AttributeProto, ModelProto, NodeProto, GraphProto, FunctionProto
from onnx.helper import make_model, make_node, make_graph, make_opsetid
from ...ext_test_case import measure_time
def has_subgraph(node: NodeProto) -> bool:
"""
Tells if a node has a subgraph as an attribute.
"""
for att in node.attribute:
if att.type == AttributeProto.GRAPH:
return True
return False
def get_node_attribute(node: NodeProto, name: str) -> AttributeProto:
"""
Returns the value of one attribute.
:param node: node
:param name: attribute name
:return: value
"""
for att in node.attribute:
if att.name == name:
return att
raise KeyError(
f"Unable to find {name!r} among {list(att.name for att in node.attribute)}."
)
[docs]def change_onnx_operator_domain(
onx: Union[ModelProto, GraphProto, FunctionProto],
op_type: str,
op_domain: str = "",
new_op_type: Optional[str] = None,
new_op_domain: Optional[str] = None,
new_opset: Optional[int] = None,
**kwargs: Dict[str, Any],
) -> Union[ModelProto, GraphProto, FunctionProto]:
"""
Replaces an operator by another one in the same domain
or another one.
:param onx: proto to modify
:param op_type: operator to look for
:param op_domain: domain to look for
:param new_op_type: new operator name or None for the same name
:param new_op_domain: new domain name or None the for the same domain
:param new_opset: new opset for the new domain, if not specified,
it is 1 for any opset other than ""
:param kwargs: modified parameters, set it to None to remove them
:return: same type as the input
The function is not recursive yet.
"""
def change_node(node):
atts = []
new_kwargs = {}
for att in node.attribute:
if att.name in kwargs:
v = kwargs[att.name]
if v is None:
continue
new_kwargs[att.name] = v
continue
atts.append(att)
for k, v in kwargs.items():
if v is None or k in new_kwargs:
continue
new_kwargs[k] = v
new_node = make_node(
new_op_type or node.op_type,
node.input,
node.output,
domain=new_op_domain or node.domain,
**new_kwargs,
)
if len(atts) > 0:
new_node.attribute.extend(atts)
return new_node
if isinstance(onx, GraphProto):
new_nodes = []
modified = False
for node in onx.node:
if has_subgraph(node):
raise NotImplementedError(
f"The function is not recursive yet and cannot "
f"handle node {node.op_type!r} from domain "
f"{node.domain!r}."
)
if node.op_type == op_type and node.domain == op_domain:
new_node = change_node(node)
new_nodes.append(new_node)
modified = True
continue
new_nodes.append(node)
if not modified:
return onx
return make_graph(
new_nodes,
onx.name,
onx.input,
onx.output,
onx.initializer,
onx.sparse_initializer,
)
if isinstance(onx, FunctionProto):
raise NotImplementedError()
if not isinstance(onx, ModelProto):
raise TypeError(f"Unexpected type for onx {type(onx)}.")
new_graph = change_onnx_operator_domain(
onx.graph,
op_type=op_type,
op_domain=op_domain,
new_opset=new_opset,
new_op_type=new_op_type,
new_op_domain=new_op_domain,
**kwargs,
)
if id(new_graph) == id(onx.graph):
# no change
return onx
if new_op_domain is None:
new_op_domain = op_domain
if new_op_domain == op_domain and new_opset is not None:
raise ValueError(
f"If new_op_domain=={new_op_domain!r}, "
f"new_opset must be None not {new_opset}."
)
opsets = list(onx.opset_import)
if new_op_domain != op_domain:
opsets.append(make_opsetid(new_op_domain, new_opset or 1))
new_model = make_model(
new_graph,
functions=onx.functions,
ir_version=onx.ir_version,
producer_name=onx.producer_name,
producer_version=onx.producer_version,
model_version=onx.model_version,
doc_string=onx.doc_string,
opset_imports=opsets,
domain=onx.domain,
)
return new_model
[docs]def optimize_model(
onx: ModelProto,
feeds: Dict[str, numpy.ndarray],
transform: Callable[[ModelProto], ModelProto],
session: Callable[[ModelProto], Any],
params: Dict[str, List[Any]],
baseline: Optional[Callable[[ModelProto], Any]] = None,
verbose: bool = False,
number: int = 10,
repeat: int = 10,
warmup: int = 5,
n_tries: int = 2,
sleep: float = 0.1,
) -> List[Dict[str, Union[str, float]]]:
"""
Optimizes a model by trying out many possibilities.
:param onx: ModelProto
:param feeds: inputs as a dictionary of numpy arrays
:param transform: function taking a ModelProto and returning a ModelProto
based on the values coming from *params*
:param session: function which takes a modifed ModelProto
and return a session
:param params: dictionary of values to test `{ param_name: [ param_values ] }`
:param baseline: function which takes a modifed ModelProto
and return a session, identified as the baseline
:param verbose: use :epkg:`tqdm` to show improvment
:param number: parameter to :func:`measure_time
<onnx_extended.ext_test_case.measure_time>`
:param repeat: parameter to :func:`measure_time
<onnx_extended.ext_test_case.measure_time>`
:param warmup: parameter to :func:`measure_time
<onnx_extended.ext_test_case.measure_time>`
:param n_tries: number of times to measure, if the measurements returns
very different results, values for *number* or *repeat* should
be increased
:param sleep: time to sleep between two measurements
:return: list of results returned by :func:`measure_time
<onnx_extended.ext_test_case.measure_time>`
See example :ref:`l-plot-optim-tree-ensemble` for an example.
"""
if sleep >= 1:
raise ValueError(f"sleep={sleep} >= 1, probably a mistake.")
keys = ["TRY"] + list(params.keys())
sets = [list(range(n_tries))] + [params[k] for k in keys[1:]]
loops = list(product(*sets))
if verbose:
from tqdm import tqdm
loop = tqdm(loops)
else:
loop = loops
res = []
if baseline is not None:
sess = baseline(onx)
# one run to make run it is working
sess.run(None, feeds)
if sleep > 0:
time.sleep(sleep)
obs: Dict[str, Any] = measure_time(
lambda sess=sess: sess.run(None, feeds),
number=number,
repeat=repeat,
warmup=warmup,
)
obs["n_exp"] = 0
obs["n_exp_name"] = "TRY=0,baseline"
obs["short_name"] = "0,baseline"
obs["TRY"] = 0
obs["name"] = "baseline"
res.append(obs)
for it, values in enumerate(loop):
if verbose:
msg = [f"i={it+1}/{len(loops)}"]
msg.extend([f"{k}={v}" for k, v in zip(keys, values)])
loop.set_description(" ".join(msg))
kwargs = dict(zip(keys, values))
del kwargs["TRY"]
onx_modified = transform(onx, **kwargs)
sess = session(onx_modified)
if sleep > 0:
time.sleep(sleep)
obsl: Dict[str, Any] = measure_time(
lambda sess=sess: sess.run(None, feeds),
number=number,
repeat=repeat,
warmup=warmup,
)
obsl.update(kwargs)
obsl["n_exp"] = it
obsl["n_exp_name"] = ",".join(f"{k}={v}" for k, v in zip(keys, values))
obsl["short_name"] = ",".join(f"{v}" for v in values)
obsl["name"] = ",".join(f"{v}" for v in values[1:])
res.append(obsl)
if baseline is not None:
for n in range(1, n_tries):
sess = baseline(onx)
if sleep > 0:
time.sleep(sleep)
obsf: Dict[str, Any] = measure_time(
lambda sess=sess: sess.run(None, feeds),
number=number,
repeat=repeat,
warmup=warmup,
)
obsf["n_exp"] = 0
obsf["n_exp_name"] = f"TRY={n},baseline"
obsf["short_name"] = f"{n},baseline"
obsf["name"] = "baseline"
obsf["TRY"] = n
res.append(obsf)
return res