Source code for experimental_experiment.xbuilder._onnx_helper

from typing import Dict, Iterator, Optional, Set, Tuple, Union
import numpy as np
from onnx import AttributeProto, FunctionProto, GraphProto, NodeProto, TensorShapeProto
from onnx.defs import onnx_opset_version, get_all_schemas_with_history


def _default_OPSET_TO_IR_VERSION() -> Dict[int, int]:
    """
    Returns the dictionary mapping the main opset
    to the corresponding `ir_version`.
    """
    return {
        1: 3,
        2: 3,
        3: 3,
        4: 3,
        5: 3,
        6: 3,
        7: 3,
        8: 4,
        9: 4,
        10: 5,
        11: 6,
        12: 7,
        13: 7,
        14: 7,
        15: 8,
        16: 8,
        17: 8,
        18: 8,
        19: 9,
        20: 9,
        21: 10,
    }


# Presumably pairs every `ai.onnx.ml` opset version (key) with a main opset
# version (value) -- this reading relies on the name only, TODO confirm.
OPSET_ML_TO_OPSET: Dict[int, int] = {1: 11, 2: 15, 3: 18}


def enumerate_subgraphs(
    node: NodeProto, recursive: bool = True
) -> Iterator[Tuple[Tuple[NodeProto, str, GraphProto], ...]]:
    """
    Returns the subgraphs inside a graph.

    :param node: node whose attributes are searched for subgraphs
    :param recursive: if True (default), the nodes of every subgraph
        are searched as well to find nested subgraphs, if False,
        only the subgraphs directly attached to *node* are returned
    :return: iterator, a subgraph attached to *node* is returned as
        a flat tuple ``(node, attribute name, graph)``, a nested subgraph
        is returned as the same tuple extended with the three elements
        describing every additional nesting level
    """
    for att in node.attribute:
        if att.type == AttributeProto.GRAPH and att.g:
            this = node, att.name, att.g
            yield this
            if not recursive:
                # The caller only wants the first level.
                continue
            for no in att.g.node:
                for tu in enumerate_subgraphs(no, recursive=recursive):
                    yield this + tu
# Cache of the operator schemas, filled by compatible_opsets on first use.
_history = None


def _nice_shape(shape: TensorShapeProto) -> str:
    """
    Renders a shape as a string, the dimensions are separated by ``x``.
    A dimension is displayed with `dim_value` when this field is set,
    with `dim_param` otherwise.

    :param shape: shape to display
    :return: string
    """
    return "x".join(
        str(dim.dim_value) if dim.HasField("dim_value") else dim.dim_param
        for dim in shape.dim
    )
def compatible_opsets(domain: str, op_type: str, current: int, new_version: int) -> bool:
    """
    Tells if two opset versions refer to the same version
    of a particular operator type.

    The function looks for the latest version of the operator
    which is not greater than *current*, then checks that *new_version*
    is not older than this version and is strictly older than
    the next version of the operator, if there is one.

    :param domain: domain, only `ai.onnx` and `ai.onnx.ml` are checked.
    :param op_type: operator type
    :param current: current domain version
    :param new_version: new version
    :return: result

    An `AssertionError` is raised if the domain or the operator type
    is unknown or if *current* is older than the first version
    of the operator.
    """
    global _history
    if _history is None:
        # The schemas are indexed once by domain, operator name and version.
        # The loop must not assign to `domain`: it is the argument and
        # it is still needed after the cache is built.
        res = {}
        for schema in get_all_schemas_with_history():
            by_name = res.setdefault(schema.domain, {})
            by_version = by_name.setdefault(schema.name, {})
            by_version[schema.since_version] = schema
        _history = res

    assert domain in _history, f"Unable to find domain {domain!r} in {list(sorted(_history))}."
    assert op_type in _history[domain], (
        f"Unable to find op_type {op_type!r}, domain={domain!r} "
        f"in {list(sorted(_history[domain]))}"
    )
    hist = _history[domain][op_type]
    version = sorted(hist)
    # Index of the latest version <= current, -1 if there is none.
    pos = np.searchsorted(version, current, side="right") - 1
    assert pos >= 0, (
        f"Available version for {op_type!r} from {domain!r}, "
        f"incompatible version is {current}"
    )
    if pos < len(version) - 1:
        # A more recent version exists: new_version must stay below it.
        a, b = version[pos], version[pos + 1]
        return a <= new_version < b
    return new_version >= version[pos]
