Source code for yobx.xshape._onnx_helper

import onnx
import onnx.helper as oh
from typing import Dict, Iterator, List, Optional, Set, Tuple, Union


def element_wise_binary_op_types() -> Set[str]:
    """
    Returns the set of operator names handled as element-wise
    binary operators.

    A new set is built on every call, so the caller may modify it freely.

    .. runpython::
        :showcode:

        import pprint
        from yobx.xshape._onnx_helper import (
            element_wise_binary_op_types,
        )
        pprint.pprint(element_wise_binary_op_types())
    """
    # The groups below only organise the names for the reader;
    # the function returns their plain union.
    arithmetic = {"Add", "Div", "Mod", "Mul", "Sub"}
    aggregates = {"Max", "Mean", "Min", "Sum"}
    logical = {"And", "Or", "Xor"}
    bitwise = {"BitwiseAnd", "BitwiseOr", "BitwiseXor"}
    return arithmetic | aggregates | logical | bitwise
def element_wise_op_cmp_types() -> Set[str]:
    """
    Returns the set of element-wise operator names performing comparisons.

    A new set is built on every call.

    .. runpython::
        :showcode:

        import pprint
        from yobx.xshape._onnx_helper import element_wise_op_cmp_types
        pprint.pprint(element_wise_op_cmp_types())
    """
    comparison_names = ("Equal", "Greater", "GreaterOrEqual", "Less", "LessOrEqual")
    return set(comparison_names)
def unary_like_op_types() -> Set[str]:
    """
    Returns the set of operator names treated as unary *like* operators.
    According to this module, they do not change the shape
    but they may change the type.

    A new set is built on every call.

    .. runpython::
        :showcode:

        import pprint
        from yobx.xshape._onnx_helper import unary_like_op_types
        pprint.pprint(unary_like_op_types())
    """
    # The grouping is purely for readability; the result is the union of all groups.
    trigonometric = {
        "Acos", "Acosh", "Asin", "Asinh", "Atan", "Atanh",
        "Cos", "Cosh", "Sin", "Sinh", "Tan", "Tanh",
    }
    activations = {
        "Celu", "Elu", "HardSigmoid", "HardSwish", "LeakyRelu", "LogSoftmax",
        "Mish", "PRelu", "Relu", "Selu", "Shrink", "Sigmoid", "Softmax",
        "SoftmaxCrossEntropyLoss", "Softplus", "Softsign",
        "ThresholdedRelu", "ThresholdRelu",
    }
    numeric = {
        "Abs", "Ceil", "Clip", "CumSum", "Erf", "Exp", "Floor", "IsInf",
        "Log", "Neg", "Reciprocal", "Round", "Sign", "Sqrt", "Trunc",
    }
    bits_and_logic = {"BitShift", "BitwiseNot", "Not"}
    casts_and_quantization = {
        "Cast", "CastLike", "DequantizeLinear", "DynamicQuantizeLinear", "QuantizeLinear",
    }
    normalization_and_other = {
        "LpNormalization", "LRN", "MeanVarianceNormalization", "Trilu",
    }
    return set().union(
        trigonometric,
        activations,
        numeric,
        bits_and_logic,
        casts_and_quantization,
        normalization_and_other,
    )
def str_tensor_proto_type() -> str:
    """
    Returns a string listing the integer constants of ``onnx.TensorProto``
    whose name is entirely upper case, formatted as ``value:NAME``,
    sorted by value (then by name) and separated by ``", "``.

    .. runpython::
        :showcode:

        from yobx.xshape._onnx_helper import str_tensor_proto_type

        print(str_tensor_proto_type())
    """
    pairs = []
    for attr_name in dir(onnx.TensorProto):
        # Only fully upper-case names are considered.
        if attr_name != attr_name.upper():
            continue
        attr_value = getattr(onnx.TensorProto, attr_name)
        if isinstance(attr_value, int):
            pairs.append((attr_value, attr_name))
    return ", ".join(f"{code}:{label}" for code, label in sorted(pairs))
def enumerate_subgraphs(graph: onnx.GraphProto) -> Iterator[onnx.GraphProto]:
    """
    Enumerates a graph and all the subgraphs nested in its nodes.

    The graph itself is yielded first. Then, for every node whose operator
    type is ``Loop``, ``Scan``, ``If`` or ``SequenceMap``, every attribute
    of type ``GRAPH`` is explored recursively, in node order.

    :param graph: graph to explore
    :return: iterator over *graph* followed by its nested subgraphs
    """
    # Operator types whose GRAPH attributes are followed.
    ops_with_subgraphs = {"Loop", "Scan", "If", "SequenceMap"}
    yield graph
    for node in graph.node:
        # A plain membership test is enough and, unlike indexing
        # ``node.op_type[0]``, does not fail on an empty operator type.
        if node.op_type not in ops_with_subgraphs:
            continue
        for att in node.attribute:
            if att.type == onnx.AttributeProto.GRAPH:
                yield from enumerate_subgraphs(att.g)
