Use gmask in first place
This commit is contained in:
parent
d467effe91
commit
9324de70a9
|
@ -922,8 +922,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
||||||
mask_token = MASK if MASK in input_ids else gMASK
|
mask_token = gMASK if gMASK in input_ids else MASK
|
||||||
use_gmask = False if MASK in input_ids else True
|
use_gmask = True if gMASK in input_ids else False
|
||||||
|
|
||||||
mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
|
mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
|
||||||
position_ids = self.get_position_ids(
|
position_ids = self.get_position_ids(
|
||||||
|
@ -1085,8 +1085,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
) -> dict:
|
) -> dict:
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
||||||
mask_token = MASK if MASK in input_ids else gMASK
|
mask_token = gMASK if gMASK in input_ids else MASK
|
||||||
use_gmask = False if MASK in input_ids else True
|
use_gmask = True if gMASK in input_ids else False
|
||||||
seqs = input_ids.tolist()
|
seqs = input_ids.tolist()
|
||||||
mask_positions = [seq.index(mask_token) for seq in seqs]
|
mask_positions = [seq.index(mask_token) for seq in seqs]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue