From 373fd6b9d484841b490856a5570d6c450f20c22c Mon Sep 17 00:00:00 2001
From: duzx16
Date: Sun, 2 Apr 2023 01:58:45 +0800
Subject: [PATCH] Fix attention_mask and position_ids

---
 tokenization_chatglm.py | 42 ++++++++++++++++++++++--------------------
 1 file changed, 22 insertions(+), 20 deletions(-)

diff --git a/tokenization_chatglm.py b/tokenization_chatglm.py
index 9776728..ba86626 100644
--- a/tokenization_chatglm.py
+++ b/tokenization_chatglm.py
@@ -340,7 +340,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         token_ids_0 += [self.sp_tokenizer[self.bos_token]]
 
         if token_ids_1 is not None:
-            if token_ids_1[-1] != eop_id:
+            if not token_ids_1 or token_ids_1[-1] != eop_id:
                 token_ids_1 += [eop_id]
             token_ids_0 += token_ids_1
 
@@ -397,26 +397,28 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
 
         # Initialize attention mask if not present.
-        if return_attention_mask:
-            if bos_token_id in required_input:
-                context_length = required_input.index(bos_token_id)
-            else:
-                context_length = seq_length
-            attention_mask = np.ones((1, seq_length, seq_length))
-            attention_mask = np.tril(attention_mask)
-            attention_mask[:, :, :context_length] = 1
-            attention_mask = np.bool_(attention_mask < 0.5)
-            encoded_inputs["attention_mask"] = attention_mask
+        if max_length is not None:
+            if "attention_mask" not in encoded_inputs:
+                if bos_token_id in required_input:
+                    context_length = required_input.index(bos_token_id)
+                else:
+                    context_length = seq_length
+                attention_mask = np.ones((1, seq_length, seq_length))
+                attention_mask = np.tril(attention_mask)
+                attention_mask[:, :, :context_length] = 1
+                attention_mask = np.bool_(attention_mask < 0.5)
+                encoded_inputs["attention_mask"] = attention_mask
 
-            position_ids = np.arange(seq_length, dtype=np.int64)
-            mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
-            if mask_token in required_input:
-                mask_position = required_input.index(mask_token)
-                position_ids[context_length:] = mask_position
-            block_position_ids = np.concatenate(
-                [np.zeros(context_length, dtype=np.int64),
-                 np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
-            encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
+            if "position_ids" not in encoded_inputs:
+                position_ids = np.arange(seq_length, dtype=np.int64)
+                mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
+                if mask_token in required_input:
+                    mask_position = required_input.index(mask_token)
+                    position_ids[context_length:] = mask_position
+                block_position_ids = np.concatenate(
+                    [np.zeros(context_length, dtype=np.int64),
+                     np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
+                encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
 
         if needs_to_be_padded:
            difference = max_length - len(required_input)
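
Note (not part of the patch): a minimal, self-contained sketch of the mask and 2-D position-id construction that the new `if max_length is not None:` guard wraps, assuming placeholder special-token ids. The helper name `build_mask_and_positions` and the example inputs are invented for illustration; the real `_pad` resolves bos/mask/gmask ids from the tokenizer's vocabulary.

import numpy as np

def build_mask_and_positions(required_input, bos_token_id, mask_token_id, gmask_token_id):
    # Everything before the bos token is bidirectional context; the rest is causal.
    seq_length = len(required_input)
    if bos_token_id in required_input:
        context_length = required_input.index(bos_token_id)
    else:
        context_length = seq_length

    attention_mask = np.tril(np.ones((1, seq_length, seq_length)))
    attention_mask[:, :, :context_length] = 1
    # True marks positions that are masked out (the patch writes np.bool_(mask < 0.5)).
    attention_mask = attention_mask < 0.5

    # Two rows of position ids: absolute positions, clamped to the [MASK]/[gMASK]
    # position after the context, plus block positions counting the generated span.
    position_ids = np.arange(seq_length, dtype=np.int64)
    mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
    if mask_token in required_input:
        position_ids[context_length:] = required_input.index(mask_token)
    block_position_ids = np.concatenate(
        [np.zeros(context_length, dtype=np.int64),
         np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
    return attention_mask, np.stack([position_ids, block_position_ids], axis=0)

# Placeholder ids: 1 = gmask, 2 = bos, 3 = mask.
mask, pos = build_mask_and_positions([10, 11, 3, 2, 20, 21],
                                     bos_token_id=2, mask_token_id=3, gmask_token_id=1)
print(mask.shape, pos)  # (1, 6, 6)  [[0 1 2 2 2 2] [0 0 0 1 2 3]]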