fix typo in use_gmask (#21)

- fix typo in use_gmask (d6504255afdd555d12137fc3af04646f099b5785)


Co-authored-by: Fan Zhang <fzhang@users.noreply.huggingface.co>
This commit is contained in:
Zhengxiao Du 2023-04-05 11:11:33 +00:00 committed by system
parent 23ad39b571
commit 551a50efec
1 changed files with 2 additions and 2 deletions

View File

@ -923,7 +923,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
if position_ids is None: if position_ids is None:
MASK, gMASK = 150000, 150001 MASK, gMASK = 150000, 150001
mask_token = MASK if MASK in input_ids else gMASK mask_token = MASK if MASK in input_ids else gMASK
use_gmask = False if MASK in input_ids else gMASK use_gmask = False if MASK in input_ids else True
mask_positions = [seq.tolist().index(mask_token) for seq in input_ids] mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
position_ids = self.get_position_ids( position_ids = self.get_position_ids(
@ -1086,7 +1086,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
batch_size, seq_length = input_ids.shape batch_size, seq_length = input_ids.shape
MASK, gMASK = 150000, 150001 MASK, gMASK = 150000, 150001
mask_token = MASK if MASK in input_ids else gMASK mask_token = MASK if MASK in input_ids else gMASK
use_gmask = False if MASK in input_ids else gMASK use_gmask = False if MASK in input_ids else True
seqs = input_ids.tolist() seqs = input_ids.tolist()
mask_positions = [seq.index(mask_token) for seq in seqs] mask_positions = [seq.index(mask_token) for seq in seqs]