From 11c270c26ca79a82b50de0456f0193c8ff4d2171 Mon Sep 17 00:00:00 2001
From: duzx16
Date: Fri, 31 Mar 2023 20:18:10 +0800
Subject: [PATCH] Fix position id for training

---
 modeling_chatglm.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index 7f0f0e1..fcbef2d 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -845,9 +845,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
-            if not gmask:
-                for i, context_length in enumerate(context_lengths):
-                    position_ids[i, context_length:] = mask_positions[i]
+            for i, context_length in enumerate(context_lengths):
+                position_ids[i, context_length:] = mask_positions[i]
             block_position_ids = [torch.cat((
                 torch.zeros(context_length, dtype=torch.long, device=device),
                 torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
@@ -1053,9 +1052,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
-            if not gmask:
-                for i, context_length in enumerate(context_lengths):
-                    position_ids[i, context_length:] = mask_positions[i]
+            for i, context_length in enumerate(context_lengths):
+                position_ids[i, context_length:] = mask_positions[i]
             block_position_ids = [torch.cat((
                 torch.zeros(context_length, dtype=torch.long, device=device),
                 torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
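
For reference, below is a minimal, self-contained sketch of the 2D position-id
construction as it reads after this patch, applied identically in both hunks
(ChatGLMModel and ChatGLMForConditionalGeneration). The function name
build_2d_position_ids, the device default, and the token ids in the usage lines
are hypothetical stand-ins for the in-class logic and config values;
bos_token_id stands in for self.config.bos_token_id, and the sketch uses
repeat rather than expand only so its in-place row writes are safe. The
substantive change is that the `if not gmask:` guard is gone: the clamp to
mask_positions[i] now runs unconditionally, so tokens after the context reuse
the mask token's position for [gMASK] sequences as well, not just [MASK] ones.

    import torch


    def build_2d_position_ids(input_ids, mask_positions, bos_token_id, device="cpu"):
        """Sketch of the patched logic; the clamp below is now unconditional."""
        batch_size, seq_length = input_ids.shape
        # The context of each sequence ends at its first BOS token.
        context_lengths = [seq.tolist().index(bos_token_id) for seq in input_ids]
        # repeat (not expand) so each row is independently writable in place.
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device) \
            .unsqueeze(0).repeat(batch_size, 1)
        # Previously guarded by `if not gmask:`; after the patch every token
        # past the context reuses its mask position, [MASK] or [gMASK] alike.
        for i, context_length in enumerate(context_lengths):
            position_ids[i, context_length:] = mask_positions[i]
        # Second plane: zeros over the context, then 1..n over the generated span.
        block_position_ids = torch.stack([
            torch.cat((
                torch.zeros(context_length, dtype=torch.long, device=device),
                torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1,
            ))
            for context_length in context_lengths
        ])
        return torch.stack((position_ids, block_position_ids), dim=1)


    # Hypothetical ids: 9 = mask token, 3 = BOS; two prompt and two generated tokens.
    input_ids = torch.tensor([[11, 12, 9, 3, 21, 22]])
    mask_positions = torch.tensor([2])  # index of the mask token in each row
    pos = build_2d_position_ids(input_ids, mask_positions, bos_token_id=3)
    print(pos[:, 0])  # tensor([[0, 1, 2, 2, 2, 2]]) -- clamped to the mask slot
    print(pos[:, 1])  # tensor([[0, 0, 0, 1, 2, 3]]) -- counts within the span

Before the patch, a [gMASK] sequence would have kept the raw arange positions
(0..5 in the example above) past the context, diverging from the clamped
positions used for [MASK]; making the clamp unconditional gives training and
generation the same position layout for both mask types, hence the subject
line "Fix position id for training".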