From 096f3de6b4959ce38bef7bb05f3129c931a3084e Mon Sep 17 00:00:00 2001 From: duzx16 Date: Tue, 28 Mar 2023 17:37:46 +0800 Subject: [PATCH] Fix context length in get_position_ids --- modeling_chatglm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modeling_chatglm.py b/modeling_chatglm.py index 65e378e..bed7e6f 100644 --- a/modeling_chatglm.py +++ b/modeling_chatglm.py @@ -769,7 +769,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel): return attention_mask def get_position_ids(self, seq, mask_position, device, gmask=False): - context_length = seq.index(self.config.bos_token_id) + 1 + context_length = len(seq) if self.position_encoding_2d: seq_length = seq.index(self.config.bos_token_id) position_ids = torch.arange(context_length, dtype=torch.long, device=device)