.torch_interpreter._aten_functions_attention¶
See https://pytorch.org/docs/stable/torch.compiler_ir.html for the full list of ATen functions.
- experimental_experiment.torch_interpreter._aten_functions_attention.aten__scaled_dot_product_efficient_attention(g: GraphBuilder, sts: Dict[str, Any] | None, outputs: List[str], query: str, key: str, value: str, attn_bias: str | None, compute_log_sumexp: bool, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, name: str = '_scaled_dot_product_efficient_attention') → Tuple[str, str, str, str] [source]¶
_scaled_dot_product_efficient_attention (CUDA)
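A minimal sketch of the ATen-level call this converter translates, assuming the op schema matches the signature above; the memory-efficient kernel is registered for CUDA tensors, so the example requires a GPU:

import torch

# query, key, value with shape (batch, heads, sequence, head_dim)
q = torch.randn(2, 4, 16, 8, device="cuda", dtype=torch.float16)
k = torch.randn(2, 4, 16, 8, device="cuda", dtype=torch.float16)
v = torch.randn(2, 4, 16, 8, device="cuda", dtype=torch.float16)

# The op returns four values: the attention output, the log-sum-exp,
# and two philox RNG state tensors, matching the four output names
# produced by the converter.
out, logsumexp, seed, offset = torch.ops.aten._scaled_dot_product_efficient_attention(
    q, k, v, None, compute_log_sumexp=True, dropout_p=0.0, is_causal=False
)
print(out.shape)  # torch.Size([2, 4, 16, 8])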
- experimental_experiment.torch_interpreter._aten_functions_attention.aten__scaled_dot_product_flash_attention(g: GraphBuilder, sts: Dict[str, Any] | None, outputs: List[str], query: str, key: str, value: str, dropout_p: float = 0.0, is_causal: bool = False, return_debug_mask: bool = False, scale: float | None = None, name: str = '_scaled_dot_product_flash_attention') → Tuple[str, str, str, str, str, str, str, str, str] [source]¶
_scaled_dot_product_flash_attention
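A minimal sketch, assuming a recent PyTorch that exposes torch.nn.attention.sdpa_kernel and a CUDA device with flash-attention support (float16 or bfloat16 inputs), showing how torch.nn.functional.scaled_dot_product_attention can be restricted to the flash backend behind this ATen function:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)

# Restrict the backend selection to flash attention so the call below
# dispatches to aten._scaled_dot_product_flash_attention.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 4, 128, 64])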
- experimental_experiment.torch_interpreter._aten_functions_attention.aten__scaled_dot_product_flash_attention_for_cpu(g: GraphBuilder, sts: Dict[str, Any] | None, outputs: List[str], query: str, key: str, value: str, dropout_p: float = 0.0, is_causal: bool = False, attn_mask: str | None = None, scale: float | None = None, return_debug_mask: bool = False, name: str = '_scaled_dot_product_flash_attention_for_cpu_default') → Tuple[str, str, str, str, str, str, str, str, str] [source]¶
_scaled_dot_product_flash_attention (CPU)
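A minimal sketch of the corresponding ATen-level call on CPU, assuming a PyTorch version that provides this CPU flash kernel and that the op schema matches the signature above:

import torch

# query, key, value with shape (batch, heads, sequence, head_dim)
q = torch.randn(2, 4, 16, 8)
k = torch.randn(2, 4, 16, 8)
v = torch.randn(2, 4, 16, 8)

results = torch.ops.aten._scaled_dot_product_flash_attention_for_cpu(
    q, k, v, dropout_p=0.0, is_causal=False
)
out = results[0]  # first result: the attention output, same shape as q
print(out.shape)  # torch.Size([2, 4, 16, 8])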
- experimental_experiment.torch_interpreter._aten_functions_attention.aten_scaled_dot_product_attention(g: GraphBuilder, sts: Dict[str, Any] | None, outputs: List[str], query: str, key: str, value: str, attn_mask: str | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, name: str = 'aten_scaled_dot_product_attention')[source]¶
scaled_dot_product_attention
See torch.nn.functional.scaled_dot_product_attention.
Equivalent to the PyTorch code:
L, S = Q.size(-2), K.size(-2)
scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale
attn_mask = (
    torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
    if is_causal
    else attn_mask
)
attn_mask = (
    attn_mask.masked_fill(attn_mask.logical_not(), -float('inf'))
    if attn_mask.dtype == torch.bool
    else attn_mask
)
attn_weight = torch.softmax(
    (Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1
)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ V
where Q, K, V are the query, key, and value tensors, respectively. L is the target sequence length, S is the source sequence length, and E is the embedding size.
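A short sketch checking the equivalent code above against torch.nn.functional.scaled_dot_product_attention, with no dropout and an additive float mask (the shapes are chosen arbitrarily for the example):

import math
import torch

Q = torch.randn(2, 4, 5, 8)    # (..., L, E)
K = torch.randn(2, 4, 7, 8)    # (..., S, E)
V = torch.randn(2, 4, 7, 8)    # (..., S, E)
attn_mask = torch.randn(5, 7)  # additive float mask of shape (L, S)

# Reference computation following the equivalent code above.
scale_factor = 1 / math.sqrt(Q.size(-1))
attn_weight = torch.softmax((Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1)
reference = attn_weight @ V

expected = torch.nn.functional.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask)
print((reference - expected).abs().max())  # small, on the order of float rounding error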