Source code for onnx_array_api.plotting.text_plot

import pprint
from collections import OrderedDict
import numpy
from onnx import AttributeProto
from ..reference import to_array_extended as to_array
from ._helper import _get_shape, _get_type, attributes_as_dict


def _rule(r):
    if r == "BRANCH_LEQ":
        return "<="
    if r == "BRANCH_LT":
        return "<"
    if r == "BRANCH_GEQ":
        return ">="
    if r == "BRANCH_GT":
        return ">"
    if r == "BRANCH_EQ":
        return "=="
    if r == "BRANCH_NEQ":
        return "!="
    raise ValueError(f"Unexpected rule {r!r}.")


def _number2str(i):
    if isinstance(i, int):
        return str(i)
    if int(i) == i:
        return str(int(i))
    return f"{i:1.2f}"


def onnx_text_plot_tree(node):
    """
    Gives a textual representation of a tree ensemble.

    Every tree starts with a line ``treeid=...``. Each following line is a
    node: ``n X<feature> <rule> <threshold>`` for a decision node,
    ``f <target>:<weight> ...`` for a leaf. A child is indented one level
    deeper than its parent and prefixed with ``-`` (false branch, printed
    first) or ``+`` (true branch).

    :param node: `TreeEnsemble*`
    :return: text
    :raises NotImplementedError: if the node is neither a
        ``TreeEnsembleRegressor`` nor a ``TreeEnsembleClassifier``

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning, FutureWarning

        import numpy
        from sklearn.datasets import load_iris
        from sklearn.tree import DecisionTreeRegressor
        from skl2onnx import to_onnx
        from onnx_array_api.plotting.text_plot import onnx_text_plot_tree

        iris = load_iris()
        X, y = iris.data.astype(numpy.float32), iris.target
        clr = DecisionTreeRegressor(max_depth=3)
        clr.fit(X, y)
        onx = to_onnx(clr, X)
        res = onnx_text_plot_tree(onx.graph.node[0])
        print(res)
    """

    class Node:
        "Node representation."

        def __init__(self, i, atts):
            self.nodes_hitrates = None
            self.nodes_missing_value_tracks_true = None
            # copies the i-th value of every ``nodes_*`` attribute
            for k, v in atts.items():
                if k.startswith("nodes"):
                    setattr(self, k, v[i])
            self.depth = 0
            self.true_false = ""
            self.targets = []

        def append_target(self, tid, weight):
            self.targets.append(dict(target_id=tid, weight=weight))

        def process_node(self):
            "node to string"
            if self.nodes_modes == "LEAF":
                if not self.targets:
                    text = f"{self.true_false}f"
                elif len(self.targets) == 1:
                    t = self.targets[0]
                    text = (
                        f"{self.true_false}f "
                        f"{t['target_id']}:{_number2str(t['weight'])}"
                    )
                else:
                    ts = " ".join(
                        f"{t['target_id']}:{_number2str(t['weight'])}"
                        for t in self.targets
                    )
                    text = f"{self.true_false}f {ts}"
            else:
                text = "%sn X%d %s %r" % (
                    self.true_false,
                    self.nodes_featureids,
                    _rule(self.nodes_modes),
                    self.nodes_values,
                )
                if self.nodes_hitrates and self.nodes_hitrates != 1:
                    text += f" hi={self.nodes_hitrates!r}"
                if self.nodes_missing_value_tracks_true:
                    text += f" miss={self.nodes_missing_value_tracks_true!r}"
            return f"{' ' * self.depth}{text}"

    def process_tree(atts, treeid):
        "tree to string"
        rows = [f"treeid={treeid!r}"]
        # NOTE(review): base_values is indexed by tree id here; the ONNX spec
        # sizes it by the number of targets/classes - confirm intent.
        if "base_values" in atts:
            if treeid < len(atts["base_values"]):
                rows.append(f"base_value={atts['base_values'][treeid]!r}")

        # restricts every per-node / per-target attribute to this tree
        short = {}
        for prefix in ["nodes", "target", "class"]:
            if (f"{prefix}_treeids") not in atts:
                continue
            idx = [
                i
                for i, tid in enumerate(atts[f"{prefix}_treeids"])
                if tid == treeid
            ]
            for k, v in atts.items():
                if k.startswith(prefix):
                    if "classlabels" in k:
                        short[k] = list(v)
                    else:
                        short[k] = [v[i] for i in idx]

        # Nodes are keyed by their id: children (nodes_truenodeids,
        # nodes_falsenodeids) and targets (*_nodeids) refer to node ids,
        # which do not have to match the position in the attribute lists.
        nodes = OrderedDict()
        for i, node_id in enumerate(short["nodes_nodeids"]):
            nodes[node_id] = Node(i, short)
        prefix = "target" if "target_treeids" in short else "class"
        for i in range(len(short[f"{prefix}_treeids"])):
            idn = short[f"{prefix}_nodeids"][i]
            node = nodes[idn]
            node.append_target(
                tid=short[f"{prefix}_ids"][i], weight=short[f"{prefix}_weights"][i]
            )

        def iterate(nodes, node, depth=0, true_false=""):
            node.depth = depth
            node.true_false = true_false
            yield node
            # a leaf has no children (its child ids are 0)
            if node.nodes_falsenodeids > 0:
                yield from iterate(
                    nodes,
                    nodes[node.nodes_falsenodeids],
                    depth=depth + 1,
                    true_false="-",
                )
                yield from iterate(
                    nodes,
                    nodes[node.nodes_truenodeids],
                    depth=depth + 1,
                    true_false="+",
                )

        # assumes the root of every tree has node id 0 - TODO confirm for
        # converters other than skl2onnx
        for node in iterate(nodes, nodes[0]):
            rows.append(node.process_node())
        return rows

    if node.op_type in ("TreeEnsembleRegressor", "TreeEnsembleClassifier"):
        d = attributes_as_dict(node)
        atts = {}
        for k, v in d.items():
            # scalars and strings are kept as they are (a string must not be
            # split into characters), everything else becomes a list
            atts[k] = v if isinstance(v, (int, float, str, bytes)) else list(v)
        trees = sorted(set(atts["nodes_treeids"]))
        if "n_targets" in atts:
            rows = [f"n_targets={atts['n_targets']!r}"]
        else:
            rows = [
                "n_classes=%r"
                % len(
                    atts.get("classlabels_int64s", atts.get("classlabels_strings", []))
                )
            ]
        rows.append(f"n_trees={len(trees)!r}")
        for tree in trees:
            r = process_tree(atts, tree)
            rows.append("----")
            rows.extend(r)
        return "\n".join(rows)

    raise NotImplementedError(  # pragma: no cover
        f"Type {node.op_type!r} cannot be displayed."
    )
