Fix Chinese punctuation

This commit is contained in:
duzx16 2023-03-22 14:49:29 +08:00
parent 3ba9437241
commit debaf0032c
1 changed files with 18 additions and 4 deletions

View File

@ -4,6 +4,7 @@ import math
import copy
import os
import warnings
import re
import torch
import torch.utils.checkpoint
@ -1099,6 +1100,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
for layer_past in past
)
def process_response(self, response):
response = response.strip()
response = response.replace("[[训练时间]]", "2023年")
punkts = [
[",", ""],
["!", ""],
[":", ""],
[";", ""],
["\?", ""],
]
for item in punkts:
response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
return response
@torch.no_grad()
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
@ -1121,8 +1137,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
outputs = self.generate(**input_ids, **gen_kwargs)
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
response = tokenizer.decode(outputs)
response = response.strip()
response = response.replace("[[训练时间]]", "2023年")
response = self.process_response(response)
history = history + [(query, response)]
return response, history
@ -1148,8 +1163,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
for outputs in self.stream_generate(**input_ids, **gen_kwargs):
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
response = tokenizer.decode(outputs)
response = response.strip()
response = response.replace("[[训练时间]]", "2023年")
response = self.process_response(response)
new_history = history + [(query, response)]
yield response, new_history