diff --git a/tokenization_chatglm.py b/tokenization_chatglm.py
index e3d9edf..1d4f0ba 100644
--- a/tokenization_chatglm.py
+++ b/tokenization_chatglm.py
@@ -326,22 +326,11 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         Returns:
             `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
         """
-        mask_ids = self.sp_tokenizer[self.mask_token]
-        gmask_ids = self.sp_tokenizer[self.gmask_token]
+        gmask_id = self.sp_tokenizer[self.gmask_token]
         eos_id = self.sp_tokenizer[self.eos_token]
-        if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0:
-            token_ids_0 += [gmask_ids]
-
-        if token_ids_0[-1] != mask_ids and token_ids_0[-1] != gmask_ids:
-            token_ids_0 += [self.sp_tokenizer[self.end_token]]
-
-        token_ids_0 += [self.sp_tokenizer[self.bos_token]]
-
+        token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]]
         if token_ids_1 is not None:
-            if not token_ids_1 or token_ids_1[-1] != eos_id:
-                token_ids_1 += [eos_id]
-            token_ids_0 += token_ids_1
-
+            token_ids_0 = token_ids_0 + token_ids_1 + [eos_id]
         return token_ids_0
 
     def _pad(