def _node_key(node):
    "Unique key of a node: its name followed by its outputs, ``name#out1|out2``."
    return node.name + "#" + "|".join(node.output)


def _append_succ_pred(
    subgraphs,
    successors,
    predecessors,
    node_map,
    node,
    prefix="",
    parent_node_name=None,
):
    """
    Registers one node, and recursively the nodes of its subgraphs, into
    the successors / predecessors maps.

    Nodes are keyed by ``prefix + name#outputs``; results by their name.

    :param subgraphs: list receiving ``(node, attribute name, graph)``
        for every subgraph found
    :param successors: dict ``key -> list of successor keys``, updated
    :param predecessors: dict ``key -> list of predecessor keys``, updated
    :param node_map: dict ``node key -> node``, updated
    :param node: node to register
    :param prefix: prefix added to the node key (used for subgraphs)
    :param parent_node_name: key of the node owning the current subgraph
        (not used by this function)
    """
    node_name = prefix + _node_key(node)
    node_map[node_name] = node
    successors[node_name] = []
    predecessors[node_name] = []
    for name in node.input:
        predecessors[node_name].append(name)
        if name not in successors:
            successors[name] = []
        successors[name].append(node_name)
    for name in node.output:
        successors[node_name].append(name)
        predecessors[name] = [node_name]

    if node.op_type in {"If", "Scan", "Loop", "Expression"}:
        for att in node.attribute:
            if (
                att.type != AttributeProto.GRAPH
                or not hasattr(att, "g")
                or att.g is None
            ):
                continue
            subgraphs.append((node, att.name, att.g))
            _append_succ_pred_s(
                subgraphs,
                successors,
                predecessors,
                node_map,
                att.g.node,
                prefix=node_name + ":/:",
                parent_node_name=node_name,
                parent_graph=att.g,
            )


