def _fp4_moe_marlin_sizes(name: str, n: int, k: int):
    """Return ``(size_n, size_k)`` for the MoE projection identified by *name*.

    Any name containing ``"w13"`` gets ``(2 * n, k)``; every other name
    (i.e. the ``w2`` projection) gets ``(k, n)``.
    """
    if "w13" in name:
        return n * 2, k
    return k, n


def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
    """Convert an FP4 MoE layer's weights and scales in place for Marlin.

    Used for weight-only FP4 execution on GPUs without native FP4 support.
    Each expert's weights are repacked with ``ops.gptq_marlin_repack``
    (``num_bits=4``), and its block scales are permuted with
    ``marlin_permute_scales`` and processed with
    ``fp4_marlin_process_scales``. Global scales are processed with
    ``fp4_marlin_process_global_scale``.

    Attributes read from ``layer``: ``num_experts``, ``hidden_size``,
    ``intermediate_size_per_partition``, ``params_dtype``, ``w13_weight``,
    ``w2_weight``, ``w13_weight_scale``, ``w2_weight_scale``,
    ``w13_weight_scale_2`` and ``w2_weight_scale_2``.

    Attributes written to ``layer``: ``workspace``. Each of the six
    weight/scale attributes above is also replaced with a new
    ``torch.nn.Parameter`` created with ``requires_grad=False``.

    Raises:
        ValueError: if ``w13_weight`` or ``w2_weight`` does not have shape
            ``(num_experts, size_n, size_k // 2)``.
    """
    logger.warning_once(
        "Your GPU does not have native support for FP4 computation but "
        "FP4 quantization is being used. Weight-only FP4 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    e = layer.num_experts
    k = layer.hidden_size
    n = layer.intermediate_size_per_partition

    # WORKSPACE
    device = layer.w13_weight.device
    param_dtype = layer.params_dtype
    layer.workspace = marlin_make_workspace_new(device, 4)
    # Empty perm tensor handed to the repack; presumably this means no row
    # reordering is applied (TODO confirm against gptq_marlin_repack docs).
    perm = torch.empty(0, dtype=torch.int, device=device)

    # WEIGHT
    # Repack every expert's quantized weight into the Marlin layout.
    for name in ("w13_weight", "w2_weight"):
        weight = getattr(layer, name)
        size_n, size_k = _fp4_moe_marlin_sizes(name, n, k)

        # The last dim is size_k // 2, consistent with two 4-bit values
        # packed per stored element. An explicit raise (not assert) keeps
        # this check active under `python -O`.
        expected_shape = (e, size_n, size_k // 2)
        if tuple(weight.shape) != expected_shape:
            raise ValueError(
                f"{name} has shape {tuple(weight.shape)}, expected "
                f"{expected_shape} for Marlin FP4 repacking")

        repacked = [
            ops.gptq_marlin_repack(
                # Reinterpret the packed bytes as int32 and transpose, as
                # the repack op expects.
                b_q_weight=weight[i].view(torch.int32).T.contiguous(),
                perm=perm,
                size_k=size_k,
                size_n=size_n,
                num_bits=4) for i in range(e)
        ]
        setattr(layer, name,
                torch.nn.Parameter(torch.stack(repacked),
                                   requires_grad=False))

    # WEIGHT SCALES
    # Permute per-expert block scales and process the global scales.
    for prefix in ("w13", "w2"):
        scales = getattr(layer, prefix + "_weight_scale").to(param_dtype)
        global_scale = getattr(layer,
                               prefix + "_weight_scale_2").to(param_dtype)
        size_n, size_k = _fp4_moe_marlin_sizes(prefix, n, k)

        permuted = [
            fp4_marlin_process_scales(
                marlin_permute_scales(s=scales[i].T,
                                      size_k=size_k,
                                      size_n=size_n,
                                      group_size=16)) for i in range(e)
        ]
        setattr(layer, prefix + "_weight_scale",
                torch.nn.Parameter(torch.stack(permuted),
                                   requires_grad=False))

        global_scale = fp4_marlin_process_global_scale(global_scale)
        setattr(layer, prefix + "_weight_scale_2",
                torch.nn.Parameter(global_scale, requires_grad=False))