Source code for onnx_array_api.translate_api.translate

from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from onnx import AttributeProto, FunctionProto, GraphProto, ModelProto, NodeProto
from onnx.numpy_helper import to_array
from ..reference import to_array_extended
from .base_emitter import EventType
from .light_emitter import LightEmitter


class Translater:
    """
    Translates an ONNX graph into code following the light API.

    The translation is driven by events: every element of the proto
    (initializer, input, node, output, ...) is passed to the emitter,
    which returns the rows of code to append.

    :param proto: model, function or graph to translate
    :param emitter: callable object producing rows of code for every event,
        a :class:`LightEmitter` is created when none is given
    """

    def __init__(
        self,
        proto: Union[ModelProto, FunctionProto, GraphProto],
        emitter: Optional[LightEmitter] = None,
    ):
        self.proto_ = proto
        self.emitter = emitter or LightEmitter()

    def __repr__(self) -> str:
        # Only the proto type is displayed, not its content.
        return f"{self.__class__.__name__}(<{type(self.proto_).__name__}>)"

    @staticmethod
    def _shape_of(value_info: Any) -> Tuple[Union[int, str], ...]:
        """
        Returns the tensor shape of a graph input or output.

        :param value_info: object exposing ``type.tensor_type.shape.dim``
        :return: tuple with one item per dimension, ``dim_value`` when it
            is truthy, ``dim_param`` otherwise
        """
        return tuple(
            d.dim_value or d.dim_param for d in value_info.type.tensor_type.shape.dim
        )

    def export(
        self, as_str: bool, single_line: bool = False
    ) -> Union[str, List[str]]:
        """
        Exports into a code.

        :param as_str: as a single string or by rows
        :param single_line: tries to compress the output into a single line,
            only used when *as_str* is true
        :return: the rows joined by the emitter if *as_str* is true,
            the list of instructions otherwise
        :raises ValueError: if the proto is not a model, a function or a graph
        :raises NotImplementedError: if the graph holds sparse initializers
        """
        proto = self.proto_
        rows = []

        if isinstance(proto, ModelProto):
            opsets = {d.domain: d.version for d in proto.opset_import}
            rows.extend(self.emitter(EventType.START, opsets=opsets))
            # Inputs, outputs, nodes and initializers of a model live
            # in its main graph.
            container = proto.graph
            initializers = container.initializer
            sparse_initializers = container.sparse_initializer
            attributes = []
            is_function = False
        elif isinstance(proto, (FunctionProto, GraphProto)):
            container = proto
            if isinstance(proto, GraphProto):
                initializers = proto.initializer
                sparse_initializers = proto.sparse_initializer
            else:
                initializers = []
                sparse_initializers = []
            attributes = proto.attribute if hasattr(proto, "attribute") else []
            is_function = isinstance(proto, FunctionProto)
        else:
            raise ValueError(f"Unexpected type {type(proto)} for proto.")

        if sparse_initializers:
            raise NotImplementedError("Sparse initializer not supported yet.")

        name = container.name
        # Event closing the whole translation, emitted at the very end.
        last_event = (
            EventType.TO_ONNX_FUNCTION if is_function else EventType.TO_ONNX_MODEL
        )

        if is_function:
            rows.extend(
                self.emitter(
                    EventType.BEGIN_FUNCTION, name=name, domain=proto.domain
                )
            )
        else:
            rows.extend(self.emitter(EventType.BEGIN_GRAPH, name=name))

        for init in initializers:
            rows.extend(
                self.emitter(
                    EventType.INITIALIZER,
                    name=init.name,
                    init=init,
                    value=to_array_extended(init),
                )
            )

        rows.extend(self.emitter(EventType.BEGIN_SIGNATURE))
        for i in container.input:
            if is_function:
                # The name itself is forwarded, no type or shape is read.
                rows.extend(self.emitter(EventType.FUNCTION_INPUT, name=i))
            else:
                rows.extend(
                    self.emitter(
                        EventType.INPUT,
                        name=i.name,
                        elem_type=i.type.tensor_type.elem_type,
                        shape=self._shape_of(i),
                    )
                )
        if is_function and attributes:
            rows.extend(
                self.emitter(EventType.FUNCTION_ATTRIBUTES, attributes=list(attributes))
            )
        rows.extend(self.emitter(EventType.END_SIGNATURE))

        for node in container.node:
            atts = self.extract_attributes(node)
            rows.extend(
                self.emitter(
                    EventType.NODE,
                    op_type=node.op_type,
                    inputs=node.input,
                    outputs=node.output,
                    domain=node.domain,
                    atts=atts,
                )
            )

        rows.extend(self.emitter(EventType.BEGIN_RETURN))
        for o in container.output:
            if is_function:
                rows.extend(self.emitter(EventType.FUNCTION_OUTPUT, name=o))
            else:
                rows.extend(
                    self.emitter(
                        EventType.OUTPUT,
                        name=o.name,
                        elem_type=o.type.tensor_type.elem_type,
                        shape=self._shape_of(o),
                    )
                )
        rows.extend(self.emitter(EventType.END_RETURN))

        rows.extend(
            self.emitter(
                EventType.END_FUNCTION if is_function else EventType.END_GRAPH,
                name=name,
            )
        )

        if isinstance(proto, ModelProto):
            # Local functions are translated with the same emitter and
            # appended after the main graph.
            for fu in proto.functions:
                cl = self.__class__(fu, self.emitter)
                rows.extend(cl.export(False, single_line=False))

        rows.extend(self.emitter(last_event))
        if as_str:
            return self.emitter.join(rows, single_line=single_line)
        return rows

    def extract_attributes(
        self, node: NodeProto
    ) -> Dict[str, Tuple[AttributeProto, Any]]:
        """
        Extracts all attributes of a node.

        :param node: node proto
        :return: dictionary mapping every attribute name to a tuple
            *(attribute proto, extracted value)*, the value is None for
            a graph attribute or an attribute referring to a function
            attribute (``ref_attr_name``)
        :raises ValueError: if the attribute type is not handled
        """
        atts: Dict[str, Tuple[AttributeProto, Any]] = {}
        for att in node.attribute:
            if hasattr(att, "ref_attr_name") and att.ref_attr_name:
                # The attribute refers to another attribute by name,
                # no value is extracted here.
                value = None
            elif att.type == AttributeProto.INT:
                value = att.i
            elif att.type == AttributeProto.FLOAT:
                value = att.f
            elif att.type == AttributeProto.INTS:
                value = np.array(att.ints)
            elif att.type == AttributeProto.FLOATS:
                value = np.array(att.floats, dtype=np.float32)
            elif (
                att.type == AttributeProto.GRAPH
                and hasattr(att, "g")
                and att.g is not None
            ):
                # The subgraph stays available through the proto itself.
                value = None
            elif att.type == AttributeProto.SPARSE_TENSOR:
                # NOTE(review): to_array is imported from onnx.numpy_helper,
                # confirm it supports sparse tensor protos.
                value = to_array(att.sparse_tensor)
            elif att.type == AttributeProto.TENSOR:
                value = to_array(att.t)
            elif att.type == AttributeProto.TENSORS:
                value = [to_array(t) for t in att.tensors]
            elif att.type == AttributeProto.SPARSE_TENSORS:
                value = [to_array(t) for t in att.sparse_tensors]
            elif att.type == AttributeProto.STRING:
                value = att.s.decode("utf-8")
            elif att.type == AttributeProto.STRINGS:
                value = np.array([s.decode("utf-8") for s in att.strings])
            else:
                raise ValueError(
                    f"Attribute {att.name!r} with type {att.type} "
                    f"cannot be extracted yet."
                )
            atts[att.name] = (att, value)
        return atts