def _append_succ_pred_s(
    subgraphs,
    successors,
    predecessors,
    node_map,
    nodes,
    prefix="",
    parent_node_name=None,
    parent_graph=None,
):
    """
    Calls :func:`_append_succ_pred` for every node in *nodes*.

    When *nodes* belong to a subgraph (*parent_node_name* is set), every
    name the subgraph consumes without defining it (not an initializer,
    not an input, not produced by a previous node of the subgraph) is
    linked to the parent node as one of its predecessors.
    """
    for node in nodes:
        _append_succ_pred(
            subgraphs,
            successors,
            predecessors,
            node_map,
            node,
            prefix=prefix,
            parent_node_name=parent_node_name,
        )

    if parent_node_name is not None:
        unknown = set()
        known = {}
        for i in parent_graph.initializer:
            known[i.name] = None
        for i in parent_graph.input:
            known[i.name] = None
        for n in parent_graph.node:
            for i in n.input:
                if i not in known:
                    unknown.add(i)
            for i in n.output:
                known[i] = n
        if unknown:
            # These inputs are coming from the graph below.
            for name in unknown:
                successors[name].append(parent_node_name)
                predecessors[parent_node_name].append(name)


def graph_predecessors_and_successors(graph):
    """
    Returns the successors and the predecessors within an ONNX graph.

    :param graph: any object with an attribute ``node`` (list of nodes)
    :return: tuple ``(subgraphs, predecessors, successors, node_map)``
    """
    node_map = {}
    successors = {}
    predecessors = {}
    subgraphs = []
    _append_succ_pred_s(subgraphs, successors, predecessors, node_map, graph.node)
    return subgraphs, predecessors, successors, node_map


def get_hidden_inputs(nodes):
    """
    Returns the names consumed by *nodes* or by their subgraphs which are
    not produced by *nodes*.

    A name consumed inside a subgraph counts unless the subgraph defines it
    itself (initializer, sparse initializer or formal input of the
    subgraph). Such names have to come from an outer scope.

    :param nodes: list of nodes
    :return: set of names
    """
    inputs = set()
    outputs = set()
    for node in nodes:
        inputs |= set(node.input)
        outputs |= set(node.output)
        for att in node.attribute:
            if (
                att.type != AttributeProto.GRAPH
                or not hasattr(att, "g")
                or att.g is None
            ):
                continue
            hidden = get_hidden_inputs(att.g.node)
            local = {i.name for i in att.g.initializer}
            local |= {i.name for i in att.g.sparse_initializer}
            # Formal inputs of the subgraph (e.g. the iteration counter of a
            # Loop body) are bound by the subgraph itself, consistent with
            # _append_succ_pred_s which treats them as known.
            local |= {i.name for i in att.g.input}
            inputs |= hidden - local
    return inputs - outputs


def _describe_nodes(nodes, done):
    "Formats one line per node for the consistency errors of reorder_nodes_for_display."
    return "\n".join(
        "%d - %s - %s - %s" % (_node_key(n) in done, n.op_type, n.name, _node_key(n))
        for n in nodes
    )


