Merge branch 'main' into dev_pt

# Conflicts:
#	modeling_chatglm.py
This commit is contained in:
duzx16 2023-03-29 20:37:39 +08:00
commit fbda1206cb
3 changed files with 157 additions and 62 deletions

View File

@ -11,6 +11,8 @@ tags:
## 介绍 ## 介绍
ChatGLM-6B 是一个开源的、支持中英双语问答的对话语言模型,基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构,具有 62 亿参数。结合模型量化技术用户可以在消费级的显卡上进行本地部署INT4 量化级别下最低只需 6GB 显存。ChatGLM-6B 使用了和 [ChatGLM](https://chatglm.cn) 相同的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练辅以监督微调、反馈自助、人类反馈强化学习等技术的加持62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。 ChatGLM-6B 是一个开源的、支持中英双语问答的对话语言模型,基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构,具有 62 亿参数。结合模型量化技术用户可以在消费级的显卡上进行本地部署INT4 量化级别下最低只需 6GB 显存。ChatGLM-6B 使用了和 [ChatGLM](https://chatglm.cn) 相同的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练辅以监督微调、反馈自助、人类反馈强化学习等技术的加持62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。
ChatGLM-6B is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue. The model is trained for about 1T tokens of Chinese and English corpus, supplemented by supervised fine-tuning, feedback bootstrap, and reinforcement learning wit human feedback. With only about 6.2 billion parameters, the model is able to generate answers that are in line with human preference.
## 软件依赖 ## 软件依赖
```shell ```shell
@ -44,6 +46,8 @@ pip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels
关于更多的使用说明,包括如何运行命令行和网页版本的 DEMO以及使用模型量化以节省显存请参考我们的 [Github Repo](https://github.com/THUDM/ChatGLM-6B)。 关于更多的使用说明,包括如何运行命令行和网页版本的 DEMO以及使用模型量化以节省显存请参考我们的 [Github Repo](https://github.com/THUDM/ChatGLM-6B)。
For more instructions, including how to run CLI and web demos, and model quantization, please refer to our [Github Repo](https://github.com/THUDM/ChatGLM-6B).
## 协议 ## 协议
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源ChatGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源ChatGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。

View File

@ -3,7 +3,9 @@
import math import math
import copy import copy
import os import os
import time import warnings
import re
import sys
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@ -11,7 +13,7 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm from torch.nn import CrossEntropyLoss, LayerNorm
from torch.nn.utils import skip_init from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List from typing import Optional, Tuple, Union, List, Callable
from transformers.utils import ( from transformers.utils import (
add_code_sample_docstrings, add_code_sample_docstrings,
@ -26,11 +28,13 @@ from transformers.modeling_outputs import (
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging from transformers.utils import logging
from transformers.generation.logits_process import LogitsProcessor from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
from .configuration_chatglm import ChatGLMConfig from .configuration_chatglm import ChatGLMConfig
# flags required to enable jit fusion kernels # flags required to enable jit fusion kernels
if sys.platform != 'darwin':
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False) torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True) torch._C._jit_override_can_fuse_on_cpu(True)
@ -294,7 +298,7 @@ def attention_fn(
if not (attention_mask == 0).all(): if not (attention_mask == 0).all():
# if auto-regressive, skip # if auto-regressive, skip
attention_scores.masked_fill_(attention_mask, -10000.0) attention_scores.masked_fill_(attention_mask, -10000.0)
dtype = attention_scores.type() dtype = attention_scores.dtype
attention_scores = attention_scores.float() attention_scores = attention_scores.float()
attention_scores = attention_scores * query_key_layer_scaling_coeff attention_scores = attention_scores * query_key_layer_scaling_coeff
@ -814,8 +818,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
return past_key_values return past_key_values
@staticmethod @staticmethod
def get_masks(seq, device): def get_masks(self, seq, device):
context_length = seq.index(150004) + 1 context_length = seq.index(self.config.bos_token_id) + 1
attention_mask = torch.ones((1, len(seq), len(seq)), device=device) attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
attention_mask.tril_() attention_mask.tril_()
@ -826,9 +830,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
return attention_mask return attention_mask
def get_position_ids(self, seq, mask_position, device, gmask=False): def get_position_ids(self, seq, mask_position, device, gmask=False):
context_length = seq.index(150004) + 1 context_length = len(seq)
if self.position_encoding_2d: if self.position_encoding_2d:
seq_length = seq.index(150004) seq_length = seq.index(self.config.bos_token_id)
position_ids = torch.arange(context_length, dtype=torch.long, device=device) position_ids = torch.arange(context_length, dtype=torch.long, device=device)
if not gmask: if not gmask:
position_ids[seq_length:] = mask_position position_ids[seq_length:] = mask_position
@ -886,14 +890,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device) past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device)
else: else:
past_key_values = tuple([None] * len(self.layers)) past_key_values = tuple([None] * len(self.layers))
MASK, gMASK = 150000, 150001
mask_token = MASK if MASK in input_ids else gMASK
use_gmask = False if MASK in input_ids else gMASK
seq = input_ids[0].tolist() seq = input_ids[0].tolist()
mask_position = seq.index(mask_token)
if attention_mask is None: if attention_mask is None:
attention_mask = self.get_masks( attention_mask = self.get_masks(
seq=seq, seq=seq,
@ -906,6 +904,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
if position_ids is None: if position_ids is None:
MASK, gMASK = 150000, 150001
mask_token = MASK if MASK in input_ids else gMASK
use_gmask = False if MASK in input_ids else gMASK
mask_position = seq.index(mask_token)
position_ids = self.get_position_ids( position_ids = self.get_position_ids(
seq=seq, seq=seq,
mask_position=mask_position, mask_position=mask_position,
@ -1009,7 +1012,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
attention_mask = (attention_mask < 0.5).bool() attention_mask = (attention_mask < 0.5).bool()
if self.position_encoding_2d: if self.position_encoding_2d:
seq_length = seq.index(150004) seq_length = seq.index(self.config.bos_token_id)
position_ids = torch.arange(context_length, dtype=torch.long, device=device) position_ids = torch.arange(context_length, dtype=torch.long, device=device)
if not gmask: if not gmask:
position_ids[seq_length:] = mask_position position_ids[seq_length:] = mask_position
@ -1047,7 +1050,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
# only last token for input_ids if past is not None # only last token for input_ids if past is not None
if past is not None or past_key_values is not None: if past is not None or past_key_values is not None:
context_length = seq.index(150004) context_length = seq.index(self.config.bos_token_id)
last_token = input_ids[:, -1].unsqueeze(-1) last_token = input_ids[:, -1].unsqueeze(-1)
if self.position_encoding_2d: if self.position_encoding_2d:
position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long, position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
@ -1155,6 +1158,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
for layer_past in past for layer_past in past
) )
def process_response(self, response):
response = response.strip()
response = response.replace("[[训练时间]]", "2023年")
punkts = [
[",", ""],
["!", ""],
[":", ""],
[";", ""],
["\?", ""],
]
for item in punkts:
response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
return response
@torch.no_grad() @torch.no_grad()
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1, def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
@ -1175,66 +1193,139 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
input_ids = tokenizer([prompt], return_tensors="pt", padding=True) input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
input_ids = input_ids.to(self.device) input_ids = input_ids.to(self.device)
outputs = self.generate(**input_ids, **gen_kwargs) outputs = self.generate(**input_ids, **gen_kwargs)
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]) - 2:] outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
response = tokenizer.decode(outputs) response = tokenizer.decode(outputs)
response = response.strip() response = self.process_response(response)
response = response.replace("[[训练时间]]", "2023年")
history = history + [(query, response)] history = history + [(query, response)]
return response, history return response, history
@torch.no_grad() @torch.no_grad()
def generate( def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
if not history:
prompt = query
else:
prompt = ""
for i, (old_query, response) in enumerate(history):
prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
input_ids = input_ids.to(self.device)
for outputs in self.stream_generate(**input_ids, **gen_kwargs):
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
response = tokenizer.decode(outputs)
response = self.process_response(response)
new_history = history + [(query, response)]
yield response, new_history
@torch.no_grad()
def stream_generate(
self, self,
input_ids,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
**kwargs, **kwargs,
): ):
MASK, gMASK = 150000, 150001 batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
bos, eos = 150004, 150005
if "eos_token_id" not in kwargs: if generation_config is None:
kwargs["eos_token_id"] = eos generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
truncate = kwargs.pop("truncate") if "truncate" in kwargs else False if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
stop = False has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
" recommend using `max_new_tokens` to control the maximum length of the generation.",
UserWarning,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
if not has_default_max_length:
logger.warn(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
UserWarning,
)
return_seqs = [] if input_ids_seq_length >= generation_config.max_length:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."
)
# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
)
stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria
)
logits_warper = self._get_logits_warper(generation_config)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
scores = None
while True: while True:
output_ids = super().generate(**kwargs) model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
return_seqs = [] # forward pass to get next token
max_length = 0 outputs = self(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
for i in range(output_ids.shape[0]): next_token_logits = outputs.logits[:, -1, :]
output_seq = output_ids[i].tolist()
if truncate: # pre-process distribution
output_seq = output_seq[len(kwargs["input_ids"][i]) - 2:] next_token_scores = logits_processor(input_ids, next_token_logits)
mask_token = MASK if MASK in output_seq else gMASK next_token_scores = logits_warper(input_ids, next_token_scores)
mask_position = output_seq.index(mask_token)
bos_position = output_seq.index(bos) # sample
if eos in output_seq: probs = nn.functional.softmax(next_token_scores, dim=-1)
eos_position = output_seq.index(eos) if generation_config.do_sample:
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else: else:
eos_position = len(output_seq) next_tokens = torch.argmax(probs, dim=-1)
return_seq = output_seq[:mask_position] + output_seq[bos_position + 1:eos_position] + output_seq[ # update generated ids, model inputs, and length for next step
mask_position + 1:bos_position] input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
max_length = max(max_length, len(return_seq)) model_kwargs = self._update_model_kwargs_for_generation(
return_seqs.append(return_seq) outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
for i in range(output_ids.shape[0]): # stop when each sentence is finished, or if we exceed the maximum length
return_seqs[i] = [0] * (max_length - len(return_seqs[i])) + return_seqs[i] # padding if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
if mask_token not in return_seqs[i]:
stop = True
if stop:
break break
yield input_ids
for return_seq in return_seqs:
return_seq += [bos]
kwargs['input_ids'] = torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
return torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
def quantize(self, bits: int): def quantize(self, bits: int):
from .quantization import quantize from .quantization import quantize

View File

@ -299,7 +299,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
""" """
if os.path.isdir(save_directory): if os.path.isdir(save_directory):
vocab_file = os.path.join( vocab_file = os.path.join(
save_directory, VOCAB_FILES_NAMES["vocab_file"] save_directory, self.vocab_files_names["vocab_file"]
) )
else: else:
vocab_file = save_directory vocab_file = save_directory