Fix embedding quantization
parent bfb1a8f2b6
commit 5fc46d22f7
@@ -1408,6 +1408,11 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         self.transformer = quantize(self.transformer, bits, use_quantization_cache=use_quantization_cache, empty_init=empty_init, **kwargs)
+        if self.device == torch.device("cpu"):
+            dtype = torch.float32
+        else:
+            dtype = torch.half
+
         if quantize_embeddings:
             logger.info("Applying quantization to embeddings")
             self.transformer.word_embeddings = QuantizedEmbedding(
@@ -1415,11 +1420,11 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 weight_tensor=self.transformer.word_embeddings.weight.to(self.device),
                 num_embeddings=self.transformer.word_embeddings.num_embeddings,
                 embedding_dim=self.transformer.word_embeddings.embedding_dim,
-                dtype=torch.half,
-                empty_init=True,
+                dtype=dtype,
+                empty_init=empty_init,
                 device=self.transformer.word_embeddings.weight.device,
             )
             self.lm_head = QuantizedLinear(
                 weight_bit_width=bits,
                 weight_tensor=self.lm_head.weight.to(self.device),
                 bias_tensor=None,
@@ -1428,8 +1433,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 bias=False,
                 quantized_weight=self.transformer.word_embeddings.weight,
                 quantized_weight_scale=self.transformer.word_embeddings.weight_scale,
-                dtype=torch.half,
-                empty_init=True,
+                dtype=dtype,
+                empty_init=empty_init,
                 device=self.lm_head.weight.device,
             )
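
What the fix does: the constructors for the repo's QuantizedEmbedding and QuantizedLinear were hard-coded to dtype=torch.half and empty_init=True, which breaks embedding quantization on CPU, where most half-precision kernels are unavailable. The commit selects torch.float32 when self.device is the CPU and torch.half otherwise, and threads the caller's empty_init through instead of forcing True. A minimal sketch of the dtype-selection pattern, using only stock PyTorch (pick_dtype and DemoQuantizedEmbedding are hypothetical stand-ins for illustration, not the repo's classes):

import torch

def pick_dtype(device: torch.device) -> torch.dtype:
    # Same rule as the commit: fall back to float32 on CPU, where many
    # half-precision ops are unsupported; keep float16 elsewhere.
    return torch.float32 if device == torch.device("cpu") else torch.half

class DemoQuantizedEmbedding(torch.nn.Embedding):
    # Hypothetical stand-in: shows only the dtype/device plumbing, not the
    # actual weight quantization done by the repo's QuantizedEmbedding.
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 dtype: torch.dtype, device: torch.device):
        super().__init__(num_embeddings, embedding_dim, dtype=dtype, device=device)

device = torch.device("cpu")
emb = DemoQuantizedEmbedding(100, 16, dtype=pick_dtype(device), device=device)
print(emb(torch.tensor([1, 2, 3])).dtype)  # torch.float32 on CPU

One consequence of this design, as far as the diff shows: the dtype is fixed once at quantization time from self.device, so moving the model between CPU and GPU afterwards will not revisit the choice.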