Fix embedding quantization
This commit is contained in:
parent
bfb1a8f2b6
commit
5fc46d22f7
|
@ -1408,6 +1408,11 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
|
|
||||||
self.transformer = quantize(self.transformer, bits, use_quantization_cache=use_quantization_cache, empty_init=empty_init, **kwargs)
|
self.transformer = quantize(self.transformer, bits, use_quantization_cache=use_quantization_cache, empty_init=empty_init, **kwargs)
|
||||||
|
|
||||||
|
if self.device == torch.device("cpu"):
|
||||||
|
dtype = torch.float32
|
||||||
|
else:
|
||||||
|
dtype = torch.half
|
||||||
|
|
||||||
if quantize_embeddings:
|
if quantize_embeddings:
|
||||||
logger.info("Applying quantization to embeddings")
|
logger.info("Applying quantization to embeddings")
|
||||||
self.transformer.word_embeddings = QuantizedEmbedding(
|
self.transformer.word_embeddings = QuantizedEmbedding(
|
||||||
|
@ -1415,8 +1420,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
weight_tensor=self.transformer.word_embeddings.weight.to(self.device),
|
weight_tensor=self.transformer.word_embeddings.weight.to(self.device),
|
||||||
num_embeddings=self.transformer.word_embeddings.num_embeddings,
|
num_embeddings=self.transformer.word_embeddings.num_embeddings,
|
||||||
embedding_dim=self.transformer.word_embeddings.embedding_dim,
|
embedding_dim=self.transformer.word_embeddings.embedding_dim,
|
||||||
dtype=torch.half,
|
dtype=dtype,
|
||||||
empty_init=True,
|
empty_init=empty_init,
|
||||||
device=self.transformer.word_embeddings.weight.device,
|
device=self.transformer.word_embeddings.weight.device,
|
||||||
)
|
)
|
||||||
self.lm_head = QuantizedLinear(
|
self.lm_head = QuantizedLinear(
|
||||||
|
@ -1428,8 +1433,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
bias=False,
|
bias=False,
|
||||||
quantized_weight=self.transformer.word_embeddings.weight,
|
quantized_weight=self.transformer.word_embeddings.weight,
|
||||||
quantized_weight_scale=self.transformer.word_embeddings.weight_scale,
|
quantized_weight_scale=self.transformer.word_embeddings.weight_scale,
|
||||||
dtype=torch.half,
|
dtype=dtype,
|
||||||
empty_init=True,
|
empty_init=empty_init,
|
||||||
device=self.lm_head.weight.device,
|
device=self.lm_head.weight.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue