# Source code for experimental_experiment.reference.ops.op_tri_matrix
import numpy as np
from onnx.reference.op_run import OpRun
# [docs]
class TriMatrix(OpRun):
    """Reference (numpy) implementation of the ``TriMatrix`` operator.

    The output is filled with three constants: one for the elements strictly
    below the main diagonal, one for the diagonal itself and one for the
    elements strictly above it.
    """

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, shape, csts):
        """Build the tri-valued matrix.

        :param shape: output shape; ``shape[0]`` is used as the number of rows
            and ``shape[1]`` as the number of columns when deciding which
            region (below / on / above the diagonal) an element belongs to
        :param csts: array of exactly three values ``(lower, diag, upper)``;
            its dtype becomes the dtype of the output
        :return: a 1-tuple holding the filled array
        """
        # Unpacking first makes a wrong number of constants fail before any
        # allocation happens.
        lower, diag, upper = list(csts)
        # Every cell starts as the diagonal value; the two strict triangles
        # are overwritten below, so the diagonal keeps ``diag``.
        result = np.full(tuple(shape), diag, dtype=csts.dtype)
        # Column vector of row indices vs. row vector of column indices:
        # comparing them broadcasts to a (rows, cols) boolean mask.
        rows = np.arange(shape[0], dtype=np.int32)[:, np.newaxis]
        cols = np.arange(shape[1], dtype=np.int32)[np.newaxis, :]
        strictly_below = rows > cols
        strictly_above = rows < cols
        result[strictly_below] = lower
        result[strictly_above] = upper
        return (result,)