@CustomOp.register("mamba_mixer2")
class MambaMixer2(MambaBase, CustomOp):
    """
    Compute ∆, A, B, C, and D, the state space parameters, and compute
    the `contextualized_states`. A and D are input-independent
    (see Mamba paper [1] Section 3.5.2 "Interpretation of A"
    for why A isn't selective), while ∆, B, and C are input-dependent
    (this is a key difference between Mamba and the linear
    time-invariant S4, and is why Mamba is called a
    **selective** state space model).
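
    Roughly, the discretized recurrence computed here is
    h_t = exp(∆A) · h_{t-1} + ∆B · x_t and y_t = C · h_t + D · x_t
    (see [1] for the exact discretization); this is a simplified sketch,
    not the exact kernel arithmetic.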
    """
    def __init__(
        self,
        hidden_size: int,
        ssm_state_size: int,
        conv_kernel_size: int,
        intermediate_size: int,
        use_conv_bias: bool,
        use_bias: bool,
        n_groups: int = 1,
        num_heads: int = 128,
        head_dim: int = 64,
        rms_norm_eps: float = 1e-5,
        activation: str = "silu",
        use_rms_norm: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # For TP, the sharding plan is as follows:
        # - for the conv modules, since
        #   conv_dim = intermediate_size + 2 * n_groups * ssm_state_size,
        #   we shard intermediate_size and n_groups
        # - since intermediate_size = n_heads * head_dim, sharding on
        #   intermediate_size is achieved by sharding on n_heads.
        # - IF world_size divides n_groups, then sharding
        #   (n_groups / world_size, n_heads / world_size)
        #   also maintains the invariant n_heads % n_groups == 0
        # - HOWEVER, IF world_size does NOT divide n_groups, then we need
        #   to allocate extra space in the shard, so that groups
        #   may be replicated to follow the head shards.
        # - NOTE: currently, for the case where world_size does NOT divide
        #   n_groups, we only support n_groups == 1
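        # - for example, with the defaults num_heads=128, head_dim=64,
        #   n_groups=1 and tp_size=4 (an illustrative configuration, not a
        #   requirement), each rank holds 128/4 = 32 heads and the single
        #   group's B/C channels are replicated on every rank.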
        self.tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        assert (num_heads % self.tp_size == 0
                ), "Tensor parallel world size must divide num heads."
        assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
            "If tensor parallel world size does not divide n_groups, "
            "then n_groups must equal 1.")
        assert (
            self.tp_size == 1 or quant_config is None
        ), "Tensor parallel currently not supported for quantized models."
        self.ssm_state_size = ssm_state_size
        self.conv_kernel_size = conv_kernel_size
        self.activation = activation
        self.intermediate_size = intermediate_size
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.n_groups = n_groups
        if n_groups % self.tp_size != 0:
            # - for TP we shard conv_dim by sharding on n_groups,
            # - but if tp_size does not divide n_groups, we need to
            #   extend with some extra (replicated) groups
            groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
                n_groups, self.tp_size)
            self.n_groups = n_groups + groups
        self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
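        # - conv_dim packs the convolution input as [x | B | C], i.e.
        #   intermediate_size channels for x plus n_groups * ssm_state_size
        #   channels each for B and C (matching the split further below).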
        self.conv1d = ColumnParallelLinear(
            input_size=conv_kernel_size,
            output_size=self.conv_dim,
            bias=use_conv_bias,
            quant_config=None,
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow overriding it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
        self.in_proj = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size + self.conv_dim + self.num_heads,
            bias=use_bias,
            quant_config=quant_config,
        )
        # - because in_proj is a concatenation of 3 weights, we
        #   need to interleave them before sharding
        # - use the custom weight loader mamba_v2_sharded_weight_loader
        #   for conv1d.bias, conv1d.weight and in_proj.weight
        # - the settings below assign the groups to the head shards
        group_shard_settings = (
            self.n_groups * self.ssm_state_size,  # expected model size
            (self.n_groups - n_groups) *
            self.ssm_state_size,  # extra dims assigned
            n_groups == 1,  # if there was only one group
        )
        intermediate_settings = (intermediate_size, 0, False)
        head_settings = (self.num_heads, 0, False)
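        # - per-parameter layouts the sharded loader has to interleave:
        #   conv1d.{weight,bias}: [x | B | C]
        #   in_proj.weight:       [gate | x | B | C | dt]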
        # - the weight already has a "weight_loader" attribute, so
        #   set_weight_attrs will raise unless we delete it
        #   before overriding it
        # - ditto for the other two weights below
        delattr(self.conv1d.bias, "weight_loader")
        set_weight_attrs(
            self.conv1d.bias,
            {
                "weight_loader":
                mamba_v2_sharded_weight_loader(
                    [
                        intermediate_settings,
                        group_shard_settings,
                        group_shard_settings,
                    ],
                    self.tp_size,
                    tp_rank,
                )
            },
        )
        delattr(self.conv1d.weight, "weight_loader")
        set_weight_attrs(
            self.conv1d.weight,
            {
                "weight_loader":
                mamba_v2_sharded_weight_loader(
                    [
                        intermediate_settings,
                        group_shard_settings,
                        group_shard_settings,
                    ],
                    self.tp_size,
                    tp_rank,
                )
            },
        )
        if quant_config is None:
            # - quant layers do not have a weight loader
            delattr(self.in_proj.weight, "weight_loader")
            set_weight_attrs(
                self.in_proj.weight,
                {
                    "weight_loader":
                    mamba_v2_sharded_weight_loader(
                        [
                            intermediate_settings,  # for gate
                            intermediate_settings,
                            group_shard_settings,
                            group_shard_settings,
                            head_settings,  # for dt
                        ],
                        self.tp_size,
                        tp_rank,
                    )
                },
            )
        # - these per-head parameters (A, D, dt_bias) are sharded
        #   across TP ranks along the head dimension
        self.A = nn.Parameter(
            torch.empty(
                divide(num_heads, self.tp_size),
                dtype=torch.float32,
            ))
        self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
        self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
        self.use_rms_norm = use_rms_norm
        set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
        a_weight_loader = composed_weight_loader(
            sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
        set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
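        # - checkpoints presumably store A in log form (A_log); the composed
        #   loader above converts it to A = -exp(A_log) in float32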
        set_weight_attrs(self.dt_bias,
                         {"weight_loader": sharded_weight_loader(0)})
        self.out_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=use_bias,
            input_is_parallel=True,
            quant_config=quant_config,
        )
        self.norm = Mixer2RMSNormGated(intermediate_size,
                                       n_groups,
                                       self.use_rms_norm,
                                       eps=rms_norm_eps)
        if envs.VLLM_USE_V1:
            compilation_config = get_current_vllm_config().compilation_config
            if prefix in compilation_config.static_forward_context:
                raise ValueError(f"Duplicate layer name: {prefix}")
            compilation_config.static_forward_context[prefix] = self
            # The outer list is for v0 PP virtual engine. Though this code path
            # only runs for v1, we have to do this to unify with the interface
            # of Attention + v0 PP.
            # The inner tuple is (conv_state, ssm_state)
            self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
        self.prefix = prefix
    def forward_native(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        mup_vector: Optional[torch.Tensor] = None,
    ):
        pass
    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        mup_vector: Optional[torch.Tensor] = None,
    ):
        if not envs.VLLM_USE_V1:
            CustomOp.forward(self, hidden_states, output, mamba_cache_params,
                             mamba2_metadata, mup_vector)
        else:
            torch.ops.vllm.mamba_mixer2(
                hidden_states,
                output,
                self.prefix,
                mup_vector,
            )
    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        mup_vector: Optional[torch.Tensor] = None,
    ):
        forward_context = get_forward_context()
        # mamba2_metadata contains metadata necessary for the mamba2 triton
        # kernels to operate in continuous batching and in chunked prefill
        # modes; it is computed at the top-level model forward since it
        # stays the same and is reused by all mamba layers in the same iteration
        attn_metadata: AttentionMetadata = forward_context.attn_metadata
        if envs.VLLM_USE_V1:
            if attn_metadata is not None:
                assert isinstance(attn_metadata, dict)
                attn_metadata = attn_metadata[self.prefix]
                mamba2_metadata = attn_metadata
                assert isinstance(attn_metadata, Mamba2AttentionMetadata)
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                # conv_state = (..., dim, width-1) yet contiguous along 'dim'
                conv_state = self_kv_cache[0].transpose(-1, -2)
                ssm_state = self_kv_cache[1]
                state_indices_tensor = attn_metadata.state_indices_tensor
                has_initial_states_p = attn_metadata.has_initial_states
                prep_initial_states = attn_metadata.prep_initial_states
                chunk_size = attn_metadata.chunk_size
                seq_idx_p = attn_metadata.seq_idx
                chunk_indices_p = attn_metadata.chunk_indices
                chunk_offsets_p = attn_metadata.chunk_offsets
        else:
            conv_state = mamba_cache_params.conv_state
            ssm_state = mamba_cache_params.ssm_state
            state_indices_tensor = mamba_cache_params.state_indices_tensor
            has_initial_states_p = mamba2_metadata.has_initial_states
            prep_initial_states = mamba2_metadata.prep_initial_states
            chunk_size = mamba2_metadata.chunk_size
            seq_idx_p = mamba2_metadata.seq_idx
            chunk_indices_p = mamba2_metadata.chunk_indices
            chunk_offsets_p = mamba2_metadata.chunk_offsets
        groups_time_state_size = self.n_groups * self.ssm_state_size
        # 1. Gated MLP's linear projection
        projected_states, _ = self.in_proj(hidden_states)
        if mup_vector is not None:
            projected_states = projected_states * mup_vector
        gate, hidden_states_B_C, dt = torch.split(
            projected_states,
            [
                self.intermediate_size // self.tp_size,
                self.conv_dim // self.tp_size,
                self.num_heads // self.tp_size,
            ],
            dim=-1,
        )
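        # - per-rank layout of projected_states: [gate | x,B,C (conv_dim) | dt],
        #   with sizes intermediate_size/tp, conv_dim/tp and num_heads/tp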
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))
        # - get hidden_states, B and C after depthwise convolution.
        split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
            hidden_states_B_C,
            [
                self.intermediate_size // self.tp_size,
                groups_time_state_size // self.tp_size,
                groups_time_state_size // self.tp_size,
            ],
            dim=-1,
        )
        if envs.VLLM_USE_V1 and attn_metadata is None:
            # V1 profile run
            hidden_states_B_C = (hidden_states_B_C.transpose(
                0, 1).clone().transpose(0, 1)).contiguous()
            hidden_states, _B, _C = split_hidden_states_B_C_fn(
                hidden_states_B_C)
            hidden_states = self.norm(hidden_states, gate)
            out, _ = self.out_proj(hidden_states)
            return out
        num_prefills = attn_metadata.num_prefills  # request count
        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
        has_prefill = num_prefills > 0
        has_decode = num_decodes > 0
        num_actual_tokens = num_prefill_tokens + num_decodes
        # NOTE: V0 puts prefill before decode, V1 puts decode before prefill
        # Separate prefill and decode by splitting varlen input
        # Split along token dimension
        if envs.VLLM_USE_V1:
            hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
                hidden_states_B_C[:num_actual_tokens],
                [num_decodes, num_prefill_tokens],
                dim=0,
            )
            dt_d, dt_p = torch.split(
                dt[:num_actual_tokens],
                [num_decodes, num_prefill_tokens],
                dim=0,
            )
            # Split along batch dimension
            state_indices_tensor_d, state_indices_tensor_p = torch.split(
                state_indices_tensor[:num_actual_tokens],
                [num_decodes, num_prefills],
                dim=0,
            )
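            # - in V1 the decode tokens come first, so the prefill slice of
            #   query_start_loc is shifted down by num_decodes to make the
            #   offsets prefill-relative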
            query_start_loc_p = (
                attn_metadata.query_start_loc[-num_prefills - 1:] -
                num_decodes if has_prefill else None)
        else:
            hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
                hidden_states_B_C,
                [num_prefill_tokens, num_decodes],
                dim=0,
            )
            dt_p, dt_d = torch.split(
                dt,
                [num_prefill_tokens, num_decodes],
                dim=0,
            )
            # Split along batch dimension
            state_indices_tensor_p, state_indices_tensor_d = torch.split(
                state_indices_tensor,
                [num_prefills, num_decodes],
                dim=0,
            )
            query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
                                                               1]
                                 if has_prefill else None)
        # Preallocate output tensor to avoid memcpy cost for merging prefill
        # and decode outputs
        preallocated_ssm_out = torch.empty(
            [
                num_prefill_tokens + num_decodes,
                (self.num_heads // self.tp_size) * self.head_dim
            ],
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        if envs.VLLM_USE_V1:
            preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
                preallocated_ssm_out,
                [num_decodes, num_prefill_tokens],
                dim=0,
            )
        else:
            preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
                preallocated_ssm_out,
                [num_prefill_tokens, num_decodes],
                dim=0,
            )
        # Process prefill requests
        if has_prefill:
            # 2. Convolution sequence transformation
            # - "cache_indices" updates the conv_state cache in positions
            #   pointed to by "state_indices_tensor"
            x = hidden_states_B_C_p.transpose(
                0, 1)  # this is the layout that causal_conv1d expects
            if mamba2_metadata.cu_seqlen is None:
                mamba2_metadata = update_metadata(x, query_start_loc_p,
                                                  mamba2_metadata)
            hidden_states_B_C_p = causal_conv1d_fn(
                x,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=conv_state,
                has_initial_state=has_initial_states_p,
                cache_indices=state_indices_tensor_p,
                metadata=mamba2_metadata,
                query_start_loc=query_start_loc_p).transpose(
                    0, 1)[:num_prefill_tokens]
            hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
                hidden_states_B_C_p)
            # 3. State Space Model sequence transformation
            initial_states = None
            if (has_initial_states_p is not None and prep_initial_states):
                # making a copy of the states
                if envs.VLLM_USE_V1:
                    initial_states = torch.where(
                        has_initial_states_p[:, None, None, None],
                        ssm_state[state_indices_tensor_p], 0)
                else:
                    initial_states = torch.where(
                        has_initial_states_p[:num_prefills, None, None, None],
                        ssm_state[state_indices_tensor_p], 0)
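                # - initial_states has shape
                #   (num_prefills, num_heads // tp_size, head_dim, ssm_state_size);
                #   sequences without a cached state start from zeros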
            # NOTE: final output is an in-place update of out tensor
            varlen_state = mamba_chunk_scan_combined(
                hidden_states_p.view(1, num_prefill_tokens,
                                     self.num_heads // self.tp_size,
                                     self.head_dim),
                dt_p.unsqueeze(0),
                self.A,
                B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
                         -1),
                C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
                         -1),
                chunk_size=chunk_size,
                D=self.D,
                z=None,
                dt_bias=self.dt_bias,
                seq_idx=seq_idx_p,
                chunk_indices=chunk_indices_p,
                chunk_offsets=chunk_offsets_p,
                cu_seqlens=query_start_loc_p,
                initial_states=initial_states,
                return_varlen_states=True,
                return_final_states=False,
                dt_softplus=True,
                dt_limit=(0.0, float("inf")),
                out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
                                                self.head_dim),
            )
            # update ssm states
            # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
            ssm_state[state_indices_tensor_p] = varlen_state
        # Process decode requests
        if has_decode:
            # 2. Convolution sequence transformation
            hidden_states_B_C_d = causal_conv1d_update(
                hidden_states_B_C_d,
                conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=state_indices_tensor_d)
            hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
                hidden_states_B_C_d)
            # 3. State Space Model sequence transformation
            n_groups = self.n_groups // self.tp_size
            A_d = self.A[:, None, ...][:, :, None].expand(
                -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
            D_d = self.D[:, None, ...].expand(-1, self.head_dim)
            B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
            C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
            hidden_states_d = hidden_states_d.view(
                -1, self.num_heads // self.tp_size, self.head_dim)
            # - the hidden states are reshaped into (bs, num_heads, head_dim)
            # - the ssm_state slots to update are selected
            #   using state_indices_tensor_d
            # NOTE: final output is an in-place update of out tensor
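            # - selective_state_update performs the single-token recurrence
            #   (roughly h = exp(dt*A) * h + dt*B*x, y = C·h + D*x), updating
            #   ssm_state in place at the slots given by state_indices_tensor_d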
            selective_state_update(
                ssm_state,
                hidden_states_d,
                dt_d,
                A_d,
                B_d,
                C_d,
                D_d,
                z=None,
                dt_bias=dt_bias,
                dt_softplus=True,
                state_batch_indices=state_indices_tensor_d,
                out=preallocated_ssm_out_d.view(num_decodes, -1,
                                                self.head_dim),
            )
        # 4. Gated MLP
        # - Mixer2RMSNormGated applies SiLU to the gate internally,
        #   before normalization (unlike standard norm usage)
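        # - i.e., roughly: hidden_states = RMSNorm(ssm_out * SiLU(gate))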
        hidden_states = self.norm(preallocated_ssm_out,
                                  gate[:num_actual_tokens])
        # 5. Final linear projection
        output[:num_actual_tokens], _ = self.out_proj(hidden_states)
    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
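        # Per-TP-rank shapes of this layer's conv and ssm state caches
        # (presumably in the same (conv_state, ssm_state) order as kv_cache).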
        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=self.intermediate_size,
            tp_world_size=get_tensor_model_parallel_world_size(),
            n_groups=self.n_groups,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            state_size=self.ssm_state_size,
            conv_kernel=self.conv_kernel_size,
        )
    @property
    def mamba_type(self) -> str:
        return "mamba2"