class TPUModelLoader(DefaultModelLoader):
    """
    A TPU model loader for model loading under SPMD mode.

    The model is built and its weights are loaded on CPU; the model is then
    moved to the XLA device, partitioned with ``shard_model`` and its decoder
    is wrapped with ``torch.compile(backend="openxla")``.
    """

    def load_model(
        self,
        vllm_config: VllmConfig,
        model_config: ModelConfig,
        mesh: Optional[xs.Mesh] = None,
    ) -> nn.Module:
        """Build the model on CPU, load its weights, shard it and compile it.

        Args:
            vllm_config: Engine configuration. Its ``model_config`` and
                ``load_config.load_format`` drive initialization and weight
                loading.
            model_config: Accepted for interface compatibility with the base
                loader, but overridden by ``vllm_config.model_config``.
            mesh: SPMD mesh handed to ``shard_model``. When ``None`` the
                QKVParallelLinear replacement check is skipped.

        Returns:
            The model in eval mode on the XLA device, sharded, with
            ``model.model`` (or ``model.language_model.model`` for multimodal
            models) compiled with the ``openxla`` backend.

        Raises:
            AssertionError: If quantization is configured, or if the
                post-partition sanity check fails.
            ValueError: If some parameters were not initialized from the
                checkpoint (only checked when ``model.load_weights`` reports
                which weights it loaded).
        """
        # Initialize model and load weights on CPU. Then, during SPMD partition,
        # weights are sharded and transferred to TPUs.
        self.counter_before_loading_weights = time.perf_counter()
        # The ``model_config`` argument is overridden by the engine-wide one.
        model_config = vllm_config.model_config
        assert model_config.quantization is None, "Quantization not supported"
        target_device = torch.device('cpu')
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = initialize_model(vllm_config=vllm_config)
            load_format = vllm_config.load_config.load_format
            if load_format != "dummy":
                weights_to_load = {
                    name
                    for name, _ in model.named_parameters()
                }
                all_weights = self.get_all_weights(model_config, model)
                loaded_weights = model.load_weights(all_weights)
                self.counter_after_loading_weights = time.perf_counter()
                logger.info(
                    "Loading weights took %.2f seconds",
                    self.counter_after_loading_weights -
                    self.counter_before_loading_weights)
                # We only enable strict check for non-quantized models
                # that have loaded weights tracking currently.
                # NOTE(review): presumably ``load_weights`` returns the set of
                # loaded parameter names, or None when untracked — confirm.
                if (model_config.quantization is None
                        and loaded_weights is not None):
                    weights_not_loaded = weights_to_load - loaded_weights
                    if weights_not_loaded:
                        raise ValueError(
                            "Following weights were not initialized from "
                            f"checkpoint: {weights_not_loaded}")
            else:
                logger.info("Use dummy weight during weight loading.")
            process_weights_after_loading(model, model_config, target_device)
        counter_before_partition = time.perf_counter()
        model = model.eval()
        model = model.to('xla')
        shard_model(model, mesh)
        counter_after_partition = time.perf_counter()
        logger.info("Partition model took %.2f seconds",
                    counter_after_partition - counter_before_partition)
        # Ensure the model is properly loaded.
        self._check_model_is_loaded(mesh, model)
        # Need to torch compile after model sharding are done. Because the
        # compiler hints ('xs.mark_sharding') are torch ops.
        if not model_config.is_multimodal_model:
            model.model = torch.compile(model.model, backend="openxla")
        else:
            model.language_model.model = torch.compile(
                model.language_model.model, backend="openxla")
        return model

    def _check_model_is_loaded(self, mesh: Optional[xs.Mesh],
                               model: nn.Module) -> None:
        """
        Ensure the model is properly loaded.

        1. All model parameters and buffers are on XLA device.
        2. Non-SPMD friendly layers are replaced as expected.

        Args:
            mesh: SPMD mesh, or ``None``. The layer-replacement check only
                runs when a mesh is given.
            model: The partitioned model to inspect.

        Raises:
            AssertionError: On the first parameter/buffer found off the XLA
                device, or on a remaining ``QKVParallelLinear`` module when a
                mesh is given.
        """
        device_type = str(xm.xla_device().type)
        # Explicit raises (instead of ``assert``) keep these checks active
        # under ``python -O``; the exception type is still AssertionError.
        # Messages are built by implicit concatenation so no source
        # indentation leaks into them.
        for kind, named_tensors in (("Parameter", model.named_parameters()),
                                    ("Buffer", model.named_buffers())):
            for name, tensor in named_tensors:
                if tensor.device.type != device_type:
                    raise AssertionError(
                        f"{kind} {name} is on {tensor.device.type} "
                        f"instead of {device_type}")
        # The replacement is only expected under SPMD mode, i.e. with a mesh.
        if mesh is not None:
            for module in model.modules():
                # NOTE(review): assumes get_fqn returns the module's class
                # name — confirm against its definition.
                if get_fqn(module) == 'QKVParallelLinear':
                    raise AssertionError(
                        "QKVParallelLinear should be replaced by "
                        "XlaQKVParallelLinear under SPMD mode.")