yobx.xshape.shape_type_compute

yobx.xshape.shape_type_compute.broadcast_shape(sh1: Tuple[int | torch.SymInt | torch.SymFloat | TracingInt | float | str, ...], sh2: Tuple[int | torch.SymInt | torch.SymFloat | TracingInt | float | str, ...], graph_builder: ShapeBuilder | None = None) -> Tuple[int | torch.SymInt | torch.SymFloat | TracingInt | float | str, ...]

Computes the output shape for broadcasting operators (e.g. Add, Mul, Where).

The function follows NumPy/ONNX broadcasting rules. Shapes are right-aligned and each pair of dimensions (a, b) is resolved according to the table below:

a                 b                 Result  Side effect
----------------  ----------------  ------  ----------------------------------------------------
int (any)         int (any)         max     none
int n ≠ 0, 1      str (symbolic)    n       register_constraint_dimension(b, n) if graph_builder
1                 str (symbolic)    b       none
str (symbolic)    int n ≠ 0, 1      n       register_constraint_dimension(a, n) if graph_builder
str (symbolic)    1                 a       none
str, a == b       str, a == b       a       none
str, a != b       str, a != b       a^b     none (^ means max)

When a symbolic dimension is paired with a concrete integer n ≠ 1, the concrete value is chosen as the output dimension and the equality is stored as a constraint via register_constraint_dimension. This avoids the need to backtrack through earlier nodes when the concrete value is discovered later: downstream operations immediately see a precise integer shape.

Parameters:
  • sh1 – first shape (tuple of ints and/or symbolic strings)

  • sh2 – second shape (tuple of ints and/or symbolic strings)

  • graph_builder – if not None, constraints are registered on this builder whenever a symbolic dimension is equated to a concrete integer

Returns:

resulting broadcast shape
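The table can be read as a small per-dimension function applied after right-aligning the two shapes. The sketch below reimplements it for plain int and str dimensions only (the real function also accepts torch.SymInt, TracingInt and others). RecordingBuilder and broadcast_shape_sketch are hypothetical stand-ins, not part of yobx: the stand-in builder only records the (name, value) pairs passed to register_constraint_dimension, and the handling of n == 0 paired with a symbolic dimension, which the table does not cover, is an assumption.

```python
from typing import List, Optional, Tuple, Union

Dim = Union[int, str]
Shape = Tuple[Dim, ...]


class RecordingBuilder:
    """Hypothetical stand-in for ShapeBuilder that only records constraints."""

    def __init__(self) -> None:
        self.constraints: List[Tuple[str, int]] = []

    def register_constraint_dimension(self, name: str, value: int) -> None:
        self.constraints.append((name, value))


def _broadcast_dim(a: Dim, b: Dim, builder: Optional[RecordingBuilder]) -> Dim:
    if isinstance(a, int) and isinstance(b, int):
        return max(a, b)  # int / int: max
    if isinstance(a, int) or isinstance(b, int):
        n, s = (a, b) if isinstance(a, int) else (b, a)  # n concrete, s symbolic
        if n == 1:
            return s  # 1 / symbolic: keep the symbolic name
        if n != 0 and builder is not None:
            builder.register_constraint_dimension(s, n)  # s is now known to equal n
        return n  # n == 0 is not in the table; returning it here is an assumption
    # Both symbolic: equal names collapse, different names become "a^b" (^ = max).
    return a if a == b else f"{a}^{b}"


def broadcast_shape_sketch(
    sh1: Shape, sh2: Shape, graph_builder: Optional[RecordingBuilder] = None
) -> Shape:
    # Right-align the shapes by left-padding the shorter one with 1s.
    rank = max(len(sh1), len(sh2))
    p1 = (1,) * (rank - len(sh1)) + tuple(sh1)
    p2 = (1,) * (rank - len(sh2)) + tuple(sh2)
    return tuple(_broadcast_dim(a, b, graph_builder) for a, b in zip(p1, p2))


builder = RecordingBuilder()
print(broadcast_shape_sketch(("batch", 4), (8, 1), builder))  # (8, 4)
print(builder.constraints)  # [('batch', 8)]
```

Returning the concrete 8 instead of "batch" is what lets later nodes see an integer shape without revisiting earlier ones; the recorded constraint keeps the link to the symbolic name.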

yobx.xshape.shape_type_compute.set_shape_type_attention_microsoft(self: ShapeBuilder, node: NodeProto)

Sets the output shape for com.microsoft.Attention.

yobx.xshape.shape_type_compute.set_shape_type_causal_conv_with_state(self: ShapeBuilder, node: NodeProto)

Sets the output shape for com.microsoft.CausalConvWithState.

Inputs: input (N, C, L), weight (C, 1, K), bias (C) (optional), past_state (N, C, K-1) (optional). Outputs: output (same shape as input), present_state (N, C, K-1) (optional).
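The shape rule above can be sketched with shapes as tuples of ints or symbolic strings. The helper name causal_conv_with_state_shapes is hypothetical, and writing a symbolic K-1 as the string "K-1" is an assumption about naming, not how yobx necessarily represents it.

```python
from typing import Tuple, Union

Dim = Union[int, str]
Shape = Tuple[Dim, ...]


def causal_conv_with_state_shapes(input_shape: Shape, weight_shape: Shape) -> Tuple[Shape, Shape]:
    """Hypothetical helper: (output, present_state) shapes for CausalConvWithState."""
    n, c, _length = input_shape  # input is (N, C, L)
    _channels, _one, k = weight_shape  # weight is (C, 1, K)
    # present_state holds the last K-1 steps; a symbolic K yields the string "K-1".
    state_len: Dim = k - 1 if isinstance(k, int) else f"{k}-1"
    return tuple(input_shape), (n, c, state_len)


print(causal_conv_with_state_shapes((2, 16, 100), (16, 1, 4)))
# ((2, 16, 100), (2, 16, 3))
```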

yobx.xshape.shape_type_compute.set_shape_type_cdist(self: ShapeBuilder, node: NodeProto)

Sets the output shape for com.microsoft.CDist.

Input A has shape (N, D) and input B has shape (M, D), so the output has shape (N, M).
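A minimal sketch of this rule, using a hypothetical helper name; the check that both feature dimensions agree when they are known integers is an added assumption, not something the entry above states.

```python
from typing import Tuple, Union

Dim = Union[int, str]


def cdist_output_shape(a_shape: Tuple[Dim, ...], b_shape: Tuple[Dim, ...]) -> Tuple[Dim, Dim]:
    """Hypothetical helper: (N, D) and (M, D) give (N, M)."""
    if len(a_shape) != 2 or len(b_shape) != 2:
        raise ValueError("CDist expects two 2-D inputs")
    (n, d_a), (m, d_b) = a_shape, b_shape
    # Only compare D when both values are concrete integers.
    if isinstance(d_a, int) and isinstance(d_b, int) and d_a != d_b:
        raise ValueError(f"feature dimensions differ: {d_a} != {d_b}")
    return (n, m)


print(cdist_output_shape(("N", 3), (7, 3)))  # ('N', 7)
```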

yobx.xshape.shape_type_compute.set_shape_type_custom(self: ShapeBuilder, node: NodeProto, exc: bool = False)

Sets the shape and type of the node outputs if it can infer them.

yobx.xshape.shape_type_compute.set_shape_type_embed_layer_normalization(self: ShapeBuilder, node: NodeProto)

Sets the output shape for com.microsoft.EmbedLayerNormalization.

Inputs: input_ids [B, S], segment_ids [B, S] (optional), word_embedding [V, D], position_embedding [P, D], segment_embedding [NS, D] (optional), gamma [D], beta [D], mask [B, S] (optional), position_ids [B, S] (optional). Outputs: output [B, S, D], mask_index [B].
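Only two inputs determine the output shapes: B and S come from input_ids, D from word_embedding. The sketch below illustrates that with a hypothetical helper name; it does not validate the other inputs.

```python
from typing import Tuple, Union

Dim = Union[int, str]
Shape = Tuple[Dim, ...]


def embed_layer_norm_shapes(input_ids_shape: Shape, word_embedding_shape: Shape) -> Tuple[Shape, Shape]:
    """Hypothetical helper: (output, mask_index) shapes for EmbedLayerNormalization."""
    batch, seq = input_ids_shape  # input_ids is [B, S]
    _vocab, hidden = word_embedding_shape  # word_embedding is [V, D]
    return (batch, seq, hidden), (batch,)


print(embed_layer_norm_shapes(("batch", 128), (30522, 768)))
# (('batch', 128, 768), ('batch',))
```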

yobx.xshape.shape_type_compute.set_shape_type_gated_relative_position_bias(self: ShapeBuilder, node: NodeProto)

Sets the output shape for com.microsoft.GatedRelativePositionBias.

Inputs: query_layer (batch_size, seq_len, num_heads*head_size), query_bias, rel_pos (1, num_heads, seq_len, seq_len), weight, bias, eco_a (1, num_heads, 1, 1), token_offset (optional). Output: (batch_size, num_heads, seq_len, seq_len).
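The output takes batch_size and seq_len from query_layer and num_heads from rel_pos. This sketch, with a hypothetical helper name, covers only the 3-D query_layer layout listed above and ignores the optional token_offset input.

```python
from typing import Tuple, Union

Dim = Union[int, str]
Shape = Tuple[Dim, ...]


def gated_rpb_output_shape(query_layer_shape: Shape, rel_pos_shape: Shape) -> Shape:
    """Hypothetical helper: output shape for GatedRelativePositionBias."""
    batch_size, seq_len, _hidden = query_layer_shape  # (batch_size, seq_len, num_heads*head_size)
    _one, num_heads, _s1, _s2 = rel_pos_shape  # (1, num_heads, seq_len, seq_len)
    return (batch_size, num_heads, seq_len, seq_len)


print(gated_rpb_output_shape((2, 64, 768), (1, 12, 64, 64)))  # (2, 12, 64, 64)
```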

Sets the output shape for com.microsoft.GreedySearch.

Input input_ids has shape (batch_size, sequence_length) and type INT32. Input max_length is a scalar INT32. Output sequences has shape (batch_size, max_length_value) and type INT32.
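A sketch of this rule follows. The helper name is hypothetical, and so is the fallback: when the value of max_length is not a known constant, the sketch keeps the second dimension symbolic under the placeholder name "max_length". The element type is returned as the ONNX integer code for INT32 (onnx.TensorProto.INT32 == 6), hard-coded so the block does not need onnx installed.

```python
from typing import Optional, Tuple, Union

Dim = Union[int, str]
INT32 = 6  # value of onnx.TensorProto.INT32


def greedy_search_output(
    input_ids_shape: Tuple[Dim, Dim], max_length_value: Optional[int] = None
) -> Tuple[Tuple[Dim, Dim], int]:
    """Hypothetical helper: (shape, element type) of the GreedySearch output sequences."""
    batch_size, _sequence_length = input_ids_shape
    # Unknown max_length: fall back to a placeholder symbolic name (assumption).
    length: Dim = max_length_value if max_length_value is not None else "max_length"
    return (batch_size, length), INT32


print(greedy_search_output(("batch", 8), 20))  # (('batch', 20), 6)
```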

yobx.xshape.shape_type_compute.set_shape_type_group_query_attention(self: ShapeBuilder, node: NodeProto)

Sets the output shape for com.microsoft.GroupQueryAttention.

yobx.xshape.shape_type_compute.set_shape_type_moe(self: ShapeBuilder, node: NodeProto)

Sets the output shape for com.microsoft.MoE (Mixture of Experts).

Input input has shape (num_tokens, hidden_size) or (batch_size, seq_len, hidden_size). Output output has the same shape and dtype as the input.

yobx.xshape.shape_type_compute.set_shape_type_murmur_hash3(self: ShapeBuilder, node: NodeProto)

Sets the output shape for com.microsoft.MurmurHash3.

yobx.xshape.shape_type_compute.set_shape_type_op_any(self: ShapeBuilder, node: NodeProto, exc: bool = False)

Sets the shape and type of the node outputs if it can infer them.

yobx.xshape.shape_type_compute.set_shape_type_packed_multi_head_attention(self: ShapeBuilder, node: NodeProto)

Sets the output shape for com.microsoft.PackedMultiHeadAttention.

yobx.xshape.shape_type_compute.set_shape_type_relative_position_bias(self: ShapeBuilder, node: NodeProto)

Sets the output shape for com.microsoft.RelativePositionBias.

Inputs: bias_table (num_heads, num_buckets), query_length (), key_length (). Output: (1, num_heads, query_length, key_length).
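Because query_length and key_length are scalar inputs, the last two output dimensions are only concrete when their values are known. The sketch below, with a hypothetical helper name, falls back to placeholder symbolic names when they are not; those placeholder names are an assumption.

```python
from typing import Optional, Tuple, Union

Dim = Union[int, str]


def relative_position_bias_shape(
    bias_table_shape: Tuple[Dim, Dim],
    query_length: Optional[int] = None,
    key_length: Optional[int] = None,
) -> Tuple[Dim, ...]:
    """Hypothetical helper: output shape for RelativePositionBias."""
    num_heads, _num_buckets = bias_table_shape  # bias_table is (num_heads, num_buckets)
    q: Dim = query_length if query_length is not None else "query_length"
    k: Dim = key_length if key_length is not None else "key_length"
    return (1, num_heads, q, k)


print(relative_position_bias_shape((12, 32), 10, 20))  # (1, 12, 10, 20)
```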

yobx.xshape.shape_type_compute.set_type_shape_bias_split_gelu(g: ShapeBuilder, node: NodeProto)

Sets the shape and type for com.microsoft.BiasSplitGelu.

The operator computes Y = left * Gelu(right) after adding a bias and splitting the last dimension into two equal halves, so the output shape equals the input shape with the last dimension halved.
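Halving the last dimension is the whole rule. A sketch with a hypothetical helper name; rejecting an odd last dimension and writing a symbolic half as "d//2" are assumptions of this sketch.

```python
from typing import Tuple, Union

Dim = Union[int, str]


def bias_split_gelu_shape(input_shape: Tuple[Dim, ...]) -> Tuple[Dim, ...]:
    """Hypothetical helper: BiasSplitGelu output shape (last dimension halved)."""
    *leading, last = input_shape
    if isinstance(last, int):
        if last % 2:
            raise ValueError(f"last dimension {last} cannot be split into two halves")
        half: Dim = last // 2
    else:
        half = f"{last}//2"  # symbolic naming is an assumption
    return (*leading, half)


print(bias_split_gelu_shape((2, 64, 2560)))  # (2, 64, 1280)
```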

yobx.xshape.shape_type_compute.set_type_shape_binary_op(g: ShapeBuilder, name: str, *input_names: List[str], begin: int = 0, cmp_op: bool = False, itype: int | None = None) -> bool

Sets the shape and type for a binary operator (add, mul, …).

yobx.xshape.shape_type_compute.set_type_shape_complex_module(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type ComplexModule (extracts real/imaginary part).

yobx.xshape.shape_type_compute.set_type_shape_fused_matmul(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type FusedMatMul.

yobx.xshape.shape_type_compute.set_type_shape_gemm(g: ShapeBuilder, name: str, x: str, y: str, transA: int, transB: int)

Sets the output shape for node type Gemm.
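The entry does not spell out the rule, but the ONNX Gemm specification does: A is (M, K), or (K, M) when transA is set; B is (K, N), or (N, K) when transB is set; the output is (M, N). A minimal sketch of that specification rule with a hypothetical helper name, not yobx's implementation:

```python
from typing import Tuple, Union

Dim = Union[int, str]


def gemm_output_shape(
    a_shape: Tuple[Dim, Dim], b_shape: Tuple[Dim, Dim], transA: int = 0, transB: int = 0
) -> Tuple[Dim, Dim]:
    """Hypothetical helper: Gemm output shape following the ONNX specification."""
    m = a_shape[1] if transA else a_shape[0]  # rows of op(A)
    n = b_shape[0] if transB else b_shape[1]  # columns of op(B)
    return (m, n)


print(gemm_output_shape((4, 3), (5, 4), transA=1, transB=1))  # (3, 5)
```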

yobx.xshape.shape_type_compute.set_type_shape_matmul(g: ShapeBuilder, name: str, x: str, y: str) -> bool

Sets the output shape for node type MatMul.

yobx.xshape.shape_type_compute.set_type_shape_multi_head_attention(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type MultiHeadAttention.

yobx.xshape.shape_type_compute.set_type_shape_reduce_op(g: ShapeBuilder, name: str, x: str, keepdim: int, axes: Tuple[int] | None = None)

Sets the output shape for any Reduce type.
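For reference, the ONNX Reduce* operators keep each reduced axis with size 1 when keepdims is 1 and drop it otherwise; negative axes count from the end, and with no axes every axis is reduced (the default noop_with_empty_axes=0). A sketch of those specification rules with a hypothetical helper name; treating both None and an empty tuple as "reduce all axes" is an assumption about how keepdim and axes are interpreted here.

```python
from typing import Optional, Tuple, Union

Dim = Union[int, str]


def reduce_output_shape(
    shape: Tuple[Dim, ...], keepdim: int, axes: Optional[Tuple[int, ...]] = None
) -> Tuple[Dim, ...]:
    """Hypothetical helper: output shape of an ONNX Reduce* node."""
    rank = len(shape)
    # No axes: reduce over every axis. Negative axes are normalised with % rank.
    reduced = set(range(rank)) if not axes else {a % rank for a in axes}
    if keepdim:
        return tuple(1 if i in reduced else d for i, d in enumerate(shape))
    return tuple(d for i, d in enumerate(shape) if i not in reduced)


print(reduce_output_shape((2, 3, 4), 1, (1,)))  # (2, 1, 4)
```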

yobx.xshape.shape_type_compute.set_type_shape_reshape(g: ShapeBuilder, name: str, input_name: str, new_shape: Sequence[int])

Sets the output shape for node type Reshape.
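The ONNX Reshape specification (with the default allowzero=0) says a 0 in the target shape copies the corresponding input dimension and at most one -1 is inferred from the total element count. A sketch of that rule for fully known integer shapes only, with a hypothetical helper name; how yobx handles symbolic dimensions here is not shown.

```python
from math import prod
from typing import Sequence, Tuple


def reshape_output_shape(input_shape: Sequence[int], new_shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: ONNX Reshape output shape (allowzero=0, integers only)."""
    # 0 copies the input dimension at the same position.
    out = [input_shape[i] if d == 0 else d for i, d in enumerate(new_shape)]
    if out.count(-1) > 1:
        raise ValueError("at most one -1 is allowed")
    if -1 in out:
        known = prod(d for d in out if d != -1)
        out[out.index(-1)] = prod(input_shape) // known  # infer the missing dimension
    return tuple(out)


print(reshape_output_shape((2, 3, 4), (0, -1)))  # (2, 12)
```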

yobx.xshape.shape_type_compute.set_type_shape_scatter_nd_of_shape(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node types ScatterNDOfShape and MaskedScatterNDOfShape.

yobx.xshape.shape_type_compute.set_type_shape_shared_input(self: ShapeBuilder, node: NodeProto)

Sets the output shapes for nodes with two outputs sharing the same inputs.

yobx.xshape.shape_type_compute.set_type_shape_to_complex(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type ToComplex (converts float to complex).

yobx.xshape.shape_type_compute.set_type_shape_transpose_2d_cast_fp16(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type Transpose2DCastFP16 (transposes and casts to float16).

yobx.xshape.shape_type_compute.set_type_shape_transpose_2d_cast_fp32(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type Transpose2DCastFP32 (transposes and casts to float32).

yobx.xshape.shape_type_compute.set_type_shape_tree_ensemble(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node types TreeEnsemble and TreeEnsembleRegressor.

yobx.xshape.shape_type_compute.set_type_shape_tri_matrix(self: ShapeBuilder, node: NodeProto)

Sets the output shape for node type TriMatrix.

yobx.xshape.shape_type_compute.set_type_shape_unary_op(g: ShapeBuilder, name: str, input_name: str, itype: int | None = None) -> bool

Sets the shape and type for a unary operator (abs, exp, …).

yobx.xshape.shape_type_compute.set_type_shape_unary_op_abs(g: ShapeBuilder, name: str, input_name: str, itype: int | None = None) -> bool

Sets the shape and type for a unary operator (abs, exp, …).

yobx.xshape.shape_type_compute.supported_ops_in_set_shape_type_custom() -> Dict[str, FrozenSet[str]]

Returns the ops supported by set_shape_type_custom() grouped by domain.

Returns a dictionary mapping each ONNX domain name to a frozenset of op type names for which set_shape_type_custom() provides shape and type inference.

The special key "" (empty string) groups ops that are handled regardless of their domain (i.e. no domain check is performed for them). Local functions registered at runtime are not included because they are determined dynamically.

Returns:

Dictionary mapping domain name to a frozenset of supported op types.
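A caller would typically check the domain-independent "" entry first, then the node's own domain. The sketch below shows that lookup on a dictionary of the documented shape. The helper is_supported and the contents of example are hypothetical: "AnyDomainOp" is a made-up op name, and the com.microsoft entries are illustrative, not the real return value of supported_ops_in_set_shape_type_custom().

```python
from typing import Dict, FrozenSet


def is_supported(supported: Dict[str, FrozenSet[str]], domain: str, op_type: str) -> bool:
    """Hypothetical helper: is (domain, op_type) covered by the mapping?"""
    # Ops listed under "" are handled whatever their domain.
    if op_type in supported.get("", frozenset()):
        return True
    return op_type in supported.get(domain, frozenset())


# Hypothetical contents, for illustration only.
example: Dict[str, FrozenSet[str]] = {
    "": frozenset({"AnyDomainOp"}),
    "com.microsoft": frozenset({"CDist", "BiasSplitGelu"}),
}

print(is_supported(example, "com.microsoft", "CDist"))  # True
print(is_supported(example, "my.domain", "AnyDomainOp"))  # True
print(is_supported(example, "", "CDist"))  # False
```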