# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.

import torch
from typing import List
import bz2
import base64
import ctypes
from transformers.utils import logging

logger = logging.get_logger(__name__)

try:
    from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up

    class Kernel:
        def __init__(self, code: bytes, function_names: List[str]):
            self.code = code
            self._function_names = function_names
            self._cmodule = LazyKernelCModule(self.code)

            # Expose each named CUDA kernel as a callable attribute.
            for name in self._function_names:
                setattr(self, name, KernelFunction(self._cmodule, name))

    # bz2-compressed, base64-encoded binary of the compiled quantization kernels.
    quantization_code = "QlpoOTFBWSZTWX/mUzwAK6f///////////////////////////////7f////////////4C5duvi2D0Oj1ppVCJ2zQFYbnbsxmq20pAC7kEDb3Z3nWrextY9NZbavON7nveSRqszudmzAGGgkeh0Pewk881e3Tz13kW9YO7uA9AUUiAWLNW2HHWCE005Mdz3jHs1Ic7QNCQBNGgmE000DRNoGjUYmA0mEmJjIaI9JtT0JoaaMTaQ0aMjTTI1TzKMmETwyaJ6k8p4Ke1T0wk2aE0anpPSHppqNM1HqYzVGj0MpsTTUGpoCAAEyAAAmhpPSYowMk9U8mqb0mJtU8ETwCZT1DQ9R5R6htE9TTyRptQeoyHqA0B6g9T1AD1HpGQGgD1A0NPUAAAA0A1Mg00gmhKPU9E2SekHoJ5QHlNDEPUeoDEaBkAHqBoABoNABoAaGgBoAAAAAAA0AAAAAAAAEmoiIgmiD0maRip+qfpR+k9U/QKaZPUepiGeST1HqeU9TQ9JoANAMhoZPU0AAYnqaBoAANABoAAAADQGgAAADTQ0IAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASJEE0AJo0GkxGJoZNKeBoTCnpNNpU9knqn+ppmUnom1PKZqTaaTTwTTFPNJ6pj1BG0eoaMgwQGkYAGk2gjT0jBqaY0RoDeqZoNEYT1NpsA/+iBrt+OVIiCKqfH7N/e67XZ2Dx9tPHyWbW4gAENNTtyzk+/WdoU604SoXU0JgfqgQxVmzbfdmaFcVxQAYINDyjTKU1FCUUzUuqqptg4SBgwIAHYE4NwQOrbY1bOF26LUVuxYr3Hp4paZXaqKU1UmXO3K+IXn2hURrgAegAaTANS+QBclUN6tpvhn85+uTPCLxzj34YO8MIMg45eRAEy9IYbKxeZTRnTy6GpPLtVGWKKK6iuDLa9wjtSmUQREX6wHfE3JeTVZdoj4Hg/3cHlBdw4c4BdGvigzZsubPr3eTi2hs6tZz3J9zUVm8qH+FPwSx4Tdr6by/OA88iLHk34rWNt7fT7NwqqqqqqqrGMYxjFcdqvY2mXyh42c2ccxhtyvBHojjUlyAKRgbvAB6nhls1wGLTOrfGMBsqRXl9Bl3sOlvafSA7sDrmAQI+mw90af+bvJ8mwjP+RKtjobGNzbfl76iTHMiIIUf9oIoygqSG2NLn0Ys/mZ+hzufu7epmzbvP1t7S0Xo8TKK7q6G5MA8vTgBb7Bf/2kITSLsH7Xmfydz7ahAt4YJbBuAQJI+1M8DLJCQH+UPbv212QWIhcCKhBrR2eryfQYIiIhKE0WtbOQ7OwM7OxtURGbF28NBndi9ejVDVA3dne37uDdzrwINS+O/0AzQTCgUjfCAwkkKFMT4Kr0aV3DicVAelGBesGYoCRcLKq5iBFR6SzOzrAwFWDFVYU2XT1oFaRJk2JBDOwVk1LFZZfwY7tQBYMGdECFA1cLZAg0IlfCTCMgZ4afRQBNvXSuMORVUTxTLSTgMFoUtaGLIr524yIM+INSFFIOHQ4TG5NZbd3Su3Nu9raSLd/ueibSYpAL0D42ZkAtD0pnXrfTxYPBw+mAt1cKPCPmDNMCDYCBiQwmANVhdDjBwsdIKyfH1slCvWbJC4QO8SBxi6A+GEpDBN6UQnPaEvBqFk3TwChKSowEENpyAueDIFs6OxxLRmFSUFpjWgYpECgDgfVBJjhg4GGcI9CD0S3igCrdziS3ZoYHlQE+7AELdvbebTVsdRvrPHCgiAbSYzUN0z0SCshLjaUaREEREQQRHNKAgAS9o0kukdJx0ulaJk0kINzlUYN0wWXLLsmRgSG1BEJNh5sCuVtIybGlKUW29BziJUTpqcA8UCCLtOGU0hH17BYTERfPKhCAwxJqSSSMd+umawlsykXZiKHesslqlVDKEHPzFhIWwJHTfcYCGE9dQK9sKixjNifLkW1iLnyZo57BBx2jksXPYjcaA6Z6rlYTl9ocZHn2URKVXnY/Wsrc5l3aym6Uq7u9eu2szSbJgwhqPqfOR1JCCZl7/AehLVBSIXc9npUk8IDzrRCS9XKMeamSDmFxK6OQDhwNnxubbnQygQb4DEL6oD5qkkG6F03dyDAUJB/awNUoDCa3CmYy2QIsK0Z46BoX1N4kY8aGNFB8WZAfWvaHeUT4gYIjEsZBBARIFAk2jCTxAmpW03GtdW4WCN0bLJiiqY3ixmHAWRqqQKqgS2hlf8mwszkhUy3LDx3GLdo5AHGAgC4BogUAVgH4QM0AGAImwbS6gwANIep0rJIU3hBgaeKAEcnzfs+g/sJZnETvInDcAH5fE7azmr8EyIFx77caxbrDBC64CEU8wCqzAHPgkk4kiPREKYHn2HaoDBWCCrFBrhR+XpeNQkdbzCBHee2hW8EW373k/qd/PxGC2R+IO4vmNEAl1AE0l4bEvmnfd5/JYs5gl9XpgQIS7g/LAK7owBwgso9j0yEB9MRIBjqmkLdG5uED3tICA6PYXe4WItRawAenfJ0lCFupoGvajxuQC/5YQPnwFpgQBMNgBndpgVNJcyw+5vCJgHtWU0EDYk2HsvD8Qkg6ANAd8UQXGH/3X3gXgNDefHyaQ/wd93Xx87hWWtW0kPCQGR+KYiPeMQse27PdNLGwhlz8WJObSnEQyHJw1JmStJXTtIg0ZKEHrLZCXd1ljLGkkxtpsDofXUiBH0LLEM43kb2waJ26KZsJ9sBbxcAqzUgWxzogNFm4vSxjMR58r5Xm8H2+6ItGcNX2AK3GhDIMzSX3YyFsbNG0u0MxvZzGFv19k2E45tXrK+1OKUYRiH2OT2Fs7kqtxMDrANVp2nxreAZg02UaFEsuf6+urQi1PxvNOhuacrStndOnonV3e5Du+Xjp8mjhiHYPNexu7UKSbt0Gs2rPIVVVSFyQ7phtQ0ZOUySoyZA79muzuLBZaLAW20gZIeuJDacErguFE3e70svo0S0mRBMBu33rjqVrNEN9A5PHvOgukEPEgb0tYAMrvcvIXB5ydzJHXQ1n+t7BUI24oJtSCTAUet75rBpXL4ylQ4LGBpbQeQCiOku+8rq90o18ga4WEGBDhvHB0YYd/CDLIMdDh2cO/i/RppcEi3Zd+CCU8OdxAAiOgi5qeghJkUnO6YGZi5LEilo2WhSiEVsU2IK7unV2rXG61Q/LbUqGx72rn2Uzx/q/fzsCWUFCQyAA+XqfGVGvL1kml0MVpjJl1A9vYoYTSatnV1+z2czsdoc4QFWLILHn1S71/r3V1S/fJMgDlXX6DVv8+FeECNi1u8zf8K8r1Khq7twFu5xPfZJT+PLpYUZWgGNDG0Jlq4rsQy86u95xqTdO0TbSGBdDOUSyyGHQAmP5mgNfVvgeY2tPzlKbyrvnaZhgQ7aWeJjzbF4mjPlro1hYjmnWUshKxVsQ6pveK850taANOgIE/aJvr0IAC0g2H2d1agVwnBkAF1kl7IPZc8mBthvlYish4AqABgI9hw2cExRabO+8Xz31+enwlCxSbnfVFlqig3UKGBQiybpEBGQLIxuoUMVYLTt53sY+lPlxSAq9f3lfnVlFmiBFrOhAeAF/0/N6HI6/+rsQ2+D5U5fenadDmtFFgeZLLESwOgWWIlgWFo+uFROhke3lKQ4bf0mLH3XSOgtDGd73hfMwDM2aF7Lonl7AlbiPbV2zY2lvu1Vj7jzlmFYoKieH93wt3fLhBXgYUGJEjga5YWEVyE00qIYWXSKd0ZaZy+vuCQlhaz5ELs9n/pjuFAHpoDCMEEtseECQF+Rk58EyW3nzCdlyCeY5WPItdkDZ4egXmjfZTLSVT29ku6KCGxHbdTBD3z52SxkuXkpoaHyy3t25+JwX5zFdYawDASl7397IB2tunNbt2FygaTBIO5qrG0asQmxEVRGCn26UX6DewTmic/QqkLZjdCTqjQDGlxy4IODucyQlmE0zkwSkR02cZjZcA1MzMczZAf1hfPnZT1IGtWIJGOcpzgYwCGyiNtoxRkupRElCCAgWJcE4igRJEQogPHYVAVBAEYDBkUEBIOSMK3KJNwQllpqWZARLCgMM"

    kernels = Kernel(
        bz2.decompress(base64.b64decode(quantization_code)),
        [
            "int4_to_fp16",
            "fp16_to_int4",
            "int8_to_fp16",
            "fp16_to_int8",
            "int4_to_bf16",
            "bf16_to_int4",
            "int8_to_bf16",
            "bf16_to_int8",
        ],
    )
except Exception as exception:
    kernels = None
    logger.warning("Failed to load kernels: " + str(exception))


def quant4(weight: torch.Tensor, scale: torch.Tensor):
    num_row = weight.size(0)
    num_chan_fp16 = weight.size(1)
    # Eight 4-bit values are packed into each 32-bit output channel.
    num_chan_int = num_chan_fp16 // 8
    qweight = torch.zeros((num_row, num_chan_int), dtype=torch.int32, device=weight.device)

    # Quantize: scale each row and round. With QLinear's per-row scale the values
    # land in [-7, 7], so the clip and the 0x0f nibble masks below are lossless.
    intweight = torch.clip(torch.round(weight.to(scale.dtype) / scale[:, None]), -16, 15).to(dtype=torch.int32)

    # Pack eight consecutive 4-bit values into one int32, lowest nibble first.
    for j in range(num_chan_int):
        qweight[:, j] = ((intweight[:, j * 8 + 7] & 0x0f) << 28) \
            | ((intweight[:, j * 8 + 6] & 0x0f) << 24) \
            | ((intweight[:, j * 8 + 5] & 0x0f) << 20) \
            | ((intweight[:, j * 8 + 4] & 0x0f) << 16) \
            | ((intweight[:, j * 8 + 3] & 0x0f) << 12) \
            | ((intweight[:, j * 8 + 2] & 0x0f) << 8) \
            | ((intweight[:, j * 8 + 1] & 0x0f) << 4) \
            | (intweight[:, j * 8] & 0x0f)
    return qweight
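

# Editor's sketch, not part of the upstream module: a pure-PyTorch unpacker that
# inverts the nibble packing performed by quant4, assuming the layout above
# (lowest nibble holds the first column of each group of eight). Useful for
# sanity-checking quantization round-trips without the CUDA kernels.
def _unpack4_reference(qweight: torch.Tensor) -> torch.Tensor:
    num_row, num_chan_int = qweight.shape
    out = torch.empty(num_row, num_chan_int * 8, dtype=torch.int32, device=qweight.device)
    for k in range(8):
        nibble = (qweight >> (4 * k)) & 0x0f
        # Sign-extend the 4-bit field: nibbles >= 8 encode negative values.
        out[:, k::8] = torch.where(nibble >= 8, nibble - 16, nibble)
    return out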


def dequant4(qweight: torch.Tensor, scale: torch.Tensor, input: torch.Tensor):
    stream = torch.cuda.current_stream()
    num_row = qweight.size(0)
    num_chan_int = qweight.size(1)
    # Each int32 channel holds eight 4-bit values.
    num_chan_fp16 = num_chan_int * 8

    out = torch.empty((num_row, num_chan_fp16), dtype=input.dtype, device=qweight.device)

    blockDim = (128, 1, 1)
    gridDim = ((num_chan_int + blockDim[0] - 1) // blockDim[0], num_row, 1)

    # Launch the kernel matching the activation dtype; the cpm_kernels call
    # convention is (grid, block, shared-memory bytes, stream, argument list).
    if input.dtype == torch.bfloat16:
        kernels.int4_to_bf16(
            gridDim,
            blockDim,
            0,
            stream,
            [
                ctypes.c_void_p(out.data_ptr()),
                ctypes.c_void_p(qweight.data_ptr()),
                ctypes.c_void_p(scale.data_ptr()),
                ctypes.c_int32(num_row),
                ctypes.c_int32(num_chan_int),
                ctypes.c_int32(num_chan_fp16),
            ],
        )
    elif input.dtype == torch.float16:
        kernels.int4_to_fp16(
            gridDim,
            blockDim,
            0,
            stream,
            [
                ctypes.c_void_p(out.data_ptr()),
                ctypes.c_void_p(qweight.data_ptr()),
                ctypes.c_void_p(scale.data_ptr()),
                ctypes.c_int32(num_row),
                ctypes.c_int32(num_chan_int),
                ctypes.c_int32(num_chan_fp16),
            ],
        )
    return out
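

# Editor's sketch, not part of the upstream module: the computation the CUDA
# kernel performs, written in plain PyTorch via the unpacker above. Intended as
# a slow reference for verifying kernel output, not for production use.
def _dequant4_reference(qweight: torch.Tensor, scale: torch.Tensor, dtype=torch.float16) -> torch.Tensor:
    intweight = _unpack4_reference(qweight).to(scale.dtype)
    # Rescale each row back to floating point, matching quant4's per-row scale.
    return (intweight * scale[:, None]).to(dtype)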


class QLinear(torch.nn.Module):
    def __init__(self, bits: int, weight: torch.Tensor, bias=None):
        super().__init__()
        self.quant_bits = bits
        # Symmetric per-row scale: the largest magnitude in each output row maps
        # to the largest representable signed value (7 for 4-bit, 127 for 8-bit).
        self.scale = weight.abs().max(dim=-1).values / ((2 ** (bits - 1)) - 1)
        self.scale = self.scale.to(torch.float32)
        if self.quant_bits == 4:
            self.weight = quant4(weight, self.scale)
        elif self.quant_bits == 8:
            self.weight = torch.round(weight.to(self.scale.dtype) / self.scale[:, None]).to(torch.int8)
        if self.quant_bits == 8:
            # Store the 8-bit weight transposed so forward() can matmul directly.
            self.weight = self.weight.T
        # The bias argument is discarded; forward() only adds a bias if one is
        # assigned after construction.
        self.bias = None

    def forward(self, input):
        # The 4-bit CUDA kernels only support half-precision activations.
        if self.quant_bits == 4:
            assert input.dtype == torch.bfloat16 or input.dtype == torch.float16

        if self.weight.device != input.device:
            self.weight = self.weight.to(input.device)
            self.scale = self.scale.to(input.device)

        if self.quant_bits == 4:
            self.scale = self.scale.to(input.dtype)
            # Dequantize on the GPU, then transpose for the matmul.
            rweight = dequant4(self.weight, self.scale, input).T
            output = torch.matmul(input, rweight)
        elif self.quant_bits == 8:
            rweight = self.weight.to(input.dtype) * self.scale.to(input.dtype)
            output = torch.matmul(input, rweight)
        if self.bias is not None:
            output = output + self.bias
        return output
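

# A usage sketch (editor's addition): quantizing the weight of an existing Linear
# layer. The shapes are illustrative. The 8-bit path is shown because it runs on
# CPU without custom kernels; the 4-bit path additionally needs a CUDA device,
# half-precision inputs, and the kernels loaded above.
if __name__ == "__main__":
    linear = torch.nn.Linear(128, 64, bias=False)
    qlinear = QLinear(bits=8, weight=linear.weight.data.clone())
    x = torch.randn(2, 128)
    print(qlinear(x).shape)  # torch.Size([2, 64])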