diff --git a/modeling_chatglm.py b/modeling_chatglm.py index dd9fb26..f431ff5 100644 --- a/modeling_chatglm.py +++ b/modeling_chatglm.py @@ -1408,6 +1408,11 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): self.transformer = quantize(self.transformer, bits, use_quantization_cache=use_quantization_cache, empty_init=empty_init, **kwargs) + if self.device == torch.device("cpu"): + dtype = torch.float32 + else: + dtype = torch.half + if quantize_embeddings: logger.info("Applying quantization to embeddings") self.transformer.word_embeddings = QuantizedEmbedding( @@ -1415,11 +1420,11 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): weight_tensor=self.transformer.word_embeddings.weight.to(self.device), num_embeddings=self.transformer.word_embeddings.num_embeddings, embedding_dim=self.transformer.word_embeddings.embedding_dim, - dtype=torch.half, - empty_init=True, + dtype=dtype, + empty_init=empty_init, device=self.transformer.word_embeddings.weight.device, ) - self.lm_head = QuantizedLinear( + self.lm_head = QuantizedLinear( weight_bit_width=bits, weight_tensor=self.lm_head.weight.to(self.device), bias_tensor=None, @@ -1428,8 +1433,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): bias=False, quantized_weight=self.transformer.word_embeddings.weight, quantized_weight_scale=self.transformer.word_embeddings.weight_scale, - dtype=torch.half, - empty_init=True, + dtype=dtype, + empty_init=empty_init, device=self.lm_head.weight.device, )