Compare commits

5 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 02a065cf27 | |
| | e214c5b71d | |
| | d8a6cfc6cb | |
| | f6b88da8c1 | |
| | 63d66b0572 | |
```diff
@@ -9,7 +9,7 @@ tags:
 ---
 # ChatGLM-6B-INT4
 <p align="center">
-   👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1th2q5u69-7tURzFuOPanmuHy9hsZnKA" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a>
+   👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1udqapmrr-ocT1DS_mxWe6dDY8ahRWzg" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a>
 </p>
 
 ## 介绍
```
```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:35828b49cf23cbae4c27788d4b04fc68c79a276300e09f14d72a49b0b738b4a9
+oid sha256:245786435bde9f4593c105ea846fa461fe42bc63c12b738d0272fcaed6276645
 size 3893083075
```
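The only change here is the Git LFS pointer's sha256 oid; the recorded size (3893083075 bytes, ≈3.9 GB) is unchanged, so the quantized weight file was re-uploaded with different content of exactly the same length. As a minimal sketch (the local filename below is a placeholder assumption, not taken from this diff), a download can be verified against the pointer's oid like this:

```python
import hashlib

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    # Stream the file in 1 MiB chunks so a multi-GB weight file
    # never has to fit in memory at once.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# oid from the new pointer; "pytorch_model.bin" is an assumed local path.
expected = "245786435bde9f4593c105ea846fa461fe42bc63c12b738d0272fcaed6276645"
assert sha256_of("pytorch_model.bin") == expected, "checksum mismatch"
```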
```diff
@@ -7,6 +7,7 @@ import bz2
 import torch
 import base64
 import ctypes
+import sys
 from transformers.utils import logging
 
 from typing import List
```
```diff
@@ -142,8 +143,12 @@ class CPUKernel:
         kernel_file = source_code[:-2] + ".so"
 
         if compile_parallel_kernel:
-            compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format(
-                source_code, kernel_file)
+            if sys.platform != 'darwin':
+                compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format(
+                    source_code, kernel_file)
+            else:
+                compile_command = "clang -O3 -fPIC -pthread -Xclang -fopenmp -lomp -std=c99 {} -shared -o {}".format(
+                    source_code, kernel_file)
             print("Compiling", compile_command)
             exit_state = os.system(compile_command)
             if not exit_state:
```
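This is the change the new `import sys` supports: on macOS (`sys.platform == 'darwin'`) the stock `gcc` is an alias for Apple clang, which rejects a bare `-fopenmp`, so OpenMP has to be enabled through the clang frontend (`-Xclang -fopenmp`) and linked explicitly against `-lomp`. A standalone sketch of the selection logic (the `.c`/`.so` names are illustrative):

```python
import sys

def openmp_compile_command(source_code: str, kernel_file: str) -> str:
    # Mirror of the branch above: GNU gcc everywhere except macOS,
    # where OpenMP needs the clang frontend flag plus libomp.
    if sys.platform != "darwin":
        template = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}"
    else:
        template = "clang -O3 -fPIC -pthread -Xclang -fopenmp -lomp -std=c99 {} -shared -o {}"
    return template.format(source_code, kernel_file)

print(openmp_compile_command("kernels.c", "kernels.so"))
```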
```diff
@@ -442,7 +447,6 @@ class QuantizedEmbedding(Embedding):  # TODO: backward, check empty_init
 def load_cpu_kernel(**kwargs):
     global cpu_kernels
     cpu_kernels = CPUKernel(**kwargs)
-    assert cpu_kernels.load
 
 
 def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=False, **kwargs):
```
```diff
@@ -453,9 +457,8 @@ def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=F
     dense_h_to_4h_quantization_cache = None
     dense_4h_to_h_quantization_cache = None
 
-    try:
-        load_cpu_kernel(**kwargs)
-    except:
+    load_cpu_kernel(**kwargs)
+    if not cpu_kernels.load:
         if kernels is None:  # CUDA kernels failed
             print("Cannot load cpu or cuda kernel, quantization failed:")
             assert kernels is not None
```
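Together with the `assert cpu_kernels.load` removed in the previous hunk, this turns a failed CPU-kernel build from a hard error into a soft condition: `quantize` now inspects the `load` flag directly and only aborts when the CUDA `kernels` are also unavailable. A simplified sketch of the resulting fallback (stand-in objects, not the repo's classes):

```python
class StubKernels:
    def __init__(self, load: bool):
        self.load = load

def select_kernels(cpu_kernels, cuda_kernels):
    # CPU kernels are preferred; if they failed to load, fall back to the
    # CUDA kernels and only fail hard when neither backend is usable.
    if not cpu_kernels.load:
        if cuda_kernels is None:  # CUDA kernels failed too
            raise RuntimeError("Cannot load cpu or cuda kernel, quantization failed")
        return cuda_kernels
    return cpu_kernels

print(select_kernels(StubKernels(False), StubKernels(True)).load)  # True: CUDA fallback
```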
```diff
@@ -31,6 +31,9 @@ class TextTokenizer:
     def tokenize(self, text):
         return self.sp.EncodeAsPieces(text)
 
+    def convert_tokens_to_string(self, tokens):
+        return self.sp.DecodePieces(tokens)
+
     def convert_tokens_to_ids(self, tokens):
         return [self.sp.PieceToId(token) for token in tokens]
 
```
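The new method is the inverse of `tokenize`: sentencepiece's `EncodeAsPieces` splits text into pieces and `DecodePieces` reassembles them into a string. A usage sketch against the plain sentencepiece API (the model path is a placeholder):

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("tokenizer.model")  # placeholder path to a trained sentencepiece model

pieces = sp.EncodeAsPieces("hello world")  # what TextTokenizer.tokenize returns
print(pieces)
print(sp.DecodePieces(pieces))             # what convert_tokens_to_string returns
```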
```diff
@@ -111,16 +114,25 @@ class SPTokenizer:
         tokens = [x + self.num_image_tokens for x in tmp]
         return tokens if add_dummy_prefix else tokens[2:]
 
-    def decode(self, text_ids: List[int]) -> str:
-        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
-        ids = [_id for _id in ids if _id >= 0]
-        text = self._get_text_tokenizer().decode(ids)
+    def postprocess(self, text):
         text = text.replace("<n>", "\n")
         text = text.replace(SPTokenizer.get_tab_token(), "\t")
         for i in range(2, self.max_blank_length + 1):
             text = text.replace(self.get_blank_token(i), " " * i)
         return text
 
+    def decode(self, text_ids: List[int]) -> str:
+        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
+        ids = [_id for _id in ids if _id >= 0]
+        text = self._get_text_tokenizer().decode(ids)
+        text = self.postprocess(text)
+        return text
+
+    def decode_tokens(self, tokens: List[str]) -> str:
+        text = self._get_text_tokenizer().convert_tokens_to_string(tokens)
+        text = self.postprocess(text)
+        return text
+
     def tokenize(
         self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
     ) -> List[str]:
```
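The refactor extracts the whitespace restoration from `decode` into `postprocess` so that the new `decode_tokens` can reuse it on piece strings without a detour through ids. A self-contained sketch of the inverse mapping; the literal token spellings and the blank-length bound are assumptions, since the hunk only shows calls to `get_tab_token()`/`get_blank_token(i)`:

```python
MAX_BLANK_LENGTH = 80  # assumed bound; the hunk only shows self.max_blank_length

def postprocess(text: str) -> str:
    # Undo the encoding-side substitutions: <n> for newline, a tab token,
    # and one token per run of 2..MAX_BLANK_LENGTH spaces.
    text = text.replace("<n>", "\n")
    text = text.replace("<|tab|>", "\t")                 # assumed spelling
    for i in range(2, MAX_BLANK_LENGTH + 1):
        text = text.replace(f"<|blank_{i}|>", " " * i)   # assumed spelling
    return text

print(repr(postprocess("a<n>b<|tab|>c<|blank_3|>d")))    # 'a\nb\tc   d'
```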
```diff
@@ -256,11 +268,12 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
 
         return seq
 
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        return self.sp_tokenizer.decode_tokens(tokens)
+
     def _decode(
         self,
         token_ids: Union[int, List[int]],
-        skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
         **kwargs
     ) -> str:
         if isinstance(token_ids, int):
```
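Defining `convert_tokens_to_string` on the tokenizer is what makes the next hunk's delegation to `super()._decode` viable: the HuggingFace base implementation decodes by converting ids to tokens and then calling exactly this hook. Roughly (a simplified sketch of the base-class flow, not a verbatim copy):

```python
def generic_decode(tokenizer, token_ids, skip_special_tokens=False):
    # Simplified shape of PreTrainedTokenizer._decode: ids -> tokens -> string,
    # with the string step delegated to the tokenizer's own
    # convert_tokens_to_string (here backed by SPTokenizer.decode_tokens).
    tokens = tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
    return tokenizer.convert_tokens_to_string(tokens)
```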
```diff
@@ -269,7 +282,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             return ""
         if self.pad_token_id in token_ids:  # remove pad
             token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
-        return self.sp_tokenizer.decode(token_ids)
+        return super()._decode(token_ids, **kwargs)
 
     def _convert_token_to_id(self, token):
         """ Converts a token (str) in an id using the vocab. """
```
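With the delegation in place, keyword arguments such as `skip_special_tokens` reach the base implementation instead of being dropped by a direct `sp_tokenizer.decode` call. A usage sketch of the standard loading path for this repository (the model id follows the usual THUDM naming and is an assumption here):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)

ids = tokenizer("hello")["input_ids"]
print(tokenizer.convert_ids_to_tokens(ids))             # piece-level view
print(tokenizer.decode(ids, skip_special_tokens=True))  # routed through super()._decode
```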