Compare commits

..

No commits in common. "main" and "v0.1.0" have entirely different histories.
main ... v0.1.0

4 changed files with 15 additions and 31 deletions

View File

@ -9,7 +9,7 @@ tags:
--- ---
# ChatGLM-6B-INT4 # ChatGLM-6B-INT4
<p align="center"> <p align="center">
👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1udqapmrr-ocT1DS_mxWe6dDY8ahRWzg" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a> 👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1th2q5u69-7tURzFuOPanmuHy9hsZnKA" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a>
</p> </p>
## 介绍 ## 介绍

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:245786435bde9f4593c105ea846fa461fe42bc63c12b738d0272fcaed6276645 oid sha256:35828b49cf23cbae4c27788d4b04fc68c79a276300e09f14d72a49b0b738b4a9
size 3893083075 size 3893083075

View File

@ -7,7 +7,6 @@ import bz2
import torch import torch
import base64 import base64
import ctypes import ctypes
import sys
from transformers.utils import logging from transformers.utils import logging
from typing import List from typing import List
@ -143,12 +142,8 @@ class CPUKernel:
kernel_file = source_code[:-2] + ".so" kernel_file = source_code[:-2] + ".so"
if compile_parallel_kernel: if compile_parallel_kernel:
if sys.platform != 'darwin': compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format(
compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format( source_code, kernel_file)
source_code, kernel_file)
else:
compile_command = "clang -O3 -fPIC -pthread -Xclang -fopenmp -lomp -std=c99 {} -shared -o {}".format(
source_code, kernel_file)
print("Compiling", compile_command) print("Compiling", compile_command)
exit_state = os.system(compile_command) exit_state = os.system(compile_command)
if not exit_state: if not exit_state:
@ -447,6 +442,7 @@ class QuantizedEmbedding(Embedding): # TODO: backward, check empty_init
def load_cpu_kernel(**kwargs): def load_cpu_kernel(**kwargs):
global cpu_kernels global cpu_kernels
cpu_kernels = CPUKernel(**kwargs) cpu_kernels = CPUKernel(**kwargs)
assert cpu_kernels.load
def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=False, **kwargs): def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=False, **kwargs):
@ -457,8 +453,9 @@ def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=F
dense_h_to_4h_quantization_cache = None dense_h_to_4h_quantization_cache = None
dense_4h_to_h_quantization_cache = None dense_4h_to_h_quantization_cache = None
load_cpu_kernel(**kwargs) try:
if not cpu_kernels.load: load_cpu_kernel(**kwargs)
except:
if kernels is None: # CUDA kernels failed if kernels is None: # CUDA kernels failed
print("Cannot load cpu or cuda kernel, quantization failed:") print("Cannot load cpu or cuda kernel, quantization failed:")
assert kernels is not None assert kernels is not None

View File

@ -31,9 +31,6 @@ class TextTokenizer:
def tokenize(self, text): def tokenize(self, text):
return self.sp.EncodeAsPieces(text) return self.sp.EncodeAsPieces(text)
def convert_tokens_to_string(self, tokens):
return self.sp.DecodePieces(tokens)
def convert_tokens_to_ids(self, tokens): def convert_tokens_to_ids(self, tokens):
return [self.sp.PieceToId(token) for token in tokens] return [self.sp.PieceToId(token) for token in tokens]
@ -114,23 +111,14 @@ class SPTokenizer:
tokens = [x + self.num_image_tokens for x in tmp] tokens = [x + self.num_image_tokens for x in tmp]
return tokens if add_dummy_prefix else tokens[2:] return tokens if add_dummy_prefix else tokens[2:]
def postprocess(self, text):
text = text.replace("<n>", "\n")
text = text.replace(SPTokenizer.get_tab_token(), "\t")
for i in range(2, self.max_blank_length + 1):
text = text.replace(self.get_blank_token(i), " " * i)
return text
def decode(self, text_ids: List[int]) -> str: def decode(self, text_ids: List[int]) -> str:
ids = [int(_id) - self.num_image_tokens for _id in text_ids] ids = [int(_id) - self.num_image_tokens for _id in text_ids]
ids = [_id for _id in ids if _id >= 0] ids = [_id for _id in ids if _id >= 0]
text = self._get_text_tokenizer().decode(ids) text = self._get_text_tokenizer().decode(ids)
text = self.postprocess(text) text = text.replace("<n>", "\n")
return text text = text.replace(SPTokenizer.get_tab_token(), "\t")
for i in range(2, self.max_blank_length + 1):
def decode_tokens(self, tokens: List[str]) -> str: text = text.replace(self.get_blank_token(i), " " * i)
text = self._get_text_tokenizer().convert_tokens_to_string(tokens)
text = self.postprocess(text)
return text return text
def tokenize( def tokenize(
@ -268,12 +256,11 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
return seq return seq
def convert_tokens_to_string(self, tokens: List[str]) -> str:
return self.sp_tokenizer.decode_tokens(tokens)
def _decode( def _decode(
self, self,
token_ids: Union[int, List[int]], token_ids: Union[int, List[int]],
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: bool = True,
**kwargs **kwargs
) -> str: ) -> str:
if isinstance(token_ids, int): if isinstance(token_ids, int):
@ -282,7 +269,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
return "" return ""
if self.pad_token_id in token_ids: # remove pad if self.pad_token_id in token_ids: # remove pad
token_ids = list(filter((self.pad_token_id).__ne__, token_ids)) token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
return super()._decode(token_ids, **kwargs) return self.sp_tokenizer.decode(token_ids)
def _convert_token_to_id(self, token): def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """ """ Converts a token (str) in an id using the vocab. """