Source code for onnx_array_api.tools.replace_constants

import numpy as np
from onnx import FunctionProto, ModelProto, GraphProto, AttributeProto
from onnx.helper import (
    make_model,
    set_model_props,
    make_graph,
    make_node,
    make_attribute,
    make_function,
    tensor_dtype_to_np_dtype,
)
from onnx.numpy_helper import from_array


def _evaluate_constant_node(node):
    """Run a single ``Constant`` node and return the array it produces."""
    # Imported lazily (as in the original code) so that importing this
    # module does not import the reference evaluator.
    from onnx_array_api.reference import ExtendedReferenceEvaluator

    return ExtendedReferenceEvaluator(node).run(None, {})[0]


def _make_filler_node(shape_name, outputs, dtype, op_type, domain):
    """Build the *op_type* node replacing a large tensor.

    The node takes the shape tensor named *shape_name* as its only input,
    produces *outputs*, and carries a ``value`` attribute equal to ``0.5``
    converted to *dtype*.
    """
    return make_node(
        op_type,
        [shape_name],
        outputs,
        value=from_array(np.array([0.5], dtype=dtype)),
        domain=domain,
    )


def _replace_in_graph_attributes(node, threshold, op_type, domain):
    """Apply the replacement recursively to every GRAPH attribute of *node*.

    Returns *node* itself when it has no GRAPH attribute. Otherwise returns
    a copy of *node* whose GRAPH attributes were rewritten; every other
    field (name, domain, doc_string, ...) is preserved.
    """
    modified = False
    atts = []
    for att in node.attribute:
        if att.type == AttributeProto.GRAPH:
            modified = True
            g = replace_initializer_by_constant_of_shape(
                att.g, threshold=threshold, op_type=op_type, domain=domain
            )
            att = make_attribute(att.name, g)
        atts.append(att)
    if not modified:
        return node
    # Copy the node rather than calling make_node(op_type, inputs, outputs):
    # the latter would lose the node name, domain and doc_string.
    new_node = type(node)()
    new_node.CopyFrom(node)
    del new_node.attribute[:]
    new_node.attribute.extend(atts)
    return new_node


