From d2bbc82a2cdd04522ad340bdd464379808676950 Mon Sep 17 00:00:00 2001 From: duzx16 Date: Wed, 22 Mar 2023 14:37:21 +0800 Subject: [PATCH] Fix Chinese punctuation --- modeling_chatglm.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/modeling_chatglm.py b/modeling_chatglm.py index 10a1df2..d833334 100644 --- a/modeling_chatglm.py +++ b/modeling_chatglm.py @@ -4,6 +4,7 @@ import math import copy import os import warnings +import re import torch import torch.utils.checkpoint @@ -1085,6 +1086,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): for layer_past in past ) + def process_response(self, response): + response = response.strip() + response = response.replace("[[训练时间]]", "2023年") + punkts = [ + [",", ","], + ["!", "!"], + [":", ":"], + [";", ";"], + ["\?", "?"], + ] + for item in punkts: + response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) + response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) + return response + @torch.no_grad() def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1, do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): @@ -1107,8 +1123,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): outputs = self.generate(**input_ids, **gen_kwargs) outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):] response = tokenizer.decode(outputs) - response = response.strip() - response = response.replace("[[训练时间]]", "2023年") + response = self.process_response(response) history = history + [(query, response)] return response, history @@ -1134,8 +1149,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): for outputs in self.stream_generate(**input_ids, **gen_kwargs): outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):] response = tokenizer.decode(outputs) - response = response.strip() - response = response.replace("[[训练时间]]", "2023年") + response = self.process_response(response) new_history = history + [(query, response)] yield response, new_history