import numpy as np
from onnx import FunctionProto, ModelProto, GraphProto, AttributeProto
from onnx.helper import (
make_model,
set_model_props,
make_graph,
make_node,
make_attribute,
make_function,
tensor_dtype_to_np_dtype,
)
from onnx.numpy_helper import from_array
# (Sphinx "[docs]" source-link marker: extraction artifact, not part of the module)
def replace_initializer_by_constant_of_shape(
    onx, threshold=128, op_type="ConstantOfShape", domain=""
):
    """
    Replaces initializers by nodes *ConstantOfShape* to reduce
    the size and still write a unit test.

    Every initializer, and every *Constant* node, holding more than
    *threshold* elements is replaced by a small int64 tensor storing its
    shape followed by a node *op_type* which takes that shape as input and
    produces the original name. The node receives a one-element ``value``
    attribute built from ``0.5`` cast to the original element type, the
    original values are therefore lost, only shape and type are kept.

    The function dispatches on the type of *onx*:

    * ``FunctionProto``: only *Constant* nodes are processed, the shape is
      stored in a new *Constant* node. The input object itself is returned
      when nothing was replaced.
    * ``ModelProto``: the main graph and every local function are processed
      and a new model is built with the same opset imports and metadata.
    * ``GraphProto``: initializers and *Constant* nodes are processed, the
      shapes are stored as initializers, and graph attributes of the other
      nodes (subgraphs) are processed recursively.

    :param onx: ModelProto, GraphProto or FunctionProto
    :param threshold: every initializer under this threshold is not impacted
    :param op_type: replace by this node
    :param domain: replace by this domain
    :return: onx, modified ModelProto (or GraphProto, FunctionProto,
        same type as the input)
    :raises RuntimeError: for a ModelProto importing the default domain with
        an opset below 9
    :raises TypeError: if *onx* is none of the supported types
    :raises NotImplementedError: if a sparse initializer is bigger than
        *threshold*
    """
    if isinstance(onx, FunctionProto):
        modified = False
        new_nodes = []
        for node in onx.node:
            if node.op_type == "Constant":
                # Imported lazily, only when a Constant node has to be evaluated.
                from onnx_array_api.reference import ExtendedReferenceEvaluator

                # Evaluates the single node to retrieve the constant as an array.
                ref = ExtendedReferenceEvaluator(node)
                cst = ref.run(None, {})[0]
                size = np.prod(cst.shape)
                if size <= threshold:
                    new_nodes.append(node)
                    continue

                # A function has no initializer: the shape is stored
                # in a Constant node.
                new_name = f"{node.output[0]}__SHAPE"
                new_nodes.append(
                    make_node(
                        "Constant",
                        [],
                        [new_name],
                        value=from_array(
                            np.array(cst.shape, dtype=np.int64), name=new_name
                        ),
                    )
                )
                dtype = cst.dtype
                # The replacement node receives the shape as an input.
                assert op_type != "Constant"
                new_nodes.append(
                    make_node(
                        op_type,
                        [new_name],
                        node.output,
                        value=from_array(np.array([0.5], dtype=dtype)),
                        domain=domain,
                    )
                )
                modified = True
                continue
            new_nodes.append(node)

        if not modified:
            # Nothing to replace, the input is returned untouched.
            return onx

        onxf = make_function(
            domain=onx.domain,
            fname=onx.name,
            inputs=onx.input,
            outputs=onx.output,
            nodes=new_nodes,
            doc_string=onx.doc_string,
            overload=onx.overload,
            opset_imports=[],
        )
        # The remaining repeated fields are copied from the original function.
        if onx.opset_import:
            onxf.opset_import.extend(onx.opset_import)
        if onx.value_info:
            onxf.value_info.extend(onx.value_info)
        if onx.attribute:
            onxf.attribute.extend(onx.attribute)
        if onx.attribute_proto:
            onxf.attribute_proto.extend(onx.attribute_proto)
        return onxf

    if isinstance(onx, ModelProto):
        new_graph = replace_initializer_by_constant_of_shape(
            onx.graph, threshold=threshold, op_type=op_type, domain=domain
        )
        new_functions = [
            replace_initializer_by_constant_of_shape(
                f, threshold=threshold, op_type=op_type, domain=domain
            )
            for f in onx.functions
        ]
        model = make_model(
            new_graph,
            functions=new_functions,
            producer_name=onx.producer_name,
            producer_version=onx.producer_version,
            ir_version=onx.ir_version,
            doc_string=onx.doc_string,
            domain=onx.domain,
            model_version=onx.model_version,
        )
        if len(onx.metadata_props) > 0:  # pragma: no cover
            values = {p.key: p.value for p in onx.metadata_props}
            set_model_props(model, values)

        # make_model fills opset_import on its own: it is replaced
        # by the opsets of the original model.
        del model.opset_import[:]  # pylint: disable=E1101
        for oimp in onx.opset_import:
            if oimp.domain == "" and oimp.version < 9:
                raise RuntimeError(
                    f"ConstantOfShape was introduced in "
                    f"opset 9 but opset is {oimp.version}."
                )
            op_set = model.opset_import.add()  # pylint: disable=E1101
            op_set.domain = oimp.domain
            op_set.version = oimp.version
        return model

    if not isinstance(onx, GraphProto):
        raise TypeError(f"onx should be a GraphProto as this stage not {type(onx)}.")

    new_nodes = []
    removed = set()
    new_inits = []
    for init in onx.initializer:
        dims = tuple(init.dims)
        size = np.prod(dims)
        if size <= threshold:
            new_inits.append(init)
            continue
        # The big initializer is replaced by an initializer holding its shape
        # and a node producing a tensor with the same name, shape and type.
        new_name = f"{init.name}__SHAPE"
        new_inits.append(
            from_array(np.array(list(dims), dtype=np.int64), name=new_name)
        )
        dtype = tensor_dtype_to_np_dtype(init.data_type)
        node = make_node(
            op_type,
            [new_name],
            [init.name],
            value=from_array(np.array([0.5], dtype=dtype)),
            domain=domain,
        )
        new_nodes.append(node)
        # The name is now a node output, it cannot remain a graph input.
        removed.add(init.name)

    new_sparse_inits = []
    for init in onx.sparse_initializer:
        dims = tuple(init.dims)
        size = np.prod(dims)
        if size <= threshold:
            new_sparse_inits.append(init)
            continue
        raise NotImplementedError(
            f"This feature is not yet implemented for sparse initializer"
            f"(name={init.name!r})."
        )

    for node in onx.node:
        if node.op_type == "Constant":
            # Imported lazily, only when a Constant node has to be evaluated.
            from onnx_array_api.reference import ExtendedReferenceEvaluator

            ref = ExtendedReferenceEvaluator(node)
            cst = ref.run(None, {})[0]
            size = np.prod(cst.shape)
            if size <= threshold:
                new_nodes.append(node)
                continue
            # In a graph, the shape is stored as an initializer.
            new_name = f"{node.output[0]}__SHAPE"
            new_inits.append(
                from_array(np.array(cst.shape, dtype=np.int64), name=new_name)
            )
            dtype = cst.dtype
            new_nodes.append(
                make_node(
                    op_type,
                    [new_name],
                    node.output,
                    value=from_array(np.array([0.5], dtype=dtype)),
                    domain=domain,
                )
            )
            continue

        # Other nodes: subgraphs stored as attributes are processed recursively.
        modified = False
        atts = []
        for att in node.attribute:
            if (
                att.type == AttributeProto.GRAPH
                and hasattr(att, "g")
                and att.g is not None
            ):
                modified = True
                g = replace_initializer_by_constant_of_shape(
                    att.g, threshold=threshold, op_type=op_type, domain=domain
                )
                att = make_attribute(att.name, g)
            atts.append(att)

        if modified:
            # The node is rebuilt with its new attributes. Name, domain and
            # doc_string must be carried over, otherwise a node from a custom
            # domain would silently move to the default domain.
            new_node = make_node(
                node.op_type,
                node.input,
                node.output,
                name=node.name,
                doc_string=node.doc_string,
                domain=node.domain,
            )
            # Field 'overload' only exists in recent versions of onnx.
            overload = getattr(node, "overload", "")
            if overload:
                new_node.overload = overload
            new_node.attribute.extend(atts)
            new_nodes.append(new_node)
        else:
            new_nodes.append(node)

    # doc_string and value_info are kept, the replaced results
    # have the same name, shape and type as before.
    graph = make_graph(
        new_nodes,
        onx.name,
        [i for i in onx.input if i.name not in removed],
        onx.output,
        initializer=new_inits,
        doc_string=onx.doc_string,
        value_info=onx.value_info,
        sparse_initializer=new_sparse_inits,
    )
    return graph