def reorder_nodes_for_display(nodes, verbose=False):
    """
    Reorders the node with breadth first search (BFS).

    At every step, every node whose predecessors are all known is a
    candidate. Each candidate is extended into a chain of nodes which
    can be computed right after it, and the longest chain is chosen.
    Ties are broken by preferring the node consuming the most outputs of
    the previous node, or, for the first step, by op_type then name.

    :param nodes: list of ONNX nodes
    :param verbose: display intermediate information
    :return: reordered list of nodes
    :raises RuntimeError: if the nodes cannot be ordered (usually a name
        used both as a result name and a node name)
    """

    class temp:
        "Fake GraphProto."

        def __init__(self, nodes):
            self.node = nodes

    _, predecessors, successors, dnodes = graph_predecessors_and_successors(temp(nodes))
    local_variables = get_hidden_inputs(nodes)

    all_outputs = set()
    all_inputs = set(local_variables)
    for node in nodes:
        all_outputs |= set(node.output)
        all_inputs |= set(node.input)
    common = all_outputs & all_inputs

    successors = {k: set(v) for k, v in successors.items()}
    predecessors = {k: set(v) for k, v in predecessors.items()}
    if verbose:
        pprint.pprint(
            [
                "[reorder_nodes_for_display]",
                "predecessors",
                predecessors,
                "successors",
                successors,
            ]
        )

    # names available before any node runs
    known = all_inputs - common
    new_nodes = []
    done = set()

    def _find_sequence(node_name, known, done):
        inputs = dnodes[node_name].input
        if any(i not in known for i in inputs):
            return []

        # alternates node keys and result names
        res = [node_name]
        while res[-1] in successors:
            next_names = successors[res[-1]]
            if res[-1] not in dnodes:
                # res[-1] is a result name, next_names are node keys
                next_names = {v for v in next_names if v not in known}
                if len(next_names) == 1:
                    next_name = next_names.pop()
                    if ":/:" in next_name or next_name in res:
                        # never pull a subgraph node into the main graph
                        # and never loop over the same node
                        break
                    inputs = dnodes[next_name].input
                    if any(i not in known for i in inputs):
                        break
                    res.append(next_name)
                else:
                    break
            else:
                # res[-1] is a node key, next_names are its outputs
                next_names = {v for v in next_names if v not in done}
                if len(next_names) == 1:
                    next_name = next_names.pop()
                    res.append(next_name)
                else:
                    break

        return [r for r in res if r in dnodes and r not in done]

    while len(done) < len(nodes):
        # possible
        possibles = OrderedDict()
        for k, v in dnodes.items():
            if k in done:
                continue
            if ":/:" in k:
                # node part of a sub graph (assuming :/: is never used in a node name)
                continue
            if predecessors[k] <= known:
                possibles[k] = v

        sequences = OrderedDict()
        for k in possibles:
            if k in done:
                continue
            sequences[k] = _find_sequence(k, known, done)
            if verbose:
                print(
                    "[reorder_nodes_for_display] * sequence(%s)=%s - %r"
                    % (k, ",".join(sequences[k]), list(sequences))
                )

        if not sequences:
            raise RuntimeError(  # pragma: no cover
                "Unexpected empty sequence (len(possibles)=%d, "
                "len(done)=%d, len(nodes)=%d). This is usually due to "
                "a name used both as result name and node name. "
                "known=%r." % (len(possibles), len(done), len(nodes), known)
            )

        # find the best sequence
        best = None
        for k, v in sequences.items():
            if best is None or len(v) > len(sequences[best]):
                # if the sequence of successors is longer
                best = k
            elif len(v) == len(sequences[best]):
                first1 = dnodes[sequences[best][0]]
                first2 = dnodes[v[0]]
                if new_nodes:
                    # then choose the next successor sharing input with
                    # previous output
                    so = set(new_nodes[-1].output)
                    if len(set(first1.input) & so) < len(set(first2.input) & so):
                        best = k
                elif first1.op_type > first2.op_type:
                    best = k
                elif first1.op_type == first2.op_type and first1.name > first2.name:
                    best = k

        if best is None:
            raise RuntimeError(  # pragma: no cover
                f"Wrong implementation (len(sequence)={len(sequences)})."
            )
        if verbose:
            print(
                "[reorder_nodes_for_display] BEST: sequence(%s)=%s"
                % (best, ",".join(sequences[best]))
            )

        # process the sequence
        for k in sequences[best]:
            v = dnodes[k]
            new_nodes.append(v)
            if verbose:
                print(f"[reorder_nodes_for_display] + {v.name!r} ({v.op_type!r})")
            done.add(k)
            known |= set(v.output)

    if len(new_nodes) != len(nodes):
        raise RuntimeError(  # pragma: no cover
            "The returned new nodes are different. "
            "len(nodes=%d) != %d=len(new_nodes). done=\n%r"
            "\n%s\n----------\n%s"
            % (
                len(nodes),
                len(new_nodes),
                done,
                _describe_nodes(nodes, done),
                _describe_nodes(new_nodes, done),
            )
        )
    n0s = {n.name for n in nodes}
    n1s = {n.name for n in new_nodes}
    if n0s != n1s:
        raise RuntimeError(  # pragma: no cover
            "The returned new nodes are different.\n"
            "%r !=\n%r\ndone=\n%r"
            "\n----------\n%s\n----------\n%s"
            % (
                n0s,
                n1s,
                done,
                _describe_nodes(nodes, done),
                _describe_nodes(new_nodes, done),
            )
        )
    return new_nodes
