diff --git a/tokenization_chatglm.py b/tokenization_chatglm.py
index 956746c..a7d1ca1 100644
--- a/tokenization_chatglm.py
+++ b/tokenization_chatglm.py
@@ -177,7 +177,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
 
     vocab_files_names = {"vocab_file": "ice_text.model"}
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["input_ids"]
+    model_input_names = ["input_ids", "attention_mask", "position_ids"]
 
     def __init__(
         self,
@@ -397,7 +397,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
 
         # Initialize attention mask if not present.
-        if needs_to_be_padded or return_attention_mask:
+        if return_attention_mask:
             context_length = required_input.index(bos_token_id)
             attention_mask = np.ones((1, seq_length, seq_length))
             attention_mask = np.tril(attention_mask)
@@ -405,7 +405,7 @@
             attention_mask = np.bool_(attention_mask < 0.5)
             encoded_inputs["attention_mask"] = attention_mask
 
-        if needs_to_be_padded or return_attention_mask:
+        if return_attention_mask:
             mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
             mask_position = required_input.index(mask_token)
             context_length = required_input.index(bos_token_id)
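
For context, a minimal sketch of the caller-visible effect, assuming the patched tokenizer is loaded from the `THUDM/chatglm-6b` checkpoint with `trust_remote_code=True` (the checkpoint name and sample text are illustrative, not part of this patch): once `attention_mask` and `position_ids` are listed in `model_input_names`, `return_attention_mask` resolves to `True` by default, and the block attention mask and position ids are built whenever they are requested, not only when the input happens to need padding.

```python
# Minimal sketch, not part of the patch. Assumes the patched tokenizer is
# available via the THUDM/chatglm-6b checkpoint and that remote code is trusted.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# No padding is requested here; before the patch the mask/position tensors were
# only guaranteed when the input needed padding, so a plain call could omit them.
encoded = tokenizer("Hello, ChatGLM")
print(sorted(encoded.keys()))  # expected to include attention_mask, input_ids, position_ids
```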