Fix decode method for torch tensor inputs
This commit is contained in:
parent
fdb7a601d8
commit
23ad39b571
|
@ -253,29 +253,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||||
|
|
||||||
return seq
|
return seq
|
||||||
|
|
||||||
def decode(
|
def _decode(
|
||||||
self,
|
self,
|
||||||
token_ids: Union[List[int], List[List[int]]],
|
token_ids: Union[int, List[int]],
|
||||||
skip_special_tokens: bool = False,
|
skip_special_tokens: bool = False,
|
||||||
clean_up_tokenization_spaces: bool = True,
|
clean_up_tokenization_spaces: bool = True,
|
||||||
spaces_between_special_tokens: bool = True,
|
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
if not isinstance(token_ids, list):
|
if isinstance(token_ids, int):
|
||||||
token_ids = [token_ids]
|
token_ids = [token_ids]
|
||||||
if len(token_ids) == 0:
|
if len(token_ids) == 0:
|
||||||
return ""
|
return ""
|
||||||
if isinstance(token_ids[0], list):
|
if self.pad_token_id in token_ids: # remove pad
|
||||||
tokens = []
|
token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
|
||||||
for single_token_ids in token_ids:
|
return self.sp_tokenizer.decode(token_ids)
|
||||||
if self.pad_token_id in single_token_ids: # remove pad
|
|
||||||
single_token_ids = list(filter((self.pad_token_id).__ne__, single_token_ids))
|
|
||||||
tokens.append(self.sp_tokenizer.decode(single_token_ids))
|
|
||||||
return (tokens)
|
|
||||||
else:
|
|
||||||
if self.pad_token_id in token_ids: # remove pad
|
|
||||||
token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
|
|
||||||
return self.sp_tokenizer.decode(token_ids)
|
|
||||||
|
|
||||||
def _convert_token_to_id(self, token):
|
def _convert_token_to_id(self, token):
|
||||||
""" Converts a token (str) in an id using the vocab. """
|
""" Converts a token (str) in an id using the vocab. """
|
||||||
|
|
Loading…
Reference in New Issue