diff --git a/tokenization_chatglm.py b/tokenization_chatglm.py
index 619a8c7..aedbcbe 100644
--- a/tokenization_chatglm.py
+++ b/tokenization_chatglm.py
@@ -130,6 +130,7 @@ class SPTokenizer:
     def decode(self, text_ids: List[int], special_tokens=False) -> str:
         ids = [int(_id) - self.num_image_tokens for _id in text_ids]
+        ids = [_id for _id in ids if _id >= 0]
         text = self._get_text_tokenizer(encode_special_tokens=special_tokens).decode(ids)
         text = text.replace("<n>", "\n")
         text = text.replace(SPTokenizer.get_tab_token(), "\t")
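
For context, the added line drops any id that falls below `num_image_tokens` (e.g. padding or image-placeholder ids), which would otherwise become negative after the shift and break the underlying SentencePiece decode. Below is a minimal sketch of that behavior in isolation; `NUM_IMAGE_TOKENS` and `decode_ids` are illustrative stand-ins, not part of the patched class, which derives the offset from its config and delegates the actual decoding to its text tokenizer.

```python
from typing import List

NUM_IMAGE_TOKENS = 20000  # hypothetical value, chosen only for illustration

def decode_ids(text_ids: List[int]) -> List[int]:
    # Shift ids back into the text tokenizer's id space.
    ids = [int(_id) - NUM_IMAGE_TOKENS for _id in text_ids]
    # Ids below NUM_IMAGE_TOKENS go negative after the shift; filter them out
    # before handing the list to the text tokenizer, as the patch does.
    return [_id for _id in ids if _id >= 0]

print(decode_ids([3, 20005, 20010]))  # -> [5, 10]; the leading id 3 is dropped
```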