Source code for onnx_extended.tools.stats_nodes

import pprint
from collections import Counter
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
import numpy as np
from onnx import (
    AttributeProto,
    FunctionProto,
    GraphProto,
    ModelProto,
    NodeProto,
    SparseTensorProto,
    TensorProto,
)
from ..reference import CReferenceEvaluator, to_array_extended


def enumerate_nodes(
    onx: Union[FunctionProto, GraphProto, ModelProto], recursive: bool = True
) -> Iterable[
    Tuple[
        Tuple[str, ...],
        Union[GraphProto, FunctionProto],
        Union[NodeProto, TensorProto, SparseTensorProto],
    ]
]:
    """
    Walks a model, a graph or a function and yields every node and
    every (dense or sparse) initializer it contains.

    :param onx: the model, graph or function to walk
    :param recursive: if True, graphs stored in node attributes
        are walked as well
    :return: iterator on tuples *(name, parent, node)* where *name*
        is a tuple of strings locating the item, *parent* is the graph
        or function holding it and *node* is the item itself

    Any other type of input yields nothing.
    """
    if isinstance(onx, ModelProto):
        # Main graph first, then every local function, each one prefixed
        # with the name of its container.
        main_graph = onx.graph
        for path, owner, item in enumerate_nodes(main_graph, recursive=recursive):
            yield (onx.graph.name, *path), owner, item
        for fct in onx.functions:
            for path, owner, item in enumerate_nodes(fct, recursive=recursive):
                yield (fct.name, *path), owner, item
        return

    if not isinstance(onx, (GraphProto, FunctionProto)):
        return

    if isinstance(onx, GraphProto):
        # Only graphs are asked for initializers.
        for dense in onx.initializer:
            yield (dense.name,), onx, dense
        for sparse in onx.sparse_initializer:
            yield (sparse.indices.name or sparse.values.name,), onx, sparse

    for position, node in enumerate(onx.node):
        assert isinstance(
            node, NodeProto
        ), f"A NodeProto is expected not {type(node)}."
        # A Constant is identified by its output, any other node by its
        # name or, if the name is empty, by its position.
        if node.op_type == "Constant":
            label = node.output[0]
        else:
            label = node.name or f"#{position}"
        yield (label,), onx, node

        if not recursive:
            continue

        for att in node.attribute:
            if not att.g:
                continue
            for path, owner, item in enumerate_nodes(att.g, recursive=recursive):
                # The prefix is built from the nested item itself
                # (falling back on the position of the enclosing node).
                if isinstance(item, NodeProto):
                    item_name = item.name or f"#{position}"
                elif isinstance(item, TensorProto):
                    item_name = item.name
                elif isinstance(item, SparseTensorProto):
                    item_name = item.indices.name or item.values.name
                else:
                    raise TypeError(f"Unexpected type {type(item)}.")
                yield (f"{item_name}/{att.name}", *path), owner, item
def extract_attributes(node: NodeProto) -> Dict[str, Tuple[AttributeProto, Any]]:
    """
    Extracts all attributes of a node.

    :param node: node proto
    :return: dictionary ``{attribute name: (attribute proto, value)}``

    The value is None for an attribute referring to a function attribute
    (``ref_attr_name``) and for a graph attribute. Tensors are converted
    with ``to_array_extended``, lists of integers, floats and strings
    become numpy arrays. An attribute whose type is not handled below
    is left out of the result.
    """
    atts: Dict[str, Tuple[AttributeProto, Any]] = {}
    for att in node.attribute:
        value: Any
        if hasattr(att, "ref_attr_name") and att.ref_attr_name:
            # The value is only known when the function is called.
            value = None
        elif att.type == AttributeProto.INT:
            value = att.i
        elif att.type == AttributeProto.FLOAT:
            value = att.f
        elif att.type == AttributeProto.INTS:
            value = np.array(att.ints)
        elif att.type == AttributeProto.FLOATS:
            value = np.array(att.floats, dtype=np.float32)
        elif (
            att.type == AttributeProto.GRAPH
            and hasattr(att, "g")
            and att.g is not None
        ):
            value = None
        elif att.type == AttributeProto.SPARSE_TENSOR:
            # Single sparse tensor: the field is `sparse_tensor`
            # and the matching type is SPARSE_TENSOR (singular).
            value = to_array_extended(att.sparse_tensor)
        elif att.type == AttributeProto.TENSOR:
            value = to_array_extended(att.t)
        elif att.type == AttributeProto.TENSORS:
            value = [to_array_extended(t) for t in att.tensors]
        elif att.type == AttributeProto.SPARSE_TENSORS:
            value = [to_array_extended(t) for t in att.sparse_tensors]
        elif att.type == AttributeProto.STRING:
            value = att.s.decode("utf-8")
        elif att.type == AttributeProto.STRINGS:
            value = np.array([s.decode("utf-8") for s in att.strings])
        else:
            # Unsupported attribute type: skipped.
            continue
        atts[att.name] = (att, value)
    return atts


