Fix gMASK handling: detect the mask token per sequence instead of once per batch
This commit is contained in:
parent
9333486c30
commit
3485994337
|
@ -689,8 +689,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
|
||||||
|
|
||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
|
def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
|
if use_gmasks is None:
|
||||||
|
use_gmasks = [False] * batch_size
|
||||||
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
||||||
if self.position_encoding_2d:
|
if self.position_encoding_2d:
|
||||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
||||||
|
@ -704,8 +706,8 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
|
||||||
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
|
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
|
||||||
else:
|
else:
|
||||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
||||||
if not gmask:
|
|
||||||
for i, context_length in enumerate(context_lengths):
|
for i, context_length in enumerate(context_lengths):
|
||||||
|
if not use_gmasks[i]:
|
||||||
position_ids[context_length:] = mask_positions[i]
|
position_ids[context_length:] = mask_positions[i]
|
||||||
|
|
||||||
return position_ids
|
return position_ids
|
||||||
|
@ -939,15 +941,20 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
||||||
mask_token = gMASK if gMASK in input_ids else MASK
|
seqs = input_ids.tolist()
|
||||||
use_gmask = True if gMASK in input_ids else False
|
|
||||||
|
mask_positions, use_gmasks = [], []
|
||||||
|
for seq in seqs:
|
||||||
|
mask_token = gMASK if gMASK in seq else MASK
|
||||||
|
use_gmask = mask_token == gMASK
|
||||||
|
mask_positions.append(seq.index(mask_token))
|
||||||
|
use_gmasks.append(use_gmask)
|
||||||
|
|
||||||
mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
|
|
||||||
position_ids = self.get_position_ids(
|
position_ids = self.get_position_ids(
|
||||||
input_ids,
|
input_ids,
|
||||||
mask_positions=mask_positions,
|
mask_positions=mask_positions,
|
||||||
device=input_ids.device,
|
device=input_ids.device,
|
||||||
gmask=use_gmask
|
use_gmasks=use_gmasks
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.pre_seq_len is not None and attention_mask is not None:
|
if self.pre_seq_len is not None and attention_mask is not None:
|
||||||
|
@ -1106,10 +1113,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
) -> dict:
|
) -> dict:
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
||||||
mask_token = gMASK if gMASK in input_ids else MASK
|
|
||||||
use_gmask = True if gMASK in input_ids else False
|
|
||||||
seqs = input_ids.tolist()
|
seqs = input_ids.tolist()
|
||||||
mask_positions = [seq.index(mask_token) for seq in seqs]
|
mask_positions, use_gmasks = [], []
|
||||||
|
for seq in seqs:
|
||||||
|
mask_token = gMASK if gMASK in seq else MASK
|
||||||
|
use_gmask = mask_token == gMASK
|
||||||
|
mask_positions.append(seq.index(mask_token))
|
||||||
|
use_gmasks.append(use_gmask)
|
||||||
|
|
||||||
# only last token for input_ids if past is not None
|
# only last token for input_ids if past is not None
|
||||||
if past is not None or past_key_values is not None:
|
if past is not None or past_key_values is not None:
|
||||||
|
@ -1152,7 +1162,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
input_ids,
|
input_ids,
|
||||||
device=input_ids.device,
|
device=input_ids.device,
|
||||||
mask_positions=mask_positions,
|
mask_positions=mask_positions,
|
||||||
gmask=use_gmask
|
use_gmasks=use_gmasks
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -176,6 +176,8 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||||
mask_token='[MASK]',
|
mask_token='[MASK]',
|
||||||
gmask_token='[gMASK]',
|
gmask_token='[gMASK]',
|
||||||
padding_side="left",
|
padding_side="left",
|
||||||
|
pad_token="<pad>",
|
||||||
|
unk_token="<unk>",
|
||||||
num_image_tokens=20000,
|
num_image_tokens=20000,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -188,6 +190,8 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||||
end_token=end_token,
|
end_token=end_token,
|
||||||
mask_token=mask_token,
|
mask_token=mask_token,
|
||||||
gmask_token=gmask_token,
|
gmask_token=gmask_token,
|
||||||
|
pad_token=pad_token,
|
||||||
|
unk_token=unk_token,
|
||||||
num_image_tokens=num_image_tokens,
|
num_image_tokens=num_image_tokens,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
@ -322,22 +326,11 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||||
Returns:
|
Returns:
|
||||||
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
||||||
"""
|
"""
|
||||||
mask_ids = self.sp_tokenizer[self.mask_token]
|
gmask_id = self.sp_tokenizer[self.gmask_token]
|
||||||
gmask_ids = self.sp_tokenizer[self.gmask_token]
|
|
||||||
eos_id = self.sp_tokenizer[self.eos_token]
|
eos_id = self.sp_tokenizer[self.eos_token]
|
||||||
if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0:
|
token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]]
|
||||||
token_ids_0 += [gmask_ids]
|
|
||||||
|
|
||||||
if token_ids_0[-1] != mask_ids and token_ids_0[-1] != gmask_ids:
|
|
||||||
token_ids_0 += [self.sp_tokenizer[self.end_token]]
|
|
||||||
|
|
||||||
token_ids_0 += [self.sp_tokenizer[self.bos_token]]
|
|
||||||
|
|
||||||
if token_ids_1 is not None:
|
if token_ids_1 is not None:
|
||||||
if not token_ids_1 or token_ids_1[-1] != eos_id:
|
token_ids_0 = token_ids_0 + token_ids_1 + [eos_id]
|
||||||
token_ids_1 += [eos_id]
|
|
||||||
token_ids_0 += token_ids_1
|
|
||||||
|
|
||||||
return token_ids_0
|
return token_ids_0
|
||||||
|
|
||||||
def _pad(
|
def _pad(
|
||||||
|
@ -402,6 +395,10 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||||
encoded_inputs["attention_mask"] = attention_mask
|
encoded_inputs["attention_mask"] = attention_mask
|
||||||
|
|
||||||
if "position_ids" not in encoded_inputs:
|
if "position_ids" not in encoded_inputs:
|
||||||
|
if bos_token_id in required_input:
|
||||||
|
context_length = required_input.index(bos_token_id)
|
||||||
|
else:
|
||||||
|
context_length = seq_length
|
||||||
position_ids = np.arange(seq_length, dtype=np.int64)
|
position_ids = np.arange(seq_length, dtype=np.int64)
|
||||||
mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
|
mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
|
||||||
if mask_token in required_input:
|
if mask_token in required_input:
|
||||||
|
|
Loading…
Reference in New Issue