diff --git a/tokenization_chatglm.py b/tokenization_chatglm.py
index 1d4f0ba..69ee85c 100644
--- a/tokenization_chatglm.py
+++ b/tokenization_chatglm.py
@@ -31,6 +31,9 @@ class TextTokenizer:
     def tokenize(self, text):
         return self.sp.EncodeAsPieces(text)
 
+    def convert_tokens_to_string(self, tokens):
+        return self.sp.DecodePieces(tokens)
+
     def convert_tokens_to_ids(self, tokens):
         return [self.sp.PieceToId(token) for token in tokens]
 
@@ -111,16 +114,25 @@ class SPTokenizer:
         tokens = [x + self.num_image_tokens for x in tmp]
         return tokens if add_dummy_prefix else tokens[2:]
 
-    def decode(self, text_ids: List[int]) -> str:
-        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
-        ids = [_id for _id in ids if _id >= 0]
-        text = self._get_text_tokenizer().decode(ids)
+    def postprocess(self, text):
         text = text.replace("<n>", "\n")
         text = text.replace(SPTokenizer.get_tab_token(), "\t")
         for i in range(2, self.max_blank_length + 1):
             text = text.replace(self.get_blank_token(i), " " * i)
         return text
 
+    def decode(self, text_ids: List[int]) -> str:
+        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
+        ids = [_id for _id in ids if _id >= 0]
+        text = self._get_text_tokenizer().decode(ids)
+        text = self.postprocess(text)
+        return text
+
+    def decode_tokens(self, tokens: List[str]) -> str:
+        text = self._get_text_tokenizer().convert_tokens_to_string(tokens)
+        text = self.postprocess(text)
+        return text
+
     def tokenize(
             self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
     ) -> List[str]:
@@ -256,11 +268,12 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
 
         return seq
 
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        return self.sp_tokenizer.decode_tokens(tokens)
+
     def _decode(
             self,
             token_ids: Union[int, List[int]],
-            skip_special_tokens: bool = False,
-            clean_up_tokenization_spaces: bool = True,
             **kwargs
     ) -> str:
         if isinstance(token_ids, int):
@@ -269,7 +282,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             return ""
         if self.pad_token_id in token_ids:  # remove pad
             token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
-        return self.sp_tokenizer.decode(token_ids)
+        return super()._decode(token_ids, **kwargs)
 
     def _convert_token_to_id(self, token):
         """ Converts a token (str) in an id using the vocab. """
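
Note on the change above: `_decode` now defers to `PreTrainedTokenizer._decode`, which filters special ids when asked, maps the remaining ids back to pieces, and calls the newly added `convert_tokens_to_string`; that in turn routes through `SPTokenizer.decode_tokens` and the factored-out `postprocess`, so the linebreak, tab, and multi-blank sentinel pieces are still restored. A minimal sketch of the behavior this enables follows; the `THUDM/chatglm-6b` checkpoint name is an assumption for illustration and is not part of this diff.

    from transformers import AutoTokenizer

    # The checkpoint name is assumed; any model that ships this
    # tokenization_chatglm.py behaves the same way.
    tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True
    )

    ids = tokenizer.encode("a\tb\n  c")

    # Before this patch, skip_special_tokens was accepted but silently
    # ignored, because _decode called sp_tokenizer.decode directly. It now
    # flows through PreTrainedTokenizer._decode -> convert_tokens_to_string
    # -> SPTokenizer.decode_tokens -> postprocess, which still maps the
    # "<n>", tab, and multi-blank sentinel pieces back to "\n", "\t", and
    # runs of spaces.
    print(tokenizer.decode(ids, skip_special_tokens=True))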