Source code for experimental_experiment.reference.ops.op_concat

import numpy as np

from onnx.reference.op_run import OpRun


class Concat(OpRun):
    """Reference implementation of the ONNX ``Concat`` operator.

    Unlike a bare :func:`numpy.concatenate`, every input whose rank is too
    small for the requested ``axis`` is first padded with trailing
    dimensions of size 1. For instance, 0-d arrays become shape ``(1,)``
    when concatenated along ``axis=0``.
    """

    def _preprocess(self, a: np.ndarray, axis: int) -> np.ndarray:
        """Make sure ``axis`` is a valid dimension of ``a``.

        :param a: one input of the operator
        :param axis: concatenation axis
        :return: ``a`` itself when ``axis`` is below its rank, otherwise
            ``a`` reshaped with enough trailing size-1 dimensions for
            ``axis`` to exist

        Negative axes always satisfy ``axis < rank`` and are therefore
        returned untouched; numpy resolves them later.
        """
        rank = len(a.shape)  # type: ignore
        if axis < rank:
            return a
        # Append exactly the number of unit dimensions needed so that
        # index ``axis`` becomes the last dimension of the reshaped array.
        padding = (1,) * (axis + 1 - rank)
        return a.reshape(a.shape + padding)  # type: ignore

    def _run(self, *args, axis=None):  # type: ignore
        """Concatenate all inputs along ``axis``.

        :param args: the operator inputs, concatenated in order
        :param axis: concatenation axis; NOTE(review): presumably always
            filled in from the node attribute by ``OpRun``. A ``None``
            value would raise ``TypeError`` in ``_preprocess``.
        :return: a one-element tuple holding the concatenated array
        """
        padded = [self._preprocess(tensor, axis) for tensor in args]
        return (np.concatenate(padded, axis),)  # type: ignore