Fix decode method for torch tensor

This commit is contained in:
duzx16 2023-04-05 18:26:09 +08:00
parent fdb7a601d8
commit 23ad39b571
1 changed file with 6 additions and 15 deletions

View File

@ -253,26 +253,17 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
return seq return seq
def decode( def _decode(
self, self,
token_ids: Union[List[int], List[List[int]]], token_ids: Union[int, List[int]],
skip_special_tokens: bool = False, skip_special_tokens: bool = False,
clean_up_tokenization_spaces: bool = True, clean_up_tokenization_spaces: bool = True,
spaces_between_special_tokens: bool = True,
**kwargs **kwargs
) -> str: ) -> str:
if not isinstance(token_ids, list): if isinstance(token_ids, int):
token_ids = [token_ids] token_ids = [token_ids]
if len(token_ids) == 0: if len(token_ids) == 0:
return "" return ""
if isinstance(token_ids[0], list):
tokens = []
for single_token_ids in token_ids:
if self.pad_token_id in single_token_ids: # remove pad
single_token_ids = list(filter((self.pad_token_id).__ne__, single_token_ids))
tokens.append(self.sp_tokenizer.decode(single_token_ids))
return (tokens)
else:
if self.pad_token_id in token_ids: # remove pad if self.pad_token_id in token_ids: # remove pad
token_ids = list(filter((self.pad_token_id).__ne__, token_ids)) token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
return self.sp_tokenizer.decode(token_ids) return self.sp_tokenizer.decode(token_ids)