class _Statistics:
    """
    Common class to statistics classes.
    It is a thin wrapper around a dictionary ``{name: value}``
    in which a name can only be registered once.
    """

    def __init__(self):
        self._statistics: Dict[str, Any] = {}

    def __len__(self) -> int:
        "Returns the number of statistics."
        return len(self._statistics)

    def add(self, name: str, value: Any):
        "Adds one statistics, fails if the name was already used."
        assert name not in self._statistics, f"Statistics {name!r} was already added."
        self._statistics[name] = value

    def __iter__(self) -> Iterable[Tuple[str, Any]]:
        "Iterates on pairs *(name, value)*."
        yield from self._statistics.items()

    def __getitem__(self, name: str) -> Any:
        "Returns one statistics."
        return self._statistics[name]

    def get(self, name: str, default_value: Optional[Any] = None) -> Any:
        "Returns one statistics or a default value if not found."
        return self._statistics.get(name, default_value)

    def __str__(self) -> str:
        "Usual"
        return f"{self.__class__.__name__}(\n{pprint.pformat(self._statistics)})"

    @property
    def dict_values(self) -> Dict[str, Any]:
        """
        Converts the statistics the class holds into a single row
        in order to build a dataframe.
        Not implemented in this base class, it always raises
        NotImplementedError.
        """
        raise NotImplementedError(
            f"Property 'dict_values' not implemented for class {type(self)}."
        )
class NodeStatistics(_Statistics):
    """
    Stores many statistics for NodeProto.

    :param parent: graph or function holding the node
    :param node: the node the statistics are about
    """

    def __init__(self, parent: Union[GraphProto, FunctionProto], node: NodeProto):
        super().__init__()
        self.parent = parent
        self.node = node

    def __str__(self) -> str:
        header = (
            f"{self.__class__.__name__}(<{self.parent.name}>, <{self.node.op_type}>,\n"
        )
        return header + f"{pprint.pformat(self._statistics)})"

    @property
    def dict_values(self) -> Dict[str, Any]:
        """
        Flattens the statistics into a single dictionary.

        Scalars are copied, sets become a sorted comma separated string,
        counters and nested statistics are expanded into keys
        ``<name>__<subname>``. Empty lists and lists of per-tree or
        per-feature statistics are ignored. Anything else raises
        TypeError.
        """
        scalar_types = (int, float, str, np.int64, np.int32, np.float32, np.float64)
        row: Dict[str, Any] = {}
        for key, stat in self._statistics.items():
            if isinstance(stat, scalar_types):
                row[key] = stat
                continue

            if isinstance(stat, set):
                row[key] = ",".join(str(item) for item in sorted(stat))
                continue

            if isinstance(stat, Counter):
                for sub, count in stat.items():
                    row[f"{key}__{sub}"] = count
                continue

            if isinstance(stat, list):
                if not stat:
                    continue
                if isinstance(stat[0], (HistTreeStatistics, TreeStatistics)):
                    # It is the statistics for every tree.
                    # Let's skip that.
                    continue
                raise TypeError(
                    f"Unexpected type {type(stat)} for statistics {key!r} "
                    f"with element {type(stat[0])}."
                )

            if isinstance(stat, _Statistics):
                # Only plain python scalars of the nested row are kept.
                for sub, value in stat.dict_values.items():
                    if isinstance(value, (int, float, str)):
                        row[f"{key}__{sub}"] = value
                continue

            raise TypeError(
                f"Unexpected type {type(stat)} for statistics {key!r}: {stat}."
            )
        return row
