# Source code for onnx_array_api.light_api.translate
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from onnx import AttributeProto, FunctionProto, GraphProto, ModelProto, NodeProto
from onnx.numpy_helper import to_array
from .emitter import EventType, Emitter
class Translater:
    """
    Translates an ONNX graph into a code following the light API.

    :param proto: model, function or graph to translate
    :param emitter: callable object turning events into rows of code,
        a default :class:`Emitter` is created when none is given
    """

    def __init__(
        self,
        proto: Union[ModelProto, FunctionProto, GraphProto],
        emitter: Optional[Emitter] = None,
    ):
        self.proto_ = proto
        self.emitter = emitter or Emitter()

    def __repr__(self) -> str:
        # Only the type name of the proto is shown, between balanced brackets.
        return f"{self.__class__.__name__}(<{type(self.proto_).__name__}>)"

    def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
        """
        Exports into a code.

        :param as_str: as a single string or by rows
        :param single_line: tries to compress the output into a single line
        :return: the rows joined by the emitter if *as_str* is true,
            the list of rows otherwise
        :raises ValueError: if the proto is not a ModelProto, a FunctionProto
            or a GraphProto
        :raises NotImplementedError: if the graph holds sparse initializers
            or if the model defines local functions
        """
        rows = []
        if isinstance(self.proto_, ModelProto):
            opsets = {d.domain: d.version for d in self.proto_.opset_import}
            rows.extend(self.emitter(EventType.START, opsets=opsets))
            inputs = self.proto_.graph.input
            outputs = self.proto_.graph.output
            nodes = self.proto_.graph.node
            initializers = self.proto_.graph.initializer
            sparse_initializers = self.proto_.graph.sparse_initializer
        elif isinstance(self.proto_, (FunctionProto, GraphProto)):
            inputs = self.proto_.input
            outputs = self.proto_.output
            nodes = self.proto_.node
            if isinstance(self.proto_, GraphProto):
                initializers = self.proto_.initializer
                sparse_initializers = self.proto_.sparse_initializer
            else:
                # No initializer is read from a FunctionProto.
                initializers = []
                sparse_initializers = []
        else:
            raise ValueError(f"Unexpected type {type(self.proto_)} for proto.")

        if sparse_initializers:
            raise NotImplementedError("Sparse initializer not supported yet.")

        is_function = isinstance(self.proto_, FunctionProto)
        rows.extend(
            self.emitter(
                EventType.BEGIN_FUNCTION if is_function else EventType.BEGIN_GRAPH
            )
        )

        for init in initializers:
            rows.extend(
                self.emitter(
                    EventType.INITIALIZER,
                    name=init.name,
                    init=init,
                    value=to_array(init),
                )
            )

        rows.extend(self._value_rows(EventType.INPUT, inputs))

        for node in nodes:
            atts = self.extract_attributes(node)
            rows.extend(
                self.emitter(
                    EventType.NODE,
                    op_type=node.op_type,
                    inputs=node.input,
                    outputs=node.output,
                    domain=node.domain,
                    atts=atts,
                )
            )

        # Outputs must be emitted as OUTPUT events, including plain names.
        rows.extend(self._value_rows(EventType.OUTPUT, outputs))

        if isinstance(self.proto_, (GraphProto, FunctionProto)):
            name = self.proto_.name
        else:
            name = self.proto_.graph.name

        rows.extend(
            self.emitter(
                EventType.END_FUNCTION if is_function else EventType.END_GRAPH,
                name=name,
            )
        )

        if isinstance(self.proto_, ModelProto) and len(self.proto_.functions) > 0:
            raise NotImplementedError("Local functions are not yet implemented.")

        rows.extend(self.emitter(EventType.TO_ONNX))
        if as_str:
            return self.emitter.join(rows, single_line=single_line)
        return rows

    def _value_rows(self, event: EventType, values: Any) -> List[str]:
        """
        Emits one event per input or output.

        :param event: event type sent to the emitter for every value
        :param values: names (strings) or objects exposing ``name`` and
            ``type.tensor_type``
        :return: rows produced by the emitter
        """
        rows = []
        for value in values:
            if isinstance(value, str):
                # Only the name is known when the value is a plain string.
                rows.extend(self.emitter(event, name=value))
                continue
            tensor_type = value.type.tensor_type
            rows.extend(
                self.emitter(
                    event,
                    name=value.name,
                    elem_type=tensor_type.elem_type,
                    # A dimension falls back to its symbolic name
                    # when dim_value is falsy.
                    shape=tuple(
                        d.dim_value or d.dim_param for d in tensor_type.shape.dim
                    ),
                )
            )
        return rows

    def extract_attributes(
        self, node: NodeProto
    ) -> Dict[str, Tuple[AttributeProto, Any]]:
        """
        Extracts all attributes of a node.

        :param node: node proto
        :return: dictionary mapping every attribute name to a tuple
            *(attribute proto, python value)*, the value is None for
            attributes referring to a function attribute and for subgraphs
        :raises ValueError: if the type of an attribute is not supported
        """
        atts: Dict[str, Tuple[AttributeProto, Any]] = {}
        for att in node.attribute:
            if hasattr(att, "ref_attr_name") and att.ref_attr_name:
                # The value is not stored in the attribute itself.
                atts[att.name] = (att, None)
                continue
            if att.type == AttributeProto.INT:
                atts[att.name] = (att, att.i)
                continue
            if att.type == AttributeProto.FLOAT:
                atts[att.name] = (att, att.f)
                continue
            if att.type == AttributeProto.INTS:
                atts[att.name] = (att, np.array(att.ints))
                continue
            if att.type == AttributeProto.FLOATS:
                atts[att.name] = (att, np.array(att.floats, dtype=np.float32))
                continue
            if (
                att.type == AttributeProto.GRAPH
                and hasattr(att, "g")
                and att.g is not None
            ):
                # Subgraphs are kept as protos, no python value is extracted.
                atts[att.name] = (att, None)
                continue
            if att.type == AttributeProto.SPARSE_TENSOR:
                # NOTE(review): confirm to_array supports sparse tensors.
                atts[att.name] = (att, to_array(att.sparse_tensor))
                continue
            if att.type == AttributeProto.TENSOR:
                atts[att.name] = (att, to_array(att.t))
                continue
            if att.type == AttributeProto.TENSORS:
                atts[att.name] = (att, [to_array(t) for t in att.tensors])
                continue
            if att.type == AttributeProto.SPARSE_TENSORS:
                # NOTE(review): confirm to_array supports sparse tensors.
                atts[att.name] = (att, [to_array(t) for t in att.sparse_tensors])
                continue
            if att.type == AttributeProto.STRING:
                atts[att.name] = (att, att.s.decode("utf-8"))
                continue
            if att.type == AttributeProto.STRINGS:
                atts[att.name] = (
                    att,
                    np.array([s.decode("utf-8") for s in att.strings]),
                )
                continue
            raise ValueError(
                f"Attribute {att.name!r} with type {att.type} cannot be extracted yet."
            )
        return atts