diff --git a/ice_text.model b/ice_text.model
index c5aa32f..0dcfe31 100644
--- a/ice_text.model
+++ b/ice_text.model
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:99871e0c85db81ad7af1028854fd091cd5778c8414ae9d94bbbc10d02c831c21
-size 2699926
+oid sha256:5e974d9a69c242ce014c88c2b26089270f6198f3c0b700a887666cd3e816f17e
+size 2706249
diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index 554d7f8..4bc0092 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -923,7 +923,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if position_ids is None:
             MASK, gMASK = 150000, 150001
             mask_token = MASK if MASK in input_ids else gMASK
-            use_gmask = False if MASK in input_ids else gMASK
+            use_gmask = False if MASK in input_ids else True
             mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
 
             position_ids = self.get_position_ids(
@@ -1086,7 +1086,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         batch_size, seq_length = input_ids.shape
         MASK, gMASK = 150000, 150001
         mask_token = MASK if MASK in input_ids else gMASK
-        use_gmask = False if MASK in input_ids else gMASK
+        use_gmask = False if MASK in input_ids else True
         seqs = input_ids.tolist()
         mask_positions = [seq.index(mask_token) for seq in seqs]
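A note on the two modeling_chatglm.py hunks above: the old right-hand branch yielded the token id gMASK (150001), so use_gmask became a truthy int rather than the bool its name and downstream checks suggest; the fix gives both branches boolean values. A minimal sketch of the before/after behavior (the input sequence is made up for illustration):

```python
# Before/after sketch of the use_gmask expression; token ids come from the
# patch, the sample sequence is hypothetical.
MASK, gMASK = 150000, 150001

input_ids = [53, 6945, gMASK, 9]  # a made-up sequence containing [gMASK]

old_use_gmask = False if MASK in input_ids else gMASK  # -> 150001 (truthy int, wrong type)
new_use_gmask = False if MASK in input_ids else True   # -> True (the intended flag)

assert old_use_gmask == gMASK and old_use_gmask is not True
assert new_use_gmask is True
```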
diff --git a/tokenization_chatglm.py b/tokenization_chatglm.py
index 3062c7c..2808c04 100644
--- a/tokenization_chatglm.py
+++ b/tokenization_chatglm.py
@@ -3,11 +3,10 @@ from typing import List, Optional, Union
 import os
 
 from transformers.tokenization_utils import PreTrainedTokenizer
-from icetk.text_tokenizer import TextTokenizer
-import icetk.sentencepiece_model_pb2 as sp_model
 from transformers.utils import logging, PaddingStrategy
 from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
 from typing import Dict
+import sentencepiece as spm
 import numpy as np
 
 logger = logging.get_logger(__name__)
@@ -17,61 +16,50 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
 }
 
 
+class TextTokenizer:
+    def __init__(self, model_path):
+        self.sp = spm.SentencePieceProcessor()
+        self.sp.Load(model_path)
+        self.num_tokens = self.sp.vocab_size()
+
+    def encode(self, text):
+        return self.sp.EncodeAsIds(text)
+
+    def decode(self, ids: List[int]):
+        return self.sp.DecodeIds(ids)
+
+    def tokenize(self, text):
+        return self.sp.EncodeAsPieces(text)
+
+    def convert_tokens_to_ids(self, tokens):
+        return [self.sp.PieceToId(token) for token in tokens]
+
+    def convert_token_to_id(self, token):
+        return self.sp.PieceToId(token)
+
+    def convert_id_to_token(self, idx):
+        return self.sp.IdToPiece(idx)
+
+    def __len__(self):
+        return self.num_tokens
+
+
 class SPTokenizer:
     def __init__(
-            self,
-            vocab_file,
-            max_blank_length=80,
-            byte_fallback=True,
+        self,
+        vocab_file,
+        max_blank_length=80,
+        byte_fallback=True,
     ):
         assert vocab_file is not None
         self.vocab_file = vocab_file
         self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<unused_0>", "<sop>", "<eop>", "<ENC>", "<dEC>"]
         self.max_blank_length = max_blank_length
         self.byte_fallback = byte_fallback
-        self.text_tokenizer = self._build_text_tokenizer(encode_special_tokens=False)
-        self.special_text_tokenizer = self._build_text_tokenizer(encode_special_tokens=True)
+        self.text_tokenizer = TextTokenizer(vocab_file)
 
-    @staticmethod
-    def _configure_tokenizer(
-            text_tokenizer: TextTokenizer,
-            special_tokens: List[str],
-            max_blank_length: int,
-            byte_fallback: bool,
-            encode_special_tokens=False,
-    ):
-        # special token
-        special_token_type = 4 if encode_special_tokens else 3  # 3 - CONTROL, 4 - USER_DEFINE
-        for token in special_tokens:
-            text_tokenizer.proto.pieces.append(
-                sp_model.ModelProto.SentencePiece(piece=token, score=0.0, type=special_token_type)
-            )
-        # whitespaces
-        for token in [SPTokenizer.get_tab_token()] + [
-            SPTokenizer.get_blank_token(i) for i in range(2, max_blank_length + 1)
-        ]:
-            text_tokenizer.proto.pieces.append(sp_model.ModelProto.SentencePiece(piece=token, score=0.0, type=4))
-        # byte fallback
-        if byte_fallback:
-            text_tokenizer.proto.trainer_spec.byte_fallback = True
-            for i in range(256):
-                text_tokenizer.proto.pieces.append(
-                    sp_model.ModelProto.SentencePiece(piece="<0x{:02X}>".format(i), score=0.0, type=6)
-                )
-        text_tokenizer.refresh()
-
-    def _build_text_tokenizer(self, encode_special_tokens=False):
-        tokenizer = TextTokenizer(self.vocab_file)
-        self._configure_tokenizer(
-            tokenizer, self.special_tokens, self.max_blank_length, self.byte_fallback, encode_special_tokens
-        )
-        return tokenizer
-
-    def _get_text_tokenizer(self, encode_special_tokens=False):
-        if encode_special_tokens:
-            return self.special_text_tokenizer
-        else:
-            return self.text_tokenizer
+    def _get_text_tokenizer(self):
+        return self.text_tokenizer
 
     @staticmethod
     def get_blank_token(length: int):
@@ -109,7 +97,7 @@ class SPTokenizer:
         return text
 
     def encode(
-            self, text: str, linebreak=True, whitespaces=True, special_tokens=False, add_dummy_prefix=True
+        self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
     ) -> List[int]:
         """
         @param text: Text to encode.
@@ -121,14 +109,14 @@ class SPTokenizer:
         text = self._preprocess(text, linebreak, whitespaces)
         if not add_dummy_prefix:
             text = "<n>" + text
-        tmp = self._get_text_tokenizer(encode_special_tokens=special_tokens).encode(text)
+        tmp = self._get_text_tokenizer().encode(text)
         tokens = [x + self.num_image_tokens for x in tmp]
         return tokens if add_dummy_prefix else tokens[2:]
 
-    def decode(self, text_ids: List[int], special_tokens=False) -> str:
+    def decode(self, text_ids: List[int]) -> str:
         ids = [int(_id) - self.num_image_tokens for _id in text_ids]
         ids = [_id for _id in ids if _id >= 0]
-        text = self._get_text_tokenizer(encode_special_tokens=special_tokens).decode(ids)
+        text = self._get_text_tokenizer().decode(ids)
         text = text.replace("<n>", "\n")
         text = text.replace(SPTokenizer.get_tab_token(), "\t")
         for i in range(2, self.max_blank_length + 1):
@@ -136,7 +124,7 @@ class SPTokenizer:
         return text
 
     def tokenize(
-            self, text: str, linebreak=True, whitespaces=True, special_tokens=False, add_dummy_prefix=True
+        self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
    ) -> List[str]:
         """
         @param text: Text to encode.
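A note on the SPTokenizer changes so far: with the icetk dependency dropped, TextTokenizer is now a thin local wrapper over sentencepiece, and encode/decode lose their special_tokens flag because there is no longer a second, specially configured tokenizer to dispatch to. A usage sketch of the new surface, assuming the patched tokenization_chatglm.py is importable and the updated ice_text.model sits in the working directory:

```python
# Usage sketch for the sentencepiece-backed SPTokenizer above; assumes the
# patched tokenization_chatglm.py is on the path and ice_text.model is local.
from tokenization_chatglm import SPTokenizer

sp = SPTokenizer("ice_text.model")

ids = sp.encode("hello\tworld")    # no special_tokens= flag anymore
print(sp.decode(ids))              # "\t" restored by the replace() calls in decode
print(sp.tokenize("hello world"))  # surface pieces rather than ids

# Text ids sit above the reserved image-token block, hence the
# +num_image_tokens offset in encode and the mirrored subtraction in decode.
assert all(i >= sp.num_image_tokens for i in ids)
```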
@@ -148,7 +136,7 @@ class SPTokenizer:
         text = self._preprocess(text, linebreak, whitespaces)
         if not add_dummy_prefix:
             text = "<n>" + text
-        tokens = self._get_text_tokenizer(encode_special_tokens=special_tokens).tokenize(text)
+        tokens = self._get_text_tokenizer().tokenize(text)
         return tokens if add_dummy_prefix else tokens[2:]
 
     def __getitem__(self, x: Union[int, str]):
@@ -253,25 +241,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
 
         return seq
 
-    def decode(
+    def _decode(
         self,
-        token_ids: Union[List[int], List[List[int]]],
+        token_ids: Union[int, List[int]],
         skip_special_tokens: bool = False,
         clean_up_tokenization_spaces: bool = True,
-        spaces_between_special_tokens: bool = True,
         **kwargs
     ) -> str:
-        if isinstance(token_ids[0], list):
-            tokens = []
-            for single_token_ids in token_ids:
-                if self.pad_token_id in single_token_ids:  # remove pad
-                    single_token_ids = list(filter((self.pad_token_id).__ne__, single_token_ids))
-                tokens.append(self.sp_tokenizer.decode(single_token_ids))
-            return (tokens)
-        else:
-            if self.pad_token_id in token_ids:  # remove pad
-                token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
-            return self.sp_tokenizer.decode(token_ids)
+        if isinstance(token_ids, int):
+            token_ids = [token_ids]
+        if len(token_ids) == 0:
+            return ""
+        if self.pad_token_id in token_ids:  # remove pad
+            token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
+        return self.sp_tokenizer.decode(token_ids)
 
     def _convert_token_to_id(self, token):
         """ Converts a token (str) in an id using the vocab. """
@@ -347,12 +330,12 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         return token_ids_0
 
     def _pad(
-            self,
-            encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
-            max_length: Optional[int] = None,
-            padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
-            pad_to_multiple_of: Optional[int] = None,
-            return_attention_mask: Optional[bool] = None,
+        self,
+        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+        max_length: Optional[int] = None,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        pad_to_multiple_of: Optional[int] = None,
+        return_attention_mask: Optional[bool] = None,
     ) -> dict:
         """
         Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
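A note on the ChatGLMTokenizer change: swapping the public decode() override for the private _decode() hook hands batching, special-token handling, and argument plumbing back to the PreTrainedTokenizer base class; the hook now only wraps a bare int, short-circuits empty input, and strips padding before delegating to sp_tokenizer. A call-pattern sketch, assuming the patched files are loaded through transformers' remote-code path (the hub id is illustrative):

```python
# Call-pattern sketch; assumes `transformers` is installed and the patched
# ChatGLM files are published at the (illustrative) hub id below.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

ids = tokenizer.encode("hello world")
print(tokenizer.decode(ids))     # single sequence: routed to the new _decode()
print(tokenizer.decode(ids[0]))  # a bare int is wrapped to [int] by _decode()

# The list-of-lists case the old decode() handled inline is now the base
# class's job, via batch_decode().
print(tokenizer.batch_decode([ids, tokenizer.encode("hi")]))
```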