class TreeStatistics(_Statistics):
    """
    Stores many statistics on a tree extracted
    from TreeEnsemble* operators.

    :param node: the node the tree belongs to
    :param tree_id: identifier of the tree
    """

    def __init__(self, node: NodeProto, tree_id: int):
        super().__init__()
        self.node = node
        self.tree_id = tree_id

    def __str__(self) -> str:
        header = f"{self.__class__.__name__}(<{self.node.op_type}>, {self.tree_id},\n"
        return header + f"{pprint.pformat(self._statistics)})"
class HistTreeStatistics(_Statistics):
    """
    Stores statistics on thresholds.

    :param node: the node the values come from
    :param featureid: identifier of the feature
    :param values: values to summarize
    :param bins: second argument given to ``np.histogram``

    The constructor registers *min*, *max*, *mean*, *median*, *size*,
    *n_distinct*, *hist* and, when there are at most 50 distinct
    values, *v_distinct*.
    """

    def __init__(
        self, node: NodeProto, featureid: int, values: np.ndarray, bins: int = 20
    ):
        super().__init__()
        self.node = node
        self.featureid = featureid

        self.add("min", values.min())
        self.add("max", values.max())
        self.add("mean", values.mean())
        self.add("median", np.median(values))
        self.add("size", len(values))

        distinct = set(values)
        self.add("n_distinct", len(distinct))
        self.add("hist", np.histogram(values, bins))
        # The distinct values are only kept when there are few of them.
        if len(distinct) <= 50:
            self.add("v_distinct", distinct)

    def __str__(self) -> str:
        "Usual"
        header = f"{self.__class__.__name__}(<{self.node.op_type}>, {self.featureid},\n"
        return header + f"{pprint.pformat(self._statistics)})"
