Fix Chinese punctuation
This commit is contained in:
parent
3ba9437241
commit
debaf0032c
|
@ -4,6 +4,7 @@ import math
|
|||
import copy
|
||||
import os
|
||||
import warnings
|
||||
import re
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
|
@ -1099,6 +1100,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
for layer_past in past
|
||||
)
|
||||
|
||||
def process_response(self, response):
|
||||
response = response.strip()
|
||||
response = response.replace("[[训练时间]]", "2023年")
|
||||
punkts = [
|
||||
[",", ","],
|
||||
["!", "!"],
|
||||
[":", ":"],
|
||||
[";", ";"],
|
||||
["\?", "?"],
|
||||
]
|
||||
for item in punkts:
|
||||
response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
|
||||
response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
|
||||
return response
|
||||
|
||||
@torch.no_grad()
|
||||
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
|
||||
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
|
||||
|
@ -1121,8 +1137,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
outputs = self.generate(**input_ids, **gen_kwargs)
|
||||
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
||||
response = tokenizer.decode(outputs)
|
||||
response = response.strip()
|
||||
response = response.replace("[[训练时间]]", "2023年")
|
||||
response = self.process_response(response)
|
||||
history = history + [(query, response)]
|
||||
return response, history
|
||||
|
||||
|
@ -1148,8 +1163,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
for outputs in self.stream_generate(**input_ids, **gen_kwargs):
|
||||
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
||||
response = tokenizer.decode(outputs)
|
||||
response = response.strip()
|
||||
response = response.replace("[[训练时间]]", "2023年")
|
||||
response = self.process_response(response)
|
||||
new_history = history + [(query, response)]
|
||||
yield response, new_history
|
||||
|
||||
|
|
Loading…
Reference in New Issue