Toggle Light / Dark / Auto color theme
Toggle table of contents sidebar
Source code for experimental_experiment.xoptim.patterns.onnx_conv
import inspect
from typing import List , Optional
from onnx import NodeProto
from ..patterns_api import MatchResult , PatternOptimization
[docs]
class ConvBiasNullPattern ( PatternOptimization ):
"""
Checks that a Conv has a null bias.
"""
def __init__ ( self , verbose : int = 0 , priority : int = 0 ):
super () . __init__ ( verbose , priority )
[docs]
def match (
self ,
g : "GraphBuilderPatternOptimization" , # noqa: F821
node : NodeProto ,
matched : List [ MatchResult ],
) -> Optional [ MatchResult ]:
if node . op_type != "Conv" or node . domain != "" :
return self . none ()
if len ( node . input ) < 3 :
return self . none ( node , inspect . currentframe () . f_lineno )
if not g . is_constant ( node . input [ 2 ]):
return self . none ( node , inspect . currentframe () . f_lineno )
cst = g . get_computed_constant ( node . input [ 2 ])
if cst is None or cst . min () != 0 or cst . max () != 0 :
return self . none ( node , inspect . currentframe () . f_lineno )
return MatchResult ( self , [ node ], self . apply , insert_at = node )
[docs]
def apply (
self ,
g : "GraphBuilder" , # noqa: F821
node : NodeProto ,
) -> List [ NodeProto ]:
new_node = g . make_node (
"Conv" ,
node . input [: 2 ],
node . output ,
name = f " { self . __class__ . __name__ } -- { node . name } " ,
doc_string = node . doc_string ,
)
new_node . attribute . extend ( node . attribute )
return [ new_node ]