From e22cddf21240ad67992cf3cb03ee32b83e7b716a Mon Sep 17 00:00:00 2001
From: duzx16
Date: Sun, 2 Apr 2023 01:04:40 +0800
Subject: [PATCH] Fix encode method

---
 tokenization_chatglm.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/tokenization_chatglm.py b/tokenization_chatglm.py
index a7d1ca1..9776728 100644
--- a/tokenization_chatglm.py
+++ b/tokenization_chatglm.py
@@ -398,19 +398,21 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
 
         # Initialize attention mask if not present.
         if return_attention_mask:
-            context_length = required_input.index(bos_token_id)
+            if bos_token_id in required_input:
+                context_length = required_input.index(bos_token_id)
+            else:
+                context_length = seq_length
             attention_mask = np.ones((1, seq_length, seq_length))
             attention_mask = np.tril(attention_mask)
             attention_mask[:, :, :context_length] = 1
             attention_mask = np.bool_(attention_mask < 0.5)
             encoded_inputs["attention_mask"] = attention_mask
 
-        if return_attention_mask:
-            mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
-            mask_position = required_input.index(mask_token)
-            context_length = required_input.index(bos_token_id)
             position_ids = np.arange(seq_length, dtype=np.int64)
-            position_ids[context_length:] = mask_position
+            mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
+            if mask_token in required_input:
+                mask_position = required_input.index(mask_token)
+                position_ids[context_length:] = mask_position
             block_position_ids = np.concatenate(
                 [np.zeros(context_length, dtype=np.int64),
                  np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
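
For context, a minimal standalone sketch of the logic this patch arrives at: previously, `required_input.index(bos_token_id)` (and likewise the mask-token lookup) raised `ValueError` when the padded/encoded input did not contain those tokens; the fix falls back to the full sequence length and skips the position overwrite. The sketch below is illustrative only and not part of the patch; the token id values and the helper name `build_mask_and_positions` are made up for the example.

```python
import numpy as np

# Hypothetical special-token ids, chosen only for this example.
bos_token_id, mask_token_id, gmask_token_id = 130004, 130000, 130001


def build_mask_and_positions(required_input):
    seq_length = len(required_input)

    # Fix 1: if there is no bos token, treat the whole sequence as context
    # instead of crashing on required_input.index(bos_token_id).
    if bos_token_id in required_input:
        context_length = required_input.index(bos_token_id)
    else:
        context_length = seq_length

    # Bidirectional attention over the context, causal attention afterwards;
    # True marks positions that must be masked out (same convention as the patch).
    attention_mask = np.tril(np.ones((1, seq_length, seq_length)))
    attention_mask[:, :, :context_length] = 1
    attention_mask = np.bool_(attention_mask < 0.5)

    # Fix 2: only point post-context positions at the mask token when a
    # [MASK]/[gMASK] token is actually present in the input.
    position_ids = np.arange(seq_length, dtype=np.int64)
    mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
    if mask_token in required_input:
        mask_position = required_input.index(mask_token)
        position_ids[context_length:] = mask_position

    block_position_ids = np.concatenate(
        [np.zeros(context_length, dtype=np.int64),
         np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
    return attention_mask, position_ids, block_position_ids


# Works for a prompt that contains gmask + bos ...
build_mask_and_positions([5, 6, 7, gmask_token_id, bos_token_id, 8, 9])
# ... and no longer raises on inputs without them.
build_mask_and_positions([5, 6, 7, 8, 9])
```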