Fix context length in get_position_ids
This commit is contained in:
parent
4a9b711e61
commit
096f3de6b4
|
@@ -769,7 +769,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|||
return attention_mask
|
||||
|
||||
def get_position_ids(self, seq, mask_position, device, gmask=False):
|
||||
-        context_length = seq.index(self.config.bos_token_id) + 1
|
||||
+        context_length = len(seq)
|
||||
if self.position_encoding_2d:
|
||||
seq_length = seq.index(self.config.bos_token_id)
|
||||
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
|
||||
|
|
Loading…
Reference in New Issue