Fix default history argument
parent bcb053bda6
commit 9d1509a1ad
@@ -1077,8 +1077,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         )
 
     @torch.no_grad()
-    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], max_length: int = 2048, num_beams=1,
+    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
              do_sample=True, top_p=0.7, temperature=0.95, **kwargs):
+        if history is None:
+            history = []
         gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                       "temperature": temperature, **kwargs}
         if not history:
@@ -1095,7 +1097,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         response = tokenizer.decode(outputs)
         response = response.strip()
         response = response.replace("[[训练时间]]", "2023年")
-        history.append((query, response))
+        history = history + [(query, response)]
         return response, history
 
     @torch.no_grad()
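Context for the fix, as a minimal standalone sketch (the chat_buggy/chat_fixed names are hypothetical, not the model's actual method): in Python, a mutable default such as history=[] is evaluated once at function definition, so the same list object is shared by every call that omits the argument, and conversation state from one call leaks into the next. Defaulting to None and allocating inside the body gives each call a fresh list; the second hunk's history + [(query, response)] additionally builds a new list rather than mutating a caller-supplied one in place.

from typing import List, Optional, Tuple

# Buggy variant: the default list is created once, at definition time,
# and shared across every call that omits `history`.
def chat_buggy(query: str, history: List[Tuple[str, str]] = []):
    history.append((query, "response"))
    return history

chat_buggy("hi")     # [('hi', 'response')]
chat_buggy("again")  # [('hi', 'response'), ('again', 'response')] -- state leaked

# Fixed variant, mirroring the commit: default to None, allocate per call,
# and concatenate so a caller-supplied history is never mutated in place.
def chat_fixed(query: str, history: Optional[List[Tuple[str, str]]] = None):
    if history is None:
        history = []
    history = history + [(query, "response")]
    return history

chat_fixed("hi")     # [('hi', 'response')]
chat_fixed("again")  # [('again', 'response')] -- fresh list each call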