experimental_experiment.torch_interpreter._aten_functions_attention

See https://pytorch.org/docs/stable/torch.compiler_ir.html for the full list of aten functions.

experimental_experiment.torch_interpreter._aten_functions_attention.aten__scaled_dot_product_efficient_attention(g: GraphBuilder, sts: Dict[str, Any] | None, outputs: List[str], query: str, key: str, value: str, attn_bias: str | None, compute_log_sumexp: bool, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, name: str = '_scaled_dot_product_efficient_attention') → Tuple[str, str, str, str]

_scaled_dot_product_efficient_attention (cuda)
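
A minimal sketch, not part of this module, of calling the underlying ATen operator this converter translates. It assumes a CUDA device is available and that the operator returns four tensors matching the converter's four outputs (attention output, logsumexp, and the two philox tensors); the call below uses only torch.ops.aten and standard tensor shapes (batch, heads, sequence, head dimension).

import torch

# assumption: a CUDA device is available; the memory-efficient kernel only runs on CUDA
q, k, v = (torch.rand(2, 4, 8, 16, device="cuda") for _ in range(3))

# positional arguments: query, key, value, attn_bias, compute_log_sumexp, dropout_p
# assumption: the op returns four tensors, matching the converter's four outputs
out, logsumexp, seed, offset = torch.ops.aten._scaled_dot_product_efficient_attention(
    q, k, v, None, True, 0.0, is_causal=False, scale=None
)
print(out.shape)  # same shape as the query: (batch, heads, sequence, head dimension)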

experimental_experiment.torch_interpreter._aten_functions_attention.aten__scaled_dot_product_flash_attention(g: GraphBuilder, sts: Dict[str, Any] | None, outputs: List[str], query: str, key: str, value: str, dropout_p: float = 0.0, is_causal: bool = False, return_debug_mask: bool = False, scale: float | None = None, name: str = '_scaled_dot_product_flash_attention') → Tuple[str, str, str, str, str, str, str, str, str]

_scaled_dot_product_flash_attention

experimental_experiment.torch_interpreter._aten_functions_attention.aten__scaled_dot_product_flash_attention_for_cpu(g: GraphBuilder, sts: Dict[str, Any] | None, outputs: List[str], query: str, key: str, value: str, dropout_p: float = 0.0, is_causal: bool = False, attn_mask: str | None = None, scale: float | None = None, return_debug_mask: bool = False, name: str = '_scaled_dot_product_flash_attention_for_cpu_default') → Tuple[str, str, str, str, str, str, str, str, str]

_scaled_dot_product_flash_attention
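
As a hedged illustration of where this node comes from, the sketch below uses only standard PyTorch: it exports a plain call to torch.nn.functional.scaled_dot_product_attention on CPU and prints the ATen targets of the decomposed graph. Depending on the PyTorch version and device, one of the printed targets is expected to be aten._scaled_dot_product_flash_attention_for_cpu, which is the node this converter handles.

import torch
from torch.export import export


class Attention(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)


q, k, v = (torch.rand(2, 4, 8, 16) for _ in range(3))  # (batch, heads, sequence, head dimension)
ep = export(Attention(), (q, k, v)).run_decompositions()
# prints the ATen operators kept in the decomposed graph
print([node.target for node in ep.graph.nodes if node.op == "call_function"])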

experimental_experiment.torch_interpreter._aten_functions_attention.aten_scaled_dot_product_attention(g: GraphBuilder, sts: Dict[str, Any] | None, outputs: List[str], query: str, key: str, value: str, attn_mask: str | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, name: str = 'aten_scaled_dot_product_attention')

scaled_dot_product_attention

See torch.nn.functional.scaled_dot_product_attention.

Equivalent to the PyTorch code:

import math

import torch


def scaled_dot_product_attention(
    Q, K, V, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
):
    # L: target sequence length, S: source sequence length
    L, S = Q.size(-2), K.size(-2)
    scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale
    # a causal mask keeps only the lower triangular part of the attention matrix
    attn_mask = (
        torch.ones(L, S, dtype=torch.bool, device=Q.device).tril(diagonal=0)
        if is_causal
        else attn_mask
    )
    # a boolean mask becomes an additive mask: -inf on the masked positions
    attn_mask = (
        torch.zeros_like(attn_mask, dtype=Q.dtype).masked_fill(~attn_mask, -float("inf"))
        if attn_mask is not None and attn_mask.dtype == torch.bool
        else attn_mask
    )
    attn_weight = torch.softmax(
        (Q @ K.transpose(-2, -1) * scale_factor)
        + (attn_mask if attn_mask is not None else 0),
        dim=-1,
    )
    attn_weight = torch.dropout(attn_weight, dropout_p, True)  # p, train
    return attn_weight @ V

where Q, K, V are the query, key, and value tensors, respectively. L is the target sequence length, S is the source sequence length, and E is the embedding size.
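
As a sanity check of the formula above, a short sketch using only standard PyTorch (no mask, no dropout) compares the manual computation with torch.nn.functional.scaled_dot_product_attention; the printed maximum absolute difference should be close to zero.

import math

import torch

Q = torch.rand(2, 4, 5, 8)  # L = 5, E = 8
K = torch.rand(2, 4, 7, 8)  # S = 7
V = torch.rand(2, 4, 7, 8)

scale_factor = 1 / math.sqrt(Q.size(-1))
attn_weight = torch.softmax(Q @ K.transpose(-2, -1) * scale_factor, dim=-1)
manual = attn_weight @ V

expected = torch.nn.functional.scaled_dot_product_attention(Q, K, V)
print((manual - expected).abs().max())  # expected to be close to zero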