Compare commits

...

5 Commits
v0.1.0 ... main

Author SHA1 Message Date
Zhengxiao Du 02a065cf27 Upload pytorch_model.bin 2023-05-15 12:41:28 +00:00
Zhengxiao Du e214c5b71d Update slack link 2023-05-12 13:49:56 +00:00
duzx16 d8a6cfc6cb Update decode method in tokenizer 2023-05-09 11:32:40 +08:00
duzx16 f6b88da8c1 Add support for parallel quantization on Mac 2023-05-04 21:45:31 +02:00
duzx16 63d66b0572 Remove assert in load_cpu_kernel 2023-04-29 10:34:45 +08:00
4 changed files with 31 additions and 15 deletions

View File

@ -9,7 +9,7 @@ tags:
--- ---
# ChatGLM-6B-INT4 # ChatGLM-6B-INT4
<p align="center"> <p align="center">
👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1th2q5u69-7tURzFuOPanmuHy9hsZnKA" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a> 👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1udqapmrr-ocT1DS_mxWe6dDY8ahRWzg" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a>
</p> </p>
## 介绍 ## 介绍

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:35828b49cf23cbae4c27788d4b04fc68c79a276300e09f14d72a49b0b738b4a9 oid sha256:245786435bde9f4593c105ea846fa461fe42bc63c12b738d0272fcaed6276645
size 3893083075 size 3893083075

View File

@ -7,6 +7,7 @@ import bz2
import torch import torch
import base64 import base64
import ctypes import ctypes
import sys
from transformers.utils import logging from transformers.utils import logging
from typing import List from typing import List
@ -142,8 +143,12 @@ class CPUKernel:
kernel_file = source_code[:-2] + ".so" kernel_file = source_code[:-2] + ".so"
if compile_parallel_kernel: if compile_parallel_kernel:
compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format( if sys.platform != 'darwin':
source_code, kernel_file) compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format(
source_code, kernel_file)
else:
compile_command = "clang -O3 -fPIC -pthread -Xclang -fopenmp -lomp -std=c99 {} -shared -o {}".format(
source_code, kernel_file)
print("Compiling", compile_command) print("Compiling", compile_command)
exit_state = os.system(compile_command) exit_state = os.system(compile_command)
if not exit_state: if not exit_state:
@ -442,7 +447,6 @@ class QuantizedEmbedding(Embedding): # TODO: backward, check empty_init
def load_cpu_kernel(**kwargs): def load_cpu_kernel(**kwargs):
global cpu_kernels global cpu_kernels
cpu_kernels = CPUKernel(**kwargs) cpu_kernels = CPUKernel(**kwargs)
assert cpu_kernels.load
def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=False, **kwargs): def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=False, **kwargs):
@ -453,9 +457,8 @@ def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=F
dense_h_to_4h_quantization_cache = None dense_h_to_4h_quantization_cache = None
dense_4h_to_h_quantization_cache = None dense_4h_to_h_quantization_cache = None
try: load_cpu_kernel(**kwargs)
load_cpu_kernel(**kwargs) if not cpu_kernels.load:
except:
if kernels is None: # CUDA kernels failed if kernels is None: # CUDA kernels failed
print("Cannot load cpu or cuda kernel, quantization failed:") print("Cannot load cpu or cuda kernel, quantization failed:")
assert kernels is not None assert kernels is not None

View File

@ -31,6 +31,9 @@ class TextTokenizer:
def tokenize(self, text): def tokenize(self, text):
return self.sp.EncodeAsPieces(text) return self.sp.EncodeAsPieces(text)
def convert_tokens_to_string(self, tokens):
return self.sp.DecodePieces(tokens)
def convert_tokens_to_ids(self, tokens): def convert_tokens_to_ids(self, tokens):
return [self.sp.PieceToId(token) for token in tokens] return [self.sp.PieceToId(token) for token in tokens]
@ -111,16 +114,25 @@ class SPTokenizer:
tokens = [x + self.num_image_tokens for x in tmp] tokens = [x + self.num_image_tokens for x in tmp]
return tokens if add_dummy_prefix else tokens[2:] return tokens if add_dummy_prefix else tokens[2:]
def decode(self, text_ids: List[int]) -> str: def postprocess(self, text):
ids = [int(_id) - self.num_image_tokens for _id in text_ids]
ids = [_id for _id in ids if _id >= 0]
text = self._get_text_tokenizer().decode(ids)
text = text.replace("<n>", "\n") text = text.replace("<n>", "\n")
text = text.replace(SPTokenizer.get_tab_token(), "\t") text = text.replace(SPTokenizer.get_tab_token(), "\t")
for i in range(2, self.max_blank_length + 1): for i in range(2, self.max_blank_length + 1):
text = text.replace(self.get_blank_token(i), " " * i) text = text.replace(self.get_blank_token(i), " " * i)
return text return text
def decode(self, text_ids: List[int]) -> str:
ids = [int(_id) - self.num_image_tokens for _id in text_ids]
ids = [_id for _id in ids if _id >= 0]
text = self._get_text_tokenizer().decode(ids)
text = self.postprocess(text)
return text
def decode_tokens(self, tokens: List[str]) -> str:
text = self._get_text_tokenizer().convert_tokens_to_string(tokens)
text = self.postprocess(text)
return text
def tokenize( def tokenize(
self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
) -> List[str]: ) -> List[str]:
@ -256,11 +268,12 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
return seq return seq
def convert_tokens_to_string(self, tokens: List[str]) -> str:
return self.sp_tokenizer.decode_tokens(tokens)
def _decode( def _decode(
self, self,
token_ids: Union[int, List[int]], token_ids: Union[int, List[int]],
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: bool = True,
**kwargs **kwargs
) -> str: ) -> str:
if isinstance(token_ids, int): if isinstance(token_ids, int):
@ -269,7 +282,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
return "" return ""
if self.pad_token_id in token_ids: # remove pad if self.pad_token_id in token_ids: # remove pad
token_ids = list(filter((self.pad_token_id).__ne__, token_ids)) token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
return self.sp_tokenizer.decode(token_ids) return super()._decode(token_ids, **kwargs)
def _convert_token_to_id(self, token): def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """ """ Converts a token (str) in an id using the vocab. """