No padding for chat function

This commit is contained in:
duzx16 2023-04-02 02:03:05 +08:00
parent 373fd6b9d4
commit 4b7ffbf070
1 changed file with 2 additions and 2 deletions


@@ -1243,7 +1243,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             for i, (old_query, response) in enumerate(history):
                 prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
             prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
-        inputs = tokenizer([prompt], return_tensors="pt", padding=True)
+        inputs = tokenizer([prompt], return_tensors="pt")
         inputs = inputs.to(self.device)
         outputs = self.generate(**inputs, **gen_kwargs)
         outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
@@ -1269,7 +1269,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             for i, (old_query, response) in enumerate(history):
                 prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
             prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
-        inputs = tokenizer([prompt], return_tensors="pt", padding=True)
+        inputs = tokenizer([prompt], return_tensors="pt")
         inputs = inputs.to(self.device)
         for outputs in self.stream_generate(**inputs, **gen_kwargs):
             outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
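A minimal usage sketch of the affected chat() path after this change: both chat and stream_chat tokenize exactly one prompt per call, so padding=True had no effect on the single-sequence batch and can be dropped. The checkpoint name and CUDA setup below are assumptions for illustration, not part of the commit.

# Sketch only: load ChatGLM-6B and call chat(), which now tokenizes the
# single prompt without padding=True (checkpoint name and .cuda() assumed).
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

# chat() builds the "[Round i]\n问:...\n答:" prompt internally; since only one
# sequence is ever passed to the tokenizer, no padding is needed.
response, history = model.chat(tokenizer, "你好", history=[])
print(response)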