Fix past_key_values
This commit is contained in:
parent
65bb3f00a7
commit
cd8041ea53
|
@ -952,6 +952,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
past: Optional[torch.Tensor] = None,
|
past: Optional[torch.Tensor] = None,
|
||||||
|
past_key_values: Optional[torch.Tensor] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
@ -966,7 +967,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
raise ValueError("You have to add either [MASK] or [gMASK] in your input")
|
raise ValueError("You have to add either [MASK] or [gMASK] in your input")
|
||||||
|
|
||||||
# only last token for input_ids if past is not None
|
# only last token for input_ids if past is not None
|
||||||
if past:
|
if past is not None or past_key_values is not None:
|
||||||
context_length = seq.index(150004)
|
context_length = seq.index(150004)
|
||||||
last_token = input_ids[:, -1].unsqueeze(-1)
|
last_token = input_ids[:, -1].unsqueeze(-1)
|
||||||
if self.position_encoding_2d:
|
if self.position_encoding_2d:
|
||||||
|
@ -975,6 +976,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
else:
|
else:
|
||||||
position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device)
|
position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device)
|
||||||
|
|
||||||
|
if past is None:
|
||||||
|
past = past_key_values
|
||||||
return {
|
return {
|
||||||
"input_ids": last_token,
|
"input_ids": last_token,
|
||||||
"past_key_values": past,
|
"past_key_values": past,
|
||||||
|
|
Loading…
Reference in New Issue