def _get_default_opset_for_domain(domain: str, main_opset: Optional[int] = None) -> int: """ Returns the associated for a domain given the main opset. :param domain: domain :param main_opset: opset for the main domain :return: version """ if main_opset is None: main_opset = choose_consistent_domain_opset("") if domain == "": return main_opset if domain == "ai.onnx.ml": if main_opset >= 18: return 3 if main_opset >= 6: return 2 return 1 if domain == "ai.onnx.training": return 1 return None
def choose_consistent_domain_opset(
    domain: str, opsets: Optional[Dict[str, int]] = None
) -> int:
    """
    Chooses a compatible opset for a particular domain given
    the existing ones.

    * if the domain is already in *opsets*, its version is returned,
    * for the main domain (``""``), the function returns
      ``onnx_opset_version() - 2``,
    * for ``ai.onnx.ml``, the version is chosen according to the main opset
      found in *opsets* if there is one, to the default main opset otherwise,
    * for any other domain, the function returns 1.

    :param domain: new domain
    :param opsets: existing opsets
    :return: version

    An `AssertionError` is raised if the main domain is requested
    while *opsets* already contains ``ai.onnx.ml`` but not the main domain.
    """
    opsets = opsets or {}
    if domain in opsets:
        return opsets[domain]
    if domain == "":
        assert "ai.onnx.ml" not in opsets, (
            "If ai.onnx.ml is part of your model, "
            "your should add a version for the main opset as well."
        )
        return onnx_opset_version() - 2
    if domain != "ai.onnx.ml":
        return 1
    # The main opset already chosen for the model, if any, drives
    # the version of ai.onnx.ml, None falls back to the default main opset.
    return _get_default_opset_for_domain(domain, opsets.get(""))
def element_wise_binary_op_types() -> Set[str]:
    """
    Returns the set of element-wise binary operators.
    A new set is created at every call.

    .. runpython::
        :showcode:

        import pprint
        from experimental_experiment.xbuilder._onnx_helper import (
            element_wise_binary_op_types,
        )
        pprint.pprint(element_wise_binary_op_types())
    """
    names = ("Add", "And", "Div", "Mod", "Mul", "Or", "Sub", "Xor")
    return set(names)
def element_wise_op_cmp_types() -> Set[str]:
    """
    Returns the set of element-wise operators doing comparisons.
    A new set is created at every call.

    .. runpython::
        :showcode:

        import pprint
        from experimental_experiment.xbuilder._onnx_helper import element_wise_op_cmp_types
        pprint.pprint(element_wise_op_cmp_types())
    """
    names = ("Equal", "Greater", "GreaterOrEqual", "Less", "LessOrEqual")
    return set(names)
def unary_like_op_types() -> Set[str]:
    """
    Returns the set of unary *like* operators.
    They do not change the shape. They may change the type.
    A new set is created at every call.

    .. runpython::
        :showcode:

        import pprint
        from experimental_experiment.xbuilder._onnx_helper import unary_like_op_types
        pprint.pprint(unary_like_op_types())
    """
    names = (
        "Abs",
        "Acos",
        "Acosh",
        "Asin",
        "Asinh",
        "Atan",
        "Atanh",
        "BitShift",
        "Cast",
        "CastLike",
        "Ceil",
        "Celu",
        "Clip",
        "Cos",
        "Cosh",
        "DequantizeLinear",
        "DynamicQuantizeLinear",
        "Elu",
        "Erf",
        "Exp",
        "IsInf",
        "Log",
        "LogSoftmax",
        "Neg",
        "Not",
        "PRelu",
        "QuantizeLinear",
        "Reciprocal",
        "Relu",
        "Round",
        "Selu",
        "Sigmoid",
        "Sign",
        "Sin",
        "Sinh",
        "Softmax",
        "SoftmaxCrossEntropyLoss",
        "Softplus",
        "Softsign",
        "Sqrt",
        "Tan",
        "Tanh",
        "ThresholdRelu",
    )
    return set(names)
