Merge branch 'main' into dev_pt

# Conflicts:
#	modeling_chatglm.py
duzx16 2023-03-29 20:37:39 +08:00
commit fbda1206cb
3 changed files with 157 additions and 62 deletions


@@ -11,6 +11,8 @@ tags:
## Introduction
ChatGLM-6B is an open-source dialogue language model that supports bilingual (Chinese-English) question answering. It is built on the [General Language Model (GLM)](https://github.com/THUDM/GLM) architecture and has 6.2 billion parameters. Combined with model quantization, it can be deployed locally on consumer-grade graphics cards (a minimum of 6GB of GPU memory at the INT4 quantization level). ChatGLM-6B uses the same technology as [ChatGLM](https://chatglm.cn) and is optimized for Chinese question answering and dialogue. After bilingual training on roughly 1T tokens of Chinese and English, supplemented by supervised fine-tuning, feedback bootstrapping, and reinforcement learning with human feedback, the 6.2-billion-parameter ChatGLM-6B can already generate answers that align well with human preferences.
ChatGLM-6B is an open bilingual language model based on the [General Language Model (GLM)](https://github.com/THUDM/GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy it locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue. The model is trained on about 1T tokens of Chinese and English corpus, supplemented by supervised fine-tuning, feedback bootstrap, and reinforcement learning with human feedback. With only about 6.2 billion parameters, the model is able to generate answers that are in line with human preferences.
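As a quick illustration of the local-deployment claim above, loading the model with INT4 quantization and running a single chat turn looks roughly like this (a minimal sketch based on the documented usage; it assumes a CUDA GPU and the dependencies listed in the next section):

```python
# Minimal sketch: load ChatGLM-6B with INT4 quantization and ask one question.
# Assumes a CUDA GPU and the pip dependencies from the "Software Dependencies" section.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(4).cuda()
model = model.eval()

response, history = model.chat(tokenizer, "你好", history=[])
print(response)
```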
## Software Dependencies
```shell
@@ -44,6 +46,8 @@ pip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels
For more usage instructions, including how to run the command-line and web demos and how to use model quantization to save GPU memory, please refer to our [Github Repo](https://github.com/THUDM/ChatGLM-6B).
For more instructions, including how to run CLI and web demos, and model quantization, please refer to our [Github Repo](https://github.com/THUDM/ChatGLM-6B).
## License
The code in this repository is open-sourced under the [Apache-2.0](LICENSE) license; use of the ChatGLM-6B model weights must additionally follow the [Model License](MODEL_LICENSE).


@@ -3,7 +3,9 @@
import math
import copy
import os
import time
import warnings
import re
import sys
import torch
import torch.utils.checkpoint
@@ -11,7 +13,7 @@ import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List
from typing import Optional, Tuple, Union, List, Callable
from transformers.utils import (
add_code_sample_docstrings,
@@ -26,15 +28,17 @@ from transformers.modeling_outputs import (
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
from .configuration_chatglm import ChatGLMConfig
# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
if sys.platform != 'darwin':
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
logger = logging.get_logger(__name__)
@@ -294,7 +298,7 @@ def attention_fn(
if not (attention_mask == 0).all():
# if auto-regressive, skip
attention_scores.masked_fill_(attention_mask, -10000.0)
dtype = attention_scores.type()
dtype = attention_scores.dtype
attention_scores = attention_scores.float()
attention_scores = attention_scores * query_key_layer_scaling_coeff
@@ -814,8 +818,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
return past_key_values
@staticmethod
def get_masks(seq, device):
context_length = seq.index(150004) + 1
def get_masks(self, seq, device):
context_length = seq.index(self.config.bos_token_id) + 1
attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
attention_mask.tril_()
@@ -826,9 +830,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
return attention_mask
def get_position_ids(self, seq, mask_position, device, gmask=False):
context_length = seq.index(150004) + 1
context_length = len(seq)
if self.position_encoding_2d:
seq_length = seq.index(150004)
seq_length = seq.index(self.config.bos_token_id)
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
if not gmask:
position_ids[seq_length:] = mask_position
@@ -886,14 +890,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device)
else:
past_key_values = tuple([None] * len(self.layers))
MASK, gMASK = 150000, 150001
mask_token = MASK if MASK in input_ids else gMASK
use_gmask = False if MASK in input_ids else gMASK
seq = input_ids[0].tolist()
mask_position = seq.index(mask_token)
if attention_mask is None:
attention_mask = self.get_masks(
seq=seq,
@@ -906,6 +904,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
if position_ids is None:
MASK, gMASK = 150000, 150001
mask_token = MASK if MASK in input_ids else gMASK
use_gmask = False if MASK in input_ids else gMASK
mask_position = seq.index(mask_token)
position_ids = self.get_position_ids(
seq=seq,
mask_position=mask_position,
@@ -1009,7 +1012,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
attention_mask = (attention_mask < 0.5).bool()
if self.position_encoding_2d:
seq_length = seq.index(150004)
seq_length = seq.index(self.config.bos_token_id)
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
if not gmask:
position_ids[seq_length:] = mask_position
@@ -1047,7 +1050,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
# only last token for input_ids if past is not None
if past is not None or past_key_values is not None:
context_length = seq.index(150004)
context_length = seq.index(self.config.bos_token_id)
last_token = input_ids[:, -1].unsqueeze(-1)
if self.position_encoding_2d:
position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
@@ -1155,6 +1158,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
for layer_past in past
)
def process_response(self, response):
response = response.strip()
response = response.replace("[[训练时间]]", "2023年")
punkts = [
[",", ","],
["!", "!"],
[":", ":"],
[";", ";"],
["\?", "?"],
]
for item in punkts:
response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
return response
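Aside (illustrative, not part of the diff): the regex pass above converts ASCII punctuation that sits next to a CJK character into its full-width counterpart, based on the restored replacement characters shown above. A self-contained sketch of the same substitution:

```python
# Standalone sketch of the full-width punctuation normalization used in process_response.
import re

def normalize_punctuation(text: str) -> str:
    punkts = [[",", ","], ["!", "!"], [":", ":"], [";", ";"], [r"\?", "?"]]
    for ascii_p, full_p in punkts:
        # punctuation preceded by a CJK character
        text = re.sub(r"([\u4e00-\u9fff])%s" % ascii_p, r"\1%s" % full_p, text)
        # punctuation followed by a CJK character
        text = re.sub(r"%s([\u4e00-\u9fff])" % ascii_p, r"%s\1" % full_p, text)
    return text

print(normalize_punctuation("你好,世界!ok?"))  # -> 你好,世界!ok?
```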
@torch.no_grad()
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
@@ -1175,66 +1193,139 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
input_ids = input_ids.to(self.device)
outputs = self.generate(**input_ids, **gen_kwargs)
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]) - 2:]
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
response = tokenizer.decode(outputs)
response = response.strip()
response = response.replace("[[训练时间]]", "2023年")
response = self.process_response(response)
history = history + [(query, response)]
return response, history
@torch.no_grad()
def generate(
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
if not history:
prompt = query
else:
prompt = ""
for i, (old_query, response) in enumerate(history):
prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
input_ids = input_ids.to(self.device)
for outputs in self.stream_generate(**input_ids, **gen_kwargs):
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
response = tokenizer.decode(outputs)
response = self.process_response(response)
new_history = history + [(query, response)]
yield response, new_history
@torch.no_grad()
def stream_generate(
self,
input_ids,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
**kwargs,
):
MASK, gMASK = 150000, 150001
bos, eos = 150004, 150005
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
if "eos_token_id" not in kwargs:
kwargs["eos_token_id"] = eos
if generation_config is None:
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
truncate = kwargs.pop("truncate") if "truncate" in kwargs else False
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
stop = False
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
" recommend using `max_new_tokens` to control the maximum length of the generation.",
UserWarning,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
if not has_default_max_length:
logger.warn(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
UserWarning,
)
return_seqs = []
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."
)
# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
)
stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria
)
logits_warper = self._get_logits_warper(generation_config)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
scores = None
while True:
output_ids = super().generate(**kwargs)
return_seqs = []
max_length = 0
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
for i in range(output_ids.shape[0]):
output_seq = output_ids[i].tolist()
if truncate:
output_seq = output_seq[len(kwargs["input_ids"][i]) - 2:]
mask_token = MASK if MASK in output_seq else gMASK
mask_position = output_seq.index(mask_token)
bos_position = output_seq.index(bos)
if eos in output_seq:
eos_position = output_seq.index(eos)
else:
eos_position = len(output_seq)
next_token_logits = outputs.logits[:, -1, :]
return_seq = output_seq[:mask_position] + output_seq[bos_position + 1:eos_position] + output_seq[
mask_position + 1:bos_position]
max_length = max(max_length, len(return_seq))
return_seqs.append(return_seq)
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
for i in range(output_ids.shape[0]):
return_seqs[i] = [0] * (max_length - len(return_seqs[i])) + return_seqs[i] # padding
if mask_token not in return_seqs[i]:
stop = True
# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
if generation_config.do_sample:
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(probs, dim=-1)
if stop:
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
# stop when each sentence is finished, or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
break
for return_seq in return_seqs:
return_seq += [bos]
kwargs['input_ids'] = torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
return torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
yield input_ids
def quantize(self, bits: int):
from .quantization import quantize
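Usage note (illustrative, not part of the diff): the streaming interface added above is a generator; consuming it looks roughly like this, assuming `model` and `tokenizer` are loaded as in the quick-start sketch earlier:

```python
# Each iteration yields the response decoded so far plus the updated history,
# so a caller can re-render the partial answer as generation proceeds.
history = []
for response, history in model.stream_chat(tokenizer, "你好", history=history):
    print(response)
```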


@@ -299,7 +299,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
"""
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory, VOCAB_FILES_NAMES["vocab_file"]
save_directory, self.vocab_files_names["vocab_file"]
)
else:
vocab_file = save_directory
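For context (illustrative, not part of the diff), `save_vocabulary` is usually reached through the standard Hugging Face save flow rather than called directly; the directory name below is made up:

```python
# Saving the tokenizer triggers save_vocabulary, which writes the sentencepiece
# vocab file into the target directory (path below is illustrative).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
tokenizer.save_pretrained("./chatglm-6b-local")
```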