Update decode method in tokenizer
commit d8a6cfc6cb
parent f6b88da8c1
@@ -31,6 +31,9 @@ class TextTokenizer:
     def tokenize(self, text):
         return self.sp.EncodeAsPieces(text)
 
+    def convert_tokens_to_string(self, tokens):
+        return self.sp.DecodePieces(tokens)
+
     def convert_tokens_to_ids(self, tokens):
         return [self.sp.PieceToId(token) for token in tokens]
 
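Note: the new TextTokenizer.convert_tokens_to_string mirrors tokenize above; sp is a sentencepiece.SentencePieceProcessor, and DecodePieces is that library's inverse of EncodeAsPieces. A minimal round-trip sketch (the model path is a placeholder):

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.Load("tokenizer.model")  # placeholder path to the SentencePiece model

    pieces = sp.EncodeAsPieces("hello world")  # e.g. ['▁hello', '▁world']
    print(sp.DecodePieces(pieces))             # 'hello world' for a typical model
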
@@ -111,16 +114,25 @@ class SPTokenizer:
         tokens = [x + self.num_image_tokens for x in tmp]
         return tokens if add_dummy_prefix else tokens[2:]
 
-    def decode(self, text_ids: List[int]) -> str:
-        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
-        ids = [_id for _id in ids if _id >= 0]
-        text = self._get_text_tokenizer().decode(ids)
+    def postprocess(self, text):
         text = text.replace("<n>", "\n")
         text = text.replace(SPTokenizer.get_tab_token(), "\t")
         for i in range(2, self.max_blank_length + 1):
             text = text.replace(self.get_blank_token(i), " " * i)
         return text
 
+    def decode(self, text_ids: List[int]) -> str:
+        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
+        ids = [_id for _id in ids if _id >= 0]
+        text = self._get_text_tokenizer().decode(ids)
+        text = self.postprocess(text)
+        return text
+
+    def decode_tokens(self, tokens: List[str]) -> str:
+        text = self._get_text_tokenizer().convert_tokens_to_string(tokens)
+        text = self.postprocess(text)
+        return text
+
     def tokenize(
         self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
     ) -> List[str]:
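Note: the old decode is split into an id-to-text step plus a shared postprocess step, so the new decode_tokens (token pieces in, text out) reuses the same marker replacements. A standalone sketch of the postprocess logic, assuming marker strings "<|tab|>" and "<|blank_{n}|>" and a max_blank_length of 80 (assumptions, not confirmed by this diff):

    def postprocess(text: str, max_blank_length: int = 80) -> str:
        text = text.replace("<n>", "\n")      # linebreak marker -> newline
        text = text.replace("<|tab|>", "\t")  # tab marker -> real tab
        # blank markers -> runs of 2..max_blank_length spaces
        for i in range(2, max_blank_length + 1):
            text = text.replace(f"<|blank_{i}|>", " " * i)
        return text

    print(repr(postprocess("a<n>b<|blank_4|>c")))  # 'a\nb    c'
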
@@ -256,11 +268,12 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
 
         return seq
 
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        return self.sp_tokenizer.decode_tokens(tokens)
+
     def _decode(
         self,
         token_ids: Union[int, List[int]],
-        skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
         **kwargs
     ) -> str:
         if isinstance(token_ids, int):
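Note: dropping skip_special_tokens and clean_up_tokenization_spaces from _decode's signature lets those flags travel via **kwargs to super()._decode, and the new convert_tokens_to_string is the hook the transformers base class uses during decoding. A simplified sketch of that flow (not the exact upstream implementation):

    def base_decode_sketch(tokenizer, token_ids, skip_special_tokens=False):
        # PreTrainedTokenizer._decode first maps ids back to token pieces...
        tokens = tokenizer.convert_ids_to_tokens(
            token_ids, skip_special_tokens=skip_special_tokens
        )
        # ...then funnels them through convert_tokens_to_string, which this
        # commit wires to sp_tokenizer.decode_tokens.
        return tokenizer.convert_tokens_to_string(tokens)
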
@@ -269,7 +282,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             return ""
         if self.pad_token_id in token_ids:  # remove pad
             token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
-        return self.sp_tokenizer.decode(token_ids)
+        return super()._decode(token_ids, **kwargs)
 
     def _convert_token_to_id(self, token):
         """ Converts a token (str) in an id using the vocab. """
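Note: with the final return delegating to super()._decode, flags such as skip_special_tokens are handled by the base-class machinery instead of being ignored by the old direct sp_tokenizer.decode(token_ids) call. Hypothetical usage, with a placeholder repo id:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True  # placeholder repo id
    )
    ids = tokenizer("Hello\tworld")["input_ids"]
    # skip_special_tokens now reaches super()._decode via **kwargs
    print(tokenizer.decode(ids, skip_special_tokens=True))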