def replace_initializer_by_constant_of_shape(
    onx, threshold=128, op_type="ConstantOfShape", domain=""
):
    """
    Replaces initializers by nodes *ConstantOfShape* to reduce
    the size and still write a unit test.

    Every initializer and every ``Constant`` node holding more than
    *threshold* elements is replaced by a node *op_type* (in *domain*)
    which receives the original shape and whose ``value`` attribute is
    ``0.5`` converted to the original element type. GRAPH attributes
    (subgraphs) of the nodes are processed recursively.

    :param onx: ModelProto, GraphProto or FunctionProto
    :param threshold: every initializer under this threshold is not impacted
    :param op_type: replace by this node
    :param domain: replace by this domain
    :return: modified proto of the same type as *onx*; a FunctionProto
        with nothing to replace is returned unchanged
    :raises TypeError: if *onx* is not one of the supported proto types
    :raises RuntimeError: if a ModelProto imports the main domain with
        an opset lower than 9
    :raises NotImplementedError: if a sparse initializer has more than
        *threshold* elements
    """
    if isinstance(onx, FunctionProto):
        modified = False
        new_nodes = []
        for node in onx.node:
            if node.op_type != "Constant":
                new_node = _replace_in_graph_attributes(
                    node, threshold, op_type, domain
                )
                if new_node is not node:
                    modified = True
                new_nodes.append(new_node)
                continue

            cst = _evaluate_constant_node(node)
            if np.prod(cst.shape) <= threshold:
                new_nodes.append(node)
                continue

            # A function has no initializers: the shape itself has to be
            # produced by a Constant node.
            new_name = f"{node.output[0]}__SHAPE"
            new_nodes.append(
                make_node(
                    "Constant",
                    [],
                    [new_name],
                    value=from_array(
                        np.array(cst.shape, dtype=np.int64), name=new_name
                    ),
                )
            )
            assert op_type != "Constant"
            new_nodes.append(
                _make_filler_node(new_name, node.output, cst.dtype, op_type, domain)
            )
            modified = True

        if not modified:
            return onx

        onxf = make_function(
            domain=onx.domain,
            fname=onx.name,
            inputs=onx.input,
            outputs=onx.output,
            nodes=new_nodes,
            doc_string=onx.doc_string,
            overload=onx.overload,
            opset_imports=[],
        )
        if onx.opset_import:
            onxf.opset_import.extend(onx.opset_import)
        if onx.value_info:
            onxf.value_info.extend(onx.value_info)
        if onx.attribute:
            onxf.attribute.extend(onx.attribute)
        if onx.attribute_proto:
            onxf.attribute_proto.extend(onx.attribute_proto)
        return onxf

    if isinstance(onx, ModelProto):
        new_graph = replace_initializer_by_constant_of_shape(
            onx.graph, threshold=threshold, op_type=op_type, domain=domain
        )
        new_functions = [
            replace_initializer_by_constant_of_shape(
                f, threshold=threshold, op_type=op_type, domain=domain
            )
            for f in onx.functions
        ]
        model = make_model(
            new_graph,
            functions=new_functions,
            producer_name=onx.producer_name,
            producer_version=onx.producer_version,
            ir_version=onx.ir_version,
            doc_string=onx.doc_string,
            domain=onx.domain,
            model_version=onx.model_version,
        )
        if len(onx.metadata_props) > 0:  # pragma: no cover
            values = {p.key: p.value for p in onx.metadata_props}
            set_model_props(model, values)

        # make_model adds a default opset; restore the original ones instead.
        del model.opset_import[:]  # pylint: disable=E1101
        for oimp in onx.opset_import:
            op_set = model.opset_import.add()  # pylint: disable=E1101
            if oimp.domain == "" and oimp.version < 9:
                raise RuntimeError(
                    f"ConstantOfShape was introduced in "
                    f"opset 9 but opset is {oimp.version}."
                )
            op_set.domain = oimp.domain
            op_set.version = oimp.version
        return model

    if not isinstance(onx, GraphProto):
        raise TypeError(f"onx should be a GraphProto as this stage not {type(onx)}.")

    new_nodes = []
    removed = set()
    new_inits = []
    for init in onx.initializer:
        dims = tuple(init.dims)
        if np.prod(dims) <= threshold:
            new_inits.append(init)
            continue
        new_name = f"{init.name}__SHAPE"
        new_inits.append(
            from_array(np.array(list(dims), dtype=np.int64), name=new_name)
        )
        dtype = tensor_dtype_to_np_dtype(init.data_type)
        new_nodes.append(
            _make_filler_node(new_name, [init.name], dtype, op_type, domain)
        )
        # The initializer may also be declared as a graph input; it is now
        # a node output, so it is dropped from the inputs below.
        removed.add(init.name)

    new_sparse_inits = []
    for init in onx.sparse_initializer:
        if np.prod(tuple(init.dims)) <= threshold:
            new_sparse_inits.append(init)
            continue
        raise NotImplementedError(
            f"This feature is not yet implemented for sparse initializer "
            f"(name={init.name!r})."
        )

    for node in onx.node:
        if node.op_type != "Constant":
            new_nodes.append(
                _replace_in_graph_attributes(node, threshold, op_type, domain)
            )
            continue

        cst = _evaluate_constant_node(node)
        if np.prod(cst.shape) <= threshold:
            new_nodes.append(node)
            continue

        # In a graph the shape can be stored as an initializer.
        new_name = f"{node.output[0]}__SHAPE"
        new_inits.append(
            from_array(np.array(cst.shape, dtype=np.int64), name=new_name)
        )
        new_nodes.append(
            _make_filler_node(new_name, node.output, cst.dtype, op_type, domain)
        )

    return make_graph(
        new_nodes,
        onx.name,
        [i for i in onx.input if i.name not in removed],
        onx.output,
        initializer=new_inits,
        sparse_initializer=new_sparse_inits,
        # Preserve intermediate type/shape information and documentation.
        value_info=onx.value_info,
        doc_string=onx.doc_string,
    )