Fix position id for training

duzx16 2023-03-31 20:18:10 +08:00
parent 9c7416d834
commit 11c270c26c
1 changed file with 4 additions and 6 deletions


@@ -845,7 +845,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
-            if not gmask:
-                for i, context_length in enumerate(context_lengths):
-                    position_ids[i, context_length:] = mask_positions[i]
+            for i, context_length in enumerate(context_lengths):
+                position_ids[i, context_length:] = mask_positions[i]
             block_position_ids = [torch.cat((
@@ -1053,7 +1052,6 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
-            if not gmask:
-                for i, context_length in enumerate(context_lengths):
-                    position_ids[i, context_length:] = mask_positions[i]
+            for i, context_length in enumerate(context_lengths):
+                position_ids[i, context_length:] = mask_positions[i]
             block_position_ids = [torch.cat((
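
For reference, a standalone sketch of the position-id computation as it stands after this change. The helper name build_position_ids and the toy token ids are assumptions for illustration only; the tensor logic mirrors the patched lines above, with a .clone() added so the broadcast view is writable outside the original context. The point of the fix is that positions past the context length now point at the mask position for gMASK sequences as well, not only for MASK, which is what the training path needs.

# Standalone sketch of the patched 2D position-id logic (illustration only,
# not the repository's API); names mirror the diff, toy inputs are assumptions.
import torch

def build_position_ids(input_ids, mask_positions, bos_token_id, device):
    batch_size, seq_length = input_ids.shape
    # Context length = index of the bos token separating prompt from generation.
    context_lengths = [seq.tolist().index(bos_token_id) for seq in input_ids]
    # .clone() makes the broadcast view writable in this standalone sketch.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length).clone()
    # After this commit the mask position is written unconditionally, so gMASK
    # training samples are handled the same way as MASK samples.
    for i, context_length in enumerate(context_lengths):
        position_ids[i, context_length:] = mask_positions[i]
    # Second dimension of the 2D encoding: 0 inside the context, then 1, 2, ...
    block_position_ids = torch.stack([torch.cat((
        torch.zeros(context_length, dtype=torch.long, device=device),
        torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
    )) for context_length in context_lengths], dim=0)
    return torch.stack((position_ids, block_position_ids), dim=1)  # [batch, 2, seq]

# Toy example: 130001 stands in for gMASK, 130004 for bos.
input_ids = torch.tensor([[10, 11, 130001, 130004, 20, 21]])
mask_positions = torch.tensor([2])
print(build_position_ids(input_ids, mask_positions, bos_token_id=130004, device="cpu"))
# -> position ids [[0, 1, 2, 2, 2, 2]], block position ids [[0, 0, 0, 1, 2, 3]]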