Source code for experimental_experiment.xshape._onnx_helper

import onnx
import onnx.helper as oh
from typing import Iterator, Optional, Set


def element_wise_binary_op_types() -> Set[str]:
    """
    Returns the set of ONNX element-wise binary operators
    (two inputs, output shape obtained by broadcasting).

    .. runpython::
        :showcode:

        import pprint
        from experimental_experiment.xshape._onnx_helper import (
            element_wise_binary_op_types,
        )

        pprint.pprint(element_wise_binary_op_types())
    """
    # A new set is built on every call so callers may mutate it freely.
    names = ("Add", "And", "Div", "Mod", "Mul", "Or", "Sub", "Xor")
    return set(names)
def element_wise_op_cmp_types() -> Set[str]:
    """
    Returns the set of ONNX element-wise comparison operators.

    .. runpython::
        :showcode:

        import pprint
        from experimental_experiment.xshape._onnx_helper import element_wise_op_cmp_types

        pprint.pprint(element_wise_op_cmp_types())
    """
    comparisons = ("Equal", "Greater", "GreaterOrEqual", "Less", "LessOrEqual")
    return set(comparisons)
def unary_like_op_types() -> Set[str]:
    """
    Returns the set of unary *like* operators.
    They do not change the shape. They may change the type.

    .. runpython::
        :showcode:

        import pprint
        from experimental_experiment.xshape._onnx_helper import unary_like_op_types

        pprint.pprint(unary_like_op_types())
    """
    # NOTE(review): "SoftmaxCrossEntropyLoss" (removes the class axis and
    # reduces to a scalar unless reduction="none") and "BitShift" (broadcasts
    # its two inputs) do not strictly preserve the input shape; they are kept
    # for compatibility with existing callers -- confirm they are handled.
    return {
        "Abs",
        "Acos",
        "Acosh",
        "Asin",
        "Asinh",
        "Atan",
        "Atanh",
        "BitShift",
        "Cast",
        "CastLike",
        "Ceil",
        "Celu",
        "Clip",
        "Cos",
        "Cosh",
        "DequantizeLinear",
        "DynamicQuantizeLinear",
        "Elu",
        "Erf",
        "Exp",
        "Floor",
        "HardSigmoid",
        "HardSwish",
        "IsInf",
        "LeakyRelu",
        "Log",
        "LogSoftmax",
        "Neg",
        "Not",
        "PRelu",
        "QuantizeLinear",
        "Reciprocal",
        "Relu",
        "Round",
        "Selu",
        "Sigmoid",
        "Sign",
        "Sin",
        "Sinh",
        "Softmax",
        "SoftmaxCrossEntropyLoss",
        "Softplus",
        "Softsign",
        "Sqrt",
        "Tan",
        "Tanh",
        # The ONNX operator is spelled "ThresholdedRelu" ("ThresholdRelu"
        # does not exist and never matched any node).
        "ThresholdedRelu",
    }
def str_tensor_proto_type() -> str:
    """
    Returns a string listing every ONNX tensor element type as
    ``<integer value>:<name>``, sorted by value, separated by ``", "``.

    .. runpython::
        :showcode:

        from experimental_experiment.xshape._onnx_helper import str_tensor_proto_type

        print(str_tensor_proto_type())
    """
    # Only the ``TensorProto.DataType`` enum is read. Scanning
    # ``dir(onnx.TensorProto)`` for upper-case integer attributes also picks
    # up the protobuf ``*_FIELD_NUMBER`` constants and the ``DataLocation``
    # values (``DEFAULT``, ``EXTERNAL``), which are not element types.
    mapping = sorted(
        (value, name) for name, value in onnx.TensorProto.DataType.items()
    )
    return ", ".join(f"{k}:{v}" for k, v in mapping)
def enumerate_subgraphs(graph: onnx.GraphProto) -> Iterator[onnx.GraphProto]:
    """
    Enumerates ``graph`` and, recursively, every subgraph attached to its
    control-flow nodes (``Loop``, ``Scan``, ``If``, ``SequenceMap``).

    :param graph: graph to walk
    :return: iterator on graphs, ``graph`` first, then the subgraphs
        in depth-first order (node order, then attribute order)
    """
    yield graph
    for node in graph.node:
        # Only these operators are inspected for graph attributes.
        # The set test alone is enough; indexing ``op_type[0]`` would raise
        # IndexError on a node with an empty op_type.
        if node.op_type not in {"Loop", "Scan", "If", "SequenceMap"}:
            continue
        for att in node.attribute:
            if att.type == onnx.AttributeProto.GRAPH:
                yield from enumerate_subgraphs(att.g)
            elif att.type == onnx.AttributeProto.GRAPHS:
                for sub in att.graphs:
                    yield from enumerate_subgraphs(sub)
def _rewrite_info(info: onnx.ValueInfoProto) -> onnx.ValueInfoProto:
    """
    Returns a new value info equal to ``info`` with every dimension made
    symbolic: an existing ``dim_param`` is kept, any other dimension
    (static or unknown) is renamed ``dim<i>_<name>``.

    Value infos that are not tensors, or tensors without a shape
    (unknown rank), are returned as an unchanged copy.

    :param info: value info to rewrite, it is not modified
    :return: a new :class:`onnx.ValueInfoProto`
    """
    tensor_type = info.type.tensor_type
    if info.type.WhichOneof("value") != "tensor_type" or not tensor_type.HasField("shape"):
        # Rebuilding such a value info with an empty shape list would turn
        # an unknown-rank tensor (or a sequence, optional, ...) into a
        # rank-0 tensor, so it is copied as is.
        copy = onnx.ValueInfoProto()
        copy.CopyFrom(info)
        return copy
    shape = [
        dim.dim_param if dim.dim_param else f"dim{i}_{info.name}"
        for i, dim in enumerate(tensor_type.shape.dim)
    ]
    return oh.make_tensor_value_info(
        info.name, tensor_type.elem_type, shape, doc_string=info.doc_string
    )
def overwrite_shape_in_model_proto(
    model: onnx.ModelProto, n_in: Optional[int] = None
) -> onnx.ModelProto:
    """
    Removes inferred shapes. Overwrites input shapes to make them all dynamic.
    ``n_in`` indicates the number of inputs for which the shape must be rewritten.

    The main graph and every subgraph returned by ``enumerate_subgraphs``
    are processed: their ``value_info`` is cleared, the first ``n_in``
    inputs (all of them if ``n_in`` is None) and all outputs get
    symbolic dimensions. The model is modified inplace and returned.

    :param model: model to modify
    :param n_in: number of inputs to rewrite in each graph, None for all
    :return: ``model`` itself
    """
    assert isinstance(model, onnx.ModelProto), f"Unexpected type {type(model)} for model."
    for subgraph in enumerate_subgraphs(model.graph):
        # Inferred shapes may keep the static dimensions the inputs are about
        # to lose; drop them so they cannot contradict the dynamic inputs.
        del subgraph.value_info[:]

        # Elements are overwritten in place, which avoids re-inserting
        # messages detached from the repeated field by ``del ...[:]``.
        for i, inp in enumerate(subgraph.input):
            if n_in is None or i < n_in:
                inp.CopyFrom(_rewrite_info(inp))
        for out in subgraph.output:
            out.CopyFrom(_rewrite_info(out))
    return model