Use gmask in first place

This commit is contained in:
duzx16 2023-04-06 23:25:10 +08:00
parent d467effe91
commit 9324de70a9
1 changed files with 4 additions and 4 deletions

View File

@ -922,8 +922,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
if position_ids is None: if position_ids is None:
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
mask_token = MASK if MASK in input_ids else gMASK mask_token = gMASK if gMASK in input_ids else MASK
use_gmask = False if MASK in input_ids else True use_gmask = True if gMASK in input_ids else False
mask_positions = [seq.tolist().index(mask_token) for seq in input_ids] mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
position_ids = self.get_position_ids( position_ids = self.get_position_ids(
@ -1085,8 +1085,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
) -> dict: ) -> dict:
batch_size, seq_length = input_ids.shape batch_size, seq_length = input_ids.shape
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
mask_token = MASK if MASK in input_ids else gMASK mask_token = gMASK if gMASK in input_ids else MASK
use_gmask = False if MASK in input_ids else True use_gmask = True if gMASK in input_ids else False
seqs = input_ids.tolist() seqs = input_ids.tolist()
mask_positions = [seq.index(mask_token) for seq in seqs] mask_positions = [seq.index(mask_token) for seq in seqs]