Source code for onnx_array_api.translate_api.inner_emitter

from typing import Any, Dict, List, Optional, Tuple
from onnx import AttributeProto
from ..annotations import ELEMENT_TYPE_NAME
from .base_emitter import BaseEmitter
from .translate import Translater


[docs] class InnerEmitter(BaseEmitter): """ Converts event into proper code. """
[docs] def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: """ Renders an attribute value into a string. :param value: value to converter :return: rows to append before, actual value """ if value[0].type == AttributeProto.GRAPH: tr = Translater(value[0].g, emitter=self) rows = tr.export(as_str=False, single_line=False) new_rows = [f"def _make_local_graph_{value[0].name}():"] for line in rows: if "make_model" in line: break new_rows.append(" " + line) new_rows.append(" return graph") new_rows.append(f"{value[0].name} = _make_local_graph_{value[0].name}()") return new_rows, value[0].name return super().render_attribute_value(value)
def _make_attribute( self, name: str, attr_type: int, ref_attr_name: Optional[str] = None ) -> str: if ref_attr_name is None: raise NotImplementedError( f"Cannot create attribute with name={name!r}, attr_type={attr_type}." ) return f"make_ref_attribute(key={name!r}, attr_type={attr_type}, ref_attr_name={ref_attr_name!r})"
[docs] def join(self, rows: List[str], single_line: bool = False) -> str: "Returns the separators. `single_line` is unused." return "\n".join(rows)
def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: lines = ["opset_imports = ["] opsets = kwargs.get("opsets", {}) for k, v in opsets.items(): lines.append(f" make_opsetid({k!r}, {v!r}),") lines.append("]") return lines def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: lines = [ "model = make_model(", " graph,", " functions=functions,", " opset_imports=opset_imports", ")", ] return lines def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: lines = [ "inputs = []", "outputs = []", "nodes = []", "initializers = []", "sparse_initializers = []", "functions = []", ] return lines def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]: name = kwargs.get("name", "noname") lines = [ "graph = make_graph(", " nodes,", f" {name!r},", " inputs,", " outputs,", " initializers,", " sparse_initializer=sparse_initializers,", ")", ] return lines def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: name = kwargs["name"] value = kwargs["value"] repl = {"bool": "bool_", "object": "object_", "str": "str_"} fra = "from_array" sdtype = repl.get(str(value.dtype), str(value.dtype)) if sdtype.startswith("("): from onnx.reference.custom_element_types import float8e4m3fn if sdtype == str(float8e4m3fn): sdtype = "float8e4m3fn" fra = "from_array_extended" else: raise NotImplementedError(f"Unexpected dtype={sdtype}.") else: sdtype = f"np.{sdtype}" return [ "initializers.append(", f" {fra}(", f" np.array({value.tolist()}, dtype={sdtype}),", f" name={name!r}", " )", ")", ] def _emit_io(self, container: str, **kwargs: Dict[str, Any]) -> List[str]: name = kwargs["name"] elem_type = kwargs.get("elem_type", None) shape = kwargs.get("shape", None) if elem_type and shape: return [ f"{container}.append(make_tensor_value_info({name!r}, TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r}))" ] if elem_type: return [ f"{container}.append(make_tensor_value_info({name!r}, TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape=[]))" ] return [ f"{container}.append(make_tensor_value_info({name!r}, TensorProto.UNDEFINED, []))" ] def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: return self._emit_io("inputs", **kwargs) def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]: return self._emit_io("outputs", **kwargs) def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: op_type = kwargs["op_type"] inputs = kwargs["inputs"] outputs = kwargs["outputs"] if kwargs.get("domain", "") != "": domain = kwargs["domain"] before_lines = [] lines = [ "nodes.append(", " make_node_extended(", f" {op_type!r},", f" {inputs},", f" {outputs},", ] domain = kwargs.get("domain", "") if domain: lines.append(f" domain={domain!r},") atts = kwargs.get("atts", {}) for k, v in atts.items(): before, value = self.render_attribute_value(v) before_lines.extend(before) lines.append(f" {k}={value},") lines[-1] = lines[-1][:-1] lines.extend([" )", ")"]) return before_lines + lines def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: lines = [ "", f"name_f = {kwargs['name']!r}", f"domain_f = {kwargs['domain']!r}", "nodes = []", "inputs = []", "outputs = []", "atts = []", ] return lines def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]: return [] def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: return [f"inputs.append({kwargs['name']!r})"] def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]: return [f"outputs.append({kwargs['name']!r})"] def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]: atts = kwargs["attributes"] if isinstance(atts, list) and all(map(lambda t: isinstance(t, str), atts)): return [f"atts.extend({atts!r})"] raise NotImplementedError(f"Unable to process function attributes {atts!r}.") def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]: lines = [ "functions.append(", " make_function(", " domain_f, ", " name_f, ", " inputs, ", " outputs, ", " nodes, ", " attributes=atts, ", " opset_imports=opset_imports,", " )", ")", ] return lines