# Source code for experimental_experiment.xoptim.patterns.onnx_shape

import inspect
from typing import List, Optional
from onnx import NodeProto
from ..patterns_api import MatchResult, PatternOptimization


class ShapeBasedShapeShapeAddPattern(PatternOptimization):
    """Tries to find another way to get a dimension obtained by adding two shapes.

    The pattern targets an ``Add`` node (default ONNX domain) whose two inputs
    are both produced by ``Shape`` nodes (default ONNX domain).

    NOTE: the pattern is a work in progress. :meth:`match` currently never
    reports a match, and :meth:`apply` raises :class:`NotImplementedError`.
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """Checks whether *node* is ``Add(Shape(.), Shape(.))``.

        :param g: graph being optimized; only ``g.node_before`` is used here
        :param node: candidate node
        :param matched: results already matched (unused by this pattern)
        :return: always the result of ``self.none(...)`` for now; the final
            ``MatchResult`` is not built yet (see the TODO below)
        """
        # Cheap early exit on the op type: no node/line info is passed here.
        if node.op_type != "Add" or node.domain != "":
            return self.none()

        shape1 = g.node_before(node.input[0])
        if shape1 is None or shape1.op_type != "Shape" or shape1.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)

        shape2 = g.node_before(node.input[1])
        if shape2 is None or shape2.op_type != "Shape" or shape2.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)

        # TODO: decide when the addition can be obtained another way.
        # Information that may help (exploration notes kept from the author):
        #   ishape1 = g.get_shape_renamed(shape1.input[0])
        #   ishape2 = g.get_shape_renamed(shape2.input[0])
        #   value1 = g.builder.value_as_shape(node.input[0])
        #   value2 = g.builder.value_as_shape(node.input[1])
        #   input_shapes = [g.get_shape_renamed(i) for i in g.builder.input_names]
        #   g.builder._known_value_shape
        #   g.builder.constraints_
        #   g.builder.replacements_dimensions_
        # Until that logic exists, the pattern never matches.
        return self.none(node, inspect.currentframe().f_lineno)
        # return MatchResult(self, [shape1, shape2, node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        shape1_node: NodeProto,
        shape2_node: NodeProto,
        add_node: NodeProto,
    ) -> List[NodeProto]:
        """Rewrites the matched nodes (not implemented yet).

        :raises NotImplementedError: always
        """
        # Use ``__class__`` (two underscores); the previous ``___class__``
        # raised AttributeError instead of the intended NotImplementedError.
        raise NotImplementedError(f"{self.__class__.__name__} is not implemented yet.")