def onnx_simple_text_plot(
    model,
    verbose=False,
    att_display=None,
    add_links=False,
    recursive=False,
    functions=True,
    raise_exc=True,
    sub_graphs_names=None,
    level=1,
    indent=True,
):
    """
    Displays an ONNX graph into text.

    :param model: ONNX graph
    :param verbose: display debugging information
    :param att_display: list of attributes to display, if None,
        a default list is used
    :param add_links: displays links of the right side
    :param recursive: display subgraphs as well
    :param functions: display functions as well
    :param raise_exc: raises an exception if the model is not valid,
        otherwise tries to continue
    :param sub_graphs_names: dictionary mapping ``id(graph)`` to the name
        displayed for that sub-graph (``G1``, ``G2``, ...), shared with the
        recursive calls
    :param level: sub-graph level
    :param indent: use indentation or not (also applied to sub-graphs
        and functions)
    :return: str

    An ONNX graph is printed the following way:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning, FutureWarning

        import numpy
        from sklearn.cluster import KMeans
        from skl2onnx import to_onnx
        from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

        x = numpy.random.randn(10, 3)
        y = numpy.random.randn(10)
        model = KMeans(3)
        model.fit(x, y)
        onx = to_onnx(model, x.astype(numpy.float32), target_opset=15)
        text = onnx_simple_text_plot(onx, verbose=False)
        print(text)

    The same graphs with links.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning, FutureWarning

        import numpy
        from sklearn.cluster import KMeans
        from skl2onnx import to_onnx
        from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

        x = numpy.random.randn(10, 3)
        y = numpy.random.randn(10)
        model = KMeans(3)
        model.fit(x, y)
        onx = to_onnx(model, x.astype(numpy.float32), target_opset=15)
        text = onnx_simple_text_plot(onx, verbose=False, add_links=True)
        print(text)

    Visually, it looks like the following:

    .. gdot::
        :script: DOT-SECTION

        # onnx_simple_text_plot
        import numpy
        from sklearn.cluster import KMeans
        from skl2onnx import to_onnx
        from onnx_array_api.plotting.dot_plot import to_dot

        x = numpy.random.randn(10, 3)
        y = numpy.random.randn(10)
        model = KMeans(3)
        model.fit(x, y)
        model_onnx = to_onnx(model, x.astype(numpy.float32), target_opset=15)
        print("DOT-SECTION", to_dot(model_onnx))
    """
    # ``indent`` is reused below as the current indentation level.
    use_indentation = indent
    if att_display is None:
        att_display = [
            "activations", "align_corners", "allowzero", "alpha", "auto_pad",
            "axis", "axes", "batch_axis", "batch_dims", "beta", "bias",
            "blocksize", "case_change_action", "ceil_mode", "center_point_box",
            "clip", "coordinate_transformation_mode", "count_include_pad",
            "cubic_coeff_a", "decay_factor", "detect_negative",
            "detect_positive", "dilation", "dilations", "direction", "dtype",
            "end", "epsilon", "equation", "exclusive", "exclude_outside",
            "extrapolation_value", "fmod", "gamma", "group", "hidden_size",
            "high", "ignore_index", "input_forget", "is_case_sensitive", "k",
            "keepdims", "kernel_shape", "lambd", "largest", "layout",
            "linear_before_reset", "locale", "low", "max_gram_length",
            "max_skip_count", "mean", "min_gram_length", "mode", "momentum",
            "nearest_mode", "ngram_counts", "ngram_indexes",
            "noop_with_empty_axes", "norm_coefficient", "norm_coefficient_post",
            "num_scan_inputs", "output_height", "output_padding",
            "output_shape", "output_width", "p", "padding_mode", "pads",
            "perm", "pooled_shape", "reduction", "reverse", "sample_size",
            "sampling_ratio", "scale", "scan_input_axes",
            "scan_input_directions", "scan_output_axes",
            "scan_output_directions", "seed", "select_last_index", "size",
            "sorted", "spatial_scale", "start", "storage_order", "strides",
            "time_axis", "to", "training_mode", "transA", "transB", "type",
            "upper", "xs", "y", "zs",
        ]

    if sub_graphs_names is None:
        sub_graphs_names = {}

    def _get_subgraph_name(idg):
        if idg in sub_graphs_names:
            return sub_graphs_names[idg]
        g = "G%d" % (len(sub_graphs_names) + 1)
        sub_graphs_names[idg] = g
        return g

    def str_node(indent, node):
        atts = []
        if hasattr(node, "attribute"):
            for att in node.attribute:
                done = True
                if hasattr(att, "ref_attr_name") and att.ref_attr_name:
                    # attribute of a function referencing a function attribute
                    atts.append(f"{att.name}=${att.ref_attr_name}")
                    continue
                if att.name in att_display:
                    if att.type == AttributeProto.INT:
                        atts.append("%s=%d" % (att.name, att.i))
                    elif att.type == AttributeProto.FLOAT:
                        atts.append(f"{att.name}={att.f:1.2f}")
                    elif att.type == AttributeProto.INTS:
                        atts.append(
                            "%s=%s" % (att.name, str(list(att.ints)).replace(" ", ""))
                        )
                    else:
                        done = False
                elif (
                    att.type == AttributeProto.GRAPH
                    and hasattr(att, "g")
                    and att.g is not None
                ):
                    atts.append(f"{att.name}={_get_subgraph_name(id(att.g))}")
                else:
                    done = False
                if done:
                    continue

                # generic (shortened) display of the attribute value
                if att.type in (
                    AttributeProto.TENSOR,
                    AttributeProto.TENSORS,
                    AttributeProto.SPARSE_TENSOR,
                    AttributeProto.SPARSE_TENSORS,
                ):
                    try:
                        val = str(to_array(att.t).tolist())
                    except TypeError as e:
                        raise TypeError(  # pragma: no cover
                            "Unable to display tensor type %r.\n%s"
                            % (att.type, str(att))
                        ) from e
                    if "\n" in val:
                        # keep the first line only (split returns a list)
                        val = val.split("\n", maxsplit=1)[0] + "..."
                    if len(val) > 10:
                        val = val[:10] + "..."
                elif att.type == AttributeProto.STRING:
                    val = str(att.s)
                    if len(val) > 50:
                        val = val[:40] + "..." + val[-10:]
                elif att.type == AttributeProto.STRINGS:
                    n_val = list(att.strings)
                    if len(n_val) < 5:
                        val = ",".join(map(str, n_val))
                    else:
                        val = "%d:[%s...%s]" % (
                            len(n_val),
                            ",".join(map(str, n_val[:2])),
                            ",".join(map(str, n_val[-2:])),
                        )
                elif att.type == AttributeProto.INT:
                    val = str(att.i)
                elif att.type == AttributeProto.FLOAT:
                    val = str(att.f)
                elif att.type == AttributeProto.INTS:
                    n_val = list(att.ints)
                    if len(n_val) < 6:
                        val = f"[{','.join(map(str, n_val))}]"
                    else:
                        val = "%d:[%s...%s]" % (
                            len(n_val),
                            ",".join(map(str, n_val[:3])),
                            ",".join(map(str, n_val[-3:])),
                        )
                elif att.type == AttributeProto.FLOATS:
                    n_val = list(att.floats)
                    if len(n_val) < 5:
                        val = f"[{','.join(map(str, n_val))}]"
                    else:
                        val = "%d:[%s...%s]" % (
                            len(n_val),
                            ",".join(map(str, n_val[:2])),
                            ",".join(map(str, n_val[-2:])),
                        )
                else:
                    val = ".%d" % att.type
                atts.append(f"{att.name}={val}")
        inputs = list(node.input)
        if atts:
            inputs.extend(atts)
        if node.domain in ("", "ai.onnx.ml"):
            domain = ""
        else:
            domain = f"[{node.domain}]"
        return "%s%s%s(%s) -> %s" % (
            " " * indent,
            node.op_type,
            domain,
            ", ".join(inputs),
            ", ".join(node.output),
        )

    rows = []
    if hasattr(model, "opset_import"):
        for opset in model.opset_import:
            rows.append(f"opset: domain={opset.domain!r} version={opset.version!r}")
    if hasattr(model, "graph"):
        if model.doc_string:
            rows.append(f"doc_string: {model.doc_string}")
        main_model = model
        model = model.graph
    else:
        main_model = None

    # inputs
    # line_name_new: result name -> row where it is created
    # line_name_in: result name -> rows where it is consumed
    line_name_new = {}
    line_name_in = {}
    if level == 0:
        rows.append("----- input ----")
    for inp in model.input:
        if isinstance(inp, str):
            rows.append(f"input: {inp!r}")
        else:
            line_name_new[inp.name] = len(rows)
            rows.append(
                "input: name=%r type=%r shape=%r"
                % (inp.name, _get_type(inp), _get_shape(inp))
            )
    if hasattr(model, "attribute"):
        for att in model.attribute:
            if isinstance(att, str):
                rows.append(f"attribute: {att!r}")
            else:
                raise NotImplementedError(  # pragma: no cover
                    "Not yet introduced in onnx."
                )

    # initializer
    if hasattr(model, "initializer"):
        if len(model.initializer) and level == 0:
            rows.append("----- initializer ----")
        for init in model.initializer:
            if numpy.prod(_get_shape(init)) < 5:
                content = f" -- {to_array(init).ravel()!r}"
            else:
                content = ""
            line_name_new[init.name] = len(rows)
            rows.append(
                "init: name=%r type=%r shape=%r%s"
                % (init.name, _get_type(init), _get_shape(init), content)
            )
    if level == 0:
        rows.append("----- main graph ----")

    # successors, predecessors, it needs to support subgraphs
    subgraphs = graph_predecessors_and_successors(model)[0]

    # walk through nodes
    init_names = set()
    indents = {}
    for inp in model.input:
        if isinstance(inp, str):
            indents[inp] = 0
            init_names.add(inp)
        else:
            indents[inp.name] = 0
            init_names.add(inp.name)
    if hasattr(model, "initializer"):
        for init in model.initializer:
            indents[init.name] = 0
            init_names.add(init.name)

    try:
        nodes = reorder_nodes_for_display(model.node, verbose=verbose)
    except RuntimeError as e:  # pragma: no cover
        if raise_exc:
            raise
        rows.append(f"ERROR: {e}")
        nodes = model.node

    previous_indent = None
    previous_out = None
    previous_in = None
    for node in nodes:
        add_break = False
        name = node.name + "#" + "|".join(node.output)
        if name in indents:
            indent = indents[name]
            if previous_indent is not None and indent < previous_indent:
                if verbose:
                    print(f"[onnx_simple_text_plot] break1 {node.op_type}")
                add_break = True
        elif previous_in is not None and set(node.input) == previous_in:
            indent = previous_indent
        else:
            inds = [indents.get(i, 0) for i in node.input if i not in init_names]
            if not inds:
                indent = 0
            else:
                mi = min(inds)
                indent = mi
                if previous_indent is not None and indent < previous_indent:
                    if verbose:
                        print(  # pragma: no cover
                            f"[onnx_simple_text_plot] break2 {node.op_type}"
                        )
                    add_break = True
        if not add_break and previous_out is not None:
            if not (set(node.input) & previous_out):
                if verbose:
                    print(f"[onnx_simple_text_plot] break3 {node.op_type}")
                add_break = True
                indent = 0

        if add_break and verbose:
            print("[onnx_simple_text_plot] add break")
        for n in node.input:
            line_name_in.setdefault(n, []).append(len(rows))
        for n in node.output:
            line_name_new[n] = len(rows)
        rows.append(str_node(indent if use_indentation else 0, node))
        indents[name] = indent

        for o in node.output:
            indents[o] = indent + 1

        previous_indent = indents[name]
        previous_out = set(node.output)
        previous_in = set(node.input)

    # outputs
    if level == 0:
        rows.append("----- output ----")
    for out in model.output:
        if isinstance(out, str):
            line_name_in.setdefault(out, []).append(len(rows))
            rows.append(f"output: name={out!r} type={'?'} shape={'?'}")
        else:
            line_name_in.setdefault(out.name, []).append(len(rows))
            rows.append(
                "output: name=%r type=%r shape=%r"
                % (out.name, _get_type(out), _get_shape(out))
            )

    if add_links:

        def _mark_link(rows, lengths, r1, r2, d):
            # draws a link on the right side from row r1 (creation)
            # to row r2 (consumption)
            maxl = max(lengths[r1], lengths[r2]) + d * 2
            maxl = max(maxl, max(len(rows[r]) for r in range(r1, r2 + 1))) + 2

            if rows[r1][-1] == "|":
                p1, p2 = rows[r1][: lengths[r1] + 2], rows[r1][lengths[r1] + 2 :]
                rows[r1] = p1 + p2.replace(" ", "-")
            rows[r1] += ("-" * (maxl - len(rows[r1]) - 1)) + "+"

            if rows[r2][-1] == " ":
                rows[r2] += "<"
            elif rows[r2][-1] == "|":
                if "<" not in rows[r2]:
                    p = lengths[r2]
                    rows[r2] = rows[r2][:p] + "<" + rows[r2][p + 1 :]
                p1, p2 = rows[r2][: lengths[r2] + 2], rows[r2][lengths[r2] + 2 :]
                rows[r2] = p1 + p2.replace(" ", "-")
            rows[r2] += ("-" * (maxl - len(rows[r2]) - 1)) + "+"

            for r in range(r1 + 1, r2):
                if len(rows[r]) < maxl:
                    rows[r] += " " * (maxl - len(rows[r]) - 1)
                rows[r] += "|"

        diffs = []
        for n, r1 in line_name_new.items():
            if n not in line_name_in:
                continue
            r2s = line_name_in[n]
            for r2 in r2s:
                if r1 >= r2:
                    continue
                diffs.append((r2 - r1, (n, r1, r2)))
        # shortest links first
        diffs.sort()
        for i in range(len(rows)):
            rows[i] += " "
        lengths = [len(r) for r in rows]

        for d, (n, r1, r2) in diffs:
            if d == 1 and len(line_name_in[n]) == 1:
                # no line for link to the next node
                continue
            _mark_link(rows, lengths, r1, r2, d)

    # subgraphs
    if recursive:
        for node, name, g in subgraphs:
            rows.append(
                "----- subgraph ---- %s - %s - att.%s=%s -- level=%d -- %s -> %s"
                % (
                    node.op_type,
                    node.name,
                    name,
                    _get_subgraph_name(id(g)),
                    level,
                    ",".join(i.name for i in g.input),
                    ",".join(i.name for i in g.output),
                )
            )
            res = onnx_simple_text_plot(
                g,
                verbose=verbose,
                att_display=att_display,
                add_links=add_links,
                recursive=recursive,
                sub_graphs_names=sub_graphs_names,
                level=level + 1,
                raise_exc=raise_exc,
                indent=use_indentation,
            )
            rows.append(res)

    # functions
    if functions and main_model is not None:
        for fct in main_model.functions:
            rows.append(f"----- function name={fct.name} domain={fct.domain}")
            if fct.doc_string:
                rows.append(f"----- doc_string: {fct.doc_string}")
            res = onnx_simple_text_plot(
                fct,
                verbose=verbose,
                att_display=att_display,
                add_links=add_links,
                recursive=recursive,
                functions=False,
                raise_exc=raise_exc,
                sub_graphs_names=sub_graphs_names,
                level=1,
                indent=use_indentation,
            )
            rows.append(res)

    return "\n".join(rows)
