import torch

# Assumed vLLM import path for the shared fixed-slot cache base class.
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache


class MinimaxCacheManager(ConstantSizeCache):
    def __init__(self, dtype, cache_shape):
        super().__init__(cache_shape[1])  # max_batch_size is cache_shape[1]
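        # One pre-allocated, fixed-size GPU tensor holds the cached state for
        # every batch slot; slots are indexed along dim 1 (cache_shape[1]).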
        self._minimax_cache = torch.empty(size=cache_shape,
                                          dtype=dtype,
                                          device="cuda")

    @property
    def cache(self):
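        """Return the underlying constant-size cache tensor."""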
        return self._minimax_cache

    def _copy_cache(self, from_index: int, to_index: int):
        assert len(self.cache) > 0
        # Batch slots sit along dim 1 of the full cache (cache_shape[1]), so
        # each slice yielded by iterating over dim 0 indexes its slots along
        # its own dim 0; copy one slot's state into another for every slice.
        for cache_t in self.cache:
            cache_t[to_index].copy_(cache_t[from_index], non_blocking=True)
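

# Usage sketch (illustrative, not from the original source): exercises the
# manager with a made-up (layers, max_batch, heads, head_dim, head_dim)
# cache_shape; real values would come from the model/scheduler config.
# Requires a CUDA device because the cache is allocated on "cuda".
if __name__ == "__main__":
    cache_shape = (4, 8, 2, 64, 64)  # hypothetical dimensions
    manager = MinimaxCacheManager(dtype=torch.float32, cache_shape=cache_shape)
    manager.cache[:, 0].fill_(1.0)                 # pretend slot 0 holds state
    manager._copy_cache(from_index=0, to_index=3)  # duplicate it into slot 3
    print(torch.equal(manager.cache[:, 3], manager.cache[:, 0]))  # True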