def _rewrite_info(info: onnx.ValueInfoProto):
    """
    Builds a new tensor value info with the same name and element type
    as *info* where every dimension is a string.

    A dimension which already has a ``dim_param`` keeps it, any other
    dimension is named ``dim<axis>_<name of info>``.
    """
    tensor_type = info.type.tensor_type
    dynamic_shape = [
        dim.dim_param or f"dim{axis}_{info.name}"
        for axis, dim in enumerate(tensor_type.shape.dim)
    ]
    return oh.make_tensor_value_info(info.name, tensor_type.elem_type, dynamic_shape)
def overwrite_shape_in_model_proto(
    model: onnx.ModelProto, n_in: Optional[int] = None
) -> onnx.ModelProto:
    """
    Overwrites input and output shapes to make them all dynamic.

    The main graph and every subgraph returned by
    :func:`enumerate_subgraphs` are modified in place: their inputs and
    outputs are replaced by the value infos built by ``_rewrite_info``.

    :param model: model to modify
    :param n_in: number of leading inputs, in each visited graph, for which
        the shape must be rewritten, the other inputs are kept as they are,
        None to rewrite all of them (outputs are always rewritten)
    :return: the same model, modified in place
    """
    assert isinstance(model, onnx.ModelProto), f"Unexpected type {type(model)} for model."
    for g in enumerate_subgraphs(model.graph):
        # Build the replacement list before clearing the repeated field.
        rewritten_inputs = []
        for position, value_info in enumerate(g.input):
            if n_in is None or position < n_in:
                rewritten_inputs.append(_rewrite_info(value_info))
            else:
                rewritten_inputs.append(value_info)
        del g.input[:]
        g.input.extend(rewritten_inputs)

        rewritten_outputs = []
        for value_info in g.output:
            rewritten_outputs.append(_rewrite_info(value_info))
        del g.output[:]
        g.output.extend(rewritten_outputs)
    return model
def _symbolic_value_infos(
    value_infos, mapping: Dict[str, Union[str, int]]
) -> List[onnx.ValueInfoProto]:
    """
    Rebuilds a sequence of value infos with static dimensions renamed.

    :param value_infos: iterable of ``ValueInfoProto`` (graph inputs or outputs)
    :param mapping: dictionary updated in place with ``{new_name: value}``
    :return: the list of value infos to put in the new graph
    """
    rewritten = []
    for info in value_infos:
        if not info.type.HasField("tensor_type") or not info.type.tensor_type.HasField(
            "shape"
        ):
            # Not a tensor, or a tensor without any shape information:
            # there is no dimension to rename, the value info is kept as is.
            rewritten.append(info)
            continue
        tensor_type = info.type.tensor_type
        shape: List[Union[int, str, None]] = []
        for d in tensor_type.shape.dim:
            if d.dim_param:
                # Already dynamic, the name is kept and maps to itself.
                shape.append(d.dim_param)
                mapping[d.dim_param] = d.dim_param
            elif d.dim_value:
                new_name = f"DIM{d.dim_value}"
                shape.append(new_name)
                mapping[new_name] = d.dim_value
            elif d.HasField("dim_value"):
                # An explicit dimension equal to 0 is left static.
                shape.append(0)
            else:
                # Neither a value nor a name: the dimension stays unknown
                # instead of being turned into a static 0.
                shape.append(None)
        rewritten.append(oh.make_tensor_value_info(info.name, tensor_type.elem_type, shape))
    return rewritten


def replace_static_dimensions_by_strings(
    model: onnx.ModelProto,
) -> Tuple[onnx.ModelProto, Dict[str, Union[str, int]]]:
    """
    Replaces static dimensions by dynamic dimensions in a model.

    Every non null static dimension ``n`` of the inputs and outputs of the
    main graph is replaced by the string ``DIM<n>``. Dimensions already
    dynamic are kept. Unknown dimensions, dimensions equal to 0 and values
    which are not tensors are left unchanged.

    A new model is built from the nodes, initializers, opsets, IR version
    and local functions of the original one. The other fields (such as the
    intermediate ``value_info``) are not copied.

    :param model: ModelProto
    :return: the modified model, a mapping ``{new_name: value}``,
        a dimension already dynamic is mapped to itself
    """
    mapping: Dict[str, Union[str, int]] = {}
    new_inputs = _symbolic_value_infos(model.graph.input, mapping)
    new_outputs = _symbolic_value_infos(model.graph.output, mapping)

    new_graph = oh.make_graph(
        model.graph.node,
        model.graph.name,
        new_inputs,
        new_outputs,
        model.graph.initializer,
        sparse_initializer=model.graph.sparse_initializer,
        # The graph keeps its own documentation, not the model's one.
        doc_string=model.graph.doc_string,
    )
    new_model = oh.make_model(
        new_graph,
        opset_imports=model.opset_import,
        ir_version=model.ir_version,
        functions=model.functions,
    )
    if model.doc_string:
        new_model.doc_string = model.doc_string
    return new_model, mapping