Add support for loading quantized model
commit 2e1be30ac4
parent c949d03152
configuration_chatglm.py

@@ -70,6 +70,7 @@ class ChatGLMConfig(PretrainedConfig):
             max_sequence_length=2048,
             inner_hidden_size=16384,
             position_encoding_2d=True,
+            quantization_bit=0,
             pre_seq_len=None,
             prefix_projection=False,
             **kwargs
@@ -86,8 +87,10 @@ class ChatGLMConfig(PretrainedConfig):
         self.eos_token_id = eos_token_id
         self.pad_token_id = pad_token_id
         self.position_encoding_2d = position_encoding_2d
+        self.quantization_bit = quantization_bit
         self.pre_seq_len = pre_seq_len
         self.prefix_projection = prefix_projection
+
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
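Note: quantization_bit defaults to 0, meaning no quantization; a quantized checkpoint records its bit width here (4 or 8 for ChatGLM) so the model can rebuild quantized layers before the weights are loaded. A minimal sketch of reading the field; the repo id is the public int4 export and is an assumption of this note, not something shown in the diff:

    from transformers import AutoConfig

    # hypothetical usage: check the bit width a quantized checkpoint declares
    cfg = AutoConfig.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
    print(cfg.quantization_bit)  # expected 4 for an int4 export, 0 for the unquantized model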
modeling_chatglm.py

@@ -139,6 +139,7 @@ class PrefixEncoder(torch.nn.Module):
     Input shape: (batch-size, prefix-length)
     Output shape: (batch-size, prefix-length, 2*layers*hidden)
     """
+
     def __init__(self, config):
         super().__init__()
         self.prefix_projection = config.prefix_projection
@@ -216,6 +217,13 @@ class RotaryEmbedding(torch.nn.Module):
             self.cos_cached, self.sin_cached = cos_cached, sin_cached
         return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
 
+    def _apply(self, fn):
+        if self.cos_cached is not None:
+            self.cos_cached = fn(self.cos_cached)
+        if self.sin_cached is not None:
+            self.sin_cached = fn(self.sin_cached)
+        return super()._apply(fn)
+
 
 def rotate_half(x):
     x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
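Note: cos_cached and sin_cached are plain tensor attributes rather than registered buffers, so the default nn.Module._apply (which backs .to(), .half() and .cuda()) would leave them behind; the override applies fn to the caches explicitly so they follow the module's device and dtype. A small stand-alone sketch of the behaviour being guarded against (class name hypothetical, not from this repo):

    import torch

    class Cache(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.cached = torch.ones(4)  # plain attribute, not added via register_buffer

    m = Cache().half()
    print(m.cached.dtype)  # torch.float32: _apply never touched the plain attribute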
@@ -931,7 +939,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 gmask=use_gmask
             )
 
-
         # [seq_len, batch, hidden_size]
         hidden_states = inputs_embeds.transpose(0, 1)
@@ -999,7 +1006,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
 
 class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: ChatGLMConfig):
         super().__init__(config)
 
         # self.hidden_size = config.hidden_size
@@ -1019,6 +1026,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             dtype=torch.half
         )
 
+        self.config = config
+
+        self.quantized = False
+
+        if self.config.quantization_bit:
+            self.quantize(self.config.quantization_bit, empty_init=True)
+
     def get_output_embeddings(self):
         return self.lm_head
 
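Note: with quantization_bit stored on the config, __init__ quantizes the freshly built transformer with empty_init=True, so an already-quantized checkpoint's state dict can be loaded straight into quantized modules; this is the "loading quantized model" path of the commit. A hedged sketch (repo id assumed, not taken from the diff):

    from transformers import AutoModel

    # config.json of this checkpoint sets quantization_bit, so the model quantizes
    # itself inside __init__ before from_pretrained copies the quantized weights in
    model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()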
@@ -1351,7 +1365,19 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 break
             yield input_ids
 
-    def quantize(self, bits: int):
+    def quantize(self, bits: int, empty_init=False, **kwargs):
+        if bits == 0:
+            return
+
         from .quantization import quantize
-        self.transformer = quantize(self.transformer, bits)
+
+        if self.quantized:
+            logger.info("Already quantized.")
+            return self
+
+        self.quantized = True
+
+        self.config.quantization_bit = bits
+
+        self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs)
         return self
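Note: quantize() now ignores bits == 0, refuses to quantize twice, records the bit width back into the config (so a later save_pretrained should yield a checkpoint that reloads through the config-driven path shown earlier), and forwards empty_init and **kwargs to the quantize() helper in quantization.py. A hedged sketch of the post-hoc path (repo id assumed):

    from transformers import AutoModel

    # load the fp16 checkpoint, then quantize the transformer in place to 4 bits
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(4).cuda()

    # a second call is a no-op: it logs "Already quantized." and returns the model unchanged
    model.quantize(4)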
File diff suppressed because one or more lines are too long