yobx.xshape.shape_type_compute

yobx.xshape.shape_type_compute.broadcast_shape(sh1: Tuple[int | torch.SymInt | torch.SymFloat | TracingInt | float | str, ...], sh2: Tuple[int | torch.SymInt | torch.SymFloat | TracingInt | float | str, ...], graph_builder: ShapeBuilder | None = None) → Tuple[int | torch.SymInt | torch.SymFloat | TracingInt | float | str, ...]

Computes the output shape for broadcasting operators (e.g. Add, Mul, Where).

The function follows NumPy/ONNX broadcasting rules. Shapes are right-aligned and each pair of dimensions (a, b) is resolved according to the table below:

| a | b | Result | Side effect |
|---|---|---|---|
| int (any) | int (any) | max | none |
| int n ∉ {0, 1} | str (symbolic) | n | register_constraint_dimension(b, n) if graph_builder |
| 1 | str (symbolic) | b | none |
| str (symbolic) | int n ∉ {0, 1} | n | register_constraint_dimension(a, n) if graph_builder |
| str (symbolic) | 1 | a | none |
| str, a == b | str, a == b | a | none |
| str, a != b | str, a != b | a^b | none (^ means max) |

When a symbolic dimension is paired with a concrete integer n other than 0 or 1, the concrete value is chosen as the output dimension and the equality is stored as a constraint via register_constraint_dimension. This avoids backtracking through earlier nodes when the concrete value is discovered later: downstream operations immediately see a precise integer shape.

Parameters:
  • sh1 – first shape (tuple of ints and/or symbolic strings)

  • sh2 – second shape (tuple of ints and/or symbolic strings)

  • graph_builder – if not None, constraints are registered on this builder whenever a symbolic dimension is equated to a concrete integer

Returns:

resulting broadcast shape
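The per-dimension rules in the table can be illustrated with a small standalone sketch. It is not the library's implementation: `broadcast_shape_sketch` and its `constraints` list are hypothetical names, and the list only stands in for the calls to register_constraint_dimension made on a graph builder. The sketch handles `int` and `str` dimensions only and assumes the symbolic rows exclude n = 0 and n = 1; a 0 paired with a symbolic dimension is not covered by the table, so the sketch rejects it.

```python
from itertools import zip_longest
from typing import List, Optional, Tuple, Union

Dim = Union[int, str]


def broadcast_shape_sketch(
    sh1: Tuple[Dim, ...],
    sh2: Tuple[Dim, ...],
    constraints: Optional[List[Tuple[str, int]]] = None,
) -> Tuple[Dim, ...]:
    """Illustrative re-implementation of the broadcasting table (int/str only)."""
    out: List[Dim] = []
    # Right-align the shapes: walk from the last dimension, padding with 1.
    for a, b in zip_longest(reversed(sh1), reversed(sh2), fillvalue=1):
        if isinstance(a, int) and isinstance(b, int):
            out.append(max(a, b))
        elif isinstance(a, int) or isinstance(b, int):
            # Exactly one side is symbolic.
            n, sym = (a, b) if isinstance(a, int) else (b, a)
            if n == 1:
                out.append(sym)
            elif n == 0:
                # Assumption: this pairing is not described by the table.
                raise ValueError(f"0 paired with symbolic dimension {sym!r}")
            else:
                out.append(n)
                if constraints is not None:
                    # Stand-in for register_constraint_dimension(sym, n).
                    constraints.append((sym, n))
        elif a == b:
            out.append(a)
        else:
            # Two different symbols: '^' denotes max in the table.
            out.append(f"{a}^{b}")
    return tuple(reversed(out))
```

For example, `broadcast_shape_sketch(("a",), (5,), log)` returns `(5,)` and appends `("a", 5)` to `log`, mirroring the second row of the table.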

yobx.xshape.shape_type_compute.set_shape_type_custom(self: ShapeBuilder, node: NodeProto, exc: bool = False)

Sets the shape and type if it can.

yobx.xshape.shape_type_compute.set_shape_type_op_any(self: ShapeBuilder, node: NodeProto, exc: bool = False)

Sets the shape and type if it can.

yobx.xshape.shape_type_compute.set_type_shape_binary_op(g: ShapeBuilder, name: str, *input_names: List[str], begin: int = 0, cmp_op: bool = False, itype: int | None = None) → bool

Sets the shape and type for a binary operator (add, mul, …).

yobx.xshape.shape_type_compute.set_type_shape_complex_module(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type ComplexModule (extracts real/imaginary part).

yobx.xshape.shape_type_compute.set_type_shape_fused_matmul(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type FusedMatMul.

yobx.xshape.shape_type_compute.set_type_shape_gemm(g: ShapeBuilder, name: str, x: str, y: str, transA: int, transB: int)

Sets the output shape for node type Gemm.
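As an illustration of the rule involved, the output shape of an ONNX Gemm node depends only on the two 2-D input shapes and the transA/transB flags. The helper below is a hypothetical standalone sketch, not the code behind set_type_shape_gemm, and it does not check that the inner dimensions agree.

```python
from typing import Tuple, Union

Dim = Union[int, str]


def gemm_output_shape_sketch(
    shape_a: Tuple[Dim, ...],
    shape_b: Tuple[Dim, ...],
    transA: int = 0,
    transB: int = 0,
) -> Tuple[Dim, Dim]:
    """Output shape (M, N) of Gemm for 2-D inputs A and B."""
    if len(shape_a) != 2 or len(shape_b) != 2:
        raise ValueError("Gemm expects two 2-D inputs")
    # A is (M, K), or (K, M) when transA is set.
    m = shape_a[1] if transA else shape_a[0]
    # B is (K, N), or (N, K) when transB is set.
    n = shape_b[0] if transB else shape_b[1]
    return (m, n)
```

Because the function only selects dimensions, symbolic dimensions such as "batch" pass through unchanged.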

yobx.xshape.shape_type_compute.set_type_shape_matmul(g: ShapeBuilder, name: str, x: str, y: str) → bool

Sets the output shape for node type MatMul.
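ONNX MatMul follows the semantics of numpy.matmul: 1-D operands are temporarily promoted to matrices, and leading batch dimensions are broadcast. The following sketch shows that rule for integer dimensions only. The function name is hypothetical, and the sketch says nothing about how set_type_shape_matmul treats symbolic dimensions.

```python
from itertools import zip_longest
from typing import Tuple


def matmul_output_shape_sketch(
    shape_a: Tuple[int, ...], shape_b: Tuple[int, ...]
) -> Tuple[int, ...]:
    """Output shape of a numpy.matmul-style product (integer dims only)."""
    a, b = tuple(shape_a), tuple(shape_b)
    if not a or not b:
        raise ValueError("matmul does not accept scalars")
    a_vec, b_vec = len(a) == 1, len(b) == 1
    if a_vec:
        a = (1,) + a  # a vector on the left becomes a row matrix
    if b_vec:
        b = b + (1,)  # a vector on the right becomes a column matrix
    if a[-1] != b[-2]:
        raise ValueError(f"inner dimensions differ: {a[-1]} != {b[-2]}")
    batch = []
    # Broadcast the leading (batch) dimensions, right-aligned.
    for x, y in zip_longest(reversed(a[:-2]), reversed(b[:-2]), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast batch dims {x} and {y}")
        batch.append(y if x == 1 else x)
    out = tuple(reversed(batch)) + (a[-2], b[-1])
    if a_vec:
        out = out[:-2] + out[-1:]  # drop the temporary row dimension
    if b_vec:
        out = out[:-1]  # drop the temporary column dimension
    return out
```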

yobx.xshape.shape_type_compute.set_type_shape_multi_head_attention(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type MultiHeadAttention.

yobx.xshape.shape_type_compute.set_type_shape_reduce_op(g: ShapeBuilder, name: str, x: str, keepdim: int, axes: Tuple[int] | None = None)

Sets the output shape for any Reduce type.
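The effect of keepdim and axes on a Reduce output shape can be sketched as follows. This is an illustrative helper with a hypothetical name, not the function documented above. It assumes that missing or empty axes mean "reduce over every axis" and does not model the ONNX noop_with_empty_axes attribute.

```python
from typing import Optional, Sequence, Tuple, Union

Dim = Union[int, str]


def reduce_output_shape_sketch(
    shape: Tuple[Dim, ...],
    keepdim: int,
    axes: Optional[Sequence[int]] = None,
) -> Tuple[Dim, ...]:
    """Output shape of a Reduce* operator for the given input shape."""
    rank = len(shape)
    if not axes:
        # Assumption: no axes given means all axes are reduced.
        reduced = set(range(rank))
    else:
        reduced = set()
        for ax in axes:
            if not -rank <= ax < rank:
                raise ValueError(f"axis {ax} out of range for rank {rank}")
            reduced.add(ax % rank)  # normalize negative axes
    if keepdim:
        # Reduced axes are kept with size 1.
        return tuple(1 if i in reduced else d for i, d in enumerate(shape))
    # Reduced axes are removed from the shape.
    return tuple(d for i, d in enumerate(shape) if i not in reduced)
```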

yobx.xshape.shape_type_compute.set_type_shape_reshape(g: ShapeBuilder, name: str, input_name: str, new_shape: Sequence[int])

Sets the output shape for node type Reshape.
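In ONNX Reshape with the default allowzero=0, a 0 in the target shape copies the corresponding input dimension and a single -1 is inferred from the remaining element count. The sketch below shows that resolution for integer shapes. The function name is hypothetical, and it is only an assumption that set_type_shape_reshape resolves new_shape along these lines.

```python
from math import prod
from typing import List, Sequence, Tuple


def reshape_output_shape_sketch(
    input_shape: Sequence[int], new_shape: Sequence[int]
) -> Tuple[int, ...]:
    """Resolve 0 and -1 entries of an ONNX Reshape target (allowzero=0)."""
    out: List[int] = []
    for i, d in enumerate(new_shape):
        if d == 0:
            # 0 copies the input dimension at the same position.
            if i >= len(input_shape):
                raise ValueError(f"no input dimension at position {i}")
            out.append(input_shape[i])
        else:
            out.append(d)
    if out.count(-1) > 1:
        raise ValueError("at most one dimension can be -1")
    total = prod(input_shape)
    if -1 in out:
        known = prod(d for d in out if d != -1)
        if known == 0 or total % known:
            raise ValueError("cannot infer the -1 dimension")
        # The inferred dimension keeps the element count unchanged.
        out[out.index(-1)] = total // known
    elif prod(out) != total:
        raise ValueError("element count changes")
    return tuple(out)
```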

yobx.xshape.shape_type_compute.set_type_shape_scatter_nd_of_shape(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node types ScatterNDOfShape and MaskedScatterNDOfShape.

yobx.xshape.shape_type_compute.set_type_shape_shared_input(self: ShapeBuilder, node: NodeProto)

Sets the output shapes for nodes with two outputs sharing the same inputs.

yobx.xshape.shape_type_compute.set_type_shape_to_complex(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type ToComplex (converts float to complex).

yobx.xshape.shape_type_compute.set_type_shape_transpose_2d_cast_fp16(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type Transpose2DCastFP16 (transposes and casts to float16).

yobx.xshape.shape_type_compute.set_type_shape_transpose_2d_cast_fp32(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type Transpose2DCastFP32 (transposes and casts to float32).

yobx.xshape.shape_type_compute.set_type_shape_tree_ensemble(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node types TreeEnsemble and TreeEnsembleRegressor.

yobx.xshape.shape_type_compute.set_type_shape_tri_matrix(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type TriMatrix.

yobx.xshape.shape_type_compute.set_type_shape_unary_op(g: ShapeBuilder, name: str, input_name: str, itype: int | None = None) → bool

Sets the shape and type for a unary operator (abs, exp, …).

yobx.xshape.shape_type_compute.set_type_shape_unary_op_abs(g: ShapeBuilder, name: str, input_name: str, itype: int | None = None) → bool

Sets the shape and type for a unary operator (abs, exp, …).