diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index eee464c..e6c6511 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -952,6 +952,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             self,
             input_ids: torch.LongTensor,
             past: Optional[torch.Tensor] = None,
+            past_key_values: Optional[torch.Tensor] = None,
             attention_mask: Optional[torch.Tensor] = None,
             **kwargs
     ) -> dict:
@@ -966,7 +967,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             raise ValueError("You have to add either [MASK] or [gMASK] in your input")
 
         # only last token for input_ids if past is not None
-        if past:
+        if past is not None or past_key_values is not None:
             context_length = seq.index(150004)
             last_token = input_ids[:, -1].unsqueeze(-1)
             if self.position_encoding_2d:
@@ -975,6 +976,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             else:
                 position_ids = torch.tensor([[mask_position]], dtype=torch.long,
                                             device=input_ids.device)
+            if past is None:
+                past = past_key_values
             return {
                 "input_ids": last_token,
                 "past_key_values": past,