Source code for experimental_experiment.xbuilder.reverse_graph_builder

from typing import Any, Dict, List, Union
import onnx
from onnx_array_api.translate_api.translate import Translater
from onnx_array_api.translate_api.builder_emitter import BuilderEmitter


class CustomBuilderEmitter(BuilderEmitter):
    """
    Custom :class:`onnx_array_api.translate_api.builder_emitter.BuilderEmitter`.

    It adjusts the code produced by the base emitter:

    * ``Squeeze``, ``Unsqueeze`` and every operator whose name starts with
      ``Reduce`` are emitted as ``<op_type>AnyOpset``,
    * ``#`` is replaced by ``__`` and ``-`` by ``_`` in result names,
    * the end of a local function is replaced by the creation of a
      ``FunctionOptions`` and a call to ``g.make_local_function``.

    :param make_model_function: name given to the function wrapping the
        generated code, forwarded to the base class
    """

    def __init__(self, make_model_function: str = "make_my_model"):
        # The argument used to be ignored (the literal "make_my_model" was
        # always forwarded); the default value preserves that behaviour.
        super().__init__(make_model_function=make_model_function)

    def _emit_node_type(self, op_type, op_domain):
        """
        Returns the operator name to emit for ``op_type``.

        ``Squeeze``, ``Unsqueeze`` and operators starting with ``Reduce`` are
        mapped to ``<op_type>AnyOpset``, any other name is returned unchanged.
        ``op_domain`` is not used.
        """
        if op_type in {"Squeeze", "Unsqueeze"} or op_type.startswith("Reduce"):
            return f"{op_type}AnyOpset"
        return op_type

    def _clean_result_name(self, name):
        """Replaces ``#`` by ``__`` and ``-`` by ``_`` in a result name."""
        return name.replace("#", "__").replace("-", "_")

    def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]:
        """
        Completes the code emitted for a local function.

        Every row returned by the base class is kept except the last one
        (presumably the base class' own registration of the function --
        TODO confirm), then rows creating a ``FunctionOptions`` and calling
        ``g.make_local_function(gr, opts, optimize=False)`` are appended.
        """
        rows = super()._emit_end_function(**kwargs)
        # NOTE(review): the leading spaces inside the literals below must match
        # the indentation the base class uses for the function body it emits;
        # verify it against the base emitter (whitespace may have been
        # collapsed in this copy of the file).
        return [
            *rows[:-1],
            " opts = FunctionOptions(",
            f" name={self.f_name!r},",
            f" domain={self.f_domain!r},",
            " move_initializer_to_constant=True,",
            " )",
            " g.make_local_function(gr, opts, optimize=False)",
        ]
def to_graph_builder_code(proto: onnx.ModelProto, function_name: str = "build_model") -> str:
    """
    Produces a code building a model with
    :class:`experimental_experiment.xbuilder.GraphBuilder`.

    :param proto: model to convert into a code
    :param function_name: name of the function wrapping the generated code,
        it is forwarded to the emitter
    :return: str, the python code (imports included)

    Example (see also :ref:`l-plot-model-to-code`):

    .. runpython::
        :showcode:

        import numpy as np
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh
        from experimental_experiment.xbuilder.reverse_graph_builder import (
            to_graph_builder_code,
        )

        TFLOAT = onnx.TensorProto.FLOAT
        TINT64 = onnx.TensorProto.INT64

        model = oh.make_model(
            oh.make_graph(
                [
                    oh.make_node(
                        "ConstantOfShape",
                        ["shape"],
                        ["cst"],
                        value=onh.from_array(np.array([0], dtype=np.float32)),
                    ),
                    oh.make_node(
                        "ScatterND",
                        ["cst", "indices", "updates"],
                        ["Z"],
                        reduction="add",
                    ),
                ],
                "create_graph",
                [
                    oh.make_tensor_value_info("shape", TINT64, [None]),
                    oh.make_tensor_value_info("indices", TINT64, [None, None]),
                    oh.make_tensor_value_info("updates", TFLOAT, [None, None, None]),
                ],
                [oh.make_tensor_value_info("Z", TFLOAT, [None, None, None])],
            ),
            opset_imports=[
                oh.make_opsetid("", 18),
            ],
            ir_version=9,
        )

        print(to_graph_builder_code(model))
    """
    # ``function_name`` used to be ignored; it is now given to the emitter so
    # the generated wrapping function carries the requested name.
    tr = Translater(proto, emitter=CustomBuilderEmitter(make_model_function=function_name))
    code = tr.export(as_str=True)
    # Non-finite constants can appear as bare ``nan`` / ``inf`` right after
    # ``array(`` (presumably numpy's repr of 0-d arrays). Those names are
    # undefined in the generated module, so they are qualified with ``np.``.
    for bare, qualified in (
        ("array(nan", "array(np.nan"),
        ("array(inf", "array(np.inf"),
        ("array(-inf", "array(-np.inf"),
    ):
        code = code.replace(bare, qualified)
    return "\n".join(
        [
            "import numpy as np",
            "from onnx import TensorProto",
            "from onnx.numpy_helper import from_array",
            "from experimental_experiment.xbuilder import GraphBuilder, FunctionOptions",
            "",
            "",
            code,
        ]
    )
