From cd8041ea536133e13caff2130e68c436fab7454d Mon Sep 17 00:00:00 2001
From: duzx16
Date: Tue, 14 Mar 2023 02:08:43 +0800
Subject: [PATCH] Fix past_key_values

---
 modeling_chatglm.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index eee464c..e6c6511 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -952,6 +952,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             self,
             input_ids: torch.LongTensor,
             past: Optional[torch.Tensor] = None,
+            past_key_values: Optional[torch.Tensor] = None,
             attention_mask: Optional[torch.Tensor] = None,
             **kwargs
     ) -> dict:
@@ -966,7 +967,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             raise ValueError("You have to add either [MASK] or [gMASK] in your input")
 
         # only last token for input_ids if past is not None
-        if past:
+        if past is not None or past_key_values is not None:
             context_length = seq.index(150004)
             last_token = input_ids[:, -1].unsqueeze(-1)
             if self.position_encoding_2d:
@@ -975,6 +976,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             else:
                 position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device)
 
+            if past is None:
+                past = past_key_values
             return {
                 "input_ids": last_token,
                 "past_key_values": past,