Fix decode method for torch tensor

This commit is contained in:
duzx16 2023-04-05 18:26:09 +08:00
parent fdb7a601d8
commit 23ad39b571
1 changed file with 6 additions and 15 deletions


@@ -253,29 +253,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         return seq
 
-    def decode(
+    def _decode(
             self,
-            token_ids: Union[List[int], List[List[int]]],
+            token_ids: Union[int, List[int]],
             skip_special_tokens: bool = False,
             clean_up_tokenization_spaces: bool = True,
-            spaces_between_special_tokens: bool = True,
             **kwargs
     ) -> str:
-        if not isinstance(token_ids, list):
+        if isinstance(token_ids, int):
             token_ids = [token_ids]
         if len(token_ids) == 0:
             return ""
-        if isinstance(token_ids[0], list):
-            tokens = []
-            for single_token_ids in token_ids:
-                if self.pad_token_id in single_token_ids: # remove pad
-                    single_token_ids = list(filter((self.pad_token_id).__ne__, single_token_ids))
-                tokens.append(self.sp_tokenizer.decode(single_token_ids))
-            return (tokens)
-        else:
-            if self.pad_token_id in token_ids: # remove pad
-                token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
-            return self.sp_tokenizer.decode(token_ids)
+        if self.pad_token_id in token_ids: # remove pad
+            token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
+        return self.sp_tokenizer.decode(token_ids)
 
     def _convert_token_to_id(self, token):
         """ Converts a token (str) in an id using the vocab. """