Fix batch input

duzx16 2023-04-01 22:38:53 +08:00
parent cc96a2271a
commit e1494f222d
1 changed file with 3 additions and 3 deletions


@@ -177,7 +177,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
     vocab_files_names = {"vocab_file": "ice_text.model"}
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["input_ids"]
+    model_input_names = ["input_ids", "attention_mask", "position_ids"]

     def __init__(
         self,
@@ -397,7 +397,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

         # Initialize attention mask if not present.
-        if needs_to_be_padded or return_attention_mask:
+        if return_attention_mask:
             context_length = required_input.index(bos_token_id)
             attention_mask = np.ones((1, seq_length, seq_length))
             attention_mask = np.tril(attention_mask)
@@ -405,7 +405,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             attention_mask = np.bool_(attention_mask < 0.5)
             encoded_inputs["attention_mask"] = attention_mask

-        if needs_to_be_padded or return_attention_mask:
+        if return_attention_mask:
             mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
             mask_position = required_input.index(mask_token)
             context_length = required_input.index(bos_token_id)
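
In short: model_input_names now lists "attention_mask" and "position_ids" alongside "input_ids", so batched encodings keep all three tensors through padding and collation, and the padding logic builds the per-sequence attention mask and position ids only when return_attention_mask is set, rather than also whenever padding occurs. A minimal usage sketch of the batched call this targets follows (not part of the commit; the hub id "THUDM/chatglm-6b" and the sample sentences are illustrative assumptions):

# Sketch of batch tokenization with the patched ChatGLMTokenizer.
# Assumes the patched code is published as "THUDM/chatglm-6b".
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# Two sequences of different lengths force padding; with the extended
# model_input_names the padded batch carries all three model inputs.
batch = tokenizer(["你好", "How are you today?"], padding=True, return_tensors="pt")
print(batch.keys())  # expected: input_ids, attention_mask, position_ids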