Source code for experimental_experiment.xoptim.patterns_ml.tree_ensemble

import inspect
from typing import Any, Dict, List, Optional
import numpy as np
import onnx.numpy_helper as onh
from onnx import AttributeProto, NodeProto
from ..patterns_api import MatchResult, PatternOptimization


[docs] class TreeEnsembleRegressorMulPattern(PatternOptimization): """ Replaces TreeEnsembleRegressor + Mul(., scalar) with TreeEnsembleRegressor. """
[docs] def match( self, g: "GraphBuilderPatternOptimization", # noqa: F821 node: NodeProto, matched: List[MatchResult], ) -> Optional[MatchResult]: if node.op_type != "TreeEnsembleRegressor" or node.domain != "ai.onnx.ml": return self.none() next_nodes = g.next_nodes(node.output[0]) if len(next_nodes) != 1: return self.none(node, inspect.currentframe().f_lineno) if next_nodes[0].op_type != "Mul" or next_nodes[0].domain != "": return self.none(node, inspect.currentframe().f_lineno) if not g.is_constant_scalar(next_nodes[0].input[1]): return self.none(node, inspect.currentframe().f_lineno) return MatchResult(self, [node, next_nodes[0]], self.apply, insert_at=node)
[docs] def apply( self, g: "GraphBuilder", # noqa: F821 tree_node: NodeProto, mul_node: NodeProto, ) -> List[NodeProto]: cst = g.get_constant_scalar(mul_node.input[1]) names = {"target_weights", "target_weights_as_tensor"} weights = None atts = [] for att in tree_node.attribute: if att.name in names: assert weights is None, f"Both {names} can be set at the same time." weights = att else: atts.append(att) if att.name == "target_weights": kwargs = {att.name: [float(f * cst) for f in att.floats]} else: value = onh.to_array(att.t) kwargs = {att.name: onh.from_array(value * cst, name=att.name)} new_tree = g.make_node( tree_node.op_type, tree_node.input, mul_node.output, name=f"{self.__class__.__name__}--{tree_node.name}", domain=tree_node.domain, **kwargs, ) new_tree.attribute.extend(atts) return [new_tree]
[docs] class TreeEnsembleRegressorConcatPattern(PatternOptimization): """ Replaces multiple TreeEnsembleRegressor + Concat(., axis=1) with one TreeEnsembleRegressor. All trees must have only one target (it can be extended to multiple) and is assigned a distinct dimension. The aggregation must be SUM. """
[docs] def match( self, g: "GraphBuilderPatternOptimization", # noqa: F821 node: NodeProto, matched: List[MatchResult], ) -> Optional[MatchResult]: if node.op_type != "TreeEnsembleRegressor" or node.domain != "ai.onnx.ml": return self.none() next_nodes = g.next_nodes(node.output[0]) if len(next_nodes) != 1: return self.none(node, inspect.currentframe().f_lineno) concat_node = next_nodes[0] if concat_node.op_type not in ("Sigmoid", "Concat") or concat_node.domain != "": return self.none(node, inspect.currentframe().f_lineno) if concat_node.op_type == "Sigmoid": next_nodes = g.next_nodes(concat_node.output[0]) if len(next_nodes) != 1: return self.none(node, inspect.currentframe().f_lineno) concat_node = next_nodes[0] axis = g.get_attribute(concat_node, "axis", exc=False) if axis is None or axis.i not in (0, 1): return self.none(node, inspect.currentframe().f_lineno) sigmoid = [] trees = [] post_transform = None inputs = None base_values_none = None for treeo_or_sigmoid in concat_node.input: if g.is_used_more_than_once(treeo_or_sigmoid): return self.none(node, inspect.currentframe().f_lineno) t = g.node_before(treeo_or_sigmoid) if t.op_type == "Sigmoid": if g.is_used_more_than_once(t.input[0]): return self.none(node, inspect.currentframe().f_lineno) sigmoid.append(t) t = g.node_before(t.input[0]) if t.op_type != "TreeEnsembleRegressor" or t.domain != "ai.onnx.ml": return self.none(node, inspect.currentframe().f_lineno) if inputs is None: inputs = list(t.input) elif inputs != list(t.input): # not the same input return self.none(node, inspect.currentframe().f_lineno) n_targets = g.get_attribute(t, "n_targets", exc=False) if n_targets is None or n_targets.i != 1: # It could be implemented in that case as well. return self.none(node, inspect.currentframe().f_lineno) # only SUM is allowed agg = g.get_attribute(t, "aggregate_function", exc=False) if agg is not None and agg.s != b"SUM": return self.none(node, inspect.currentframe().f_lineno) # one unique post_transform is allowed post = g.get_attribute(t, "post_transform", exc=False) if post is None: if post_transform is not None and post_transform != post: return self.none(node, inspect.currentframe().f_lineno) post_transform = b"NONE" elif post_transform is None: post_transform = post.s elif post_transform != post.s: return self.none(node, inspect.currentframe().f_lineno) if len(sigmoid) > 0 and post_transform != b"NONE": return self.none(node, inspect.currentframe().f_lineno) # specific rule for base_values: all none or all filled bb = ( g.get_attribute(t, "base_values", exc=False) is not None or g.get_attribute(t, "base_values_as_tensor", exc=False) is not None ) if base_values_none is None: base_values_none = bb elif bb != base_values_none: return self.none(node, inspect.currentframe().f_lineno) trees.append(t) if len(sigmoid) != 0 and len(sigmoid) != len(trees): return self.none(node, inspect.currentframe().f_lineno) return MatchResult( self, [concat_node, *sigmoid, *trees], self.apply, insert_at=concat_node )
@classmethod def get_attribute_value( cls, g: "GraphBuilder", node: NodeProto, name: str, exc: bool = True # noqa: F821 ) -> Any: att = g.get_attribute(node, name, exc=exc) if not exc and att is None: return None if att.type == AttributeProto.INTS: return att.ints if att.type == AttributeProto.FLOATS: return att.floats if att.type == AttributeProto.STRING: return att.s.decode("ascii") if att.type == AttributeProto.STRINGS: return [t.decode("ascii") for t in att.strings] if att.type == AttributeProto.TENSOR: return onh.to_array(att.t) raise AssertionError(f"Unexpected attribute type {att.type} in {att}") @classmethod def _first_tree_id( cls, g: "GraphBuilder", trees: List[NodeProto] # noqa: F821 ) -> Dict[int, int]: res = {} current = 0 for i, t in enumerate(trees): res[i] = current nodes_treeids = cls.get_attribute_value(g, t, "nodes_treeids") target_treeids = cls.get_attribute_value(g, t, "target_treeids") max_id = max(max(nodes_treeids), max(target_treeids)) + 1 current += max_id return res @classmethod def _merge( cls, g: "GraphBuilder", # noqa: F821 trees: List[NodeProto], name: str, as_tensor: Optional[str] = None, increment=False, unique=False, first_tree_id: Optional[Dict[int, int]] = None, ) -> Any: if as_tensor is None: if unique: return cls.get_attribute_value(g, trees[0], name) collected = [cls.get_attribute_value(g, t, name) for t in trees] if not increment: assert ( first_tree_id is None ), "increment is False but first_tree_id is not None" merged = [] for c in collected: merged.extend(c) return merged # all attribute value are necessarily integers # nodes_treeids, target_treeids, target_ids assert first_tree_id is not None, "increment is True but first_tree_id is None" merged = [] for i, value in enumerate(collected): first = first_tree_id[i] for v in value: merged.append(v + first) return merged assert not increment, "as_tensor is true, increment is true, incompatible" assert not unique, "as_tensor is true, unique is true, incompatible" assert first_tree_id is None, "increment is False but first_tree_id is not None" collected = [] for tree in trees: val = g.get_attribute(tree, name, exc=False) if val is not None: collected.append(np.array(val.floats, dtype=np.float32)) continue collected.append(cls.get_attribute_value(g, tree, as_tensor, exc=False)) if any(c is None for c in collected): return None merged = np.hstack(collected) return onh.from_array(merged, name=as_tensor) def _opset_process( cls, g: "GraphBuilder", atts: Dict[str, Any] # noqa: F821 ) -> Dict[str, Any]: """ as_tensor not supported in opset ai.onnx.ml < 3. """ if g.opsets["ai.onnx.ml"] >= 3: new_atts = {} for k, v in atts.items(): if v is None: continue new_atts[k] = v return new_atts new_atts = {} for k, v in atts.items(): if v is None: continue if not k.endswith("_as_tensor"): new_atts[k] = v continue nv = onh.to_array(v) nk = k[: -len("_as_tensor")] new_atts[nk] = [float(x) for x in nv.tolist()] return new_atts
[docs] def apply( self, g: "GraphBuilder", # noqa: F821 concat_node: NodeProto, *trees_or_sigmoid: NodeProto, ) -> List[NodeProto]: sigmoid = [t for t in trees_or_sigmoid if t.op_type == "Sigmoid"] trees = [t for t in trees_or_sigmoid if t.op_type != "Sigmoid"] first_tree_id = self._first_tree_id(g, trees) axis = g.get_attribute(concat_node, "axis").i new_atts = dict( aggregate_function="SUM", n_targets=len(trees), nodes_falsenodeids=self._merge(g, trees, "nodes_falsenodeids"), nodes_featureids=self._merge(g, trees, "nodes_featureids"), nodes_hitrates_as_tensor=self._merge( g, trees, "nodes_hitrates", "nodes_hitrates_as_tensor" ), nodes_missing_value_tracks_true=self._merge( g, trees, "nodes_missing_value_tracks_true" ), nodes_modes=self._merge(g, trees, "nodes_modes"), nodes_nodeids=self._merge(g, trees, "nodes_nodeids"), nodes_treeids=self._merge( g, trees, "nodes_treeids", increment=True, first_tree_id=first_tree_id ), nodes_truenodeids=self._merge(g, trees, "nodes_truenodeids"), nodes_values_as_tensor=self._merge( g, trees, "nodes_values", "nodes_values_as_tensor" ), post_transform=self._merge(g, trees, "post_transform", unique=True), target_ids=self._merge( g, trees, "target_ids", increment=True, first_tree_id=dict(enumerate(range(len(trees)))), ), target_nodeids=self._merge(g, trees, "target_nodeids"), target_treeids=self._merge( g, trees, "target_treeids", increment=True, first_tree_id=first_tree_id ), target_weights_as_tensor=self._merge( g, trees, "target_weights", "target_weights_as_tensor" ), ) bv = self._merge(g, trees, "base_values", "base_values_as_tensor") if bv is not None: new_atts["base_values_as_tensor"] = bv outputs = ( concat_node.output if axis == 1 and len(sigmoid) == 0 else [g.unique_name(f"{self.__class__.__name__}_{concat_node.output[0]}")] ) new_tree = g.make_node( trees[0].op_type, trees[0].input, outputs, name=f"{self.__class__.__name__}--{trees[0].name}", domain=trees[0].domain, **self._opset_process(g, new_atts), ) if axis == 1: if len(sigmoid) == 0: return [new_tree] sigmoid = g.make_node( "Sigmoid", outputs, concat_node.output, name=f"{self.__class__.__name__}--{trees[0].name}", ) return [new_tree, sigmoid] sig_node = [] if len(sigmoid) > 0: sig_outs = g.unique_name(f"{self.__class__.__name__}_{concat_node.output[0]}Sig") sig_node.append( g.make_node( "Sigmoid", outputs, [sig_outs], name=f"{self.__class__.__name__}--{trees[0].name}", ) ) outputs = [sig_outs] transpose_output = g.unique_name(f"{self.__class__.__name__}_{concat_node.output[0]}T") transpose = g.make_node( "Transpose", outputs, [transpose_output], perm=[1, 0], name=f"{self.__class__.__name__}--{trees[0].name}", ) new_shape = g.make_initializer( "", np.array([-1, 1], dtype=np.int64), source="TreeEnsembleRegressorConcatPattern.apply.new_shape", ) reshape = g.make_node( "Reshape", [transpose_output, new_shape], concat_node.output, name=f"{self.__class__.__name__}--{trees[0].name}", ) return [new_tree, *sig_node, transpose, reshape]