# Source code for experimental_experiment.xoptim.patterns_ort.batch_normalization

import inspect
from typing import List, Optional
import numpy as np
from onnx import NodeProto
from ...helpers import tensor_dtype_to_np_dtype
from ..patterns_api import MatchResult, PatternOptimization


class OrtBatchNormalizationTrainingPattern(PatternOptimization):
    """
    onnxruntime does not support batch normalization with training=1.

    This pattern rewrites a ``BatchNormalization`` node (default domain) whose
    ``training_mode`` attribute is non-zero, and only when the graph builder
    reports a ``"CUDA"`` processor. The node is replaced by:

    * the batch mean and the biased batch variance of ``X`` computed over every
      axis except axis 1 (``ReduceMean``, ``Squeeze``, ``Sub``, ``Mul``),
    * a ``BatchNormalization`` with ``training_mode=0`` fed with those batch
      statistics, which produces the first output,
    * for each requested optional output (running mean, running variance),
      ``running * momentum + batch_stat * (1 - momentum)``.
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Matches a ``BatchNormalization`` node with a non-zero ``training_mode``.

        :param g: graph being optimized
        :param node: candidate node
        :param matched: results already matched (not used by this pattern)
        :return: a :class:`MatchResult` rewriting the node with :meth:`apply`,
            or the result of ``self.none(...)`` when the node does not match
        """
        if not g.has_processor("CUDA"):
            # The rewrite is only enabled when CUDA is one of the processors.
            return self.none(node, inspect.currentframe().f_lineno)
        if node.op_type != "BatchNormalization" or node.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        training_mode = g.get_attribute(node, "training_mode", exc=False)
        if training_mode is None or training_mode.i == 0:
            # Absent or zero training_mode: already an inference node.
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        bn_node: NodeProto,
    ) -> List[NodeProto]:
        """
        Replaces a training-mode ``BatchNormalization`` by an explicit
        computation of the batch statistics followed by an inference
        ``BatchNormalization``.

        The optional outputs (running mean at index 1, running variance at
        index 2) are produced independently, only when the node declares
        them with a non-empty name.

        :param g: graph builder
        :param bn_node: the matched ``BatchNormalization`` node
        :return: the new nodes, in topological order
        """
        node_name = f"{self.__class__.__name__}--{bn_node.name}"
        x = bn_node.input[0]

        stat_nodes, current_mean, current_var = self._batch_statistics(g, bn_node, node_name)

        atts = g.get_attributes_with_default(bn_node, epsilon=None, momentum=None)
        # NOTE(review): when the running statistics (inputs 3, 4) do not have
        # X's element type, current_mean/current_var (computed in X's type)
        # do not match the declared type of those inputs; confirm whether
        # onnxruntime accepts this or whether a Cast is needed here.
        new_bn_node = g.make_node(
            "BatchNormalization",
            [*bn_node.input[:3], current_mean, current_var],
            [bn_node.output[0]],
            training_mode=0,
            name=node_name,
            **atts,
        )

        # Each optional output is checked on its own: the node may declare
        # fewer than three outputs, or only one of the two running statistics.
        requested = [
            (bn_node.output[index], running, batch_stat)
            for index, running, batch_stat in (
                (1, bn_node.input[3], current_mean),
                (2, bn_node.input[4], current_var),
            )
            if index < len(bn_node.output) and bn_node.output[index]
        ]

        running_nodes: List[NodeProto] = []
        if requested:
            # 0.9 is the default value of attribute momentum in the onnx specifications.
            momentum = atts.get("momentum", 0.9)
            x_itype = g.get_type(x)
            dtype = tensor_dtype_to_np_dtype(x_itype)
            mom_name = g.make_initializer(
                "",
                np.array([momentum], dtype=dtype),
                source="BatchNormalizationTrainingPattern.2",
            )
            mom_1_name = g.make_initializer(
                "",
                np.array([1 - momentum], dtype=dtype),
                source="BatchNormalizationTrainingPattern.3",
            )
            for output, running, batch_stat in requested:
                running_nodes.extend(
                    self._running_statistic(
                        g,
                        x_itype,
                        running,
                        batch_stat,
                        output,
                        mom_name,
                        mom_1_name,
                        node_name,
                    )
                )

        return [*stat_nodes, new_bn_node, *running_nodes]

    def _batch_statistics(
        self,
        g: "GraphBuilder",  # noqa: F821
        bn_node: NodeProto,
        node_name: str,
    ) -> tuple:
        """
        Builds the nodes computing the batch mean and the biased batch variance
        (mean of squared deviations) of ``bn_node.input[0]`` over every axis
        except axis 1.

        :param g: graph builder
        :param bn_node: the matched ``BatchNormalization`` node
        :param node_name: name given to every created node
        :return: ``(nodes, mean_name, var_name)``; both statistics are reduced
            without kept dimensions
        """
        x = bn_node.input[0]
        rk = g.get_rank(x)
        axes_name = g.make_initializer(
            "",
            np.array([i for i in range(rk) if i != 1], dtype=np.int64),
            source="BatchNormalizationTrainingPattern.1",
        )
        mean_keepdims = g.unique_name(f"{x}_mean")
        current_mean = g.unique_name(f"{x}_mean")
        centered = g.unique_name(f"{x}_center")
        current_var = g.unique_name(f"{x}_var")
        squared = g.unique_name(f"{x}_var2")
        nodes = [
            # keepdims=1 so that the mean broadcasts against x in Sub below
            g.make_node(
                "ReduceMean", [x, axes_name], [mean_keepdims], name=node_name, keepdims=1
            ),
            g.make_node("Squeeze", [mean_keepdims, axes_name], [current_mean], name=node_name),
            g.make_node("Sub", [x, mean_keepdims], [centered], name=node_name),
            g.make_node("Mul", [centered, centered], [squared], name=node_name),
            g.make_node(
                "ReduceMean", [squared, axes_name], [current_var], name=node_name, keepdims=0
            ),
        ]
        return nodes, current_mean, current_var

    def _running_statistic(
        self,
        g: "GraphBuilder",  # noqa: F821
        x_itype: int,
        running: str,
        batch_stat: str,
        output: str,
        mom_name: str,
        mom_1_name: str,
        node_name: str,
    ) -> List[NodeProto]:
        """
        Builds ``output = running * momentum + batch_stat * (1 - momentum)``.

        The arithmetic is done in X's element type ``x_itype``, which is the
        type of the momentum constants and of ``batch_stat``. When ``output``
        has another element type, ``running`` is cast to ``x_itype`` first,
        so that every ``Mul``/``Add`` has operands of one type, and the result
        is cast back to the type of ``output``.

        :param g: graph builder
        :param x_itype: element type of X
        :param running: running statistic input (input 3 or 4 of the node)
        :param batch_stat: statistic computed on the current batch
        :param output: name of the output to produce
        :param mom_name: initializer holding ``momentum``
        :param mom_1_name: initializer holding ``1 - momentum``
        :param node_name: name given to every created node
        :return: the created nodes, in topological order
        """
        out_itype = g.get_type(output)
        same_type = out_itype == x_itype
        nodes: List[NodeProto] = []
        if not same_type:
            running_cast = g.unique_name(f"{running}_cast")
            nodes.append(
                g.make_node("Cast", [running], [running_cast], to=x_itype, name=node_name)
            )
            running = running_cast
        p1 = g.unique_name(f"{output}_m1")
        p2 = g.unique_name(f"{output}_m2")
        total = output if same_type else g.unique_name(f"{output}_m3")
        nodes.extend(
            [
                g.make_node("Mul", [running, mom_name], [p1], name=node_name),
                g.make_node("Mul", [batch_stat, mom_1_name], [p2], name=node_name),
                g.make_node("Add", [p1, p2], [total], name=node_name),
            ]
        )
        if not same_type:
            nodes.append(g.make_node("Cast", [total], [output], to=out_itype, name=node_name))
        return nodes