def same_function_proto(
    f1: FunctionProto, f2: FunctionProto, verbose: int = 0
) -> Union[str, bool]:
    """
    Compares two functions and tells if they are equal.
    They may have different names and their results may have
    different names as long as the nodes are connected the same way.
    The nodes are compared in the order they are stored.

    :param f1: first function
    :param f2: second function
    :param verbose: to know why the comparison failed,
        the function returns a string in that case or True
    :return: comparison, True if the functions are the same,
        otherwise False or, if *verbose* is not null, a string
        explaining the first difference, this string is not empty:
        the result should be compared to True with ``is True``
    """
    if len(f1.input) != len(f2.input):
        return "different number of inputs" if verbose else False
    if len(f1.output) != len(f2.output):
        return "different number of outputs" if verbose else False
    if len(f1.node) != len(f2.node):
        return "different number of nodes" if verbose else False
    if len(f1.attribute) != len(f2.attribute):
        return "different number of attributes" if verbose else False
    if len(f1.attribute_proto) != len(f2.attribute_proto):
        return "different number of attributes (2)" if verbose else False
    if list(f1.attribute) != list(f2.attribute):
        return "different attribute names" if verbose else False
    if [a.SerializeToString() for a in f1.attribute_proto] != [
        a.SerializeToString() for a in f2.attribute_proto
    ]:
        return "different attribute protos" if verbose else False

    # Maps every result name of f1 to the corresponding name in f2.
    mapped = dict(zip(f1.input, f2.input))
    # An empty name stands for an omitted optional input or output,
    # it can only correspond to another empty name.
    mapped[""] = ""
    for pos, (n1, n2) in enumerate(zip(f1.node, f2.node)):
        if n1.op_type != n2.op_type:
            return (
                f"different node type at position {pos} - {n1.op_type} != {n2.op_type}"
                if verbose
                else False
            )
        if n1.domain != n2.domain:
            # The same operator name in two domains is not the same operator.
            return (
                f"different node domain at position {pos} - "
                f"{n1.domain!r} != {n2.domain!r}"
                if verbose
                else False
            )
        if len(n1.input) != len(n2.input):
            return f"different number of inputs at node {pos}" if verbose else False
        if len(n1.output) != len(n2.output):
            return f"different number of outputs at node {pos}" if verbose else False
        if len(n1.attribute) != len(n2.attribute):
            return f"different number of attributes at node {pos}" if verbose else False
        n2_input = [mapped[name] for name in n1.input]
        if list(n2.input) != n2_input:
            return (
                f"different input names at node {pos}, {n1.input}, {n2.input} != {n2_input}"
                if verbose
                else False
            )
        for a1, a2 in zip(n1.attribute, n2.attribute):
            # The test should be improved for subgraphs.
            if a1.SerializeToString() != a2.SerializeToString():
                return f"different attribute {a1.name!r}" if verbose else False
        for o1, o2 in zip(n1.output, n2.output):
            if (o1 == "") != (o2 == ""):
                # One node omits an optional output the other one produces.
                return f"different optional outputs at node {pos}" if verbose else False
            if o1:
                mapped[o1] = o2

    f2_output = [mapped[name] for name in f1.output]
    if list(f2.output) != f2_output:
        return (
            f"different output names, {f1.output}, {f2.output} != {f2_output}"
            if verbose
            else False
        )
    return True