class DualChunkFlashAttentionImpl(FlashAttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.
    The prompts might have different lengths, while the generation tokens
    always have length 1.
    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.
    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        layer_idx: int = -1,
        dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Configure a dual-chunk flash-attention layer.

        Args:
            num_heads: Number of query heads handled by this tensor-parallel
                rank.
            head_size: Per-head dimension; must be one of
                ``DualChunkFlashAttentionBackend.get_supported_head_sizes()``.
            scale: Softmax scale (stored as a float).
            num_kv_heads: Number of key/value heads on this rank.
            alibi_slopes: Optional ALiBi slopes, converted to a float32
                tensor.
            sliding_window: Must be None; sliding window is not supported.
            kv_cache_dtype: KV-cache dtype string.
            logits_soft_cap: Not used by this constructor.
            attn_type: Not used by this constructor.
            kv_sharing_target_layer_name: Must be None; KV sharing is not
                supported.
            layer_idx: Selects this layer's entry in
                ``sparse_attention_config``.
            dual_chunk_attention_config: Required. Recognized keys and their
                defaults: ``chunk_size`` (8192), ``local_size`` (1024),
                ``original_max_position_embeddings`` (0),
                ``sparse_attention_config`` (None),
                ``sparse_attention_enabled`` (True only when a non-empty
                sparse config is given), ``sparse_attention_threshold``
                (32768) and ``sparse_attention_last_q`` (64).

        Raises:
            NotImplementedError: If ``kv_sharing_target_layer_name`` is set.
            ValueError: If ``sliding_window`` is set or ``head_size`` is not
                supported.
        """
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0 "
                                      "DUAL_CHUNK_FLASH_ATTN backend.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        if sliding_window is not None:
            # NOTE(woosuk): flash-attn's sliding window does not work with
            # paged KV cache.
            raise ValueError(
                "Sliding window is not supported in FlashAttention.")
        support_head_sizes = (
            DualChunkFlashAttentionBackend.get_supported_head_sizes())
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by FlashAttention. "
                f"Supported head sizes are: {support_head_sizes}.")
        assert dual_chunk_attention_config is not None
        self.chunk_size = dual_chunk_attention_config.get("chunk_size", 8192)
        self.local_size = dual_chunk_attention_config.get("local_size", 1024)
        self.original_max_position_embeddings = dual_chunk_attention_config.get(
            "original_max_position_embeddings", 0)
        self.sparse_attention_config = dual_chunk_attention_config.get(
            "sparse_attention_config", None)
        if not self.sparse_attention_config:
            logger.warning_once("Sparse attention will not be enabled as "
                                "sparse attention config is not provided.")
        # Default must agree with the warning above: an empty config (e.g.
        # ``{}``) is "not provided" too. Otherwise sparse attention would be
        # enabled without a per-head config to index into.
        self.sparse_attention_enabled = dual_chunk_attention_config.get(
            "sparse_attention_enabled", bool(self.sparse_attention_config))
        self.sparse_attention_threshold = dual_chunk_attention_config.get(
            "sparse_attention_threshold", 32768)
        self.sparse_attention_last_q = dual_chunk_attention_config.get(
            "sparse_attention_last_q", 64)
        self.layer_idx = layer_idx
        self.dual_chunk_attention_config = dual_chunk_attention_config
        if self.sparse_attention_config:
            # Take this layer's entry, with head-index keys converted to int.
            self.sparse_attention_config = {
                int(i): j
                for i, j in self.sparse_attention_config[
                    self.layer_idx].items()
            }
            # Keep only the heads owned by this tensor-parallel rank, so the
            # resulting list is indexed by local head id.
            start_head = self.num_heads * get_tensor_model_parallel_rank()
            end_head = start_head + self.num_heads
            self.sparse_attention_config = [
                self.sparse_attention_config[i]
                for i in range(start_head, end_head)
            ]
        if self.sparse_attention_enabled:
            self.arange = torch.arange(self.sparse_attention_last_q,
                                       device="cuda")
            # [1, 1, L, L] lower-triangular (i >= j) causal mask over the
            # last ``sparse_attention_last_q`` query positions.
            self.last_q_mask = (self.arange[None, None, :, None]
                                >= self.arange[None, None, None, :])
    def forward(  # type: ignore
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: DualChunkFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with DualChunkFlashAttention.

        Args:
            layer: Attention layer; its ``_k_scale``/``_v_scale`` are passed
                to the KV-cache write.
            query: shape = [num_tokens, 5 * num_heads * head_size]. The last
                dimension packs five equally sized query projections, in
                order: query, query_succ, query_inter, query_succ_critical
                and query_inter_critical.
            key: shape = [num_tokens, num_kv_heads * head_size]. When
                ``original_max_position_embeddings > 0`` it is rescaled in
                place by the per-sequence scaling factors.
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks, block_size, num_kv_heads,
                head_size]. It is None or empty during the memory profiling
                run, and then nothing is cached.
            attn_metadata: Metadata for attention.
            output: Must be None; a preallocated output is not supported.
            output_scale: Must be None; fused output quantization is not
                supported.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is None, "Output tensor not supported for DualChunk"
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashAttentionImpl")
        # Unpack the five query variants packed along the last dimension.
        (
            query,
            query_succ,
            query_inter,
            query_succ_critical,
            query_inter_critical,
        ) = torch.split(query, query.shape[-1] // 5, dim=-1)
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        query_succ = query_succ.view(-1, self.num_heads, self.head_size)
        query_inter = query_inter.view(-1, self.num_heads, self.head_size)
        query_succ_critical = query_succ_critical.view(-1, self.num_heads,
                                                       self.head_size)
        query_inter_critical = query_inter_critical.view(
            -1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if self.original_max_position_embeddings > 0:
            # Rescale keys in place (before caching) with the per-sequence
            # scaling factors from the metadata.
            if prefill_meta := attn_metadata.prefill_metadata:
                assert prefill_meta.scaling_factor is not None
                assert prefill_meta.query_start_loc is not None
                assert prefill_meta.orig_seq_lens is not None
                current_start = 0
                query_start_loc_cpu = prefill_meta.query_start_loc.cpu()
                for i in range(len(prefill_meta.orig_seq_lens)):
                    current_end = (current_start +
                                   (query_start_loc_cpu[i + 1] -
                                    query_start_loc_cpu[i]).item())
                    key[current_start:current_end].mul_(
                        prefill_meta.scaling_factor[i])
                    current_start = current_end
                # After the loop ``current_start`` is the end of the last
                # sequence. Unlike ``current_end``, it is bound even when
                # there are no sequences.
                assert current_start <= attn_metadata.num_prefill_tokens
            if decode_meta := attn_metadata.decode_metadata:
                assert decode_meta.scaling_factor is not None
                scaling_factor = decode_meta.scaling_factor
                key[attn_metadata.num_prefill_tokens:].mul_(
                    scaling_factor.unsqueeze(-1).unsqueeze(-1))

        # If kv_cache is not provided (None or empty), the new key and value
        # tensors are not cached. This happens during the initial memory
        # profiling run. The same flag selects the non-paged prefill path
        # below, so ``key_cache``/``value_cache`` are only read when bound.
        has_kv_cache = kv_cache is not None and kv_cache.numel() > 0
        if has_kv_cache:
            key_cache = kv_cache[0]
            value_cache = kv_cache[1]
            # Reshape the input keys and values and store them in the cache.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        decode_query_succ = query_succ[num_prefill_tokens:]
        decode_query_inter = query_inter[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        query_succ = query_succ[:num_prefill_tokens]
        query_inter = query_inter[:num_prefill_tokens]
        query_succ_critical = query_succ_critical[:num_prefill_tokens]
        query_inter_critical = query_inter_critical[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]
        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if (not has_kv_cache or prefill_meta.block_tables is None
                    or prefill_meta.block_tables.numel() == 0):
                # normal attention, called during the profiling run.
                out = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention over the paged cache.
                assert prefill_meta.seq_lens is not None
                assert prefill_meta.orig_seq_lens is not None
                output[:num_prefill_tokens] = (
                    self._dual_chunk_flash_attn_prefill(
                        q=query,
                        q_succ=query_succ,
                        q_inter=query_inter,
                        q_succ_critical=query_succ_critical,
                        q_inter_critical=query_inter_critical,
                        k=key_cache,
                        v=value_cache,
                        cu_seqlens_q=prefill_meta.query_start_loc,
                        cu_seqlens_k=prefill_meta.seq_start_loc,
                        orig_seq_lens=prefill_meta.orig_seq_lens,
                        scaling_factor=prefill_meta.scaling_factor,
                        softmax_scale=self.scale,
                        causal=True,
                        window_size=(-1, -1),
                        alibi_slopes=self.alibi_slopes,
                        block_table=prefill_meta.block_tables,
                        chunk_size=self.chunk_size,
                        local_size=self.local_size,
                    ))

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run: one query token per sequence, hence the
            # unsqueeze(1)/squeeze(1) around the call.
            output[num_prefill_tokens:] = (
                self._dual_chunk_flash_attn_decoding(
                    decode_query.unsqueeze(1),
                    decode_query_succ.unsqueeze(1),
                    decode_query_inter.unsqueeze(1),
                    key_cache,
                    value_cache,
                    block_table=decode_meta.block_tables,
                    cache_seqlens=decode_meta.seq_lens_tensor,
                    softmax_scale=self.scale,
                    causal=True,
                    alibi_slopes=self.alibi_slopes,
                    chunk_size=self.chunk_size,
                    local_size=self.local_size,
                    original_max_position_embeddings=self.
                    original_max_position_embeddings,
                    decode_meta=decode_meta,
                ).squeeze(1))
        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
    def _dual_chunk_flash_attn_prefill(
        self,
        q,
        q_succ,
        q_inter,
        q_succ_critical,
        q_inter_critical,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        orig_seq_lens: List[int],
        scaling_factor: torch.Tensor,
        softmax_scale: float,
        causal: Optional[bool] = True,
        window_size: Tuple[int, int] = (-1, -1),
        alibi_slopes: Optional[torch.Tensor] = None,
        block_table: Optional[torch.Tensor] = None,
        chunk_size: int = 8192,
        local_size: int = 1024,
    ):
        """Run dual-chunk attention for each prefill sequence in the batch.

        Each sequence is processed independently and the per-sequence
        outputs are concatenated along the token dimension.

        Args:
            q, q_succ, q_inter, q_succ_critical, q_inter_critical: Query
                variants, [num_prefill_tokens, num_heads, head_size].
            k, v: Paged caches [num_blocks, block_size, num_kv_heads,
                head_size] when ``block_table`` is given; otherwise
                (presumably) contiguous [num_tokens, num_kv_heads, head_size]
                tensors sliced with ``cu_seqlens_k``.
            cu_seqlens_q: Cumulative query offsets, one entry per sequence
                plus a leading 0.
            cu_seqlens_k: Cumulative key offsets, same layout.
            orig_seq_lens: Per-sequence lengths compared against
                ``self.sparse_attention_threshold``.
            scaling_factor: Per-sequence scaling factors, indexed by batch
                position.
            softmax_scale: Base softmax scale.
            causal: Must be True.
            window_size: Must be (-1, -1).
            alibi_slopes: Must be None.
            block_table: Optional per-sequence block tables for paged k/v.
            chunk_size: Dual-chunk chunk size.
            local_size: Dual-chunk local window size.

        Returns:
            The concatenated per-sequence outputs, one row per query token.

        Raises:
            ValueError: If alibi_slopes, causal=False or a window size is
                requested.
        """
        if alibi_slopes is not None:
            raise ValueError(
                "Dual Chunk Attention does not support alibi_slopes")
        if not causal:
            raise ValueError(
                "Dual Chunk Attention does not support causal=False")
        if window_size != (-1, -1):
            raise ValueError(
                "Dual Chunk Attention does not support window_size")
        cu_seqlens_q_cpu = cu_seqlens_q.cpu().tolist()
        cu_seqlens_k_cpu = cu_seqlens_k.cpu().tolist()
        # head_size is the last axis of v in both the paged and the
        # contiguous layout (v.shape[2] would be num_kv_heads when paged).
        v_head_size = v.shape[-1]
        all_outputs = []
        for i in range(len(cu_seqlens_q_cpu) - 1):
            qs, qe = cu_seqlens_q_cpu[i], cu_seqlens_q_cpu[i + 1]
            ks, ke = cu_seqlens_k_cpu[i], cu_seqlens_k_cpu[i + 1]
            current_q = q[qs:qe]
            if current_q.shape[0] == 0:
                continue
            current_q_succ = q_succ[qs:qe]
            current_q_inter = q_inter[qs:qe]
            current_q_succ_critical = q_succ_critical[qs:qe]
            current_q_inter_critical = q_inter_critical[qs:qe]
            current_orig_seq_len = orig_seq_lens[i]
            if block_table is None:
                current_k = k[ks:ke]
                current_v = v[ks:ke]
                current_block_table = None
            else:
                # Paged layout: hand over the whole cache together with this
                # sequence's block table.
                current_block_table = block_table[i]
                current_k = k
                current_v = v
            if current_k.shape[0] == 0:
                all_outputs.append(
                    torch.zeros(
                        (current_q.shape[0], current_q.shape[1], v_head_size),
                        device=q.device,
                        dtype=q.dtype,
                    ))
                continue
            sparse_attn_enabled = (self.sparse_attention_enabled
                                   and current_orig_seq_len
                                   > self.sparse_attention_threshold)
            current_scaling_factor = scaling_factor[i].item()
            num_q_heads = current_q.size(-2)
            # Number of query heads sharing one KV head (1 without GQA).
            group_size = num_q_heads // current_k.size(-2)
            if sparse_attn_enabled:
                heads_vertical_size = torch.empty(size=(num_q_heads, ),
                                                  dtype=torch.int32)
                heads_slash_size = torch.empty(size=(num_q_heads, ),
                                               dtype=torch.int32)
                for head_id in range(num_q_heads):
                    (
                        ty,
                        vertical_size,
                        slash_size,
                        _,
                    ) = self.sparse_attention_config[head_id]
                    assert ty == "vertical_and_slash", "only support slash mode"
                    # NOTE(review): a configured vertical size of exactly 30
                    # is bumped by 100; the rationale is not visible here.
                    if vertical_size == 30:
                        vertical_size += 100
                    heads_vertical_size[head_id] = vertical_size
                    heads_slash_size[head_id] = slash_size
                current_output = self._dual_chunk_flash_attn_prefill_func(
                    current_q,  # allheads
                    current_q_succ,
                    current_q_inter,
                    current_q_succ_critical,
                    current_q_inter_critical,
                    current_k,
                    current_v,
                    current_block_table,
                    softmax_scale,
                    chunk_size,
                    local_size,
                    current_scaling_factor,
                    ke - ks,
                    sparse_attn_enabled=sparse_attn_enabled,
                    heads_vertical_size=heads_vertical_size,
                    heads_slash_size=heads_slash_size,
                    group_size=group_size)
            else:
                current_output = torch.empty_like(current_q)
                for head_id in range(num_q_heads):
                    # Size-1 slices keep the head axis at position -2 for both
                    # the [tokens, heads, dim] and the paged
                    # [blocks, block_size, heads, dim] layouts.
                    q_head = slice(head_id, head_id + 1)
                    # Query heads in the same GQA group read the same KV head.
                    kv_head = slice(head_id // group_size,
                                    head_id // group_size + 1)
                    current_out = self._dual_chunk_flash_attn_prefill_func(
                        current_q[:, q_head, :],
                        current_q_succ[:, q_head, :],
                        current_q_inter[:, q_head, :],
                        current_q_succ_critical[:, q_head, :],
                        current_q_inter_critical[:, q_head, :],
                        current_k[..., kv_head, :],
                        current_v[..., kv_head, :],
                        current_block_table,
                        softmax_scale,
                        chunk_size,
                        local_size,
                        current_scaling_factor,
                        ke - ks,
                        sparse_attn_enabled=sparse_attn_enabled,
                    )
                    current_output[:, q_head, :] = current_out
            all_outputs.append(current_output)
        if not all_outputs:
            # Every query slice was empty; torch.cat rejects an empty list.
            return q.new_empty((0, q.shape[1], v_head_size))
        return torch.cat(all_outputs, dim=0)
    def _dual_chunk_flash_attn_prefill_func(
        self,
        q,
        q_succ,
        q_inter,
        q_succ_critical,
        q_inter_critical,
        k,
        v,
        block_table,
        softmax_scale: float,
        chunk_size: int,
        local_size: int,
        scaling_factor: float,
        k_length: int,
        sparse_attn_enabled: Optional[bool] = True,
        heads_vertical_size=None,
        heads_slash_size=None,
        group_size=None,
    ):
        flash_results = []
        chunk_len = chunk_size - local_size
        if block_table is not None:
            block_size = v.shape[1]
            if chunk_len % block_size != 0:
                raise ValueError("chunk_len must be divisible by block_size.")
        else:
            block_size = 1
        if self.original_max_position_embeddings > 0:
            softmax_scale = softmax_scale * scaling_factor
        begin = k_length - q.shape[0]
        while begin < k_length:
            flash_per_chunk = []
            prev_chunk_end_pos = (begin // chunk_len) * chunk_len
            next_chunk_end_pos = prev_chunk_end_pos + chunk_len
            end = min(next_chunk_end_pos, k_length)
            qbegin = begin - (k_length - q.shape[0])
            qend = end - (k_length - q.shape[0])
            qk_chunks = []
            q_states_intra = q[qbegin:qend]
            # choose critical token
            if block_table is not None:
                block_tables_intra = _get_block(block_table, block_size,
                                                prev_chunk_end_pos, end)
                k_states_intra = k[block_tables_intra].view(
                    -1, *k.shape[-2:])[:(end - prev_chunk_end_pos)]
                v_states_intra = v[block_tables_intra].view(
                    -1, *v.shape[-2:])[:(end - prev_chunk_end_pos)]
            else:
                block_tables_intra = None
                k_states_intra = k[prev_chunk_end_pos:end]
                v_states_intra = v[prev_chunk_end_pos:end]
            if sparse_attn_enabled:
                last_q_size = min(qend - qbegin, self.sparse_attention_last_q)
                _, num_device_k_heads, head_dim = k_states_intra.shape
                k_states_intra = (k_states_intra.unsqueeze(2).repeat(
                    1, 1, group_size,
                    1).reshape(-1, num_device_k_heads * group_size, head_dim))
                v_states_intra = (v_states_intra.unsqueeze(2).repeat(
                    1, 1, group_size,
                    1).reshape(-1, num_device_k_heads * group_size, head_dim))
                qk_chunks.append(
                    (q_states_intra.transpose(0, 1)[:, -last_q_size:] *
                     softmax_scale) @ k_states_intra.permute(1, 2, 0))
            if prev_chunk_end_pos - chunk_len >= 0:
                q_states_succ = q_succ[qbegin:qend]
                q_states_succ_critical = q_succ_critical[qbegin:qend]
                if block_table is not None:
                    block_tables_succ = _get_block(
                        block_table, block_size,
                        prev_chunk_end_pos - chunk_len, prev_chunk_end_pos)
                    k_states_succ = k[block_tables_succ].view(
                        -1, *k.shape[-2:])[:chunk_len]
                    v_states_succ = v[block_tables_succ].view(
                        -1, *v.shape[-2:])[:chunk_len]
                else:
                    k_states_succ = k[prev_chunk_end_pos -
                                      chunk_len:prev_chunk_end_pos]
                    v_states_succ = v[prev_chunk_end_pos -
                                      chunk_len:prev_chunk_end_pos]
                if sparse_attn_enabled:
                    k_states_succ = (k_states_succ.unsqueeze(2).repeat(
                        1, 1, group_size,
                        1).reshape(-1, num_device_k_heads * group_size,
                                   head_dim))
                    v_states_succ = (v_states_succ.unsqueeze(2).repeat(
                        1, 1, group_size,
                        1).reshape(-1, num_device_k_heads * group_size,
                                   head_dim))
                    qk_chunks.append((q_states_succ_critical.transpose(
                        0, 1)[:, -last_q_size:] * softmax_scale)
                                     @ k_states_succ.permute(1, 2, 0))
            if prev_chunk_end_pos - chunk_len * 2 >= 0:
                q_states_inter = q_inter[qbegin:qend]
                q_states_inter_critical = q_inter_critical[qbegin:qend]
                if block_table is not None:
                    block_tables_inter = _get_block(
                        block_table, block_size, 0,
                        prev_chunk_end_pos - chunk_len)
                    k_states_inter = k[block_tables_inter].view(
                        -1, *k.shape[-2:])[:(prev_chunk_end_pos - chunk_len)]
                    v_states_inter = v[block_tables_inter].view(
                        -1, *v.shape[-2:])[:(prev_chunk_end_pos - chunk_len)]
                else:
                    k_states_inter = k[:prev_chunk_end_pos - chunk_len]
                    v_states_inter = v[:prev_chunk_end_pos - chunk_len]
                if sparse_attn_enabled:
                    k_states_inter = (k_states_inter.unsqueeze(2).repeat(
                        1, 1, group_size,
                        1).reshape(-1, num_device_k_heads * group_size,
                                   head_dim))
                    v_states_inter = (v_states_inter.unsqueeze(2).repeat(
                        1, 1, group_size,
                        1).reshape(-1, num_device_k_heads * group_size,
                                   head_dim))
                    qk_chunks.append((q_states_inter_critical.transpose(
                        0, 1)[:, -last_q_size:] * softmax_scale)
                                     @ k_states_inter.permute(1, 2, 0))
            if sparse_attn_enabled:
                reversed_qk = qk_chunks[::-1]
                qk = torch.cat(reversed_qk, dim=-1)
                qk[:, :, -last_q_size:] = torch.where(
                    self.last_q_mask[..., -last_q_size:,
                                     -last_q_size:].to(qk.device),
                    qk[:, :, -last_q_size:], -torch.inf)
                qk = F.softmax(qk, dim=-1, dtype=torch.float32)
                vertical = qk.sum(-2, keepdim=True)
                vertical[..., :30] = torch.inf
                # Avoid sorting by using the min/max ints to fill the indexer
                # buffers.
                int32_max = torch.iinfo(torch.int32).max
                int32_min = torch.iinfo(torch.int32).min
                n_heads = qk.size()[0]
                max_slash_topk = torch.max(heads_slash_size).item()
                max_vertical_topk = torch.max(heads_vertical_size).item()
                # store each head's slash topk, vertical topk
                vertical = vertical.reshape((n_heads, -1))
                # prevent out of range when prompt size < max_vertical_topk
                max_vertical_topk = min(vertical.shape[-1], max_vertical_topk)
                vertical_topk_buffer = torch.topk(vertical, max_vertical_topk,
                                                  -1).indices
                slash_topk_buffer = torch.empty(size=(n_heads, max_slash_topk),
                                                dtype=torch.int64,
                                                device=qk.device)
                for head_i in range(n_heads):
                    #  (nqheads=1, lastq, k_len)
                    head_score = qk[head_i:head_i + 1, :, :]
                    slash_scores = _sum_all_diagonal_matrix(head_score)
                    if head_score.size(1) != 1:
                        # drop right up corner
                        slash_scores = slash_scores[..., :-last_q_size + 1]
                    slash_scores[..., -100:] = torch.inf
                    head_slash_size = heads_slash_size[head_i]
                    head_slash_size = min(head_slash_size, vertical.size(-1))
                    slash_topk = torch.topk(slash_scores, head_slash_size,
                                            -1).indices
                    #(nheads, max_topk)
                    slash_topk_buffer[head_i, :head_slash_size] = slash_topk
                    # reset heads topk
                    heads_slash_size[head_i] = head_slash_size
                    heads_vertical_size[head_i] = min(
                        heads_vertical_size[head_i], max_vertical_topk)
                # store
                vertical_buffer = torch.full((n_heads, max_vertical_topk),
                                             int32_max,
                                             dtype=torch.int64,
                                             device=q.device)
                slash_buffer = torch.full((n_heads, max_slash_topk),
                                          int32_min,
                                          dtype=torch.int64,
                                          device=q.device)
                succ_vertical_buffer = torch.full((n_heads, max_vertical_topk),
                                                  int32_max,
                                                  dtype=torch.int64,
                                                  device=q.device)
                succ_slash_buffer = torch.full((n_heads, max_slash_topk),
                                               int32_min,
                                               dtype=torch.int64,
                                               device=q.device)
                inter_vertical_buffer = torch.full(
                    (n_heads, max_vertical_topk),
                    int32_max,
                    dtype=torch.int64,
                    device=q.device)
                inter_slash_buffer = torch.full((n_heads, max_slash_topk),
                                                int32_min,
                                                dtype=torch.int64,
                                                device=q.device)
                vertical_size_buffer = torch.empty(size=(n_heads, ),
                                                   dtype=torch.int32,
                                                   device=q.device)
                slash_sizes_buffer = torch.empty(size=(n_heads, ),
                                                 dtype=torch.int32,
                                                 device=q.device)
                succ_vertical_size_buffer = torch.empty(size=(n_heads, ),
                                                        dtype=torch.int32,
                                                        device=q.device)
                succ_slash_sizes_buffer = torch.empty(size=(n_heads, ),
                                                      dtype=torch.int32,
                                                      device=q.device)
                inter_vertical_size_buffer = torch.empty(size=(n_heads, ),
                                                         dtype=torch.int32,
                                                         device=q.device)
                inter_slash_sizes_buffer = torch.empty(size=(n_heads, ),
                                                       dtype=torch.int32,
                                                       device=q.device)
                for head_i in range(n_heads):
                    vertical_topk = vertical_topk_buffer[
                        head_i, :heads_vertical_size[head_i]]
                    # intra
                    intra_vertical_indices = vertical_topk[
                        vertical_topk >=
                        prev_chunk_end_pos] - prev_chunk_end_pos
                    if intra_vertical_indices.nelement() == 0:
                        intra_vertical_indices = torch.cat([
                            intra_vertical_indices,
                            torch.arange(0,
                                         k_states_intra.size(0),
                                         max(1,
                                             k_states_intra.size(0) / 5),
                                         dtype=torch.int32,
                                         device=intra_vertical_indices.device)
                        ])
                    slash_topk = slash_topk_buffer[
                        head_i, :heads_slash_size[head_i]]
                    intra_slash_indices = (
                        (qk.size(-1) - 1) -
                        slash_topk[slash_topk >= prev_chunk_end_pos])
                    # fill buffer
                    v_count = intra_vertical_indices.nelement()
                    s_count = intra_slash_indices.nelement()
                    vertical_size_buffer[head_i] = v_count
                    slash_sizes_buffer[head_i] = s_count
                    vertical_buffer[head_i, :v_count].copy_(
                        intra_vertical_indices)
                    slash_buffer[head_i, :s_count].copy_(intra_slash_indices)
                    # succ
                    if prev_chunk_end_pos - chunk_len >= 0:
                        succ_vertical_indices = vertical_topk[
                            (vertical_topk < prev_chunk_end_pos)
                            & (vertical_topk >= prev_chunk_end_pos -
                               chunk_len)] - (prev_chunk_end_pos - chunk_len)
                        # TODO: support no vertical
                        if succ_vertical_indices.nelement() == 0:
                            succ_vertical_indices = torch.cat([
                                succ_vertical_indices,
                                torch.arange(
                                    0,
                                    k_states_succ.size(0),
                                    max(1,
                                        k_states_succ.size(0) / 5),
                                    dtype=torch.int32,
                                    device=intra_vertical_indices.device)
                            ])
                        succ_slash_indices = (
                            (prev_chunk_end_pos + (qend - qbegin) - 1) -
                            slash_topk[((slash_topk >=
                                         (prev_chunk_end_pos - chunk_len)) &
                                        (slash_topk < (prev_chunk_end_pos +
                                                       (qend - qbegin))))])
                        if succ_slash_indices.nelement() == 0:
                            succ_slash_indices = torch.cat([
                                succ_slash_indices,
                                torch.arange(
                                    0,
                                    k_states_succ.size(0),
                                    max(1,
                                        k_states_succ.size(0) / 5),
                                    dtype=torch.int32,
                                    device=intra_vertical_indices.device)
                            ])
                        # fill buffer
                        v_count = succ_vertical_indices.nelement()
                        s_count = succ_slash_indices.nelement()
                        succ_vertical_size_buffer[head_i] = v_count
                        succ_slash_sizes_buffer[head_i] = s_count
                        succ_vertical_buffer[head_i, :v_count].copy_(
                            succ_vertical_indices)
                        succ_slash_buffer[head_i, :s_count].copy_(
                            succ_slash_indices)
                    if prev_chunk_end_pos - 2 * chunk_len >= 0:
                        inter_vertical_indices = vertical_topk[
                            vertical_topk < prev_chunk_end_pos - chunk_len]
                        if inter_vertical_indices.nelement() == 0:
                            inter_vertical_indices = torch.cat([
                                inter_vertical_indices,
                                torch.arange(
                                    0,
                                    k_states_inter.size(0),
                                    max(1,
                                        k_states_inter.size(0) / 5),
                                    dtype=torch.int32,
                                    device=intra_vertical_indices.device)
                            ])
                        inter_slash_indices = (
                            (prev_chunk_end_pos - chunk_len +
                             (qend - qbegin) - 1) -
                            slash_topk[slash_topk < (prev_chunk_end_pos -
                                                     chunk_len +
                                                     (qend - qbegin))])
                        if inter_slash_indices.nelement() == 0:
                            inter_slash_indices = torch.cat([
                                inter_slash_indices,
                                torch.arange(
                                    0,
                                    k_states_inter.size(0),
                                    max(1,
                                        k_states_inter.size(0) / 5),
                                    dtype=torch.int32,
                                    device=intra_vertical_indices.device)
                            ])
                        # fill buffer
                        v_count = inter_vertical_indices.nelement()
                        s_count = inter_slash_indices.nelement()
                        inter_vertical_size_buffer[head_i] = v_count
                        inter_slash_sizes_buffer[head_i] = s_count
                        inter_vertical_buffer[head_i, :v_count].copy_(
                            inter_vertical_indices)
                        inter_slash_buffer[head_i, :s_count].copy_(
                            inter_slash_indices)
            else:
                intra_vertical_indices, intra_slash_indices = None, None
                succ_vertical_indices, succ_slash_indices = None, None
                inter_vertical_indices, inter_slash_indices = None, None
            if sparse_attn_enabled:
                flash_result = self._do_flash_attn(
                    q_states_intra,
                    k_states_intra,
                    v_states_intra,
                    softmax_scale=softmax_scale,
                    causal=True,
                    stage="intra",
                    vertical_indices=vertical_buffer,
                    slash_indices=slash_buffer,
                    vertical_indices_count=vertical_size_buffer,
                    slash_indices_count=slash_sizes_buffer,
                    mergehead_softmax_scale=softmax_scale,
                    sparse_attn_enabled=sparse_attn_enabled)
            else:
                flash_result = self._do_flash_attn(
                    q_states_intra,
                    k_states_intra,
                    v_states_intra,
                    softmax_scale=softmax_scale,
                    causal=True,
                    stage="intra",
                    vertical_indices=intra_vertical_indices,
                    slash_indices=intra_slash_indices,
                    sparse_attn_enabled=sparse_attn_enabled)
            flash_per_chunk.append(flash_result)
            if prev_chunk_end_pos - chunk_len >= 0:
                if sparse_attn_enabled:
                    flash_result = self._do_flash_attn(
                        q_states_succ,
                        k_states_succ,
                        v_states_succ,
                        softmax_scale=softmax_scale,
                        causal=False,
                        stage="succ",
                        vertical_indices=succ_vertical_buffer,
                        slash_indices=succ_slash_buffer,
                        vertical_indices_count=succ_vertical_size_buffer,
                        slash_indices_count=succ_slash_sizes_buffer,
                        mergehead_softmax_scale=softmax_scale,
                        sparse_attn_enabled=sparse_attn_enabled)
                else:
                    flash_result = self._do_flash_attn(
                        q_states_succ,
                        k_states_succ,
                        v_states_succ,
                        softmax_scale=softmax_scale,
                        causal=False,
                        stage="succ",
                        vertical_indices=succ_vertical_indices,
                        slash_indices=succ_slash_indices,
                        sparse_attn_enabled=sparse_attn_enabled)
                flash_per_chunk.append(flash_result)
            if prev_chunk_end_pos - chunk_len * 2 >= 0:
                if sparse_attn_enabled:
                    flash_result = self._do_flash_attn(
                        q_states_inter,
                        k_states_inter,
                        v_states_inter,
                        softmax_scale=softmax_scale,
                        causal=False,
                        stage="inter",
                        vertical_indices=inter_vertical_buffer,
                        slash_indices=inter_slash_buffer,
                        vertical_indices_count=inter_vertical_size_buffer,
                        slash_indices_count=inter_slash_sizes_buffer,
                        mergehead_softmax_scale=softmax_scale,
                        sparse_attn_enabled=sparse_attn_enabled)
                else:
                    flash_result = self._do_flash_attn(
                        q_states_inter,
                        k_states_inter,
                        v_states_inter,
                        softmax_scale=softmax_scale,
                        causal=False,
                        stage="inter",
                        vertical_indices=inter_vertical_indices,
                        slash_indices=inter_slash_indices,
                        sparse_attn_enabled=sparse_attn_enabled)
                flash_per_chunk.append(flash_result)
            flash_results.append(flash_per_chunk)
            begin = end
        attn_output = self._merge_attn_outputs(flash_results)
        del flash_results
        return attn_output
    def _do_flash_attn(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        softmax_scale: float,
        causal: bool = True,
        max_seqlen_k: Optional[int] = None,
        stage: str = "intra",
        vertical_indices: Optional[torch.Tensor] = None,
        slash_indices: Optional[torch.Tensor] = None,
        vertical_indices_count: Optional[torch.Tensor] = None,
        slash_indices_count: Optional[torch.Tensor] = None,
        mergehead_softmax_scale: Optional[float] = None,
        sparse_attn_enabled: Optional[bool] = False,
    ):
        """Compute attention for a single dual-chunk stage.

        Dispatches to ``_vertical_slash_sparse_attention`` when
        ``sparse_attn_enabled`` is truthy and otherwise to
        ``flash_attn_varlen_func``, which treats the inputs as one varlen
        sequence.

        Args:
            query_states, key_states, value_states: indexed as
                ``[seq, heads, head_dim]`` by the shape reads below.
            softmax_scale: scale for the dense path. It is also used by the
                sparse path when no per-head index counts are given.
            causal: on the sparse path it must be truthy exactly when
                ``stage == "intra"`` (asserted). It is forwarded as-is to
                the dense kernel.
            max_seqlen_k: key length used by the dense kernel. Defaults to
                ``key_states.shape[0]``.
            stage: ``"intra"``, ``"succ"`` or ``"inter"``.
            vertical_indices, slash_indices: sparse index tensors
                (``slash_indices`` is required on the sparse path).
            vertical_indices_count, slash_indices_count: when both are
                given, the sparse kernel is called with these counts and
                with ``mergehead_softmax_scale`` (which must then be set).
            sparse_attn_enabled: selects the sparse path.

        Returns:
            ``(output, lse)``. The LSE is viewed as ``(1, heads, q_len)``
            and cast to float32 on every path. Sparse outputs are viewed as
            ``(q_len, heads, head_dim)``. The dense output is returned
            exactly as produced by ``flash_attn_varlen_func``.
        """
        n_q = query_states.shape[0]
        n_heads = query_states.shape[1]
        head_dim = query_states.shape[-1]
        kv_len = key_states.shape[0] if max_seqlen_k is None else max_seqlen_k

        if not sparse_attn_enabled:
            # Dense path: describe the whole input as a single sequence.
            dev = query_states.device
            cu_q = torch.tensor([0, n_q], dtype=torch.int32, device=dev)
            cu_k = torch.tensor([0, kv_len], dtype=torch.int32, device=dev)
            dense_out, dense_lse = flash_attn_varlen_func(
                q=query_states,
                k=key_states,
                v=value_states,
                softmax_scale=softmax_scale,
                cu_seqlens_q=cu_q,
                max_seqlen_q=n_q,
                cu_seqlens_k=cu_k,
                max_seqlen_k=kv_len,
                causal=causal,
                return_softmax_lse=True,
            )
            dense_lse = dense_lse.view(n_q, n_heads, 1).transpose(0,
                                                                  2).float()
            return dense_out, dense_lse

        assert slash_indices is not None
        # Only the "intra" stage is causal; every other stage is not.
        assert bool(causal) == (stage == "intra")

        # [seq, heads, dim] -> [1, heads, seq, dim] for the sparse kernel.
        q, k, v = (t.unsqueeze(0).transpose(1, 2)
                   for t in (query_states, key_states, value_states))

        if vertical_indices_count is None or slash_indices_count is None:
            sparse_out, sparse_lse = _vertical_slash_sparse_attention(
                q,
                k,
                v,
                vertical_indices,
                slash_indices,
                softmax_scale,
                causal=causal,
                stage=stage)
            return (sparse_out.view(n_q, n_heads, head_dim),
                    sparse_lse.view(n_q, n_heads, 1).transpose(0, 2).float())

        assert mergehead_softmax_scale is not None
        sparse_out, sparse_lse = _vertical_slash_sparse_attention(
            q,
            k,
            v,
            vertical_indices,
            slash_indices,
            mergehead_softmax_scale,
            causal=causal,
            stage=stage,
            vertical_indices_count=vertical_indices_count,
            slash_indices_count=slash_indices_count)
        # This kernel variant yields head-major results: (heads, q_len, ...).
        return (sparse_out.view(n_heads, n_q, head_dim).transpose(0, 1),
                sparse_lse.view(n_heads, n_q).unsqueeze(0).float())
    def _merge_attn_outputs(
        self,
        flash_results: List[List[Tuple[torch.Tensor, torch.Tensor]]],
        return_lse: Optional[bool] = False,
    ) -> torch.Tensor:
        attn_outputs_all = []
        logits_all = []
        for flash_per_chunk in flash_results:
            if len(flash_per_chunk) == 1:
                attn_outputs_all.append(flash_per_chunk[0][0])
                if return_lse:
                    logits_all.append(flash_per_chunk[0][1])
                continue
            attn_outputs = torch.stack([
                flash_attn_output[0] for flash_attn_output in flash_per_chunk
            ])
            logits = torch.stack([
                flash_attn_output[1] for flash_attn_output in flash_per_chunk
            ])
            logits = logits.to(torch.float32)
            if return_lse:
                max_val = torch.max(logits, dim=0).values
                diff = torch.abs(logits[0] - logits[1])
                log_sum_exp = max_val + torch.log1p(torch.exp(-diff))
                logits_all.append(log_sum_exp)
            max_logits = torch.max(logits, dim=0).values
            stable_logits = logits - max_logits.unsqueeze(0)
            lse_s = torch.exp(stable_logits).detach()
            lse_sum = torch.sum(lse_s, dim=0)
            lse_s /= lse_sum
            attn_outputs *= lse_s.unsqueeze(-1).transpose(2, 3).squeeze(1)
            attn_outputs_all.append(attn_outputs.sum(dim=0))
        if return_lse:
            return (torch.cat(attn_outputs_all,
                              dim=0), torch.cat(logits_all, dim=-1))
        else:
            return torch.cat(attn_outputs_all, dim=0)
    def _dual_chunk_flash_attn_decoding(
        self,
        query: torch.Tensor,
        query_succ: torch.Tensor,
        query_inter: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        softmax_scale: float,
        causal: bool,
        alibi_slopes: Optional[torch.Tensor],
        chunk_size: int,
        local_size: int,
        original_max_position_embeddings: int,
        decode_meta: DualChunkFlashAttentionMetadata,
    ):
        """Decode-time dual chunk attention over the paged KV cache.

        The intra stage always runs. The succ stage runs only when
        ``decode_meta.max_seq_len_succ`` is truthy, and the inter stage only
        when ``decode_meta.max_seq_len_inter`` is truthy. Each stage is
        computed by ``_dual_chunk_flash_attn_decoding_with_exp_sums``
        (called with ``causal=False``). The stage outputs are summed with
        weights ``exp(lse - max) / sum`` taken over the stage dimension.

        NOTE(review): ``cache_seqlens`` is not read here, and
        ``chunk_size - local_size`` is only used for the divisibility check.

        Raises:
            ValueError: if ``causal`` is falsy, or if
                ``chunk_size - local_size`` is not a multiple of
                ``value_cache.shape[1]``.
        """
        if not causal:
            raise ValueError(
                "Dual Chunk Attention does not support causal=False")
        page_size = value_cache.shape[1]
        if (chunk_size - local_size) % page_size != 0:
            raise ValueError("chunk_len must be divisible by block_size.")

        if original_max_position_embeddings > 0:
            assert decode_meta.scaling_factor is not None
            q_dtype = query.dtype
            scale = decode_meta.scaling_factor.view(-1, 1, 1, 1)
            # The rescale happens outside the kernel and is cast back to the
            # query dtype, which may cost precision (ideally fused).
            query = (query * scale).to(q_dtype)
            query_succ = (query_succ * scale).to(q_dtype)
            query_inter = (query_inter * scale).to(q_dtype)

        # (query, block table, sequence lengths) for every active stage.
        stage_inputs = [(query, decode_meta.block_tables_intra,
                         decode_meta.seq_lens_intra)]
        if decode_meta.max_seq_len_succ:
            stage_inputs.append((query_succ, decode_meta.block_tables_succ,
                                 decode_meta.seq_lens_succ))
        if decode_meta.max_seq_len_inter:
            # NOTE(review): the block-table columns are sliced by
            # max_seq_len_inter, which looks like a token count rather than a
            # block count -- confirm this is intended.
            stage_inputs.append(
                (query_inter, block_table[:, :decode_meta.max_seq_len_inter],
                 decode_meta.seq_lens_inter))

        stage_results = [
            self._dual_chunk_flash_attn_decoding_with_exp_sums(
                stage_q,
                key_cache,
                value_cache,
                stage_blocks,
                stage_lens,
                softmax_scale,
                alibi_slopes,
                causal=False,
            ) for stage_q, stage_blocks, stage_lens in stage_inputs
        ]

        # Stack along a new leading stage dimension and merge by LSE.
        stacked_out = torch.stack([res[0] for res in stage_results], dim=0)
        stacked_lse = torch.stack([res[1] for res in stage_results],
                                  dim=0).to(torch.float32)
        del stage_results
        peak = stacked_lse.max(dim=0).values
        weights = torch.exp(stacked_lse - peak.unsqueeze(0)).detach()
        weights /= weights.sum(dim=0)
        stacked_out *= weights.unsqueeze(-1).transpose(2, 3)
        return stacked_out.sum(0)
    def _dual_chunk_flash_attn_decoding_with_exp_sums(
        self,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        softmax_scale: float,
        alibi_slopes: Optional[torch.Tensor],
        causal: bool,
    ):
        """Run ``flash_attn_with_kvcache`` for one stage and return (out, lse).

        Rows whose ``cache_seqlens`` entry is 0 are overwritten: their
        output is set to 0 and their LSE to ``-inf``. After the max-shifted
        exponentiation used by the callers, such rows therefore receive zero
        weight whenever another stage has a finite LSE.
        """
        kernel_kwargs = dict(
            q=query,
            k_cache=key_cache,
            v_cache=value_cache,
            block_table=block_table,
            cache_seqlens=cache_seqlens,
            softmax_scale=softmax_scale,
            alibi_slopes=alibi_slopes,
            causal=causal,
            return_softmax_lse=True,
        )
        out, softmax_lse = flash_attn_with_kvcache(**kernel_kwargs)
        empty_rows = cache_seqlens == 0
        out[empty_rows] = 0
        softmax_lse[empty_rows] = float("-inf")
        return out, softmax_lse