Change mask positions to batch

duzx16 2023-04-14 15:54:43 +08:00
parent 3a99d7951d
commit 4de8efebc8
1 changed file with 21 additions and 11 deletions

@@ -689,8 +689,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         return attention_mask
 
-    def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
+    def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
         batch_size, seq_length = input_ids.shape
+        if use_gmasks is None:
+            use_gmasks = [False] * batch_size
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
@@ -704,8 +706,8 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
             position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
-            if not gmask:
-                for i, context_length in enumerate(context_lengths):
+            for i, context_length in enumerate(context_lengths):
+                if not use_gmasks[i]:
                     position_ids[context_length:] = mask_positions[i]
 
         return position_ids
@@ -939,15 +941,20 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if position_ids is None:
             MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
-            mask_token = gMASK if gMASK in input_ids else MASK
-            use_gmask = True if gMASK in input_ids else False
-            mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
+            seqs = input_ids.tolist()
+            mask_positions, use_gmasks = [], []
+            for seq in seqs:
+                mask_token = gMASK if gMASK in seq else MASK
+                use_gmask = mask_token == gMASK
+                mask_positions.append(seq.index(mask_token))
+                use_gmasks.append(use_gmask)
             position_ids = self.get_position_ids(
                 input_ids,
                 mask_positions=mask_positions,
                 device=input_ids.device,
-                gmask=use_gmask
+                use_gmasks=use_gmasks
             )
 
         if self.pre_seq_len is not None and attention_mask is not None:
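This hunk is where the fix matters: the old code evaluated gMASK in input_ids once over the whole tensor, so a single [gMASK] sample anywhere in the batch flipped use_gmask for every sample. The new loop decides per sequence. A toy illustration of the bookkeeping (token ids invented, standing in for config.mask_token_id and config.gmask_token_id):

# Hypothetical token ids, for illustration only.
MASK, gMASK = 130000, 130001
seqs = [
    [11, 12, MASK, 13, 14],   # sample 0 uses [MASK]
    [11, gMASK, 12, 13, 14],  # sample 1 uses [gMASK]
]
mask_positions, use_gmasks = [], []
for seq in seqs:
    mask_token = gMASK if gMASK in seq else MASK
    mask_positions.append(seq.index(mask_token))
    use_gmasks.append(mask_token == gMASK)

print(mask_positions)  # [2, 1]
print(use_gmasks)      # [False, True], where the old code gave True for both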
@@ -1106,10 +1113,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
     ) -> dict:
         batch_size, seq_length = input_ids.shape
         MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
-        mask_token = gMASK if gMASK in input_ids else MASK
-        use_gmask = True if gMASK in input_ids else False
         seqs = input_ids.tolist()
-        mask_positions = [seq.index(mask_token) for seq in seqs]
+        mask_positions, use_gmasks = [], []
+        for seq in seqs:
+            mask_token = gMASK if gMASK in seq else MASK
+            use_gmask = mask_token == gMASK
+            mask_positions.append(seq.index(mask_token))
+            use_gmasks.append(use_gmask)
 
         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
@@ -1152,7 +1162,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 input_ids,
                 device=input_ids.device,
                 mask_positions=mask_positions,
-                gmask=use_gmask
+                use_gmasks=use_gmasks
             )
 
             return {
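Putting the pieces together, a hedged end-to-end check using get_position_ids_sketch from above (the BOS id is again an invented placeholder, not the real config value):

import torch

BOS = 130004  # illustrative bos_token_id
input_ids = torch.tensor([
    [11, 12, 130000, BOS, 20, 21],   # [MASK] sample
    [11, 130001, 12, BOS, 20, 21],   # [gMASK] sample
])
position_ids = get_position_ids_sketch(
    input_ids, mask_positions=[2, 1], use_gmasks=[False, True], bos_token_id=BOS
)
print(position_ids)
# tensor([[0, 1, 2, 2, 2, 2],   <- pinned to the [MASK] position
#         [0, 1, 2, 3, 4, 5]])  <- gMASK sample keeps running positions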