class HistStatistics(_Statistics):
    """
    Stores statistics on constants.

    :param parent: graph or function holding the constant
    :param node: a Constant node, an initializer or a sparse initializer
    :param bins: second argument given to ``np.histogram``

    The constructor registers *sparse*, *shape*, *dtype*, *min*, *max*,
    *mean*, *median*, *size*, *n_distinct*, *hist* and, when there are
    at most 50 distinct values, *v_distinct*.
    """

    def __init__(
        self,
        parent: Union[GraphProto, FunctionProto],
        node: Union[NodeProto, TensorProto, SparseTensorProto],
        bins: int = 20,
    ):
        _Statistics.__init__(self)
        self.parent = parent
        self.node = node
        values = self.values

        self.add("sparse", 1 if self.is_sparse else 0)
        self.add("shape", values.shape)
        self.add("dtype", values.dtype)
        self.add("min", values.min())
        self.add("max", values.max())
        self.add("mean", values.mean())
        self.add("median", np.median(values))
        self.add("size", values.size)

        # Number of *distinct* values, not the number of elements
        # (which is already stored as "size").
        distinct = set(values.ravel())
        n_distinct = len(distinct)
        self.add("n_distinct", n_distinct)

        if values.size > 1:
            try:
                hist = np.histogram(values, bins)
            except IndexError as e:
                raise RuntimeError(
                    f"Unable to process values with shape={values.shape}, "
                    f"dtype={values.dtype}, {values}."
                ) from e
            self.add("hist", hist)
        else:
            # NOTE(review): np.histogram returns (counts, bin_edges) whereas
            # this fallback stores (values, counts) - confirm the order
            # is the intended one before relying on it in dict_values.
            self.add("hist", (values, np.array([1], dtype=np.int64)))

        # The distinct values are only kept when there are few of them.
        if n_distinct <= 50:
            self.add("v_distinct", distinct)

    @property
    def dict_values(self) -> Dict[str, Any]:
        """
        Returns the statistics as a dictionary.
        The histogram is expanded into keys ``hist_y_<i>`` (first item
        of *hist*) and ``hist_x_<i>`` (second item of *hist*) unless its
        first item is empty or has no dimension.
        """
        obs = {}
        for k in [
            "size",
            "shape",
            "dtype",
            "min",
            "max",
            "mean",
            "median",
            "n_distinct",
        ]:
            obs[k] = self[k]

        hist = self["hist"]
        if hist[0].size > 0 and len(hist[0].shape) > 0:
            for i, v in enumerate(hist[0]):
                obs[f"hist_y_{i}"] = v
            for i, v in enumerate(hist[1]):
                obs[f"hist_x_{i}"] = v
        return obs

    @property
    def is_sparse(self) -> bool:
        "Tells if the tensor is sparse."
        return isinstance(self.node, SparseTensorProto)

    @property
    def name(self) -> str:
        "Returns the name of the tensor."
        if isinstance(self.node, SparseTensorProto):
            return self.node.indices.name or self.node.values.name
        if isinstance(self.node, NodeProto):
            return self.node.output[0]
        return self.node.name

    def __str__(self) -> str:
        "Usual"
        if isinstance(self.node, NodeProto):
            return (
                f"{self.__class__.__name__}(<{self.parent.name}>, "
                f"<{self.node.op_type}>,\n"
                f"{pprint.pformat(self._statistics)})"
            )
        return (
            f"{self.__class__.__name__}(<{self.parent.name}>, <{self.name}>,\n"
            f"{pprint.pformat(self._statistics)})"
        )

    @property
    def values(self):
        """
        Returns the values as an array.
        A node is evaluated with ``CReferenceEvaluator`` and its first
        output is returned, a tensor is converted with
        ``to_array_extended``.
        """
        if isinstance(self.node, NodeProto):
            model = CReferenceEvaluator(self.node)
            return model.run(None, {})[0]
        return to_array_extended(self.node)
def stats_tree_ensemble(
    parent: Union[GraphProto, FunctionProto], node: NodeProto
) -> NodeStatistics:
    """
    Computes statistics on every tree of a TreeEnsemble.

    :param parent: function or graph proto hosting the node
    :param node: node
    :return: instance of NodeStatistics

    The function reads attributes *nodes_treeids*, *nodes_featureids*,
    *nodes_modes*, *nodes_values*, *nodes_nodeids* and either
    *n_targets* or *class_ids*. The node is reported as a regressor
    when it has an attribute *n_targets*, as a classifier otherwise.
    Global statistics are stored on the result, per-feature statistics
    in *features* and per-tree statistics in *trees*.
    """
    stats = NodeStatistics(parent, node)
    atts = {k: v[1] for k, v in extract_attributes(node).items()}

    tree_ids = set(atts["nodes_treeids"])
    feature_ids = set(atts["nodes_featureids"])
    modes = set(atts["nodes_modes"])

    stats.add("kind", "Regressor" if "n_targets" in atts else "Classifier")
    stats.add("n_trees", len(tree_ids))
    stats.add(
        "n_outputs",
        atts["n_targets"] if "n_targets" in atts else len(atts["class_ids"]),
    )
    stats.add("max_featureid", max(atts["nodes_featureids"]))
    stats.add("n_features", len(feature_ids))
    stats.add("n_rules", len(modes))
    stats.add("rules", modes)
    stats.add("hist_rules", Counter(atts["nodes_modes"]))

    # One histogram of thresholds per feature.
    features = []
    for fid in sorted(feature_ids):
        indices = atts["nodes_featureids"] == fid
        features.append(HistTreeStatistics(node, fid, atts["nodes_values"][indices]))
    stats.add("features", features)

    # Per-tree statistics: every attribute nodes_* is filtered
    # with the mask selecting the nodes of the tree.
    atts_nodes = {k: v for k, v in atts.items() if k.startswith("nodes")}
    tree_stats = []
    for treeid in sorted(tree_ids):
        tr = TreeStatistics(node, treeid)
        indices = atts_nodes["nodes_treeids"] == treeid
        atts_tree = {k: v[indices] for k, v in atts_nodes.items()}
        tree_modes = atts_tree["nodes_modes"]
        tr.add("n_nodes", len(atts_tree["nodes_nodeids"]))
        # Counts the True entries of the mask: its length would be
        # the number of nodes, not the number of leaves.
        tr.add("n_leaves", int(np.count_nonzero(tree_modes == "LEAF")))
        tr.add("max_featureid", max(atts_tree["nodes_featureids"]))
        tr.add("n_features", len(set(atts_tree["nodes_featureids"])))
        tr.add("n_rules", len(set(tree_modes)))
        tr.add("rules", set(tree_modes))
        tr.add("hist_rules", Counter(tree_modes))
        tree_stats.append(tr)
    stats.add("trees", tree_stats)
    return stats