def onnx_text_plot_io(model, verbose=False, att_display=None):
    """
    Displays information about input and output types.

    Works with a ModelProto, a GraphProto or a FunctionProto. Inputs and
    outputs given as plain strings (FunctionProto) are displayed without
    type and shape, the same way :func:`onnx_simple_text_plot` does.

    :param model: ONNX graph
    :param verbose: not used, kept for backward compatibility
    :param att_display: not used, kept for backward compatibility
    :return: str

    An ONNX graph is printed the following way:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning, FutureWarning

        import numpy
        from sklearn.cluster import KMeans
        from skl2onnx import to_onnx
        from onnx_array_api.plotting.text_plot import onnx_text_plot_io

        x = numpy.random.randn(10, 3)
        y = numpy.random.randn(10)
        model = KMeans(3)
        model.fit(x, y)
        onx = to_onnx(model, x.astype(numpy.float32), target_opset=15)
        text = onnx_text_plot_io(onx, verbose=False)
        print(text)
    """
    rows = []
    if hasattr(model, "opset_import"):
        for opset in model.opset_import:
            rows.append(f"opset: domain={opset.domain!r} version={opset.version!r}")
    if hasattr(model, "graph"):
        model = model.graph

    # inputs
    for inp in model.input:
        if isinstance(inp, str):
            rows.append(f"input: {inp!r}")
        else:
            rows.append(
                "input: name=%r type=%r shape=%r"
                % (inp.name, _get_type(inp), _get_shape(inp))
            )
    # initializer (a FunctionProto has none)
    for init in getattr(model, "initializer", []):
        rows.append(
            "init: name=%r type=%r shape=%r"
            % (init.name, _get_type(init), _get_shape(init))
        )
    # outputs
    for out in model.output:
        if isinstance(out, str):
            rows.append(f"output: name={out!r} type={'?'} shape={'?'}")
        else:
            rows.append(
                "output: name=%r type=%r shape=%r"
                % (out.name, _get_type(out), _get_shape(out))
            )
    return "\n".join(rows)