Source code for onnx_array_api.validation.tools

from typing import Union
import numpy
from onnx import (
    AttributeProto,
    GraphProto,
    FunctionProto,
    ModelProto,
    NodeProto,
    SparseTensorProto,
    TensorProto,
)
from onnx.helper import (
    make_attribute,
    make_function,
    make_graph,
    make_model,
    make_node,
    set_model_props,
)
from ..reference import from_array_extended as from_array, to_array_extended as to_array


# ONNX element types whose values get randomized. Any other element type
# (integers, booleans, strings, float8/bfloat16 bit patterns, ...) is copied
# unchanged: randomizing it would crash (bool) or likely break shape, index
# or axis inputs (integers).
_RANDOMIZED_TENSOR_TYPES = frozenset(
    {TensorProto.FLOAT, TensorProto.DOUBLE, TensorProto.FLOAT16}
)


def _copy_proto(proto):
    """Return an independent deep copy of a protobuf message."""
    new_proto = type(proto)()
    new_proto.CopyFrom(proto)
    return new_proto


def _randomize_tensor(onx: TensorProto) -> TensorProto:
    """
    Returns a copy of *onx* with its values replaced by uniform random ones.

    The random values are drawn within ``[min, max]`` of the original values.
    Tensors whose element type is not in ``_RANDOMIZED_TENSOR_TYPES``, and
    empty tensors (which have no min/max), are returned as unchanged copies.
    """
    if onx.data_type not in _RANDOMIZED_TENSOR_TYPES:
        return _copy_proto(onx)
    t = to_array(onx)
    if t.size == 0:
        return _copy_proto(onx)
    # Work with python floats so that the range computation cannot overflow
    # for small float types such as float16.
    mini, maxi = float(t.min()), float(t.max())
    new_t = numpy.clip(
        numpy.random.random(t.shape) * (maxi - mini) + mini, mini, maxi
    )
    return from_array(new_t.astype(t.dtype), name=onx.name)


def _randomize_sparse(onx: SparseTensorProto) -> SparseTensorProto:
    """Returns a copy of *onx* with its ``values`` randomized (indices and dims kept)."""
    new_onx = _copy_proto(onx)
    new_onx.values.CopyFrom(_randomize_tensor(onx.values))
    return new_onx


def _randomize_node(node: NodeProto) -> NodeProto:
    """
    Returns a copy of *node* with its subgraphs randomized recursively and,
    for an ONNX ``Constant`` node, its tensor value randomized as well.
    """
    is_constant = node.op_type == "Constant" and node.domain in ("", "ai.onnx")
    new_node = _copy_proto(node)
    del new_node.attribute[:]
    for att in node.attribute:
        new_att = _copy_proto(att)
        if att.ref_attr_name:
            # Reference to a function attribute: there is no value to change.
            pass
        elif att.type == AttributeProto.GRAPH:
            new_att.g.CopyFrom(randomize_proto(att.g))
        elif att.type == AttributeProto.GRAPHS:
            del new_att.graphs[:]
            new_att.graphs.extend(randomize_proto(g) for g in att.graphs)
        elif is_constant and att.type == AttributeProto.TENSOR:
            new_att.t.CopyFrom(_randomize_tensor(att.t))
        elif is_constant and att.type == AttributeProto.SPARSE_TENSOR:
            new_att.sparse_tensor.CopyFrom(_randomize_sparse(att.sparse_tensor))
        new_node.attribute.append(new_att)
    return new_node


def randomize_proto(
    onx: Union[
        ModelProto,
        GraphProto,
        FunctionProto,
        NodeProto,
        TensorProto,
        SparseTensorProto,
    ],
) -> Union[
    ModelProto, GraphProto, FunctionProto, NodeProto, TensorProto, SparseTensorProto
]:
    """
    Randomizes float initializers or constant nodes.

    Every float, double or float16 tensor is replaced by random values drawn
    uniformly between the minimum and the maximum of the original tensor.
    This covers:

    * initializers;
    * sparse initializers (their ``values``);
    * the tensor value of ``Constant`` nodes.

    Subgraphs (``GRAPH``/``GRAPHS`` attributes) and local functions are
    processed recursively. Tensors of any other element type, and empty
    tensors, are kept as they are. Every other field is copied over, so
    metadata, ``value_info``, doc strings, function attributes and so on
    are preserved. The input proto is not modified.

    :param onx: onnx model or proto
    :return: a new proto of the same type
    :raises TypeError: if *onx* is not one of the supported proto types
    """
    if isinstance(onx, TensorProto):
        return _randomize_tensor(onx)
    if isinstance(onx, SparseTensorProto):
        return _randomize_sparse(onx)
    if isinstance(onx, NodeProto):
        return _randomize_node(onx)
    if isinstance(onx, ModelProto):
        new_model = _copy_proto(onx)
        new_model.graph.CopyFrom(randomize_proto(onx.graph))
        del new_model.functions[:]
        new_model.functions.extend(randomize_proto(f) for f in onx.functions)
        return new_model
    if isinstance(onx, (GraphProto, FunctionProto)):
        new_onx = _copy_proto(onx)
        del new_onx.node[:]
        new_onx.node.extend(_randomize_node(node) for node in onx.node)
        if isinstance(onx, GraphProto):
            del new_onx.initializer[:]
            new_onx.initializer.extend(
                _randomize_tensor(init) for init in onx.initializer
            )
            del new_onx.sparse_initializer[:]
            new_onx.sparse_initializer.extend(
                _randomize_sparse(init) for init in onx.sparse_initializer
            )
        return new_onx
    raise TypeError(f"Unexpected type for onx {type(onx)}.")