diff --git a/modeling_chatglm.py b/modeling_chatglm.py index 65e378e..bed7e6f 100644 --- a/modeling_chatglm.py +++ b/modeling_chatglm.py @@ -769,7 +769,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel): return attention_mask def get_position_ids(self, seq, mask_position, device, gmask=False): - context_length = seq.index(self.config.bos_token_id) + 1 + context_length = len(seq) if self.position_encoding_2d: seq_length = seq.index(self.config.bos_token_id) position_ids = torch.arange(context_length, dtype=torch.long, device=device)