fix typo in use_gmask (#21)
- fix typo in use_gmask (d6504255afdd555d12137fc3af04646f099b5785) Co-authored-by: Fan Zhang <fzhang@users.noreply.huggingface.co>
This commit is contained in:
parent
23ad39b571
commit
551a50efec
|
@ -923,7 +923,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
MASK, gMASK = 150000, 150001
|
MASK, gMASK = 150000, 150001
|
||||||
mask_token = MASK if MASK in input_ids else gMASK
|
mask_token = MASK if MASK in input_ids else gMASK
|
||||||
use_gmask = False if MASK in input_ids else gMASK
|
use_gmask = False if MASK in input_ids else True
|
||||||
|
|
||||||
mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
|
mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
|
||||||
position_ids = self.get_position_ids(
|
position_ids = self.get_position_ids(
|
||||||
|
@ -1086,7 +1086,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
MASK, gMASK = 150000, 150001
|
MASK, gMASK = 150000, 150001
|
||||||
mask_token = MASK if MASK in input_ids else gMASK
|
mask_token = MASK if MASK in input_ids else gMASK
|
||||||
use_gmask = False if MASK in input_ids else gMASK
|
use_gmask = False if MASK in input_ids else True
|
||||||
seqs = input_ids.tolist()
|
seqs = input_ids.tolist()
|
||||||
mask_positions = [seq.index(mask_token) for seq in seqs]
|
mask_positions = [seq.index(mask_token) for seq in seqs]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue