Compare commits
20 Commits
Author | SHA1
---|---
Zhengxiao Du | 02a065cf27
Zhengxiao Du | e214c5b71d
duzx16 | d8a6cfc6cb
duzx16 | f6b88da8c1
duzx16 | 63d66b0572
duzx16 | f55a1089a2
duzx16 | e02ba894cf
duzx16 | 6498797e79
duzx16 | 1e40d965fe
songxxzp | 630d0efd8b
songxxzp | bcc35f08b4
songxxzp | fe0674f86d
songxxzp | c7d8998bb3
duzx16 | 3485994337
duzx16 | 9333486c30
duzx16 | 6466cdcff5
duzx16 | 9163f7e6d9
duzx16 | 649466f2d7
duzx16 | 41fda88421
duzx16 | a7272d4c93
**README.md**

````diff
@@ -7,7 +7,11 @@ tags:
 - chatglm
 - thudm
 ---
-# ChatGLM-6B
+# ChatGLM-6B-INT4
+
+<p align="center">
+   👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1udqapmrr-ocT1DS_mxWe6dDY8ahRWzg" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a>
+</p>
 
 ## Introduction
 ChatGLM-6B is an open-source, bilingual (Chinese-English) conversational language model based on the [General Language Model (GLM)](https://github.com/THUDM/GLM) architecture, with 6.2 billion parameters. Combined with model quantization, it can be deployed locally on consumer-grade GPUs (as little as 6 GB of VRAM at the INT4 quantization level). ChatGLM-6B uses the same technology as [ChatGLM](https://chatglm.cn) and is optimized for Chinese question answering and dialogue. Trained on roughly 1T tokens of Chinese and English text and refined with supervised fine-tuning, feedback bootstrapping, and reinforcement learning from human feedback, the 6.2-billion-parameter ChatGLM-6B can already generate answers that align well with human preferences.
@@ -18,7 +22,7 @@ ChatGLM-6B-INT4 is the quantized model weights of ChatGLM-6B. Specifically, ChatGLM-6B
 ## Software Dependencies
 
 ```shell
-pip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels
+pip install protobuf transformers==4.27.1 cpm_kernels
 ```
 
 ## Code Usage
````
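The dependency line above drops `icetk` and moves to `transformers==4.27.1`; the documented way to load the model is unchanged. A minimal usage sketch, assuming the Hub repo id `THUDM/chatglm-6b-int4` and a CUDA device (the `chat` helper comes from the repo's own modeling code):

```python
from transformers import AutoModel, AutoTokenizer

# trust_remote_code=True is required because the modeling and tokenization code
# lives in the model repository itself (the files changed in this comparison).
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
model = model.eval()

response, history = model.chat(tokenizer, "Hello", history=[])
print(response)
```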
**config.json**

```diff
@@ -10,16 +10,16 @@
   },
   "bos_token_id": 130004,
   "eos_token_id": 130005,
+  "mask_token_id": 130000,
   "gmask_token_id": 130001,
+  "pad_token_id": 3,
   "hidden_size": 4096,
   "inner_hidden_size": 16384,
   "layernorm_epsilon": 1e-05,
-  "mask_token_id": 130000,
   "max_sequence_length": 2048,
   "model_type": "chatglm",
   "num_attention_heads": 32,
   "num_layers": 28,
-  "pad_token_id": 3,
   "position_encoding_2d": true,
   "quantization_bit": 4,
   "quantization_embeddings": false,
```
**modeling_chatglm.py**

```diff
@@ -56,7 +56,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         if torch.isnan(scores).any() or torch.isinf(scores).any():
             scores.zero_()
-            scores[..., 20005] = 5e4
+            scores[..., 5] = 5e4
         return scores


```
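The only functional change in this hunk is the fallback token id (20005 to 5) written when generation produces NaN or Inf scores. A small sketch of what the processor does to a broken score row; the tensor values are made up for illustration:

```python
import torch

scores = torch.tensor([[float("nan"), 0.2, 0.1, 0.0, 0.4, 0.3]])
if torch.isnan(scores).any() or torch.isinf(scores).any():
    scores.zero_()          # wipe the corrupted distribution
    scores[..., 5] = 5e4    # put all probability mass on a single fallback token id
# argmax / sampling now deterministically picks token id 5 instead of propagating NaN
```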
```diff
@@ -280,10 +280,8 @@ def attention_fn(
         # [sk, b, np, hn] -> [sk, b * np, hn]
         key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

-        matmul_result = torch.empty(
-            output_size[0] * output_size[1],
-            output_size[2],
-            output_size[3],
+        matmul_result = torch.zeros(
+            1, 1, 1,
             dtype=query_layer.dtype,
             device=query_layer.device,
         )
```
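The pre-allocated addend for the attention `baddbmm` shrinks from a full uninitialized `(b*np, sq, sk)` buffer to a broadcastable `(1, 1, 1)` zeros tensor. A sketch of the pattern with illustrative shapes (not the model's real sizes): `torch.baddbmm` broadcasts its `input` argument, so the tiny tensor is sufficient, and zero-filled memory keeps the addend well defined instead of relying on whatever `torch.empty` happens to contain.

```python
import torch

b_np, sq, sk, hn = 64, 16, 16, 128          # illustrative sizes only
query = torch.randn(b_np, sq, hn)
key = torch.randn(b_np, sk, hn)

# The (1, 1, 1) addend is broadcast against the (b_np, sq, sk) matmul result;
# with beta=0.0 its values do not contribute to the output.
attn_scores = torch.baddbmm(
    torch.zeros(1, 1, 1, dtype=query.dtype),
    query,
    key.transpose(1, 2),
    beta=0.0,
    alpha=1.0,
)
assert attn_scores.shape == (b_np, sq, sk)
```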
```diff
@@ -348,10 +346,18 @@ def attention_fn(
     return outputs


+def default_init(cls, *args, **kwargs):
+    return cls(*args, **kwargs)
+
+
 class SelfAttention(torch.nn.Module):
     def __init__(self, hidden_size, num_attention_heads,
                  layer_id, hidden_size_per_attention_head=None, bias=True,
-                 params_dtype=torch.float, position_encoding_2d=True):
+                 params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         super(SelfAttention, self).__init__()

         self.layer_id = layer_id
```
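The new `empty_init` flag chooses between `torch.nn.utils.skip_init` (the previously hard-coded behaviour) and plain construction via the new `default_init`. A minimal sketch of the difference; the layer sizes are illustrative, not taken from the model:

```python
import torch
from torch.nn.utils import skip_init

def default_init(cls, *args, **kwargs):
    # Ordinary construction: parameters are materialized and randomly initialized.
    return cls(*args, **kwargs)

# skip_init builds the module without running its parameter initialization,
# which saves time when from_pretrained() will immediately overwrite the weights.
qkv_fast = skip_init(torch.nn.Linear, 4096, 3 * 4096, bias=True, dtype=torch.half)

# default_init is needed when the module must start from properly initialized
# weights, e.g. when building the model without loading a checkpoint.
qkv_slow = default_init(torch.nn.Linear, 4096, 3 * 4096, bias=True, dtype=torch.half)
```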
```diff
@@ -379,7 +385,7 @@ class SelfAttention(torch.nn.Module):
             self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head

         # Strided linear layer.
-        self.query_key_value = skip_init(
+        self.query_key_value = init_method(
             torch.nn.Linear,
             hidden_size,
             3 * self.inner_hidden_size,
@@ -387,7 +393,7 @@ class SelfAttention(torch.nn.Module):
             dtype=params_dtype,
         )

-        self.dense = skip_init(
+        self.dense = init_method(
             torch.nn.Linear,
             self.inner_hidden_size,
             hidden_size,
@@ -500,8 +506,12 @@ class GEGLU(torch.nn.Module):

 class GLU(torch.nn.Module):
     def __init__(self, hidden_size, inner_hidden_size=None,
-                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float):
+                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
         super(GLU, self).__init__()
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         self.layer_id = layer_id
         self.activation_func = activation_func

@@ -510,7 +520,7 @@ class GLU(torch.nn.Module):
         if inner_hidden_size is None:
             inner_hidden_size = 4 * hidden_size
         self.inner_hidden_size = inner_hidden_size
-        self.dense_h_to_4h = skip_init(
+        self.dense_h_to_4h = init_method(
             torch.nn.Linear,
             self.hidden_size,
             self.inner_hidden_size,
@@ -518,7 +528,7 @@ class GLU(torch.nn.Module):
             dtype=params_dtype,
         )
         # Project back to h.
-        self.dense_4h_to_h = skip_init(
+        self.dense_4h_to_h = init_method(
             torch.nn.Linear,
             self.inner_hidden_size,
             self.hidden_size,
@@ -554,7 +564,8 @@ class GLMBlock(torch.nn.Module):
             use_bias=True,
             params_dtype=torch.float,
             num_layers=28,
-            position_encoding_2d=True
+            position_encoding_2d=True,
+            empty_init=True
     ):
         super(GLMBlock, self).__init__()
         # Set output layer initialization if not provided.
@@ -574,7 +585,8 @@ class GLMBlock(torch.nn.Module):
             hidden_size_per_attention_head=hidden_size_per_attention_head,
             bias=use_bias,
             params_dtype=params_dtype,
-            position_encoding_2d=self.position_encoding_2d
+            position_encoding_2d=self.position_encoding_2d,
+            empty_init=empty_init
         )

         # Layernorm on the input data.
@@ -589,6 +601,7 @@ class GLMBlock(torch.nn.Module):
             bias=use_bias,
             layer_id=layer_id,
             params_dtype=params_dtype,
+            empty_init=empty_init
         )

     def forward(
@@ -676,8 +689,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):

         return attention_mask

-    def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
+    def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
         batch_size, seq_length = input_ids.shape
+        if use_gmasks is None:
+            use_gmasks = [False] * batch_size
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
@@ -691,8 +706,8 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
             position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
-            if not gmask:
-                for i, context_length in enumerate(context_lengths):
+            for i, context_length in enumerate(context_lengths):
+                if not use_gmasks[i]:
                     position_ids[context_length:] = mask_positions[i]

         return position_ids
```
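`get_position_ids` now receives a per-sequence `use_gmasks` list instead of one batch-wide `gmask` flag, so a batch can mix `[MASK]` sequences (span infilling, positions after the context are clamped to the mask position) and `[gMASK]` sequences (generation, positions stay monotonic). A simplified 1-D sketch of the effect; this is an illustration under assumed shapes, not the repo's exact code:

```python
import torch

def position_ids_1d(input_ids, mask_positions, bos_token_id, use_gmasks):
    batch_size, seq_length = input_ids.shape
    context_lengths = [seq.tolist().index(bos_token_id) for seq in input_ids]
    position_ids = torch.arange(seq_length).unsqueeze(0).repeat(batch_size, 1)
    for i, context_length in enumerate(context_lengths):
        if not use_gmasks[i]:
            # [MASK] sequences: everything after the context points back at the
            # mask slot; [gMASK] sequences keep their 0, 1, 2, ... positions.
            position_ids[i, context_length:] = mask_positions[i]
    return position_ids
```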
```diff
@@ -783,9 +798,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
     `encoder_hidden_states` is then expected as an input to the forward pass.
     """

-    def __init__(self, config: ChatGLMConfig):
+    def __init__(self, config: ChatGLMConfig, empty_init=True):
         super().__init__(config)
-
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         # recording parameters
         self.max_sequence_length = config.max_sequence_length
         self.hidden_size = config.hidden_size
@@ -800,7 +818,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.pre_seq_len = config.pre_seq_len
         self.prefix_projection = config.prefix_projection

-        self.word_embeddings = skip_init(
+        self.word_embeddings = init_method(
             torch.nn.Embedding,
             num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
             dtype=self.params_dtype
@@ -819,6 +837,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 use_bias=True,
                 params_dtype=self.params_dtype,
                 position_encoding_2d=self.position_encoding_2d,
+                empty_init=empty_init
             )

         self.layers = torch.nn.ModuleList(
@@ -899,7 +918,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape[:2]
         elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape[:2]
+            batch_size, seq_length = inputs_embeds.shape[:2]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

@@ -922,15 +941,20 @@ class ChatGLMModel(ChatGLMPreTrainedModel):

         if position_ids is None:
             MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
-            mask_token = gMASK if gMASK in input_ids else MASK
-            use_gmask = True if gMASK in input_ids else False
+            seqs = input_ids.tolist()
+
+            mask_positions, use_gmasks = [], []
+            for seq in seqs:
+                mask_token = gMASK if gMASK in seq else MASK
+                use_gmask = mask_token == gMASK
+                mask_positions.append(seq.index(mask_token))
+                use_gmasks.append(use_gmask)

-            mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
             position_ids = self.get_position_ids(
                 input_ids,
                 mask_positions=mask_positions,
                 device=input_ids.device,
-                gmask=use_gmask
+                use_gmasks=use_gmasks
             )

         if self.pre_seq_len is not None and attention_mask is not None:
@@ -948,9 +972,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):

         if attention_mask is None:
             attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
-
         else:
-            attention_mask = attention_mask.to(input_ids.device)
+            attention_mask = attention_mask.to(hidden_states.device)

         for i, layer in enumerate(self.layers):

```
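Previously a single `mask_token`/`use_gmask` pair was derived from the whole batch tensor; the loop above records one mask position and one gMASK flag per sequence. A small illustration using the token ids now listed in config.json (the surrounding ids are invented):

```python
MASK, gMASK, BOS = 130000, 130001, 130004   # from config.json above; other ids below are made up

seqs = [
    [15, 16, MASK, 17, BOS],    # span-infilling prompt, uses [MASK]
    [18, 19, 20, gMASK, BOS],   # generation prompt, uses [gMASK]
]

mask_positions, use_gmasks = [], []
for seq in seqs:
    mask_token = gMASK if gMASK in seq else MASK
    use_gmask = mask_token == gMASK
    mask_positions.append(seq.index(mask_token))
    use_gmasks.append(use_gmask)

assert mask_positions == [2, 3]
assert use_gmasks == [False, True]
```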
```diff
@@ -1006,8 +1029,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):


 class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
-    def __init__(self, config: ChatGLMConfig):
+    def __init__(self, config: ChatGLMConfig, empty_init=True):
         super().__init__(config)
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init

         # self.hidden_size = config.hidden_size
         # self.params_dtype = torch.half
@@ -1016,9 +1043,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):

         self.position_encoding_2d = config.position_encoding_2d

-        self.transformer = ChatGLMModel(config)
+        self.transformer = ChatGLMModel(config, empty_init=empty_init)

-        self.lm_head = skip_init(
+        self.lm_head = init_method(
             nn.Linear,
             config.hidden_size,
             config.vocab_size,
@@ -1085,10 +1112,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
     ) -> dict:
         batch_size, seq_length = input_ids.shape
         MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
-        mask_token = gMASK if gMASK in input_ids else MASK
-        use_gmask = True if gMASK in input_ids else False
         seqs = input_ids.tolist()
-        mask_positions = [seq.index(mask_token) for seq in seqs]
+        mask_positions, use_gmasks = [], []
+        for seq in seqs:
+            mask_token = gMASK if gMASK in seq else MASK
+            use_gmask = mask_token == gMASK
+            mask_positions.append(seq.index(mask_token))
+            use_gmasks.append(use_gmask)

         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
@@ -1131,7 +1161,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             input_ids,
             device=input_ids.device,
             mask_positions=mask_positions,
-            gmask=use_gmask
+            use_gmasks=use_gmasks
         )

         return {
```
**Model weights (Git LFS pointer)**

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:35828b49cf23cbae4c27788d4b04fc68c79a276300e09f14d72a49b0b738b4a9
+oid sha256:245786435bde9f4593c105ea846fa461fe42bc63c12b738d0272fcaed6276645
 size 3893083075
```
**quantization.py** (125 lines changed): file diff suppressed because one or more lines are too long.
**tokenization_chatglm.py**

```diff
@@ -31,6 +31,9 @@ class TextTokenizer:
     def tokenize(self, text):
         return self.sp.EncodeAsPieces(text)

+    def convert_tokens_to_string(self, tokens):
+        return self.sp.DecodePieces(tokens)
+
     def convert_tokens_to_ids(self, tokens):
         return [self.sp.PieceToId(token) for token in tokens]

```
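The new `convert_tokens_to_string` helper is a thin wrapper over sentencepiece's piece-level decoding. A small sketch of the underlying calls; the vocabulary file name is assumed for illustration, not taken from this diff:

```python
import sentencepiece as spm

# TextTokenizer wraps a SentencePieceProcessor; "ice_text.model" is an assumed file name.
sp = spm.SentencePieceProcessor(model_file="ice_text.model")

pieces = sp.EncodeAsPieces("Hello world")   # what TextTokenizer.tokenize returns
text = sp.DecodePieces(pieces)              # what the new convert_tokens_to_string returns
ids = [sp.PieceToId(p) for p in pieces]     # what convert_tokens_to_ids returns
```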
```diff
@@ -111,16 +114,25 @@ class SPTokenizer:
             tokens = [x + self.num_image_tokens for x in tmp]
             return tokens if add_dummy_prefix else tokens[2:]

-    def decode(self, text_ids: List[int]) -> str:
-        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
-        ids = [_id for _id in ids if _id >= 0]
-        text = self._get_text_tokenizer().decode(ids)
+    def postprocess(self, text):
         text = text.replace("<n>", "\n")
         text = text.replace(SPTokenizer.get_tab_token(), "\t")
         for i in range(2, self.max_blank_length + 1):
             text = text.replace(self.get_blank_token(i), " " * i)
         return text

+    def decode(self, text_ids: List[int]) -> str:
+        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
+        ids = [_id for _id in ids if _id >= 0]
+        text = self._get_text_tokenizer().decode(ids)
+        text = self.postprocess(text)
+        return text
+
+    def decode_tokens(self, tokens: List[str]) -> str:
+        text = self._get_text_tokenizer().convert_tokens_to_string(tokens)
+        text = self.postprocess(text)
+        return text
+
     def tokenize(
         self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
     ) -> List[str]:
```
```diff
@@ -170,12 +182,14 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         vocab_file,
         do_lower_case=False,
         remove_space=False,
-        bos_token='sop',
-        eos_token='eos',
-        eop_token='eop',
+        bos_token='<sop>',
+        eos_token='<eop>',
+        end_token='</s>',
         mask_token='[MASK]',
         gmask_token='[gMASK]',
         padding_side="left",
+        pad_token="<pad>",
+        unk_token="<unk>",
         num_image_tokens=20000,
         **kwargs
     ) -> None:
@@ -183,6 +197,14 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             do_lower_case=do_lower_case,
             remove_space=remove_space,
             padding_side=padding_side,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            end_token=end_token,
+            mask_token=mask_token,
+            gmask_token=gmask_token,
+            pad_token=pad_token,
+            unk_token=unk_token,
+            num_image_tokens=num_image_tokens,
             **kwargs
         )

@@ -192,7 +214,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):

         self.bos_token = bos_token
         self.eos_token = eos_token
-        self.eop_token = eop_token
+        self.end_token = end_token
         self.mask_token = mask_token
         self.gmask_token = gmask_token

@@ -207,14 +229,14 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         return self.convert_tokens_to_ids(self.gmask_token)

     @property
-    def eop_token_id(self) -> Optional[int]:
+    def end_token_id(self) -> Optional[int]:
         """
-        `Optional[int]`: Id of the end of sentence token in the vocabulary. Returns `None` if the token has not been
+        `Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been
         set.
         """
-        if self.eop_token is None:
+        if self.end_token is None:
             return None
-        return self.convert_tokens_to_ids(self.eop_token)
+        return self.convert_tokens_to_ids(self.end_token)

     @property
     def vocab_size(self):
@@ -246,11 +268,12 @@ class ChatGLMTokenizer(PreTrainedTokenizer):

         return seq

+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        return self.sp_tokenizer.decode_tokens(tokens)
+
     def _decode(
         self,
         token_ids: Union[int, List[int]],
-        skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
         **kwargs
     ) -> str:
         if isinstance(token_ids, int):
@@ -259,7 +282,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             return ""
         if self.pad_token_id in token_ids:  # remove pad
             token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
-        return self.sp_tokenizer.decode(token_ids)
+        return super()._decode(token_ids, **kwargs)

     def _convert_token_to_id(self, token):
         """ Converts a token (str) in an id using the vocab. """
```
```diff
@@ -316,22 +339,11 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         Returns:
             `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
         """
-        mask_ids = self.sp_tokenizer[self.mask_token]
-        gmask_ids = self.sp_tokenizer[self.gmask_token]
-        eop_id = self.sp_tokenizer[self.eop_token]
-        if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0:
-            token_ids_0 += [gmask_ids]
-
-        if token_ids_0[-1] != mask_ids and token_ids_0[-1] != gmask_ids:
-            token_ids_0 += [self.sp_tokenizer[self.eos_token]]
-
-        token_ids_0 += [self.sp_tokenizer[self.bos_token]]
-
+        gmask_id = self.sp_tokenizer[self.gmask_token]
+        eos_id = self.sp_tokenizer[self.eos_token]
+        token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]]
         if token_ids_1 is not None:
-            if not token_ids_1 or token_ids_1[-1] != eop_id:
-                token_ids_1 += [eop_id]
-            token_ids_0 += token_ids_1
-
+            token_ids_0 = token_ids_0 + token_ids_1 + [eos_id]
         return token_ids_0

     def _pad(
```
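The rewritten `build_inputs_with_special_tokens` always appends `[gMASK]` and the bos piece (`<sop>`) to the prompt, and closes a sequence pair with the eos piece (now `<eop>`, per the updated tokenizer_config.json below) only when a second segment is given. A small sketch of the resulting layouts, with token strings standing in for ids:

```python
prompt = ["▁Question"]   # token_ids_0, shown as strings for readability
answer = ["▁Answer"]     # token_ids_1

# Single segment: prompt + [gMASK] + <sop>
single = prompt + ["[gMASK]", "<sop>"]

# Pair: prompt + [gMASK] + <sop> + answer + <eop>
pair = prompt + ["[gMASK]", "<sop>"] + answer + ["<eop>"]
```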
```diff
@@ -396,6 +408,10 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             encoded_inputs["attention_mask"] = attention_mask

         if "position_ids" not in encoded_inputs:
+            if bos_token_id in required_input:
+                context_length = required_input.index(bos_token_id)
+            else:
+                context_length = seq_length
             position_ids = np.arange(seq_length, dtype=np.int64)
             mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
             if mask_token in required_input:
```
**tokenizer_config.json**

```diff
@@ -1,8 +1,8 @@
 {
   "name_or_path": "THUDM/chatglm-6b-int4",
   "bos_token": "<sop>",
-  "eop_token": "<eop>",
-  "eos_token": "</s>",
+  "eos_token": "<eop>",
+  "end_token": "</s>",
   "gmask_token": "[gMASK]",
   "mask_token": "[MASK]",
   "pad_token": "<pad>",
```