Fix decode method for torch tensor
This commit is contained in:
parent
fdb7a601d8
commit
23ad39b571
|
@ -253,29 +253,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
|||
|
||||
return seq
|
||||
|
||||
def decode(
|
||||
def _decode(
|
||||
self,
|
||||
token_ids: Union[List[int], List[List[int]]],
|
||||
token_ids: Union[int, List[int]],
|
||||
skip_special_tokens: bool = False,
|
||||
clean_up_tokenization_spaces: bool = True,
|
||||
spaces_between_special_tokens: bool = True,
|
||||
**kwargs
|
||||
) -> str:
|
||||
if not isinstance(token_ids, list):
|
||||
if isinstance(token_ids, int):
|
||||
token_ids = [token_ids]
|
||||
if len(token_ids) == 0:
|
||||
return ""
|
||||
if isinstance(token_ids[0], list):
|
||||
tokens = []
|
||||
for single_token_ids in token_ids:
|
||||
if self.pad_token_id in single_token_ids: # remove pad
|
||||
single_token_ids = list(filter((self.pad_token_id).__ne__, single_token_ids))
|
||||
tokens.append(self.sp_tokenizer.decode(single_token_ids))
|
||||
return (tokens)
|
||||
else:
|
||||
if self.pad_token_id in token_ids: # remove pad
|
||||
token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
|
||||
return self.sp_tokenizer.decode(token_ids)
|
||||
if self.pad_token_id in token_ids: # remove pad
|
||||
token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
|
||||
return self.sp_tokenizer.decode(token_ids)
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
""" Converts a token (str) in an id using the vocab. """
|
||||
|
|
Loading…
Reference in New Issue