yobx.xshape.shape_type_compute

yobx.xshape.shape_type_compute.broadcast_shape(sh1: Tuple[int | torch.SymInt | torch.SymFloat | float | str, ...], sh2: Tuple[int | torch.SymInt | torch.SymFloat | float | str, ...], graph_builder: ShapeBuilder | None = None) → Tuple[int | torch.SymInt | torch.SymFloat | float | str, ...]

Computes the output shape of many broadcasting operators. This function should be used while converting the graph into ONNX: it assumes the broadcast is possible and adds to the GraphBuilder the constraints on dynamic dimensions that the broadcast requires.

Parameters:
  • sh1 – first shape

  • sh2 – second shape

  • graph_builder – if not None, the function registers in it any constraint that appears while applying the broadcast

Returns:

resulting shape
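To make the behaviour described above concrete, here is a minimal, hypothetical sketch of broadcasting two shapes whose dimensions are either integers or symbolic names (strings). It is not the library's implementation: `broadcast_shape_sketch` and its `constraints` list are illustrative stand-ins for `broadcast_shape` and the GraphBuilder, and the sketch assumes that a symbolic dimension paired with a different dimension is never 1 at runtime.

```python
from typing import List, Optional, Tuple, Union

Dim = Union[int, str]


def broadcast_shape_sketch(
    sh1: Tuple[Dim, ...],
    sh2: Tuple[Dim, ...],
    constraints: Optional[List[Tuple[Dim, Dim]]] = None,
) -> Tuple[Dim, ...]:
    """Hypothetical helper: broadcasts shapes made of ints and symbolic names.

    It assumes the broadcast is possible: two different dimensions, neither
    equal to 1, are assumed equal at runtime, and that equality is appended
    to ``constraints`` when a list is given.
    """
    # Left-pad the shorter shape with 1s (NumPy / ONNX broadcasting rules).
    rank = max(len(sh1), len(sh2))
    a = (1,) * (rank - len(sh1)) + tuple(sh1)
    b = (1,) * (rank - len(sh2)) + tuple(sh2)
    result: List[Dim] = []
    for d1, d2 in zip(a, b):
        if d1 == d2:
            result.append(d1)
        elif d1 == 1:
            result.append(d2)
        elif d2 == 1:
            result.append(d1)
        elif isinstance(d1, int) and isinstance(d2, int):
            raise ValueError(f"shapes {sh1} and {sh2} cannot be broadcast")
        else:
            # At least one symbolic dimension: assume both sides are equal,
            # keep the static value when there is one, record the constraint.
            if constraints is not None:
                constraints.append((d1, d2))
            result.append(d1 if isinstance(d1, int) else d2)
    return tuple(result)
```

For example, `broadcast_shape_sketch(("batch", 1, 4), (3, 4))` gives `("batch", 3, 4)`, while broadcasting `("n", 8)` with `("m", 8)` keeps one name and records that `n` and `m` must be equal.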

yobx.xshape.shape_type_compute.set_shape_type_custom(self: ShapeBuilder, node: NodeProto, exc: bool = False)

Sets the shape and type of the node outputs when they can be inferred.

yobx.xshape.shape_type_compute.set_shape_type_op_any(self: ShapeBuilder, node: NodeProto, exc: bool = False)

Sets the shape and type of the node outputs when they can be inferred.

yobx.xshape.shape_type_compute.set_type_shape_binary_op(g: ShapeBuilder, name: str, *input_names: List[str], begin: int = 0, cmp_op: bool = False, itype: int | None = None) → bool

Sets the shape and type for a binary operator (add, mul, …).
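The following hypothetical helper sketches what a binary operator's shape and type inference computes for static shapes. It assumes that `cmp_op` marks comparison operators (Equal, Less, …), whose output is boolean, and it does not model the `begin` or `itype` parameters; the type codes are the `onnx.TensorProto` values `FLOAT = 1` and `BOOL = 9`.

```python
from typing import Tuple

import numpy as np

# Element type codes defined by onnx.TensorProto (FLOAT = 1, BOOL = 9).
FLOAT, BOOL = 1, 9


def binary_op_output(
    itype: int,
    shape_a: Tuple[int, ...],
    shape_b: Tuple[int, ...],
    cmp_op: bool = False,
) -> Tuple[int, Tuple[int, ...]]:
    """Hypothetical helper returning (element type, shape) of a binary operator.

    ONNX uses NumPy-style (multidirectional) broadcasting, so the output shape
    is the broadcast of both input shapes.
    """
    out_shape = tuple(np.broadcast_shapes(shape_a, shape_b))
    # Assumption: comparison operators produce booleans, arithmetic
    # operators keep the input element type.
    out_type = BOOL if cmp_op else itype
    return out_type, out_shape
```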

yobx.xshape.shape_type_compute.set_type_shape_complex_module(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type ComplexModule (extracts real/imaginary part).

yobx.xshape.shape_type_compute.set_type_shape_fused_matmul(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type FusedMatMul.
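A minimal sketch of the shape rule, assuming FusedMatMul follows the ONNX Runtime `com.microsoft` contrib operator (`Y = alpha * A' @ B'`, where `transA` / `transB` swap the last two axes of each input). The helper name is hypothetical, only static shapes of rank 2 or more are handled, and the batch-transpose attributes are not modelled.

```python
from typing import Sequence, Tuple

import numpy as np


def fused_matmul_output_shape(
    shape_a: Sequence[int], shape_b: Sequence[int], transA: int = 0, transB: int = 0
) -> Tuple[int, ...]:
    """Hypothetical helper: output shape of FusedMatMul for static shapes."""
    a, b = list(shape_a), list(shape_b)
    if len(a) < 2 or len(b) < 2:
        raise ValueError("this sketch only handles inputs of rank >= 2")
    # transA / transB swap the last two axes before the matrix product.
    if transA:
        a[-2], a[-1] = a[-1], a[-2]
    if transB:
        b[-2], b[-1] = b[-1], b[-2]
    if a[-1] != b[-2]:
        raise ValueError(f"inner dimensions differ: {a[-1]} != {b[-2]}")
    # Leading (batch) dimensions broadcast as in MatMul.
    batch = tuple(np.broadcast_shapes(tuple(a[:-2]), tuple(b[:-2])))
    return batch + (a[-2], b[-1])
```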

yobx.xshape.shape_type_compute.set_type_shape_gemm(g: ShapeBuilder, name: str, x: str, y: str, transA: int, transB: int)

Sets the output shape for node type Gemm.
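The ONNX Gemm operator computes `Y = alpha * A' @ B' + beta * C`, where `A'` is `A` transposed when `transA` is set (likewise for `B`), so `A'` is `(M, K)`, `B'` is `(K, N)` and `Y` is `(M, N)`. The hypothetical helper below sketches that rule on shapes whose dimensions are ints or symbolic names; it is an illustration, not the library's code.

```python
from typing import Tuple, Union

Dim = Union[int, str]


def gemm_output_shape(
    shape_a: Tuple[Dim, Dim], shape_b: Tuple[Dim, Dim], transA: int = 0, transB: int = 0
) -> Tuple[Dim, Dim]:
    """Hypothetical helper: output shape (M, N) of ONNX Gemm."""
    m, k_a = (shape_a[1], shape_a[0]) if transA else (shape_a[0], shape_a[1])
    k_b, n = (shape_b[1], shape_b[0]) if transB else (shape_b[0], shape_b[1])
    # Only static inner dimensions can be checked; symbolic ones are trusted.
    if isinstance(k_a, int) and isinstance(k_b, int) and k_a != k_b:
        raise ValueError(f"inner dimensions differ: {k_a} != {k_b}")
    return (m, n)
```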

yobx.xshape.shape_type_compute.set_type_shape_matmul(g: ShapeBuilder, name: str, x: str, y: str) → bool

Sets the output shape for node type MatMul.
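ONNX MatMul follows `numpy.matmul` semantics: 1-D inputs are promoted to matrices and the added axis is removed afterwards, and the leading (batch) dimensions broadcast. The hypothetical helper below sketches that rule for static shapes only; symbolic dimensions are left out to keep it short.

```python
from typing import Tuple

import numpy as np


def matmul_output_shape(shape_a: Tuple[int, ...], shape_b: Tuple[int, ...]) -> Tuple[int, ...]:
    """Hypothetical helper: output shape of ONNX MatMul for static shapes."""
    a, b = tuple(shape_a), tuple(shape_b)
    if not a or not b:
        raise ValueError("MatMul does not accept scalar inputs")
    # A 1-D first input becomes a row vector, a 1-D second input a column
    # vector; the added axis is removed from the result at the end.
    squeeze_a, squeeze_b = len(a) == 1, len(b) == 1
    if squeeze_a:
        a = (1,) + a
    if squeeze_b:
        b = b + (1,)
    if a[-1] != b[-2]:
        raise ValueError(f"inner dimensions differ: {shape_a} @ {shape_b}")
    # Leading (batch) dimensions broadcast like an elementwise operator.
    batch = tuple(np.broadcast_shapes(a[:-2], b[:-2]))
    out = batch + (a[-2], b[-1])
    if squeeze_a:
        out = out[:-2] + out[-1:]
    if squeeze_b:
        out = out[:-1]
    return out
```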

yobx.xshape.shape_type_compute.set_type_shape_multi_head_attention(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type MultiHeadAttention.

yobx.xshape.shape_type_compute.set_type_shape_reduce_op(g: ShapeBuilder, name: str, x: str, keepdim: int, axes: Tuple[int] | None = None)

Sets the output shape for any Reduce operator.
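A minimal sketch of the Reduce shape rule: reduced axes become 1 when `keepdim` is set and disappear otherwise. The helper is hypothetical, and it assumes `axes=None` means "reduce every axis" and that the given axes lie in `[-rank, rank)`.

```python
from typing import Optional, Sequence, Tuple, Union

Dim = Union[int, str]


def reduce_output_shape(
    shape: Sequence[Dim], keepdim: int, axes: Optional[Sequence[int]] = None
) -> Tuple[Dim, ...]:
    """Hypothetical helper: output shape of an ONNX Reduce* operator."""
    rank = len(shape)
    # Assumption: no axes means all axes; negative axes count from the end.
    reduced = set(range(rank)) if axes is None else {a % rank for a in axes}
    if keepdim:
        # Reduced axes are kept with size 1.
        return tuple(1 if i in reduced else d for i, d in enumerate(shape))
    # Reduced axes are dropped.
    return tuple(d for i, d in enumerate(shape) if i not in reduced)
```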

yobx.xshape.shape_type_compute.set_type_shape_reshape(g: ShapeBuilder, name: str, input_name: str, new_shape: Sequence[int])

Sets the output shape for node type Reshape.
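In ONNX Reshape (with the default `allowzero=0`), a `0` in the target shape copies the matching input dimension and a single `-1` is inferred from the total number of elements. The hypothetical helper below sketches that rule for a static input shape; it is an illustration, not the library's code.

```python
import math
from typing import List, Sequence, Tuple


def reshape_output_shape(input_shape: Sequence[int], new_shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: output shape of ONNX Reshape (allowzero=0)."""
    # 0 copies the input dimension at the same position.
    out: List[int] = [input_shape[i] if d == 0 else d for i, d in enumerate(new_shape)]
    if out.count(-1) > 1:
        raise ValueError("at most one dimension can be -1")
    if -1 in out:
        # -1 takes the size that keeps the number of elements unchanged.
        known = math.prod(d for d in out if d != -1)
        total = math.prod(input_shape)
        if known == 0 or total % known:
            raise ValueError(f"cannot reshape {tuple(input_shape)} into {tuple(new_shape)}")
        out[out.index(-1)] = total // known
    return tuple(out)
```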

yobx.xshape.shape_type_compute.set_type_shape_scatter_nd_of_shape(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node types ScatterNDOfShape and MaskedScatterNDOfShape.

yobx.xshape.shape_type_compute.set_type_shape_shared_input(self: ShapeBuilder, node: NodeProto)

Sets the output shapes for nodes with two outputs sharing the same inputs.

yobx.xshape.shape_type_compute.set_type_shape_to_complex(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type ToComplex (converts float to complex).

yobx.xshape.shape_type_compute.set_type_shape_transpose_2d_cast_fp16(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type Transpose2DCastFP16 (transposes and casts to float16).

yobx.xshape.shape_type_compute.set_type_shape_transpose_2d_cast_fp32(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type Transpose2DCastFP32 (transposes and casts to float32).
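Both Transpose2DCast variants follow the same shape rule: a 2-D input of shape `(rows, cols)` yields `(cols, rows)` with the target float type. The sketch below assumes exactly the semantics stated above (transpose, then cast) and uses NumPy as a reference; both helper names are hypothetical.

```python
import numpy as np


def transpose_2d_cast(x: np.ndarray, dtype) -> np.ndarray:
    """Hypothetical reference: transpose a 2-D tensor, then cast it."""
    if x.ndim != 2:
        raise ValueError(f"expected a 2-D input, got shape {x.shape}")
    return x.T.astype(dtype)


def transpose_2d_cast_output(shape, dtype):
    """Hypothetical helper: only shape and element type matter for inference."""
    rows, cols = shape
    return (cols, rows), np.dtype(dtype)
```

With `np.float16` this matches Transpose2DCastFP16, with `np.float32` it matches Transpose2DCastFP32.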

yobx.xshape.shape_type_compute.set_type_shape_tree_ensemble(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node types TreeEnsemble and TreeEnsembleRegressor.

yobx.xshape.shape_type_compute.set_type_shape_tri_matrix(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type TriMatrix.

yobx.xshape.shape_type_compute.set_type_shape_unary_op(g: ShapeBuilder, name: str, input_name: str, itype: int | None = None) → bool

Sets the shape and type for a unary operator (abs, exp, …).

yobx.xshape.shape_type_compute.set_type_shape_unary_op_abs(g: ShapeBuilder, name: str, input_name: str, itype: int | None = None) → bool

Sets the shape and type for a unary operator (abs, exp, …).