Add support for loading quantized models
parent c949d03152
commit 2e1be30ac4
configuration_chatglm.py

@@ -70,6 +70,7 @@ class ChatGLMConfig(PretrainedConfig):
         max_sequence_length=2048,
         inner_hidden_size=16384,
         position_encoding_2d=True,
+        quantization_bit=0,
         pre_seq_len=None,
         prefix_projection=False,
         **kwargs
@@ -86,8 +87,10 @@ class ChatGLMConfig(PretrainedConfig):
         self.eos_token_id = eos_token_id
         self.pad_token_id = pad_token_id
         self.position_encoding_2d = position_encoding_2d
+        self.quantization_bit = quantization_bit
         self.pre_seq_len = pre_seq_len
         self.prefix_projection = prefix_projection
+
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
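The two hunks above thread a single new knob, quantization_bit, through ChatGLMConfig: 0 (the default) means full precision, while 4 or 8 select INT4/INT8 weight quantization. Because PretrainedConfig serializes every attribute set in __init__, the value round-trips through config.json, which is what later lets a checkpoint announce that its weights are already quantized. A minimal sketch of that round trip, assuming configuration_chatglm.py is importable and using an illustrative save directory:

from configuration_chatglm import ChatGLMConfig

config = ChatGLMConfig(quantization_bit=4)  # mark the config as INT4
config.save_pretrained("./chatglm-int4")    # writes quantization_bit into config.json

reloaded = ChatGLMConfig.from_pretrained("./chatglm-int4")
assert reloaded.quantization_bit == 4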
modeling_chatglm.py

@@ -139,6 +139,7 @@ class PrefixEncoder(torch.nn.Module):
     Input shape: (batch-size, prefix-length)
     Output shape: (batch-size, prefix-length, 2*layers*hidden)
     """
+
     def __init__(self, config):
         super().__init__()
         self.prefix_projection = config.prefix_projection
@@ -216,6 +217,13 @@ class RotaryEmbedding(torch.nn.Module):
         self.cos_cached, self.sin_cached = cos_cached, sin_cached
         return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

+    def _apply(self, fn):
+        if self.cos_cached is not None:
+            self.cos_cached = fn(self.cos_cached)
+        if self.sin_cached is not None:
+            self.sin_cached = fn(self.sin_cached)
+        return super()._apply(fn)
+

 def rotate_half(x):
     x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
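The _apply override is needed because cos_cached and sin_cached are plain tensor attributes rather than registered buffers, so model.half(), model.cuda(), and model.to(...) would otherwise leave them behind as float32 CPU tensors while the rest of the module moves. All of those convenience methods funnel through nn.Module._apply, so intercepting that one hook is sufficient. A self-contained sketch of the same pattern (the CachedModule class and its cached field are illustrative, not part of the diff):

import torch

class CachedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Plain attribute: .half()/.to() ignore it unless _apply forwards fn to it.
        self.cached = torch.arange(4, dtype=torch.float32)

    def _apply(self, fn):
        if self.cached is not None:
            self.cached = fn(self.cached)
        return super()._apply(fn)

m = CachedModule().half()
print(m.cached.dtype)  # torch.float16; without the override it would stay float32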
@@ -931,7 +939,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 gmask=use_gmask
             )

-
         # [seq_len, batch, hidden_size]
         hidden_states = inputs_embeds.transpose(0, 1)

@@ -999,7 +1006,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):


 class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: ChatGLMConfig):
         super().__init__(config)

         # self.hidden_size = config.hidden_size
@@ -1019,6 +1026,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             dtype=torch.half
         )

+        self.config = config
+
+        self.quantized = False
+
+        if self.config.quantization_bit:
+            self.quantize(self.config.quantization_bit, empty_init=True)
+
     def get_output_embeddings(self):
         return self.lm_head
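Together with the config change, this hunk is what makes a pre-quantized checkpoint loadable: from_pretrained instantiates the model from its config before copying in the state dict, so quantizing inside __init__ ensures the parameter shapes already match the INT4/INT8 tensors on disk, and empty_init=True tells the quantization routine not to materialize full weights that are about to be overwritten. A hedged usage sketch (the INT4 Hub id stands for any checkpoint whose config.json sets quantization_bit=4):

from transformers import AutoModel, AutoTokenizer

# No .quantize() call needed here: the checkpoint's config.json carries
# "quantization_bit": 4, so __init__ quantizes the empty model before
# the INT4 state dict is loaded into it.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()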
@@ -1351,7 +1365,19 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             break
         yield input_ids

-    def quantize(self, bits: int):
+    def quantize(self, bits: int, empty_init=False, **kwargs):
+        if bits == 0:
+            return
+
         from .quantization import quantize
-        self.transformer = quantize(self.transformer, bits)
+
+        if self.quantized:
+            logger.info("Already quantized.")
+            return self
+
+        self.quantized = True
+
+        self.config.quantization_bit = bits
+
+        self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs)
         return self
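The reworked quantize now serves both entry points: it is called from __init__ with empty_init=True when loading an already-quantized checkpoint, or by hand after loading a full-precision model. The guard on self.quantized makes repeat calls a no-op, and writing bits back into config.quantization_bit means a model saved afterwards reloads through the __init__ path shown earlier. A usage sketch for the manual path (Hub id and save directory illustrative):

from transformers import AutoModel

# Start from the full-precision checkpoint and quantize to INT4 in place.
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half()
model = model.quantize(4)  # returns self; a second call just logs "Already quantized."
model.save_pretrained("./chatglm-6b-int4")  # saved config now records quantization_bit=4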
File diff suppressed because one or more lines are too long