# Source code for experimental_experiment.gradient.ops.op_broadcast_gradient_args
import numpy
from onnx.reference.op_run import OpRun
class BroadcastGradientArgs(OpRun):
    """Reference implementation of ``com.microsoft.BroadcastGradientArgs``.

    For two input shapes, this kernel reports, for each input, the axes of
    the right-aligned broadcast result along which that input was broadcast.
    An axis is reported for an input in two cases:

    * both shapes have the axis, the sizes differ, and this input's size
      there is 1;
    * the input has fewer dimensions, and the axis is one of the leading
      ones it lacks.

    Presumably these are the axes over which each input's gradient must be
    summed; confirm against the callers.
    """

    op_domain = "com.microsoft"

    def _run(self, a_shape, b_shape):
        """Compute the broadcast axes of both inputs.

        :param a_shape: shape of the first input, as an indexable sequence of
            dimensions (presumably a 1-D int64 tensor; TODO confirm)
        :param b_shape: shape of the second input, same kind as *a_shape*
        :return: tuple ``(a_axes, b_axes)`` of 1-D ``numpy.int64`` arrays,
            each listing axes of the broadcast result in descending order
        :raises RuntimeError: if two aligned dimensions differ and neither
            of them is 1
        """
        rank_a = len(a_shape)
        rank_b = len(b_shape)
        # Index of the current axis in the right-aligned broadcast result.
        out_axis = max(rank_a, rank_b) - 1

        reduce_a = []
        reduce_b = []

        # Compare dimensions pairwise, starting from the trailing one, for as
        # many positions as both shapes have.
        for offset in range(1, min(rank_a, rank_b) + 1):
            dim_a = a_shape[rank_a - offset]
            dim_b = b_shape[rank_b - offset]
            if dim_a != dim_b:
                if dim_a == 1:
                    reduce_a.append(out_axis)
                elif dim_b == 1:
                    reduce_b.append(out_axis)
                else:
                    a = a_shape[:rank_a]
                    b = b_shape[:rank_b]
                    raise RuntimeError(
                        f"Broadcast is not possible between inputs of shapes: {a} and {b}."
                    )
            out_axis -= 1

        # The leading axes that only the longer shape has are attributed to
        # the shorter input.  With equal ranks, out_axis is already -1 and the
        # range below is empty.
        shorter = reduce_a if rank_a <= rank_b else reduce_b
        shorter.extend(range(out_axis, -1, -1))

        return (
            numpy.array(reduce_a, dtype=numpy.int64),
            numpy.array(reduce_b, dtype=numpy.int64),
        )