Add support for loading quantized model

duzx16 2023-03-31 10:48:38 +08:00
parent c949d03152
commit 2e1be30ac4
3 changed files with 77 additions and 34 deletions

View File

@@ -70,6 +70,7 @@ class ChatGLMConfig(PretrainedConfig):
         max_sequence_length=2048,
         inner_hidden_size=16384,
         position_encoding_2d=True,
+        quantization_bit=0,
         pre_seq_len=None,
         prefix_projection=False,
         **kwargs
@@ -86,8 +87,10 @@ class ChatGLMConfig(PretrainedConfig):
         self.eos_token_id = eos_token_id
         self.pad_token_id = pad_token_id
         self.position_encoding_2d = position_encoding_2d
+        self.quantization_bit = quantization_bit
         self.pre_seq_len = pre_seq_len
         self.prefix_projection = prefix_projection
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
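
The new quantization_bit field defaults to 0 (no quantization); a nonzero value (4 or 8) makes the model quantize itself right after construction (see the __init__ change below), so already-quantized checkpoints load without extra steps. A minimal loading sketch, assuming the Hugging Face from_pretrained API with trust_remote_code=True; the checkpoint id is a placeholder, not named in this commit:

    from transformers import AutoConfig, AutoModel

    # Checkpoint id is an assumption; any repo shipping this modeling code behaves the same.
    config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    config.quantization_bit = 4  # 0 = full precision; 4 or 8 = quantize at load time
    model = AutoModel.from_pretrained(
        "THUDM/chatglm-6b", config=config, trust_remote_code=True
    ).half().cuda()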

View File

@@ -139,6 +139,7 @@ class PrefixEncoder(torch.nn.Module):
     Input shape: (batch-size, prefix-length)
     Output shape: (batch-size, prefix-length, 2*layers*hidden)
     """
+
     def __init__(self, config):
         super().__init__()
         self.prefix_projection = config.prefix_projection
@@ -216,6 +217,13 @@ class RotaryEmbedding(torch.nn.Module):
             self.cos_cached, self.sin_cached = cos_cached, sin_cached
         return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
 
+    def _apply(self, fn):
+        if self.cos_cached is not None:
+            self.cos_cached = fn(self.cos_cached)
+        if self.sin_cached is not None:
+            self.sin_cached = fn(self.sin_cached)
+        return super()._apply(fn)
+
 
 def rotate_half(x):
     x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
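
cos_cached and sin_cached are plain attributes rather than registered buffers, so by default they would not follow the module through .to(), .half(), or .cuda(). Overriding _apply forwards the conversion function to both caches, so the rotary tables end up on the same device and dtype as the (possibly quantized) model. A standalone illustration of the same pattern, not taken from the repository:

    import torch

    class CachedRotary(torch.nn.Module):  # hypothetical toy module for illustration
        def __init__(self):
            super().__init__()
            self.cos_cached = torch.zeros(4)  # plain attribute, not a buffer

        def _apply(self, fn):
            # fn is the conversion used by .to()/.half()/.cuda(); forward it to the cache
            if self.cos_cached is not None:
                self.cos_cached = fn(self.cos_cached)
            return super()._apply(fn)

    m = CachedRotary().half()
    print(m.cos_cached.dtype)  # torch.float16, because the cache was converted too
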
@@ -931,7 +939,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 gmask=use_gmask
             )
 
         # [seq_len, batch, hidden_size]
         hidden_states = inputs_embeds.transpose(0, 1)
@@ -999,7 +1006,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
 
 class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: ChatGLMConfig):
         super().__init__(config)
         # self.hidden_size = config.hidden_size
@@ -1019,6 +1026,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             dtype=torch.half
         )
 
+        self.config = config
+
+        self.quantized = False
+
+        if self.config.quantization_bit:
+            self.quantize(self.config.quantization_bit, empty_init=True)
+
     def get_output_embeddings(self):
         return self.lm_head
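
Because from_pretrained overwrites every weight with the checkpoint right after construction, __init__ calls quantize(..., empty_init=True). Presumably (quantization.py is suppressed in this diff) the flag skips initializing the replacement quantized weights, in the spirit of torch.nn.utils.skip_init; a hedged illustration of that idea, not the repo's code:

    import torch

    # skip_init builds the module on the meta device and materializes it with
    # uninitialized storage, avoiding a pointless random init for weights that
    # the checkpoint (or the quantizer) will overwrite anyway.
    linear = torch.nn.utils.skip_init(torch.nn.Linear, 4096, 4096)
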
@@ -1351,7 +1365,19 @@
                 break
             yield input_ids
 
-    def quantize(self, bits: int):
+    def quantize(self, bits: int, empty_init=False, **kwargs):
+        if bits == 0:
+            return
+
         from .quantization import quantize
-        self.transformer = quantize(self.transformer, bits)
+
+        if self.quantized:
+            logger.info("Already quantized.")
+            return self
+
+        self.quantized = True
+
+        self.config.quantization_bit = bits
+
+        self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs)
         return self
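
With these guards, quantize is safe to call at most once per model: bits == 0 returns early (note it returns None rather than self), a repeated call logs "Already quantized." and returns the model unchanged, and the chosen bit width is written back into config.quantization_bit so it is preserved by save_pretrained and re-applied on the next load. A minimal usage sketch; the checkpoint id is assumed:

    from transformers import AutoModel

    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    model = model.quantize(4)   # INT4 weight quantization of the transformer; returns self
    model.quantize(8)           # no-op: logs "Already quantized." and leaves the model as-is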

File diff suppressed because one or more lines are too long