diff --git a/tokenization_chatglm.py b/tokenization_chatglm.py
index caa33d8..0583e9a 100644
--- a/tokenization_chatglm.py
+++ b/tokenization_chatglm.py
@@ -253,29 +253,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         return seq
 
-    def decode(
+    def _decode(
             self,
-            token_ids: Union[List[int], List[List[int]]],
+            token_ids: Union[int, List[int]],
             skip_special_tokens: bool = False,
             clean_up_tokenization_spaces: bool = True,
-            spaces_between_special_tokens: bool = True,
             **kwargs
     ) -> str:
-        if not isinstance(token_ids, list):
+        if isinstance(token_ids, int):
             token_ids = [token_ids]
         if len(token_ids) == 0:
             return ""
-        if isinstance(token_ids[0], list):
-            tokens = []
-            for single_token_ids in token_ids:
-                if self.pad_token_id in single_token_ids: # remove pad
-                    single_token_ids = list(filter((self.pad_token_id).__ne__, single_token_ids))
-                tokens.append(self.sp_tokenizer.decode(single_token_ids))
-            return (tokens)
-        else:
-            if self.pad_token_id in token_ids: # remove pad
-                token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
-            return self.sp_tokenizer.decode(token_ids)
+        if self.pad_token_id in token_ids: # remove pad
+            token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
+        return self.sp_tokenizer.decode(token_ids)
 
     def _convert_token_to_id(self, token):
         """ Converts a token (str) in an id using the vocab. """
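A minimal usage sketch of the reworked method (illustrative only; it assumes the tokenizer is loaded from the THUDM/chatglm-6b checkpoint and relies on the standard transformers behavior that PreTrainedTokenizer.decode() dispatches to _decode() and batch_decode() calls decode() once per sequence):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

    # A flat list of ids or a single int id both reach _decode; pad ids are
    # filtered out before the SentencePiece decode.
    ids = tokenizer.encode("Hello")
    text = tokenizer.decode(ids)
    first = tokenizer.decode(ids[0])

    # Decoding a List[List[int]] is no longer handled inside _decode;
    # batch decoding goes through batch_decode instead.
    texts = tokenizer.batch_decode([ids, ids])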