def stats_constant(
    parent: Union[GraphProto, FunctionProto],
    node: Union[NodeProto, TensorProto, SparseTensorProto],
) -> HistStatistics:
    """
    Computes basic statistics on constants.

    :param parent: function or graph proto hosting the node
    :param node: node, initializer or sparse initializer
    :return: instance of HistStatistics
    """
    statistics = HistStatistics(parent, node)
    return statistics
def enumerate_stats_nodes(
    onx: Union[FunctionProto, GraphProto, ModelProto],
    recursive: bool = True,
    stats_fcts: Optional[
        Dict[
            Tuple[str, str],
            Callable[
                [
                    Union[GraphProto, FunctionProto],
                    Union[NodeProto, TensorProto, SparseTensorProto],
                ],
                Union[NodeStatistics, HistStatistics],
            ],
        ]
    ] = None,
) -> Iterable[
    Tuple[
        Tuple[str, ...],
        Union[GraphProto, FunctionProto],
        Union[NodeStatistics, HistStatistics],
    ]
]:
    """
    Computes statistics on the nodes and initializers of a model.

    :param onx: the model
    :param recursive: look into subgraphs
    :param stats_fcts: a dictionary of functions to call for every node,
        the key is *(domain, op_type)*, if None, default functions
        are used for tree ensembles and constants
    :return: iterator on tuples *(name, parent, statistics)*

    A node is processed only if its *(domain, op_type)* is a key of
    *stats_fcts*. Anything which is not a node (initializers) is given
    to the function registered for ``("", "Constant")`` if there is one,
    and results with an integer type (int32, int64) and fewer than
    10 elements are dropped.
    """
    if stats_fcts is None:
        # Kept verbatim: this string is compared to node.domain.
        dom_optim = "onnx_extented.ortops.optim.cpu"
        stats_fcts = {
            (domain, op_type): stats_tree_ensemble
            for domain in ("ai.onnx.ml", dom_optim)
            for op_type in ("TreeEnsembleRegressor", "TreeEnsembleClassifier")
        }
        stats_fcts["", "Constant"] = stats_constant

    constant_key = ("", "Constant")
    for name, parent, node in enumerate_nodes(onx, recursive=recursive):
        if isinstance(node, NodeProto):
            key = (node.domain, node.op_type)
            if key in stats_fcts:
                yield name, parent, stats_fcts[key](parent, node)
            continue

        if constant_key not in stats_fcts:
            continue

        stati = stats_fcts[constant_key](parent, node)
        # A small integer tensor is probably a shape. It is skipped.
        probably_shape = (
            stati["dtype"] in (np.int64, np.int32) and stati["size"] < 10
        )
        if not probably_shape:
            yield name, parent, stati