From 9324de70a93207c9a310cf99d5d6261791489691 Mon Sep 17 00:00:00 2001 From: duzx16 Date: Thu, 6 Apr 2023 23:25:10 +0800 Subject: [PATCH] Use gMASK in the first place --- modeling_chatglm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modeling_chatglm.py b/modeling_chatglm.py index 4887139..8ee5cf1 100644 --- a/modeling_chatglm.py +++ b/modeling_chatglm.py @@ -922,8 +922,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel): if position_ids is None: MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id - mask_token = MASK if MASK in input_ids else gMASK - use_gmask = False if MASK in input_ids else True + mask_token = gMASK if gMASK in input_ids else MASK + use_gmask = True if gMASK in input_ids else False mask_positions = [seq.tolist().index(mask_token) for seq in input_ids] position_ids = self.get_position_ids( @@ -1085,8 +1085,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): ) -> dict: batch_size, seq_length = input_ids.shape MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id - mask_token = MASK if MASK in input_ids else gMASK - use_gmask = False if MASK in input_ids else True + mask_token = gMASK if gMASK in input_ids else MASK + use_gmask = True if gMASK in input_ids else False seqs = input_ids.tolist() mask_positions = [seq.index(mask_token) for seq in seqs]