diff --git a/README.md b/README.md
index 200f826..89cf0a7 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,8 @@ tags:
 ## 介绍
 ChatGLM-6B 是一个开源的、支持中英双语问答的对话语言模型,基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。ChatGLM-6B 使用了和 [ChatGLM](https://chatglm.cn) 相同的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。

+ChatGLM-6B is an open bilingual language model based on the [General Language Model (GLM)](https://github.com/THUDM/GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy the model locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue. The model is trained on about 1T tokens of Chinese and English corpus, supplemented by supervised fine-tuning, feedback bootstrap, and reinforcement learning with human feedback. With only about 6.2 billion parameters, the model is able to generate answers that are in line with human preference.
+
 ## 软件依赖

 ```shell
@@ -44,6 +46,8 @@ pip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels

 关于更多的使用说明,包括如何运行命令行和网页版本的 DEMO,以及使用模型量化以节省显存,请参考我们的 [Github Repo](https://github.com/THUDM/ChatGLM-6B)。

+For more instructions, including how to run the CLI and web demos and how to use model quantization to save GPU memory, please refer to our [Github Repo](https://github.com/THUDM/ChatGLM-6B).
+
 ## 协议

 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源,ChatGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。
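The README hunk above claims consumer-GPU deployment (about 6GB of GPU memory at the INT4 quantization level) and defers usage details to the GitHub repo. As a minimal, non-authoritative sketch of that path, assuming the `THUDM/chatglm-6b` Hub id and the `chat`/`quantize` methods defined in `modeling_chatglm.py` further down in this patch, loading and querying the model typically looks like this:

```python
# Minimal sketch, not part of this patch: load ChatGLM-6B via the remote code
# shipped in this repo and apply INT4 quantization to fit roughly 6GB of GPU memory.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# Drop .quantize(4) if the card has enough memory for the FP16 weights.
model = model.half().quantize(4).cuda()
model = model.eval()

response, history = model.chat(tokenizer, "你好", history=[])
print(response)
```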
diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index d961505..bab58f8 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -3,7 +3,9 @@
 import math
 import copy
 import os
-import time
+import warnings
+import re
+import sys

 import torch
 import torch.utils.checkpoint
@@ -11,7 +13,7 @@ import torch.nn.functional as F
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
 from torch.nn.utils import skip_init
-from typing import Optional, Tuple, Union, List
+from typing import Optional, Tuple, Union, List, Callable

 from transformers.utils import (
     add_code_sample_docstrings,
@@ -26,15 +28,17 @@ from transformers.modeling_outputs import (
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig

 from .configuration_chatglm import ChatGLMConfig

 # flags required to enable jit fusion kernels
-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)
+
+if sys.platform != 'darwin':
+    torch._C._jit_set_profiling_mode(False)
+    torch._C._jit_set_profiling_executor(False)
+    torch._C._jit_override_can_fuse_on_cpu(True)
+    torch._C._jit_override_can_fuse_on_gpu(True)

 logger = logging.get_logger(__name__)

@@ -294,7 +298,7 @@ def attention_fn(
         if not (attention_mask == 0).all():
             # if auto-regressive, skip
             attention_scores.masked_fill_(attention_mask, -10000.0)
-        dtype = attention_scores.type()
+        dtype = attention_scores.dtype
         attention_scores = attention_scores.float()
         attention_scores = attention_scores * query_key_layer_scaling_coeff

@@ -814,8 +818,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return past_key_values

     @staticmethod
-    def get_masks(seq, device):
-        context_length = seq.index(150004) + 1
+    def get_masks(self, seq, device):
+        context_length = seq.index(self.config.bos_token_id) + 1

         attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
         attention_mask.tril_()
@@ -826,9 +830,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return attention_mask

     def get_position_ids(self, seq, mask_position, device, gmask=False):
-        context_length = seq.index(150004) + 1
+        context_length = len(seq)
         if self.position_encoding_2d:
-            seq_length = seq.index(150004)
+            seq_length = seq.index(self.config.bos_token_id)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
             if not gmask:
                 position_ids[seq_length:] = mask_position
@@ -886,14 +890,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device)
         else:
             past_key_values = tuple([None] * len(self.layers))
-
-        MASK, gMASK = 150000, 150001
-        mask_token = MASK if MASK in input_ids else gMASK
-        use_gmask = False if MASK in input_ids else gMASK

         seq = input_ids[0].tolist()
-        mask_position = seq.index(mask_token)
-
         if attention_mask is None:
             attention_mask = self.get_masks(
                 seq=seq,
@@ -906,6 +904,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)

         if position_ids is None:
+            MASK, gMASK = 150000, 150001
+            mask_token = MASK if MASK in input_ids else gMASK
+            use_gmask = False if MASK in input_ids else gMASK
+
+            mask_position = seq.index(mask_token)
             position_ids = self.get_position_ids(
                 seq=seq,
                 mask_position=mask_position,
@@ -1009,7 +1012,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             attention_mask = (attention_mask < 0.5).bool()

         if self.position_encoding_2d:
-            seq_length = seq.index(150004)
+            seq_length = seq.index(self.config.bos_token_id)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
             if not gmask:
                 position_ids[seq_length:] = mask_position
@@ -1047,7 +1050,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):

         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
-            context_length = seq.index(150004)
+            context_length = seq.index(self.config.bos_token_id)
             last_token = input_ids[:, -1].unsqueeze(-1)
             if self.position_encoding_2d:
                 position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
@@ -1155,6 +1158,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             for layer_past in past
         )

+    def process_response(self, response):
+        response = response.strip()
+        response = response.replace("[[训练时间]]", "2023年")
+        punkts = [
+            [",", ","],
+            ["!", "!"],
+            [":", ":"],
+            [";", ";"],
+            ["\?", "?"],
+        ]
+        for item in punkts:
+            response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
+            response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
+        return response
+
     @torch.no_grad()
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
              do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
@@ -1175,66 +1193,139 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
         input_ids = input_ids.to(self.device)
         outputs = self.generate(**input_ids, **gen_kwargs)
-        outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]) - 2:]
+        outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
         response = tokenizer.decode(outputs)
-        response = response.strip()
-        response = response.replace("[[训练时间]]", "2023年")
+        response = self.process_response(response)
         history = history + [(query, response)]
         return response, history

     @torch.no_grad()
-    def generate(
+    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
+                    do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
+                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+        if not history:
+            prompt = query
+        else:
+            prompt = ""
+            for i, (old_query, response) in enumerate(history):
+                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
+            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+        input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
+        input_ids = input_ids.to(self.device)
+        for outputs in self.stream_generate(**input_ids, **gen_kwargs):
+            outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
+            response = tokenizer.decode(outputs)
+            response = self.process_response(response)
+            new_history = history + [(query, response)]
+            yield response, new_history
+
+    @torch.no_grad()
+    def stream_generate(
+            self,
+            input_ids,
+            generation_config: Optional[GenerationConfig] = None,
+            logits_processor: Optional[LogitsProcessorList] = None,
+            stopping_criteria: Optional[StoppingCriteriaList] = None,
+            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
             **kwargs,
     ):
-        MASK, gMASK = 150000, 150001
-        bos, eos = 150004, 150005
+        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

-        if "eos_token_id" not in kwargs:
-            kwargs["eos_token_id"] = eos
+        if generation_config is None:
+            generation_config = self.generation_config
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)
+        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id

-        truncate = kwargs.pop("truncate") if "truncate" in kwargs else False
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]

-        stop = False
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        if has_default_max_length and generation_config.max_new_tokens is None:
+            warnings.warn(
+                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+                " recommend using `max_new_tokens` to control the maximum length of the generation.",
+                UserWarning,
+            )
+        elif generation_config.max_new_tokens is not None:
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+            if not has_default_max_length:
+                logger.warn(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
+                    UserWarning,
+                )

-        return_seqs = []
+        if input_ids_seq_length >= generation_config.max_length:
+            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+            logger.warning(
+                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                " increasing `max_new_tokens`."
+            )
+        # 2. Set generation parameters if not already defined
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+        logits_processor = self._get_logits_processor(
+            generation_config=generation_config,
+            input_ids_seq_length=input_ids_seq_length,
+            encoder_input_ids=input_ids,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            logits_processor=logits_processor,
+        )
+
+        stopping_criteria = self._get_stopping_criteria(
+            generation_config=generation_config, stopping_criteria=stopping_criteria
+        )
+        logits_warper = self._get_logits_warper(generation_config)
+
+        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        scores = None
         while True:
-            output_ids = super().generate(**kwargs)
-
-            return_seqs = []
-            max_length = 0
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            # forward pass to get next token
+            outputs = self(
+                **model_inputs,
+                return_dict=True,
+                output_attentions=False,
+                output_hidden_states=False,
+            )

-            for i in range(output_ids.shape[0]):
-                output_seq = output_ids[i].tolist()
-                if truncate:
-                    output_seq = output_seq[len(kwargs["input_ids"][i]) - 2:]
-                mask_token = MASK if MASK in output_seq else gMASK
-                mask_position = output_seq.index(mask_token)
-                bos_position = output_seq.index(bos)
-                if eos in output_seq:
-                    eos_position = output_seq.index(eos)
-                else:
-                    eos_position = len(output_seq)
+            next_token_logits = outputs.logits[:, -1, :]

-                return_seq = output_seq[:mask_position] + output_seq[bos_position + 1:eos_position] + output_seq[
-                             mask_position + 1:bos_position]
-                max_length = max(max_length, len(return_seq))
-                return_seqs.append(return_seq)
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+            next_token_scores = logits_warper(input_ids, next_token_scores)

-            for i in range(output_ids.shape[0]):
-                return_seqs[i] = [0] * (max_length - len(return_seqs[i])) + return_seqs[i]  # padding
-                if mask_token not in return_seqs[i]:
-                    stop = True
+            # sample
+            probs = nn.functional.softmax(next_token_scores, dim=-1)
+            if generation_config.do_sample:
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(probs, dim=-1)

-            if stop:
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+            )
+            unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
+
+            # stop when each sentence is finished, or if we exceed the maximum length
+            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                 break
-
-            for return_seq in return_seqs:
-                return_seq += [bos]
-
-            kwargs['input_ids'] = torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
-
-        return torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
+            yield input_ids

     def quantize(self, bits: int):
         from .quantization import quantize
diff --git a/tokenization_chatglm.py b/tokenization_chatglm.py
index aedbcbe..5f594e6 100644
--- a/tokenization_chatglm.py
+++ b/tokenization_chatglm.py
@@ -299,7 +299,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         """
         if os.path.isdir(save_directory):
             vocab_file = os.path.join(
-                save_directory, VOCAB_FILES_NAMES["vocab_file"]
+                save_directory, self.vocab_files_names["vocab_file"]
             )
         else:
             vocab_file = save_directory
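Beyond moving the MASK/gMASK handling and switching hard-coded token ids to `self.config.bos_token_id`, the modeling changes above replace the old `generate` override with a `stream_generate` sampling loop and a `stream_chat` wrapper that yields partial responses. As a hedged usage sketch, assuming a `model` and `tokenizer` loaded as in the earlier example:

```python
# Illustrative consumption of the stream_chat generator added in modeling_chatglm.py.
# Each iteration yields the partially decoded response plus the updated history,
# which is what lets the CLI and web demos refresh their output incrementally.
history = []
for response, history in model.stream_chat(tokenizer, "你好", history=history):
    print(response)  # re-print (or re-render) the growing response at each step
```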