Compare commits

..

5 Commits
v0.1.0 ... main

Author SHA1 Message Date
Zhengxiao Du 02a065cf27 Upload pytorch_model.bin 2023-05-15 12:41:28 +00:00
Zhengxiao Du e214c5b71d Update slack link 2023-05-12 13:49:56 +00:00
duzx16 d8a6cfc6cb Update decode method in tokenizer 2023-05-09 11:32:40 +08:00
duzx16 f6b88da8c1 Add support for parallel quantization on Mac 2023-05-04 21:45:31 +02:00
duzx16 63d66b0572 Remove assert in load_cpu_kernel 2023-04-29 10:34:45 +08:00
4 changed files with 31 additions and 15 deletions

View File

@@ -9,7 +9,7 @@ tags:
---
# ChatGLM-6B-INT4
<p align="center">
👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1th2q5u69-7tURzFuOPanmuHy9hsZnKA" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a>
👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1udqapmrr-ocT1DS_mxWe6dDY8ahRWzg" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a>
</p>
## 介绍

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:35828b49cf23cbae4c27788d4b04fc68c79a276300e09f14d72a49b0b738b4a9
oid sha256:245786435bde9f4593c105ea846fa461fe42bc63c12b738d0272fcaed6276645
size 3893083075

View File

@@ -7,6 +7,7 @@ import bz2
import torch
import base64
import ctypes
import sys
from transformers.utils import logging
from typing import List
@@ -142,8 +143,12 @@ class CPUKernel:
kernel_file = source_code[:-2] + ".so"
if compile_parallel_kernel:
compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format(
source_code, kernel_file)
if sys.platform != 'darwin':
compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format(
source_code, kernel_file)
else:
compile_command = "clang -O3 -fPIC -pthread -Xclang -fopenmp -lomp -std=c99 {} -shared -o {}".format(
source_code, kernel_file)
print("Compiling", compile_command)
exit_state = os.system(compile_command)
if not exit_state:
@@ -442,7 +447,6 @@ class QuantizedEmbedding(Embedding):  # TODO: backward, check empty_init
def load_cpu_kernel(**kwargs):
global cpu_kernels
cpu_kernels = CPUKernel(**kwargs)
assert cpu_kernels.load
def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=False, **kwargs):
@@ -453,9 +457,8 @@ def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=F
dense_h_to_4h_quantization_cache = None
dense_4h_to_h_quantization_cache = None
try:
load_cpu_kernel(**kwargs)
except:
load_cpu_kernel(**kwargs)
if not cpu_kernels.load:
if kernels is None: # CUDA kernels failed
print("Cannot load cpu or cuda kernel, quantization failed:")
assert kernels is not None

View File

@@ -31,6 +31,9 @@ class TextTokenizer:
def tokenize(self, text):
return self.sp.EncodeAsPieces(text)
def convert_tokens_to_string(self, tokens):
return self.sp.DecodePieces(tokens)
def convert_tokens_to_ids(self, tokens):
return [self.sp.PieceToId(token) for token in tokens]
@@ -111,16 +114,25 @@ class SPTokenizer:
tokens = [x + self.num_image_tokens for x in tmp]
return tokens if add_dummy_prefix else tokens[2:]
def decode(self, text_ids: List[int]) -> str:
ids = [int(_id) - self.num_image_tokens for _id in text_ids]
ids = [_id for _id in ids if _id >= 0]
text = self._get_text_tokenizer().decode(ids)
def postprocess(self, text):
text = text.replace("<n>", "\n")
text = text.replace(SPTokenizer.get_tab_token(), "\t")
for i in range(2, self.max_blank_length + 1):
text = text.replace(self.get_blank_token(i), " " * i)
return text
def decode(self, text_ids: List[int]) -> str:
ids = [int(_id) - self.num_image_tokens for _id in text_ids]
ids = [_id for _id in ids if _id >= 0]
text = self._get_text_tokenizer().decode(ids)
text = self.postprocess(text)
return text
def decode_tokens(self, tokens: List[str]) -> str:
text = self._get_text_tokenizer().convert_tokens_to_string(tokens)
text = self.postprocess(text)
return text
def tokenize(
self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
) -> List[str]:
@@ -256,11 +268,12 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
return seq
def convert_tokens_to_string(self, tokens: List[str]) -> str:
return self.sp_tokenizer.decode_tokens(tokens)
def _decode(
self,
token_ids: Union[int, List[int]],
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: bool = True,
**kwargs
) -> str:
if isinstance(token_ids, int):
@@ -269,7 +282,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
return ""
if self.pad_token_id in token_ids: # remove pad
token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
return self.sp_tokenizer.decode(token_ids)
return super()._decode(token_ids, **kwargs)
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """