Toggle Light / Dark / Auto color theme
Toggle table of contents sidebar
Source code for experimental_experiment.xoptim.patterns.onnx_equal
import inspect
from typing import List , Optional
from onnx import NodeProto
from ..patterns_api import MatchResult , PatternOptimization
[docs]
class UnsqueezeEqualPattern(PatternOptimization):
    """
    Replaces the sequence R -> Equal -> Unsqueeze, R -> Unsqueeze,
    into R -> Unsqueeze -> Equal.

    The pattern only matches when the second input of ``Equal`` is a
    constant scalar (``g.is_constant_scalar``), when ``Equal`` has a single
    consumer which is an ``Unsqueeze``, and when that ``Unsqueeze`` uses the
    same axes input as the ``Unsqueeze`` applied directly on ``R``.
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Checks whether *node* is the ``Equal`` node of the pattern.

        :param g: graph holding the node, queried for consumers
            (``next_nodes``) and constants (``is_constant_scalar``)
        :param node: candidate node
        :param matched: matches found so far, not used by this pattern
        :return: a :class:`MatchResult` over the nodes
            ``[Unsqueeze(R), Equal, Unsqueeze(Equal)]`` when the pattern
            applies, otherwise the value returned by ``self.none``
        """
        if node.op_type != "Equal" or node.domain != "":
            return self.none()
        if not g.is_constant_scalar(node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)

        after = g.next_nodes(node.output[0])
        if len(after) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        trailing = after[0]
        # The single consumer of Equal must be an Unsqueeze carrying its axes
        # as a second input. Without this guard, a one-input consumer
        # (Not, Cast, ...) makes ``input[1]`` raise IndexError, and an
        # unrelated two-input consumer could be rewritten by mistake.
        if (
            trailing.op_type != "Unsqueeze"
            or trailing.domain != ""
            or len(trailing.input) < 2
        ):
            return self.none(node, inspect.currentframe().f_lineno)

        next_path = g.next_nodes(node.input[0])
        if len(next_path) != 2:
            return self.none(node, inspect.currentframe().f_lineno)

        # R must feed exactly one Equal and one Unsqueeze, in any order.
        first, second = next_path
        if first.op_type == node.op_type and second.op_type == "Unsqueeze":
            sibling = second
        elif second.op_type == node.op_type and first.op_type == "Unsqueeze":
            sibling = first
        else:
            return self.none(node, inspect.currentframe().f_lineno)

        # The sibling Unsqueeze must also expose its axes as an input.
        if sibling.domain != "" or len(sibling.input) < 2:
            return self.none(node, inspect.currentframe().f_lineno)
        # Both Unsqueeze nodes must use the very same axes.
        if sibling.input[1] != trailing.input[1]:
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [sibling, node, trailing], self.apply)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node_unsqueeze: NodeProto,
        node_equal: NodeProto,
        node_equal_unsqueeze: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the nodes replacing the matched ones.

        :param g: graph used to create the new node (``make_node``)
        :param node_unsqueeze: the ``Unsqueeze`` applied directly on ``R``,
            returned unchanged
        :param node_equal: the matched ``Equal`` node
        :param node_equal_unsqueeze: the ``Unsqueeze`` consuming ``Equal``
        :return: ``node_unsqueeze`` followed by a new ``Equal`` comparing the
            output of ``node_unsqueeze`` with the second input of
            ``node_equal`` and writing to the output of
            ``node_equal_unsqueeze``
        """
        return [
            node_unsqueeze,
            g.make_node(
                node_equal.op_type,
                [node_unsqueeze.output[0], node_equal.input[1]],
                [node_equal_unsqueeze.output[0]],
                domain=node_equal.domain,
                name=f"{self.__class__.__name__}--{node_equal.name}",
            ),
        ]