"""Source code of module ``onnx_array_api.translate_api.light_emitter``."""

from typing import Any, Dict, List
from ..annotations import ELEMENT_TYPE_NAME
from .base_emitter import BaseEmitter


class LightEmitter(BaseEmitter):
    """
    Converts events into code chained with ``.``.

    Each ``_emit_*`` method reads the event attributes from its keyword
    arguments and returns a list of code fragments (for instance
    ``start(...)``, ``vin(...)``, ``bring(...)``, ``rename(...)``).
    :meth:`join` assembles the fragments into a single chained expression.
    """

    def join(self, rows: List[str], single_line: bool = False) -> str:
        """
        Joins the rows into one chained expression.

        :param rows: code fragments produced by the ``_emit_*`` methods
        :param single_line: if True, fragments are joined with ``.`` on a
            single line; otherwise the chain is wrapped in parentheses with
            one fragment per line
        :return: the generated code
        """
        if single_line:
            return ".".join(rows)
        return "".join(["(\n ", "\n .".join(rows), "\n)"])

    def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
        """
        Emits the ``start(...)`` call opening the chain.

        ``kwargs["opsets"]`` maps a domain to its opset version. The default
        domain ``""`` is rendered as ``opset=<version>``; the remaining
        domains, if any, are rendered as ``opsets={...}``. The mapping given
        by the caller is copied and never modified.
        """
        # Copy: the original implementation deleted the "" key from the
        # caller's own dictionary, corrupting it for any later use.
        opsets = dict(kwargs.get("opsets") or {})
        # The default domain has its own argument, so it must not also
        # appear in the ``opsets=`` mapping.
        opset = opsets.pop("", None)
        args = []
        if opset:
            args.append(f"opset={opset}")
        if opsets:
            args.append(f"opsets={opsets}")
        return [f"start({', '.join(args)})"]

    def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
        """Emits the final ``to_onnx()`` call."""
        return ["to_onnx()"]

    def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]:
        """Returns no code for this event."""
        return []

    def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
        """Returns no code for this event."""
        return []

    def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
        """Returns no code for this event."""
        return []

    def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
        """
        Emits a constant named ``kwargs["name"]``.

        ``kwargs["value"]`` must expose ``dtype`` and ``tolist()``
        (presumably a numpy array); it is rebuilt in the generated code with
        ``np.array(..., dtype=np.<dtype>)``.
        """
        name = kwargs["name"]
        value = kwargs["value"]
        # These dtype names are exposed by numpy with a trailing underscore.
        repl = {"bool": "bool_", "object": "object_", "str": "str_"}
        dtype_name = str(value.dtype)
        sdtype = repl.get(dtype_name, dtype_name)
        return [
            f"cst(np.array({value.tolist()}, dtype=np.{sdtype}))",
            f"rename({name!r})",
        ]

    def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
        """
        Emits a ``vin(...)`` call declaring a graph input.

        The element type is rendered only when ``kwargs["elem_type"]`` is
        truthy; the shape only when the element type is rendered and
        ``kwargs["shape"]`` is truthy (an empty shape is omitted).
        """
        name = kwargs["name"]
        elem_type = kwargs.get("elem_type", None)
        shape = kwargs.get("shape", None)
        if not elem_type:
            return [f"vin({name!r})"]
        type_arg = f"elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}"
        if shape:
            return [f"vin({name!r}, {type_arg}, shape={shape!r})"]
        return [f"vin({name!r}, {type_arg})"]

    def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
        """
        Emits a ``vout(...)`` call declaring a graph output.

        When ``kwargs`` contains ``name``, a ``bring(name)`` fragment comes
        first. Element type and shape follow the same rules as
        :meth:`_emit_input`.
        """
        inst = []
        if "name" in kwargs:
            inst.append(f"bring({kwargs['name']!r})")
        elem_type = kwargs.get("elem_type", None)
        shape = kwargs.get("shape", None)
        if not elem_type:
            inst.append("vout()")
            return inst
        type_arg = f"elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}"
        if shape:
            inst.append(f"vout({type_arg}, shape={shape!r})")
        else:
            inst.append(f"vout({type_arg})")
        return inst

    def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
        """
        Emits a node as ``bring(inputs)``, the operator call, then
        ``rename(outputs)``.

        A non-empty ``kwargs["domain"]`` prefixes the operator name
        (``domain.OpType``). Attributes in ``kwargs["atts"]`` are rendered as
        keyword arguments through ``render_attribute_value``.

        :raises NotImplementedError: if an attribute needs code emitted
            before the node (graph attributes)
        """
        op_type = kwargs["op_type"]
        inputs = kwargs["inputs"]
        outputs = kwargs["outputs"]
        # Both "" and None denote the default domain; None previously
        # produced an invalid "None.<op_type>" call.
        domain = kwargs.get("domain") or ""
        if domain:
            op_type = f"{domain}.{op_type}"
        args = []
        for k, v in (kwargs.get("atts") or {}).items():
            before, vatt = self.render_attribute_value(v)
            if before:
                # A chained expression has no place for preliminary code.
                raise NotImplementedError("Graph attribute not supported yet.")
            args.append(f"{k}={vatt}")
        str_inputs = ", ".join(repr(i) for i in inputs)
        # One output or several, the rendering is the same comma-separated
        # list of names.
        str_outputs = ", ".join(repr(o) for o in outputs)
        return [
            f"bring({str_inputs})",
            f"{op_type}({', '.join(args)})",
            f"rename({str_outputs})",
        ]