Fix context length in get_position_ids
This commit is contained in:
parent
4a9b711e61
commit
096f3de6b4
|
@ -769,7 +769,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
def get_position_ids(self, seq, mask_position, device, gmask=False):
|
def get_position_ids(self, seq, mask_position, device, gmask=False):
|
||||||
context_length = seq.index(self.config.bos_token_id) + 1
|
context_length = len(seq)
|
||||||
if self.position_encoding_2d:
|
if self.position_encoding_2d:
|
||||||
seq_length = seq.index(self.config.bos_token_id)
|
seq_length = seq.index(self.config.bos_token_id)
|
||||||
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
|
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
|
||||||
|
|
Loading…
Reference in New Issue