Source code for experimental_experiment.gradient.ops.op_broadcast_gradient_args

import numpy
from onnx.reference.op_run import OpRun


[docs] class BroadcastGradientArgs(OpRun): op_domain = "com.microsoft" def _run(self, a_shape, b_shape): A_dims = a_shape B_dims = b_shape a_size = len(a_shape) b_size = len(b_shape) ndim = max(a_size, b_size) i = a_size - 1 j = b_size - 1 k = ndim - 1 a_axes = [] b_axes = [] while i >= 0 and j >= 0: A_dim = A_dims[i] B_dim = B_dims[j] if A_dim != B_dim: if A_dim == 1: a_axes.append(k) elif B_dim == 1: b_axes.append(k) else: a = A_dims[:a_size] b = B_dims[:b_size] raise RuntimeError( f"Broadcast is not possible between inputs of shapes: {a} and {b}." ) i -= 1 j -= 1 k -= 1 if i < 0: while k >= 0: a_axes.append(k) k -= 1 else: while k >= 0: b_axes.append(k) k -= 1 return ( numpy.array(a_axes, dtype=numpy.int64), numpy.array(b_axes, dtype=numpy.int64), )