diff --git a/modeling_chatglm.py b/modeling_chatglm.py index 9b344d0..77c3bdd 100644 --- a/modeling_chatglm.py +++ b/modeling_chatglm.py @@ -923,7 +923,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel): if position_ids is None: MASK, gMASK = 150000, 150001 mask_token = MASK if MASK in input_ids else gMASK - use_gmask = False if MASK in input_ids else gMASK + use_gmask = False if MASK in input_ids else True mask_positions = [seq.tolist().index(mask_token) for seq in input_ids] position_ids = self.get_position_ids( @@ -1086,7 +1086,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): batch_size, seq_length = input_ids.shape MASK, gMASK = 150000, 150001 mask_token = MASK if MASK in input_ids else gMASK - use_gmask = False if MASK in input_ids else gMASK + use_gmask = False if MASK in input_ids else True seqs = input_ids.tolist() mask_positions = [seq.index(mask_token) for seq in seqs]