vllm.attention.ops.flashmla
 
 flash_mla_with_kvcache(
    q: Tensor,
    k_cache: Tensor,
    block_table: Tensor,
    cache_seqlens: Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: Tensor,
    num_splits: Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[Tensor, Tensor]
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| q | Tensor | (batch_size, seq_len_q, num_heads_q, head_dim). | required | 
| k_cache | Tensor | (num_blocks, page_block_size, num_heads_k, head_dim). | required | 
| block_table | Tensor | (batch_size, max_num_blocks_per_seq), torch.int32. | required | 
| cache_seqlens | Tensor | (batch_size), torch.int32. | required | 
| head_dim_v | int | Head dimension of v. | required | 
| tile_scheduler_metadata | Tensor | (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. | required | 
| num_splits | Tensor | (batch_size + 1), torch.int32, returned by get_mla_metadata. | required | 
| softmax_scale | Optional[float] | The scale applied to QK^T before softmax. Defaults to 1 / sqrt(head_dim). | None | 
| causal | bool | Whether to apply a causal attention mask. | False | 
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
Source code in vllm/attention/ops/flashmla.py
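For orientation, here is a minimal sketch that wires the two functions on this page together for a single paged decode step. The concrete shapes (head_dim 576 with head_dim_v 512, one KV head, page size 64, 128 query heads) are assumptions in the usual DeepSeek-style MLA layout, not requirements stated here, and running it needs a GPU that the FlashMLA kernel supports.

```python
import torch

from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                         get_mla_metadata)

# Assumed MLA-style decode shapes: head_dim 576 = 512 (latent) + 64 (rope),
# head_dim_v 512, a single KV head, page size 64. These mirror the usual
# DeepSeek MLA layout and are not taken from this page.
batch_size, seq_len_q, num_heads_q, num_heads_k = 2, 1, 128, 1
head_dim, head_dim_v, page_block_size, max_blocks = 576, 512, 64, 8

device, dtype = "cuda", torch.bfloat16
q = torch.randn(batch_size, seq_len_q, num_heads_q, head_dim,
                device=device, dtype=dtype)
k_cache = torch.randn(batch_size * max_blocks, page_block_size, num_heads_k,
                      head_dim, device=device, dtype=dtype)
block_table = torch.arange(batch_size * max_blocks, device=device,
                           dtype=torch.int32).view(batch_size, max_blocks)
cache_seqlens = torch.full((batch_size,), 130, device=device,
                           dtype=torch.int32)

# Scheduling metadata is computed once per batch from the cached lengths
# (see get_mla_metadata below).
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens, seq_len_q * num_heads_q // num_heads_k, num_heads_k)

out, softmax_lse = flash_mla_with_kvcache(
    q, k_cache, block_table, cache_seqlens, head_dim_v,
    tile_scheduler_metadata, num_splits, causal=True)

# out:         (2, 1, 128, 512)
# softmax_lse: (2, 128, 1)
```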
  
 get_mla_metadata(
    cache_seqlens: Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[Tensor, Tensor]
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| cache_seqlens | Tensor | (batch_size), dtype torch.int32. | required | 
| num_heads_per_head_k | int | Equal to seq_len_q * num_heads_q // num_heads_k. | required | 
| num_heads_k | int | Number of k heads. | required | 
Returns:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
Source code in vllm/attention/ops/flashmla.py
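A minimal call, assuming a decode step with seq_len_q = 1, 128 query heads, and one KV head (assumed values, not stated here); both returned tensors are opaque int32 buffers passed straight through to flash_mla_with_kvcache.

```python
import torch

from vllm.attention.ops.flashmla import get_mla_metadata

# Assumed decode setup: seq_len_q = 1, 128 query heads, 1 KV head.
cache_seqlens = torch.tensor([130, 77, 1024], dtype=torch.int32, device="cuda")
seq_len_q, num_heads_q, num_heads_k = 1, 128, 1

tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens, seq_len_q * num_heads_q // num_heads_k, num_heads_k)

# tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), int32
# num_splits:              (batch_size + 1,), int32
print(tile_scheduler_metadata.shape, num_splits.shape)
```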
  
 is_flashmla_supported() -> Tuple[bool, Optional[str]]
Returns:
is_supported_flag, unsupported_reason (optional).
Source code in vllm/attention/ops/flashmla.py
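A guarded call site might look like the sketch below; the printed message is illustrative only.

```python
from vllm.attention.ops.flashmla import is_flashmla_supported

supported, reason = is_flashmla_supported()
if not supported:
    # `reason` explains why FlashMLA cannot be used on this platform,
    # e.g. an unsupported GPU architecture or a missing kernel build.
    print(f"FlashMLA unavailable: {reason}")
```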