def to_graph_pattern_matching(
    proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
) -> str:
    """
    Produces a code matching a pattern.

    The generated code walks the pattern backward, starting from its unique
    output. The node producing the output is compared to ``node``; every other
    node is retrieved with ``g.node_before(<result name>)``, and its type and
    domain are checked. An intermediate result that is neither a graph input
    nor an initializer and is consumed by a single node of the pattern gets an
    extra ``g.is_used_more_than_once`` check.

    :param proto: model to convert into a code, only
        :class:`onnx.GraphProto` and :class:`onnx.ModelProto` are supported
    :return: str
    :raises NotImplementedError: if *proto* is a :class:`onnx.FunctionProto`
    :raises TypeError: if *proto* has any other unsupported type

    Example (see also :ref:`l-plot-model-to-code`):

    .. runpython::
        :showcode:

        import numpy as np
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh
        from experimental_experiment.xbuilder.reverse_graph_builder import (
            to_graph_pattern_matching,
        )

        TFLOAT = onnx.TensorProto.FLOAT
        TINT64 = onnx.TensorProto.INT64

        model = oh.make_model(
            oh.make_graph(
                [
                    oh.make_node(
                        "ConstantOfShape",
                        ["shape"],
                        ["cst"],
                        value=onh.from_array(np.array([0], dtype=np.float32)),
                    ),
                    oh.make_node(
                        "ScatterND",
                        ["cst", "indices", "updates"],
                        ["Z"],
                        reduction="add",
                    ),
                ],
                "create_graph",
                [
                    oh.make_tensor_value_info("shape", TINT64, [None]),
                    oh.make_tensor_value_info("indices", TINT64, [None, None]),
                    oh.make_tensor_value_info("updates", TFLOAT, [None, None, None]),
                ],
                [oh.make_tensor_value_info("Z", TFLOAT, [None, None, None])],
            ),
            opset_imports=[
                oh.make_opsetid("", 18),
            ],
            ir_version=9,
        )

        print(to_graph_pattern_matching(model))
    """
    if isinstance(proto, onnx.FunctionProto):
        raise NotImplementedError("Not yet implemented for FunctionProto.")
    if isinstance(proto, onnx.ModelProto):
        graph = proto.graph
    elif isinstance(proto, onnx.GraphProto):
        graph = proto
    else:
        raise TypeError(f"Unable to process type {type(proto)}.")

    nodes = graph.node
    input_names = [i.name for i in graph.input]
    output_names = [o.name for o in graph.output]
    inits = graph.initializer
    assert nodes, "No node to process."
    assert len(output_names) == 1, (
        f"Function is not implemented yet for "
        f"input_names={input_names!r} and output_names={output_names!r}"
    )

    def _clean(s: str) -> str:
        # Turns a result name into a valid python identifier. Every character
        # that cannot continue an identifier ('.', '-', '/', ':', ...) becomes
        # '_'. A name that still cannot start an identifier (leading digit, or
        # the empty name of a missing optional input) gets a '_' prefix.
        cleaned = "".join(c if f"_{c}".isidentifier() else "_" for c in s)
        return cleaned if cleaned.isidentifier() else f"_{cleaned}"

    def _key(node):
        # Identifies a node; outputs are unique in a valid graph.
        return node.op_type, node.domain, tuple(node.input), tuple(node.output)

    position = {_key(node): i for i, node in enumerate(nodes)}
    # Results coming from outside the pattern: graph inputs and initializers.
    outside = set(input_names) | {i.name for i in inits}
    successors = {}
    predecessors = {}
    for node in nodes:
        for name in node.input:
            successors.setdefault(name, []).append(node)
        for name in node.output:
            predecessors[name] = node

    # Row returned by every failing check except the first one, which keeps
    # the argument-less ``self.none()`` (the cheap type filter on ``node``).
    none_row = " return self.none(node, inspect.currentframe().f_lineno)"
    processed = set()
    first_node = True
    rows = []
    stack_names = [*output_names]
    nodes_names = []
    while stack_names:
        rows.append("")
        name = stack_names.pop()
        if name not in predecessors:
            # Not produced by any node of the pattern (graph input,
            # initializer, missing optional input...): stop here.
            rows.append(f"# {_clean(name)} has no predecessor.")
            continue
        if name not in outside and len(successors.get(name, ())) == 1:
            rows.extend([f"if g.is_used_more_than_once({_clean(name)}):", none_row])
        node = predecessors[name]
        if not node.input:
            # A constant. We skip.
            continue
        key = _key(node)
        if key in processed:
            # We skip for the time being but we should do extract verification.
            rows.append(f"# {_clean(name)} is already processed.")
            continue

        node_name = _clean(f"node_{position[key]}_{node.op_type}")
        nodes_names.append(node_name)
        type_check = f"{node_name}.op_type != {node.op_type!r} or {node_name}.domain != {node.domain!r}"
        if first_node:
            # The node producing the output is the one the pattern is called on.
            first_node = False
            rows.extend(
                [
                    f"{node_name} = node",
                    f"if {type_check}:",
                    " return self.none()",
                ]
            )
        else:
            rows.extend(
                [
                    f"{node_name} = g.node_before({_clean(name)})",
                    f"if {node_name} is None or {type_check}:",
                    none_row,
                ]
            )
        processed.add(key)
        stack_names.extend(node.input)
        rows.extend(f"{_clean(n_)} = {node_name}.input[{i_}]" for i_, n_ in enumerate(node.input))

    rows.extend(
        [
            "",
            "# list of nodes",
            f"nodes = [{', '.join(nodes_names[::-1])}]",
        ]
    )
    return "\n".join(rows)