Fix embedding quantization

This commit is contained in:
duzx16 2023-04-07 23:34:41 +08:00
parent bfb1a8f2b6
commit 5fc46d22f7
1 changed files with 10 additions and 5 deletions

View File

@@ -1408,6 +1408,11 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         self.transformer = quantize(self.transformer, bits, use_quantization_cache=use_quantization_cache, empty_init=empty_init, **kwargs)
+        if self.device == torch.device("cpu"):
+            dtype = torch.float32
+        else:
+            dtype = torch.half
         if quantize_embeddings:
             logger.info("Applying quantization to embeddings")
             self.transformer.word_embeddings = QuantizedEmbedding(
@@ -1415,8 +1420,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 weight_tensor=self.transformer.word_embeddings.weight.to(self.device),
                 num_embeddings=self.transformer.word_embeddings.num_embeddings,
                 embedding_dim=self.transformer.word_embeddings.embedding_dim,
-                dtype=torch.half,
-                empty_init=True,
+                dtype=dtype,
+                empty_init=empty_init,
                 device=self.transformer.word_embeddings.weight.device,
             )
             self.lm_head = QuantizedLinear(
@@ -1428,8 +1433,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 bias=False,
                 quantized_weight=self.transformer.word_embeddings.weight,
                 quantized_weight_scale=self.transformer.word_embeddings.weight_scale,
-                dtype=torch.half,
-                empty_init=True,
+                dtype=dtype,
+                empty_init=empty_init,
                 device=self.lm_head.weight.device,
             )