Fix Chinese punctuation
This commit is contained in:
parent
3ba9437241
commit
debaf0032c
|
@ -4,6 +4,7 @@ import math
|
||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
import re
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
|
@ -1099,6 +1100,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
for layer_past in past
|
for layer_past in past
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def process_response(self, response):
    """Normalize a decoded model response string.

    Strips surrounding whitespace, substitutes the training-time
    placeholder, and converts half-width ASCII punctuation that is
    directly adjacent to a CJK character into the corresponding
    full-width Chinese punctuation mark.

    Args:
        response: Raw decoded string produced by the tokenizer.

    Returns:
        The cleaned response string.
    """
    response = response.strip()
    # "[[训练时间]]" is a placeholder baked into the training data;
    # runtime string must stay byte-identical.
    response = response.replace("[[训练时间]]", "2023年")
    # (regex pattern, full-width replacement) pairs. The left element is
    # used as a regex, so "?" needs escaping — a raw string r"\?" is used
    # instead of "\?", which is an invalid escape sequence in a plain
    # string literal (SyntaxWarning from Python 3.12).
    punkts = [
        [",", ","],
        ["!", "!"],
        [":", ":"],
        [";", ";"],
        [r"\?", "?"],
    ]
    for item in punkts:
        # Replace half-width punctuation immediately following a CJK char...
        response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
        # ...or immediately preceding one.
        response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
    return response
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
|
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
|
||||||
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
|
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
|
||||||
|
@ -1121,8 +1137,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
outputs = self.generate(**input_ids, **gen_kwargs)
|
outputs = self.generate(**input_ids, **gen_kwargs)
|
||||||
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
||||||
response = tokenizer.decode(outputs)
|
response = tokenizer.decode(outputs)
|
||||||
response = response.strip()
|
response = self.process_response(response)
|
||||||
response = response.replace("[[训练时间]]", "2023年")
|
|
||||||
history = history + [(query, response)]
|
history = history + [(query, response)]
|
||||||
return response, history
|
return response, history
|
||||||
|
|
||||||
|
@ -1148,8 +1163,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
for outputs in self.stream_generate(**input_ids, **gen_kwargs):
|
for outputs in self.stream_generate(**input_ids, **gen_kwargs):
|
||||||
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
||||||
response = tokenizer.decode(outputs)
|
response = tokenizer.decode(outputs)
|
||||||
response = response.strip()
|
response = self.process_response(response)
|
||||||
response = response.replace("[[训练时间]]", "2023年")
|
|
||||||
new_history = history + [(query, response)]
|
new_history = history + [(query, response)]
|
||||||
yield response, new_history
|
yield response, new_history
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue