diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index 330077d..1429dbb 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -1077,8 +1077,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         )
 
     @torch.no_grad()
-    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], max_length: int = 2048, num_beams=1,
+    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
              do_sample=True, top_p=0.7, temperature=0.95, **kwargs):
+        if history is None:
+            history = []
         gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                       "temperature": temperature, **kwargs}
         if not history:
@@ -1095,7 +1097,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         response = tokenizer.decode(outputs)
         response = response.strip()
         response = response.replace("[[训练时间]]", "2023年")
-        history.append((query, response))
+        history = history + [(query, response)]
         return response, history
 
     @torch.no_grad()