Fix past_key_values
This commit is contained in:
parent
65bb3f00a7
commit
cd8041ea53
|
@ -952,6 +952,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
past: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
|
@ -966,7 +967,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
raise ValueError("You have to add either [MASK] or [gMASK] in your input")
|
||||
|
||||
# only last token for input_ids if past is not None
|
||||
if past:
|
||||
if past is not None or past_key_values is not None:
|
||||
context_length = seq.index(150004)
|
||||
last_token = input_ids[:, -1].unsqueeze(-1)
|
||||
if self.position_encoding_2d:
|
||||
|
@ -975,6 +976,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
else:
|
||||
position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device)
|
||||
|
||||
if past is None:
|
||||
past = past_key_values
|
||||
return {
|
||||
"input_ids": last_token,
|
||||
"past_key_values": past,
|
||||
|
|
Loading…
Reference in New Issue