Fix context length in get_position_ids

duzx16 2023-03-28 17:37:46 +08:00
parent 4a9b711e61
commit 096f3de6b4
1 changed file with 1 addition and 1 deletion


@@ -769,7 +769,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return attention_mask

     def get_position_ids(self, seq, mask_position, device, gmask=False):
-        context_length = seq.index(self.config.bos_token_id) + 1
+        context_length = len(seq)
     if self.position_encoding_2d:
         seq_length = seq.index(self.config.bos_token_id)
         position_ids = torch.arange(context_length, dtype=torch.long, device=device)
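
For context, the one-line change matters during incremental decoding: once tokens have been generated past the BOS token, seq.index(self.config.bos_token_id) + 1 no longer covers the whole sequence, while len(seq) does. A minimal sketch of the difference, using a toy sequence and a placeholder bos_token_id rather than the model's real vocabulary value:

    import torch

    # Toy setup: placeholder ids, not ChatGLM's actual vocabulary values.
    bos_token_id = 4
    seq = [10, 11, 12, bos_token_id, 20, 21]  # prompt + BOS + two generated tokens

    # Old computation: stops at BOS, so position ids fall short of the sequence.
    old_context_length = seq.index(bos_token_id) + 1  # 4

    # Fixed computation: covers every token, including those generated after BOS.
    context_length = len(seq)                         # 6

    position_ids = torch.arange(context_length, dtype=torch.long)
    assert position_ids.shape[0] == len(seq)          # holds only with the fix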