Compare commits
5 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 02a065cf27 | |
| | e214c5b71d | |
| | d8a6cfc6cb | |
| | f6b88da8c1 | |
| | 63d66b0572 | |
**README.md**

```diff
@@ -9,7 +9,7 @@ tags:
 ---
 # ChatGLM-6B-INT4
 <p align="center">
-   👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1th2q5u69-7tURzFuOPanmuHy9hsZnKA" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a>
+   👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1udqapmrr-ocT1DS_mxWe6dDY8ahRWzg" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a>
 </p>

 ## 介绍
```
**Git LFS pointer (weights file; filename not preserved in this view)**

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:35828b49cf23cbae4c27788d4b04fc68c79a276300e09f14d72a49b0b738b4a9
+oid sha256:245786435bde9f4593c105ea846fa461fe42bc63c12b738d0272fcaed6276645
 size 3893083075
```
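Only the Git LFS pointer changes here: `oid` is the SHA-256 of the actual weights file, so swapping in new weights rewrites the `oid` line, while the unchanged `size` line shows the file is still 3,893,083,075 bytes. A minimal sketch of reading such a pointer (the `read_lfs_pointer` helper is illustrative, not part of this repo):

```python
from pathlib import Path

def read_lfs_pointer(path: str) -> dict:
    """Parse a Git LFS pointer file into its key/value fields.

    A pointer is a small text file of 'key value' lines:
        version https://git-lfs.github.com/spec/v1
        oid sha256:<hex digest of the real file>
        size <bytes>
    """
    fields = {}
    for line in Path(path).read_text().splitlines():
        key, _, value = line.partition(" ")
        if key:
            fields[key] = value
    return fields

# e.g. fields["oid"] -> "sha256:2457..."; int(fields["size"]) -> 3893083075
```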
**quantization.py**

```diff
@@ -7,6 +7,7 @@ import bz2
 import torch
 import base64
 import ctypes
+import sys
 from transformers.utils import logging

 from typing import List
```
```diff
@@ -142,8 +143,12 @@ class CPUKernel:
         kernel_file = source_code[:-2] + ".so"

         if compile_parallel_kernel:
-            compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format(
-                source_code, kernel_file)
+            if sys.platform != 'darwin':
+                compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format(
+                    source_code, kernel_file)
+            else:
+                compile_command = "clang -O3 -fPIC -pthread -Xclang -fopenmp -lomp -std=c99 {} -shared -o {}".format(
+                    source_code, kernel_file)
             print("Compiling", compile_command)
             exit_state = os.system(compile_command)
             if not exit_state:
```
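The hunk above (together with the new `import sys`) fixes kernel compilation on macOS: Apple clang does not accept gcc's `-fopenmp` flag directly and needs the `-Xclang -fopenmp -lomp` spelling instead. A standalone sketch of the same platform switch (the helper name and example file names are illustrative):

```python
import sys

def build_compile_command(source_code: str, kernel_file: str) -> str:
    # gcc handles OpenMP natively; on macOS, Apple clang needs
    # -Xclang -fopenmp -lomp to enable OpenMP and link libomp.
    if sys.platform != 'darwin':
        return "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format(
            source_code, kernel_file)
    return "clang -O3 -fPIC -pthread -Xclang -fopenmp -lomp -std=c99 {} -shared -o {}".format(
        source_code, kernel_file)

print(build_compile_command("quantization_kernels_parallel.c",
                            "quantization_kernels_parallel.so"))
```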
```diff
@@ -442,7 +447,6 @@ class QuantizedEmbedding(Embedding): # TODO: backward, check empty_init
 def load_cpu_kernel(**kwargs):
     global cpu_kernels
     cpu_kernels = CPUKernel(**kwargs)
-    assert cpu_kernels.load


 def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=False, **kwargs):
```
```diff
@@ -453,9 +457,8 @@ def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=F
         dense_h_to_4h_quantization_cache = None
         dense_4h_to_h_quantization_cache = None

-    try:
-        load_cpu_kernel(**kwargs)
-    except:
+    load_cpu_kernel(**kwargs)
+    if not cpu_kernels.load:
         if kernels is None:  # CUDA kernels failed
             print("Cannot load cpu or cuda kernel, quantization failed:")
             assert kernels is not None
```
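Two changes work together here: `load_cpu_kernel` no longer ends with `assert cpu_kernels.load` (previous hunk), and `quantize` drops the bare `try/except` in favor of checking the `cpu_kernels.load` flag, so a failed CPU-kernel build can fall through to the CUDA kernels instead of dying on an assertion. A self-contained sketch of that control flow (stub class, and a `RuntimeError` standing in for the original assert):

```python
class CPUKernelStub:
    """Stands in for CPUKernel: .load records whether the shared
    library compiled and loaded (False here to show the fallback path)."""
    load = False

cpu_kernels = None

def load_cpu_kernel(**kwargs):
    # No assert anymore: failure is reported via the .load flag.
    global cpu_kernels
    cpu_kernels = CPUKernelStub()

def select_kernels(cuda_kernels):
    load_cpu_kernel()
    if not cpu_kernels.load:        # CPU kernels unavailable
        if cuda_kernels is None:    # ...and the CUDA kernels failed too
            raise RuntimeError("Cannot load cpu or cuda kernel, quantization failed")
        return cuda_kernels         # fall back to CUDA
    return cpu_kernels

print(select_kernels(cuda_kernels="fake-cuda-kernels"))  # falls back to CUDA
```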
**tokenization_chatglm.py**

```diff
@@ -31,6 +31,9 @@ class TextTokenizer:
     def tokenize(self, text):
         return self.sp.EncodeAsPieces(text)

+    def convert_tokens_to_string(self, tokens):
+        return self.sp.DecodePieces(tokens)
+
     def convert_tokens_to_ids(self, tokens):
         return [self.sp.PieceToId(token) for token in tokens]

```
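`TextTokenizer` wraps a `sentencepiece` processor, and the new `convert_tokens_to_string` exposes `DecodePieces`, which rebuilds text from piece strings instead of naively concatenating them (concatenation loses spacing and mangles multi-byte characters). A round-trip sketch, assuming a sentencepiece model file is on disk (`ice_text.model` is the tokenizer model this repo ships; the local path is an assumption):

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="ice_text.model")  # assumed local path

pieces = sp.EncodeAsPieces("Hello 世界")  # text -> list of piece strings
text = sp.DecodePieces(pieces)            # piece strings -> original text
assert text == "Hello 世界"
```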
```diff
@@ -111,16 +114,25 @@ class SPTokenizer:
         tokens = [x + self.num_image_tokens for x in tmp]
         return tokens if add_dummy_prefix else tokens[2:]

-    def decode(self, text_ids: List[int]) -> str:
-        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
-        ids = [_id for _id in ids if _id >= 0]
-        text = self._get_text_tokenizer().decode(ids)
+    def postprocess(self, text):
         text = text.replace("<n>", "\n")
         text = text.replace(SPTokenizer.get_tab_token(), "\t")
         for i in range(2, self.max_blank_length + 1):
             text = text.replace(self.get_blank_token(i), " " * i)
         return text

+    def decode(self, text_ids: List[int]) -> str:
+        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
+        ids = [_id for _id in ids if _id >= 0]
+        text = self._get_text_tokenizer().decode(ids)
+        text = self.postprocess(text)
+        return text
+
+    def decode_tokens(self, tokens: List[str]) -> str:
+        text = self._get_text_tokenizer().convert_tokens_to_string(tokens)
+        text = self.postprocess(text)
+        return text
+
     def tokenize(
         self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
     ) -> List[str]:
```
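The refactor pulls the cleanup steps out of `decode` into a reusable `postprocess`, so the new token-based path (`decode_tokens`) gets the same treatment: ChatGLM's vocabulary stores a newline as `<n>`, a tab as a dedicated tab token, and runs of 2 to `max_blank_length` spaces as blank tokens. A self-contained sketch with the token spellings written out (`<|tab|>` / `<|blank_n|>` follow the scheme of `get_tab_token` / `get_blank_token`, but treat them here as illustrative):

```python
def postprocess(text: str, max_blank_length: int = 80) -> str:
    # Whitespace placeholder tokens back to real newlines, tabs, space runs.
    text = text.replace("<n>", "\n")
    text = text.replace("<|tab|>", "\t")
    for i in range(2, max_blank_length + 1):
        text = text.replace("<|blank_{}|>".format(i), " " * i)
    return text

print(repr(postprocess("a<n>b<|tab|>c<|blank_3|>d")))  # 'a\nb\tc   d'
```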
```diff
@@ -256,11 +268,12 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         return seq

+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        return self.sp_tokenizer.decode_tokens(tokens)
+
     def _decode(
         self,
         token_ids: Union[int, List[int]],
         skip_special_tokens: bool = False,
         clean_up_tokenization_spaces: bool = True,
         **kwargs
     ) -> str:
         if isinstance(token_ids, int):
```
```diff
@@ -269,7 +282,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             return ""
         if self.pad_token_id in token_ids:  # remove pad
             token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
-        return self.sp_tokenizer.decode(token_ids)
+        return super()._decode(token_ids, **kwargs)

     def _convert_token_to_id(self, token):
         """ Converts a token (str) in an id using the vocab. """
```
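In `_decode`, padding is now stripped before handing off to the parent class (which routes through the new `convert_tokens_to_string`), and `filter` with a bound `__ne__` removes every occurrence of the pad id in one pass. A quick standalone illustration (the ids are arbitrary):

```python
pad_token_id = 3
token_ids = [3, 3, 17, 42, 3, 99]

# filter(x.__ne__, seq) keeps every element that is != x
token_ids = list(filter(pad_token_id.__ne__, token_ids))
print(token_ids)  # [17, 42, 99]
```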