yobx.torch.new_tracing.tensor

class yobx.torch.new_tracing.tensor.TracingTensor(size: Tuple[int, ...] | TracingShape, dtype: dtype, device: str | device | None = None, requires_grad: bool = False, tracer: GraphTracer | None = None)

A torch.Tensor subclass that records all dispatch-level operations into a torch.fx.Graph via __torch_dispatch__.

TracingTensor uses __torch_dispatch__ to intercept all tensor operations at the C++ dispatcher level and records them as nodes in a torch.fx.Graph. This produces a full computation graph without requiring Python-level symbolic proxy objects.

Note

TracingTensor instances are created internally by DispatchTracer. Use trace() to trace a callable rather than constructing TracingTensor directly.

Each TracingTensor carries two attributes:

  • _tracer: The DispatchTracer managing this tensor’s graph.

  • _node: The torch.fx.Node corresponding to this tensor in the graph.
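The recording idea can be illustrated without PyTorch. The sketch below is a toy analogue, not the yobx implementation: Node, ToyTracer, and ToyTracingValue are hypothetical names, and Python operator overloading stands in for __torch_dispatch__. The real hook receives every aten operator, not only + and *. Each toy value carries the same two attributes described above, _tracer and _node.

```python
from dataclasses import dataclass, field
from typing import Any, List, Tuple


@dataclass
class Node:
    """Minimal stand-in for a torch.fx.Node: a name, an op, and its inputs."""

    name: str
    op: str
    args: Tuple[Any, ...]


@dataclass
class ToyTracer:
    """Owns the recorded graph, playing the role of the real tracer."""

    nodes: List[Node] = field(default_factory=list)

    def emit(self, op: str, *args: Any) -> "ToyTracingValue":
        # Traced inputs are referenced by node name; constants are stored as-is.
        refs = tuple(a._node.name if isinstance(a, ToyTracingValue) else a for a in args)
        node = Node(f"n{len(self.nodes)}", op, refs)
        self.nodes.append(node)
        return ToyTracingValue(self, node)


class ToyTracingValue:
    """Hypothetical analogue of TracingTensor: every operation becomes a node."""

    def __init__(self, tracer: ToyTracer, node: Node) -> None:
        self._tracer = tracer  # the tracer managing this value's graph
        self._node = node      # the node this value corresponds to

    # TracingTensor uses one __torch_dispatch__ hook for all aten ops;
    # these two operator overloads are the closest stdlib-only analogue.
    def __add__(self, other: Any) -> "ToyTracingValue":
        return self._tracer.emit("add", self, other)

    def __mul__(self, other: Any) -> "ToyTracingValue":
        return self._tracer.emit("mul", self, other)


tracer = ToyTracer()
x = tracer.emit("placeholder", "x")
y = (x + 1) * x  # records add(n0, 1) and then mul(n1, n0)
for node in tracer.nodes:
    print(node)
```

No arithmetic is ever performed. Only the structure of the computation is kept, which is the property that lets the real class build a full graph.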

classmethod from_tensor(t: Tensor, dynamic_shapes: Dict[int, Any] | None = None, tracer: GraphTracer | None = None) → TracingTensor

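Creates a TracingTensor from the tensor t, optionally attached to tracer. dynamic_shapes maps dimension indices to dynamic-dimension specifications.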
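As a hedged illustration of how a Dict[int, Any] keyed by dimension index could turn a concrete shape into a partly symbolic one, here is a small sketch. It assumes that the values are dimension names (strings), matching the string-valued symbolic dimensions mentioned under make_empty_instance(). The helper symbolic_shape is hypothetical and is not part of yobx.

```python
from typing import Any, Dict, Optional, Tuple, Union

Dim = Union[int, str]  # concrete size or symbolic dimension name


def symbolic_shape(
    shape: Tuple[int, ...], dynamic_shapes: Optional[Dict[int, Any]] = None
) -> Tuple[Dim, ...]:
    """Hypothetical helper: replaces the listed dimensions with their names."""
    dynamic_shapes = dynamic_shapes or {}
    for axis in dynamic_shapes:
        assert 0 <= axis < len(shape), f"axis {axis} out of range for rank {len(shape)}"
    # Dimensions absent from dynamic_shapes keep their concrete size.
    return tuple(
        str(dynamic_shapes[i]) if i in dynamic_shapes else d for i, d in enumerate(shape)
    )


print(symbolic_shape((8, 3, 224), {0: "batch"}))  # ('batch', 3, 224)
```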

item() → int | float | TracingInt

Intercepts .item() during tracing to return a symbolic TracingInt backed by a graph node instead of raising an error or returning a bare Python scalar.

When a tracer is active, this method delegates to _handle_local_scalar_dense(), which emits an aten._local_scalar_dense FX node and wraps the result in a TracingInt. The caller can then use the returned TracingInt as a dynamic slice endpoint (e.g. x[..., :shape.item()]), which is recognised by _index_has_symbolic_tracing_int() and routed through _handle_symbolic_getitem().

Returns:

A TracingInt (symbolic) when a tracer is active, or the actual Python scalar when no tracer is present (falls back to torch.Tensor.item()).
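The two return paths of item() can be sketched with toy classes. ToySymInt and ToyScalarTensor are hypothetical stand-ins for TracingInt and TracingTensor. A list of (name, op) pairs plays the role of the FX graph, and None in its place means "no tracer is active". The real method emits an aten._local_scalar_dense node; this sketch only records its name.

```python
from typing import List, Optional, Tuple, Union


class ToySymInt:
    """Hypothetical stand-in for TracingInt: a symbolic integer tied to a node."""

    def __init__(self, node_name: str) -> None:
        self.node_name = node_name

    def __repr__(self) -> str:
        return f"ToySymInt({self.node_name})"


class ToyScalarTensor:
    """Holds a concrete value and, when tracing, a list of recorded nodes."""

    def __init__(
        self, value: Union[int, float], graph: Optional[List[Tuple[str, str]]] = None
    ) -> None:
        self.value = value
        self.graph = graph  # None means no tracer is active

    def item(self) -> Union[int, float, ToySymInt]:
        if self.graph is None:
            # No tracer: behave like torch.Tensor.item() and return the scalar.
            return self.value
        # Tracer active: record a node standing for aten._local_scalar_dense
        # and hand back a symbolic integer instead of the concrete value.
        name = f"n{len(self.graph)}"
        self.graph.append((name, "_local_scalar_dense"))
        return ToySymInt(name)


print(ToyScalarTensor(5).item())  # 5
graph: List[Tuple[str, str]] = []
sym = ToyScalarTensor(5, graph).item()
print(sym, graph)  # ToySymInt(n0) [('n0', '_local_scalar_dense')]
```

The symbolic result keeps a reference to the node that produced it. This is what allows a later slice such as x[..., :sym] to be recorded instead of evaluated.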

make_empty_instance(dyanmic_shape_values: Dict[str, int] | None = None) → Tensor

Allocates an uninitialised torch.empty() tensor whose dtype and device match this TracingTensor.

Concrete integer dimensions are used as-is. Symbolic (string) dimensions must be resolved by supplying dyanmic_shape_values, a mapping from dimension name to its concrete integer value. A missing entry for any symbolic dimension raises AssertionError.
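The shape-resolution rule can be sketched in plain Python. resolve_shape and its dim_values parameter are hypothetical names. The sketch stops at computing the concrete shape; per the description above, the real method then allocates the tensor with torch.empty using this TracingTensor's dtype and device.

```python
from typing import Dict, List, Optional, Sequence, Tuple, Union


def resolve_shape(
    dims: Sequence[Union[int, str]], dim_values: Optional[Dict[str, int]] = None
) -> Tuple[int, ...]:
    """Hypothetical helper: maps every dimension to a concrete integer."""
    values = dim_values or {}
    resolved: List[int] = []
    for d in dims:
        if isinstance(d, int):
            resolved.append(d)  # concrete dimension, used as-is
        else:
            # Symbolic dimension: the caller must supply its size.
            assert d in values, f"missing value for symbolic dimension {d!r}"
            resolved.append(values[d])
    return tuple(resolved)


print(resolve_shape(("batch", 3), {"batch": 4}))  # (4, 3)
```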

Parameters:

dyanmic_shape_values – Optional mapping from symbolic dimension names (e.g. "batch") to their concrete integer sizes.

Returns:

A real torch.Tensor with the resolved shape, the same dtype, and the same device as this TracingTensor. The tensor is uninitialised (contents are undefined).

Raises:

AssertionError – If dyanmic_shape_values has no entry for a symbolic dimension.

new_zeros(size: Any, **kwargs: Any) → Any

Accepts a TracingShape as the size argument in addition to the types accepted by the base class.

When a module attribute is temporarily replaced with a TracingTensor during tracing, self.attr.shape returns a TracingShape instead of a torch.Size. The C++ pybind for new_zeros does not accept custom sequence types, so this override converts a concrete TracingShape to a plain tuple of ints before delegating to the parent implementation (which routes through __torch_dispatch__).
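The size normalisation can be sketched with a tuple subclass standing in for TracingShape. ToyShape and normalize_size are hypothetical names. The sketch only reproduces the rule stated here: fully concrete shapes become a plain tuple, and everything else is forwarded unchanged.

```python
from typing import Any


class ToyShape(tuple):
    """Hypothetical stand-in for TracingShape: may hold int or str dimensions."""

    def is_concrete(self) -> bool:
        return all(isinstance(d, int) for d in self)


def normalize_size(size: Any) -> Any:
    # Concrete shapes become a plain tuple of ints, which C++ bindings accept.
    if isinstance(size, ToyShape) and size.is_concrete():
        return tuple(int(d) for d in size)
    # Symbolic shapes and ordinary sizes are forwarded unchanged.
    return size


print(type(normalize_size(ToyShape((2, 3)))).__name__)  # tuple
```

Converting only fully concrete shapes keeps the override narrow. A symbolic shape still reaches the parent implementation as-is.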

Parameters:

  • size – The desired output shape. A TracingShape with concrete integer dimensions is converted to a plain tuple; all other values are forwarded unchanged.

  • kwargs – Additional keyword arguments forwarded to torch.Tensor.new_zeros().

Returns:

A TracingTensor representing the zero-filled tensor.

numel() → int | TracingInt

Computes the total number of elements from _tracing_shape.

Concrete integer dimensions contribute their actual value directly. Symbolic (string-valued) TracingInt dimensions are folded into the product as symbolic terms, yielding a TracingInt return value.

This lets guards of the form if x.numel() == 0: be resolved at trace time through the TracingBool mechanism. Models should call torch._check(x.numel() != 0) to register the non-empty constraint (as with ControlFlowShapeCheck); after that, __bool__() resolves the equality to False by looking up its negation.

A concrete dimension of 0 still causes an immediate return of 0 so that genuinely empty static shapes are identified correctly.
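The folding rule can be sketched with strings for symbolic dimensions. toy_numel is a hypothetical helper. Its (constant factor, symbol names) tuple is only a stand-in for the symbolic TracingInt that the real method returns.

```python
from math import prod
from typing import Sequence, Tuple, Union

Dim = Union[int, str]  # concrete size or symbolic dimension name


def toy_numel(dims: Sequence[Dim]) -> Union[int, Tuple[int, Tuple[str, ...]]]:
    """Returns an int if all dims are concrete, else (constant, symbol names)."""
    # A concrete zero makes the tensor empty, whatever the symbolic sizes are.
    if any(d == 0 for d in dims if isinstance(d, int)):
        return 0
    constant = prod(d for d in dims if isinstance(d, int))  # 1 for no ints
    symbols = tuple(d for d in dims if isinstance(d, str))
    return constant if not symbols else (constant, symbols)


print(toy_numel((2, 3, 4)))        # 24
print(toy_numel(("batch", 3, 4)))  # (12, ('batch',))
print(toy_numel(("batch", 0, 4)))  # 0
```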

When a tracer is active and the result is symbolic, FX nodes are emitted for the numel computation and stored on the returned TracingInt (see _emit_numel_node()). Subsequent comparisons such as numel() > 0 then also emit comparison FX nodes, allowing the result to serve as a torch.cond predicate.
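The negation lookup that resolves if x.numel() == 0: can be sketched as a set of registered facts. ToyConstraints, its check and resolve methods, and the (symbol, operator, constant) tuple encoding are all hypothetical. check plays the role of torch._check, and resolve plays the role of TracingBool.__bool__(). The real comparisons also emit FX nodes, which this sketch omits.

```python
from typing import Set, Tuple

Fact = Tuple[str, str, int]  # (symbol, comparison operator, constant)

NEGATION = {"==": "!=", "!=": "==", "<": ">=", ">=": "<", ">": "<=", "<=": ">"}


class ToyConstraints:
    def __init__(self) -> None:
        self.facts: Set[Fact] = set()

    def check(self, fact: Fact) -> None:
        """Analogue of torch._check: registers a fact assumed true while tracing."""
        self.facts.add(fact)

    def resolve(self, fact: Fact) -> bool:
        """Analogue of TracingBool.__bool__: answers from registered facts."""
        symbol, op, value = fact
        if fact in self.facts:
            return True
        if (symbol, NEGATION[op], value) in self.facts:
            return False  # the negation was registered, so this fact is false
        raise RuntimeError(f"cannot resolve {symbol} {op} {value} at trace time")


constraints = ToyConstraints()
constraints.check(("numel", "!=", 0))            # like torch._check(x.numel() != 0)
print(constraints.resolve(("numel", "==", 0)))   # False: the empty branch is skipped
```

A comparison that matches no registered fact and no registered negation raises an error. This mirrors why the constraint must be registered before the guard is evaluated.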

Returns:

Plain int when every dimension is concrete; TracingInt when any dimension is symbolic.

Return type:

Union[int, TracingInt]

property shape: TracingShape

Returns the shape as a TracingShape.