Fix past_key_values

This commit is contained in:
duzx16 2023-03-14 02:08:43 +08:00
parent 65bb3f00a7
commit cd8041ea53
1 changed files with 4 additions and 1 deletions

View File

@ -952,6 +952,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
past: Optional[torch.Tensor] = None, past: Optional[torch.Tensor] = None,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
**kwargs **kwargs
) -> dict: ) -> dict:
@ -966,7 +967,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
raise ValueError("You have to add either [MASK] or [gMASK] in your input") raise ValueError("You have to add either [MASK] or [gMASK] in your input")
# only last token for input_ids if past is not None # only last token for input_ids if past is not None
if past: if past is not None or past_key_values is not None:
context_length = seq.index(150004) context_length = seq.index(150004)
last_token = input_ids[:, -1].unsqueeze(-1) last_token = input_ids[:, -1].unsqueeze(-1)
if self.position_encoding_2d: if self.position_encoding_2d:
@ -975,6 +976,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
else: else:
position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device) position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device)
if past is None:
past = past_key_values
return { return {
"input_ids": last_token, "input_ids": last_token,
"past_key_values": past, "past_key_values": past,