From 3485994337c23d60a30f90407da4d5ba53ace76e Mon Sep 17 00:00:00 2001
From: duzx16
Date: Fri, 14 Apr 2023 15:57:11 +0800
Subject: [PATCH] Fix gmask

---
 modeling_chatglm.py     | 32 +++++++++++++++++++++-----------
 tokenization_chatglm.py | 25 +++++++++++--------------
 2 files changed, 32 insertions(+), 25 deletions(-)

diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index 49798d5..52ac051 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -689,8 +689,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
 
         return attention_mask
 
-    def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
+    def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
         batch_size, seq_length = input_ids.shape
+        if use_gmasks is None:
+            use_gmasks = [False] * batch_size
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
@@ -704,8 +706,8 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
             position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
-            if not gmask:
-                for i, context_length in enumerate(context_lengths):
+            for i, context_length in enumerate(context_lengths):
+                if not use_gmasks[i]:
                     position_ids[context_length:] = mask_positions[i]
 
         return position_ids
@@ -939,15 +941,20 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         if position_ids is None:
             MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
-            mask_token = gMASK if gMASK in input_ids else MASK
-            use_gmask = True if gMASK in input_ids else False
+            seqs = input_ids.tolist()
+
+            mask_positions, use_gmasks = [], []
+            for seq in seqs:
+                mask_token = gMASK if gMASK in seq else MASK
+                use_gmask = mask_token == gMASK
+                mask_positions.append(seq.index(mask_token))
+                use_gmasks.append(use_gmask)
 
-            mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
             position_ids = self.get_position_ids(
                 input_ids,
                 mask_positions=mask_positions,
                 device=input_ids.device,
-                gmask=use_gmask
+                use_gmasks=use_gmasks
             )
 
         if self.pre_seq_len is not None and attention_mask is not None:
@@ -1106,10 +1113,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
     ) -> dict:
         batch_size, seq_length = input_ids.shape
         MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
-        mask_token = gMASK if gMASK in input_ids else MASK
-        use_gmask = True if gMASK in input_ids else False
         seqs = input_ids.tolist()
-        mask_positions = [seq.index(mask_token) for seq in seqs]
+        mask_positions, use_gmasks = [], []
+        for seq in seqs:
+            mask_token = gMASK if gMASK in seq else MASK
+            use_gmask = mask_token == gMASK
+            mask_positions.append(seq.index(mask_token))
+            use_gmasks.append(use_gmask)
 
         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
@@ -1152,7 +1162,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 input_ids,
                 device=input_ids.device,
                 mask_positions=mask_positions,
-                gmask=use_gmask
+                use_gmasks=use_gmasks
             )
 
             return {
diff --git a/tokenization_chatglm.py b/tokenization_chatglm.py
index 2138987..1d4f0ba 100644
--- a/tokenization_chatglm.py
+++ b/tokenization_chatglm.py
@@ -176,6 +176,8 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             mask_token='[MASK]',
             gmask_token='[gMASK]',
             padding_side="left",
+            pad_token="<pad>",
+            unk_token="<unk>",
             num_image_tokens=20000,
             **kwargs
     ) -> None:
@@ -188,6 +190,8 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             end_token=end_token,
             mask_token=mask_token,
             gmask_token=gmask_token,
+            pad_token=pad_token,
+            unk_token=unk_token,
             num_image_tokens=num_image_tokens,
             **kwargs
         )
@@ -322,22 +326,11 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         Returns:
             `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
         """
-        mask_ids = self.sp_tokenizer[self.mask_token]
-        gmask_ids = self.sp_tokenizer[self.gmask_token]
+        gmask_id = self.sp_tokenizer[self.gmask_token]
         eos_id = self.sp_tokenizer[self.eos_token]
-        if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0:
-            token_ids_0 += [gmask_ids]
-
-        if token_ids_0[-1] != mask_ids and token_ids_0[-1] != gmask_ids:
-            token_ids_0 += [self.sp_tokenizer[self.end_token]]
-
-        token_ids_0 += [self.sp_tokenizer[self.bos_token]]
-
+        token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]]
         if token_ids_1 is not None:
-            if not token_ids_1 or token_ids_1[-1] != eos_id:
-                token_ids_1 += [eos_id]
-            token_ids_0 += token_ids_1
-
+            token_ids_0 = token_ids_0 + token_ids_1 + [eos_id]
         return token_ids_0
 
     def _pad(
@@ -402,6 +395,10 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
                 encoded_inputs["attention_mask"] = attention_mask
 
             if "position_ids" not in encoded_inputs:
+                if bos_token_id in required_input:
+                    context_length = required_input.index(bos_token_id)
+                else:
+                    context_length = seq_length
                 position_ids = np.arange(seq_length, dtype=np.int64)
                 mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
                 if mask_token in required_input: