Synchronize with chatglm 6b repo

This commit is contained in:
songxxzp 2023-04-03 09:20:14 +08:00
parent 7458231b5a
commit 7aaf3fe491
5 changed files with 359 additions and 137 deletions

View File

@ -10,6 +10,7 @@
}, },
"bos_token_id": 150004, "bos_token_id": 150004,
"eos_token_id": 150005, "eos_token_id": 150005,
"pad_token_id": 20003,
"hidden_size": 4096, "hidden_size": 4096,
"inner_hidden_size": 16384, "inner_hidden_size": 16384,
"layernorm_epsilon": 1e-05, "layernorm_epsilon": 1e-05,

View File

@ -72,6 +72,8 @@ class ChatGLMConfig(PretrainedConfig):
position_encoding_2d=True, position_encoding_2d=True,
quantization_bit=0, quantization_bit=0,
quantization_embeddings=False, quantization_embeddings=False,
pre_seq_len=None,
prefix_projection=False,
**kwargs **kwargs
): ):
self.num_layers = num_layers self.num_layers = num_layers
@ -86,8 +88,11 @@ class ChatGLMConfig(PretrainedConfig):
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id self.pad_token_id = pad_token_id
self.position_encoding_2d = position_encoding_2d self.position_encoding_2d = position_encoding_2d
self.quantization_bit=quantization_bit self.quantization_bit = quantization_bit
self.quantization_embeddings=quantization_embeddings self.quantization_embeddings = quantization_embeddings
self.pre_seq_len = pre_seq_len
self.prefix_projection = prefix_projection
super().__init__( super().__init__(
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
bos_token_id=bos_token_id, bos_token_id=bos_token_id,

View File

@ -5,6 +5,7 @@ import copy
import os import os
import warnings import warnings
import re import re
import sys
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@ -12,7 +13,7 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm from torch.nn import CrossEntropyLoss, LayerNorm
from torch.nn.utils import skip_init from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List, Callable from typing import Optional, Tuple, Union, List, Callable, Dict, Any
from transformers.utils import ( from transformers.utils import (
add_code_sample_docstrings, add_code_sample_docstrings,
@ -27,16 +28,18 @@ from transformers.modeling_outputs import (
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging from transformers.utils import logging
from transformers.generation.logits_process import LogitsProcessor from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
from .configuration_chatglm import ChatGLMConfig from .configuration_chatglm import ChatGLMConfig
# flags required to enable jit fusion kernels # flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False) if sys.platform != 'darwin':
torch._C._jit_override_can_fuse_on_cpu(True) torch._C._jit_set_profiling_mode(False)
torch._C._jit_override_can_fuse_on_gpu(True) torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -131,6 +134,36 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
return model return model
class PrefixEncoder(torch.nn.Module):
    """
    Encode prefix token indices into a flat per-layer key/value vector.

    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)

    When ``config.prefix_projection`` is truthy the indices are embedded at
    ``hidden_size`` and then passed through a Linear -> Tanh -> Linear stack;
    otherwise a single embedding table emits the final width directly.
    """

    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        # Width of the flattened key + value vectors for all layers.
        kv_width = config.num_layers * config.hidden_size * 2
        if not self.prefix_projection:
            # Direct lookup: one row of size kv_width per prefix position.
            self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_width)
            return
        # Re-parameterised path: small embedding followed by a two-layer MLP.
        self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
        self.trans = torch.nn.Sequential(
            torch.nn.Linear(config.hidden_size, config.hidden_size),
            torch.nn.Tanh(),
            torch.nn.Linear(config.hidden_size, kv_width),
        )

    def forward(self, prefix: torch.Tensor):
        """Return the encoded prefix for the index tensor ``prefix``."""
        embedded = self.embedding(prefix)
        # Only the projected variant owns a ``trans`` module.
        return self.trans(embedded) if self.prefix_projection else embedded
@torch.jit.script @torch.jit.script
def gelu_impl(x): def gelu_impl(x):
"""OpenAI's gelu implementation.""" """OpenAI's gelu implementation."""
@ -219,7 +252,7 @@ def attention_fn(
use_cache=False, use_cache=False,
): ):
if layer_past is not None: if layer_past is not None:
past_key, past_value = layer_past past_key, past_value = layer_past[0], layer_past[1]
key_layer = torch.cat((past_key, key_layer), dim=0) key_layer = torch.cat((past_key, key_layer), dim=0)
value_layer = torch.cat((past_value, value_layer), dim=0) value_layer = torch.cat((past_value, value_layer), dim=0)
@ -273,7 +306,7 @@ def attention_fn(
if not (attention_mask == 0).all(): if not (attention_mask == 0).all():
# if auto-regressive, skip # if auto-regressive, skip
attention_scores.masked_fill_(attention_mask, -10000.0) attention_scores.masked_fill_(attention_mask, -10000.0)
dtype = attention_scores.type() dtype = attention_scores.dtype
attention_scores = attention_scores.float() attention_scores = attention_scores.float()
attention_scores = attention_scores * query_key_layer_scaling_coeff attention_scores = attention_scores * query_key_layer_scaling_coeff
@ -619,10 +652,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
""" """
is_parallelizable = False is_parallelizable = False
supports_gradient_checkpointing = False supports_gradient_checkpointing = True
config_class = ChatGLMConfig config_class = ChatGLMConfig
base_model_prefix = "transformer" base_model_prefix = "transformer"
_no_split_modules = ["GLM6BBlock"] _no_split_modules = ["GLMBlock"]
def __init__(self, *inputs, **kwargs): def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs) super().__init__(*inputs, **kwargs)
@ -631,6 +664,43 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
"""Initialize the weights.""" """Initialize the weights."""
return return
def get_masks(self, input_ids, device):
    """Build the boolean attention mask for a batch of sequences.

    For each row, the context length is the index of the first
    ``self.config.bos_token_id`` (``ValueError`` if absent). Context columns
    are visible to every query position; the remaining columns are visible
    causally (lower triangle).

    Returns a bool tensor of shape (batch_size, 1, seq_length, seq_length)
    in which ``True`` marks positions that must be masked out.
    """
    batch_size, seq_length = input_ids.shape
    bos_id = self.config.bos_token_id
    prefix_lengths = [row.tolist().index(bos_id) for row in input_ids]

    # 1 = may attend, 0 = blocked; start from a causal lower triangle.
    visible = torch.ones((batch_size, seq_length, seq_length), device=device).tril_()
    for row_idx, prefix_len in enumerate(prefix_lengths):
        # The context part is fully bidirectional.
        visible[row_idx, :, :prefix_len] = 1

    # Insert the head axis, then flip the meaning: True now means "masked".
    visible = visible.unsqueeze(1)
    return visible < 0.5
def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
    """Build position ids for a batch of token sequences.

    Args:
        input_ids: 2-D tensor of token ids, (batch_size, seq_length). Each row
            must contain ``self.config.bos_token_id``; otherwise
            ``list.index`` raises ``ValueError``.
        mask_positions: per-row position value, indexed by batch row, that is
            written to every position from the row's context length onward.
        device: device on which the result is allocated.
        gmask: consulted only when ``self.position_encoding_2d`` is falsy; a
            truthy value leaves the plain ``0..seq_length-1`` positions intact.

    Returns:
        With 2-D position encoding, a long tensor of shape
        (batch_size, 2, seq_length): channel 0 holds the positions, channel 1
        holds block positions (0 over the context, then 1, 2, ...).
        Otherwise a long tensor of shape (batch_size, seq_length).
    """
    batch_size, seq_length = input_ids.shape
    context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]

    # repeat() gives each row its own storage. expand() would return a
    # stride-0 view, so the per-row in-place writes below would all land in
    # the same memory and every row would end up with the last row's values.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)

    if self.position_encoding_2d:
        for i, context_length in enumerate(context_lengths):
            position_ids[i, context_length:] = mask_positions[i]
        block_position_ids = torch.stack([
            torch.cat((
                torch.zeros(context_length, dtype=torch.long, device=device),
                torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
            ))
            for context_length in context_lengths
        ], dim=0)
        return torch.stack((position_ids, block_position_ids), dim=1)

    if not gmask:
        for i, context_length in enumerate(context_lengths):
            # Index row i explicitly; slicing only the first axis would
            # select batch rows rather than sequence positions.
            position_ids[i, context_length:] = mask_positions[i]
    return position_ids
def _set_gradient_checkpointing(self, module, value=False):
    """Set ``module.gradient_checkpointing`` to ``value`` for ``ChatGLMModel`` instances only."""
    if not isinstance(module, ChatGLMModel):
        # Any other module type is left untouched.
        return
    module.gradient_checkpointing = value
CHATGLM_6B_START_DOCSTRING = r""" CHATGLM_6B_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
@ -727,12 +797,15 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
self.inner_hidden_size = config.inner_hidden_size self.inner_hidden_size = config.inner_hidden_size
self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
self.position_encoding_2d = config.position_encoding_2d self.position_encoding_2d = config.position_encoding_2d
self.pre_seq_len = config.pre_seq_len
self.prefix_projection = config.prefix_projection
self.word_embeddings = skip_init( self.word_embeddings = skip_init(
torch.nn.Embedding, torch.nn.Embedding,
num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
dtype=self.params_dtype dtype=self.params_dtype
) )
self.gradient_checkpointing = False
def get_layer(layer_id): def get_layer(layer_id):
return GLMBlock( return GLMBlock(
@ -755,43 +828,38 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
# Final layer norm before output. # Final layer norm before output.
self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)
if self.pre_seq_len is not None:
for param in self.parameters():
param.requires_grad = False
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
self.prefix_encoder = PrefixEncoder(config)
self.dropout = torch.nn.Dropout(0.1)
# total_params = sum(p.numel() for p in self.parameters())
# trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
# print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params))
def get_input_embeddings(self): def get_input_embeddings(self):
return self.word_embeddings return self.word_embeddings
def set_input_embeddings(self, new_embeddings: torch.Tensor): def set_input_embeddings(self, new_embeddings: torch.Tensor):
self.word_embeddings = new_embeddings self.word_embeddings = new_embeddings
def get_masks(self, seq, device): def get_prompt(self, batch_size, device, dtype=torch.half):
context_length = seq.index(self.config.bos_token_id) + 1 prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
attention_mask = torch.ones((1, len(seq), len(seq)), device=device) past_key_values = past_key_values.view(
attention_mask.tril_() batch_size,
attention_mask[..., :context_length - 1] = 1 self.pre_seq_len,
attention_mask.unsqueeze_(1) self.num_layers * 2,
attention_mask = (attention_mask < 0.5).bool() self.num_attention_heads,
self.hidden_size // self.num_attention_heads
return attention_mask )
# seq_len, b, nh, hidden_size
def get_position_ids(self, seq, mask_position, device, gmask=False): past_key_values = self.dropout(past_key_values)
context_length = seq.index(self.config.bos_token_id) + 1 past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
if self.position_encoding_2d: # past_key_values = [(v[0], v[1]) for v in past_key_values]
seq_length = seq.index(self.config.bos_token_id) return past_key_values
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
if not gmask:
position_ids[seq_length:] = mask_position
block_position_ids = torch.cat((
torch.zeros(seq_length, dtype=torch.long, device=device),
torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
))
position_ids = torch.stack((position_ids, block_position_ids), dim=0)
else:
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
if not gmask:
position_ids[context_length - 1:] = mask_position
position_ids = position_ids.unsqueeze(0)
return position_ids
@add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
@ -819,6 +887,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
@ -828,31 +903,41 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
if past_key_values is None: if past_key_values is None:
past_key_values = tuple([None] * len(self.layers)) if self.pre_seq_len is not None:
seq = input_ids[0].tolist() past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device,
dtype=inputs_embeds.dtype)
else:
past_key_values = tuple([None] * len(self.layers))
if attention_mask is None: if attention_mask is None:
attention_mask = self.get_masks( attention_mask = self.get_masks(
seq=seq, input_ids,
device=input_ids.device device=input_ids.device
) )
if position_ids is None: if position_ids is None:
MASK, gMASK = 150000, 150001 MASK, gMASK = 150000, 150001
mask_token = MASK if MASK in input_ids else gMASK mask_token = MASK if MASK in input_ids else gMASK
use_gmask = False if MASK in input_ids else gMASK use_gmask = False if MASK in input_ids else gMASK
mask_position = seq.index(mask_token) mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
position_ids = self.get_position_ids( position_ids = self.get_position_ids(
seq=seq, input_ids,
mask_position=mask_position, mask_positions=mask_positions,
device=input_ids.device, device=input_ids.device,
gmask=use_gmask gmask=use_gmask
) )
if inputs_embeds is None: if self.pre_seq_len is not None and attention_mask is not None:
inputs_embeds = self.word_embeddings(input_ids) prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
attention_mask.device)
prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
# [seq_len, batch, hidden_size] # [seq_len, batch, hidden_size]
hidden_states = inputs_embeds.transpose(0, 1) hidden_states = inputs_embeds.transpose(0, 1)
@ -861,11 +946,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[0]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None: if attention_mask is None:
attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
@ -876,16 +956,29 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer_past = past_key_values[i]
layer_ret = layer( if self.gradient_checkpointing and self.training:
hidden_states, layer_ret = torch.utils.checkpoint.checkpoint(
position_ids=position_ids, layer,
attention_mask=attention_mask, hidden_states,
layer_id=torch.tensor(i), position_ids,
layer_past=past_key_values[i], attention_mask,
use_cache=use_cache, torch.tensor(i),
output_attentions=output_attentions layer_past,
) use_cache,
output_attentions
)
else:
layer_ret = layer(
hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
layer_id=torch.tensor(i),
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions
)
hidden_states = layer_ret[0] hidden_states = layer_ret[0]
@ -946,31 +1039,40 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
def set_output_embeddings(self, new_embeddings): def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings self.lm_head = new_embeddings
def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False): def _update_model_kwargs_for_generation(
attention_mask = torch.ones((1, context_length, context_length), device=device) self,
attention_mask.tril_() outputs: ModelOutput,
attention_mask[..., :context_length - 1] = 1 model_kwargs: Dict[str, Any],
attention_mask.unsqueeze_(1) is_encoder_decoder: bool = False,
attention_mask = (attention_mask < 0.5).bool() standardize_cache_format: bool = False,
) -> Dict[str, Any]:
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
outputs, standardize_cache_format=standardize_cache_format
)
if self.position_encoding_2d: # update attention mask
seq_length = seq.index(self.config.bos_token_id) if "attention_mask" in model_kwargs:
position_ids = torch.arange(context_length, dtype=torch.long, device=device) attention_mask = model_kwargs["attention_mask"]
if not gmask: if attention_mask is not None and attention_mask.dtype == torch.bool:
position_ids[seq_length:] = mask_position attention_mask = torch.cat(
block_position_ids = torch.cat(( [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
torch.zeros(seq_length, dtype=torch.long, device=device), new_attention_mask = attention_mask[:, :, -1:].clone()
torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1 new_attention_mask[..., -1] = False
)) model_kwargs["attention_mask"] = torch.cat(
position_ids = torch.stack((position_ids, block_position_ids), dim=0) [attention_mask, new_attention_mask], dim=2
else: )
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
if not gmask:
position_ids[context_length - 1:] = mask_position
position_ids = position_ids.unsqueeze(0) # update position ids
if "position_ids" in model_kwargs:
position_ids = model_kwargs["position_ids"]
new_position_id = position_ids[..., -1:].clone()
new_position_id[:, 1, :] += 1
model_kwargs["position_ids"] = torch.cat(
[position_ids, new_position_id], dim=-1
)
return attention_mask, position_ids return model_kwargs
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, self,
@ -978,27 +1080,34 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
past: Optional[torch.Tensor] = None, past: Optional[torch.Tensor] = None,
past_key_values: Optional[torch.Tensor] = None, past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs **kwargs
) -> dict: ) -> dict:
batch_size, seq_length = input_ids.shape
MASK, gMASK = 150000, 150001 MASK, gMASK = 150000, 150001
mask_token = MASK if MASK in input_ids else gMASK mask_token = MASK if MASK in input_ids else gMASK
use_gmask = False if MASK in input_ids else gMASK use_gmask = False if MASK in input_ids else gMASK
seq = input_ids[0].tolist() seqs = input_ids.tolist()
mask_position = seq.index(mask_token) mask_positions = [seq.index(mask_token) for seq in seqs]
if mask_token not in seq:
raise ValueError("You have to add either [MASK] or [gMASK] in your input")
# only last token for input_ids if past is not None # only last token for input_ids if past is not None
if past is not None or past_key_values is not None: if past is not None or past_key_values is not None:
context_length = seq.index(self.config.bos_token_id)
last_token = input_ids[:, -1].unsqueeze(-1) last_token = input_ids[:, -1].unsqueeze(-1)
if self.position_encoding_2d: if attention_mask is not None and attention_mask.dtype == torch.bool:
position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long, attention_mask = attention_mask[:, :, -1:]
device=input_ids.device)
else: else:
position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device) attention_mask = None
if position_ids is not None:
position_ids = position_ids[..., -1:]
else:
context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
if self.position_encoding_2d:
position_ids = torch.tensor(
[[mask_position, seq_length - context_length] for mask_position, context_length in
zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
else:
position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
device=input_ids.device).unsqueeze(-1)
if past is None: if past is None:
past = past_key_values past = past_key_values
@ -1006,15 +1115,24 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
"input_ids": last_token, "input_ids": last_token,
"past_key_values": past, "past_key_values": past,
"position_ids": position_ids, "position_ids": position_ids,
"attention_mask": attention_mask
} }
else: else:
attention_mask, position_ids = self.get_masks_and_position_ids( if attention_mask is not None and attention_mask.dtype != torch.bool:
seq=seq, logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
mask_position=mask_position, attention_mask = None
context_length=len(seq), if attention_mask is None:
device=input_ids.device, attention_mask = self.get_masks(
gmask=use_gmask input_ids,
) device=input_ids.device
)
if position_ids is None:
position_ids = self.get_position_ids(
input_ids,
device=input_ids.device,
mask_positions=mask_positions,
gmask=use_gmask
)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
@ -1063,7 +1181,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
shift_logits = lm_logits[..., :-1, :].contiguous() shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous() shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens # Flatten the tokens
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
lm_logits = lm_logits.to(hidden_states.dtype) lm_logits = lm_logits.to(hidden_states.dtype)
@ -1132,10 +1250,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
for i, (old_query, response) in enumerate(history): for i, (old_query, response) in enumerate(history):
prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
input_ids = tokenizer([prompt], return_tensors="pt", padding=True) inputs = tokenizer([prompt], return_tensors="pt")
input_ids = input_ids.to(self.device) inputs = inputs.to(self.device)
outputs = self.generate(**input_ids, **gen_kwargs) outputs = self.generate(**inputs, **gen_kwargs)
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):] outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
response = tokenizer.decode(outputs) response = tokenizer.decode(outputs)
response = self.process_response(response) response = self.process_response(response)
history = history + [(query, response)] history = history + [(query, response)]
@ -1158,10 +1276,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
for i, (old_query, response) in enumerate(history): for i, (old_query, response) in enumerate(history):
prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
input_ids = tokenizer([prompt], return_tensors="pt", padding=True) inputs = tokenizer([prompt], return_tensors="pt")
input_ids = input_ids.to(self.device) inputs = inputs.to(self.device)
for outputs in self.stream_generate(**input_ids, **gen_kwargs): for outputs in self.stream_generate(**inputs, **gen_kwargs):
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):] outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
response = tokenizer.decode(outputs) response = tokenizer.decode(outputs)
response = self.process_response(response) response = self.process_response(response)
new_history = history + [(query, response)] new_history = history + [(query, response)]
@ -1298,6 +1416,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
num_embeddings=self.transformer.word_embeddings.num_embeddings, num_embeddings=self.transformer.word_embeddings.num_embeddings,
embedding_dim=self.transformer.word_embeddings.embedding_dim, embedding_dim=self.transformer.word_embeddings.embedding_dim,
dtype=torch.half, dtype=torch.half,
empty_init=True,
device=self.transformer.word_embeddings.weight.device, device=self.transformer.word_embeddings.weight.device,
) )
self.lm_head = QuantizedLinear( self.lm_head = QuantizedLinear(
@ -1310,6 +1429,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
quantized_weight=self.transformer.word_embeddings.weight, quantized_weight=self.transformer.word_embeddings.weight,
quantized_weight_scale=self.transformer.word_embeddings.weight_scale, quantized_weight_scale=self.transformer.word_embeddings.weight_scale,
dtype=torch.half, dtype=torch.half,
empty_init=True,
device=self.lm_head.weight.device, device=self.lm_head.weight.device,
) )

View File

@ -7,10 +7,13 @@ import bz2
import torch import torch
import base64 import base64
import ctypes import ctypes
from transformers.utils import logging
from typing import List from typing import List
from functools import partial from functools import partial
logger = logging.get_logger(__name__)
try: try:
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
@ -37,18 +40,18 @@ try:
) )
except Exception as exception: except Exception as exception:
kernels = None kernels = None
print("Failed to load cpm_kernels:", exception) logger.warning("Failed to load cpm_kernels:", exception)
class W8A16Linear(torch.autograd.Function): class W8A16Linear(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width): def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
ctx.inp_shape = inp.size() ctx.inp_shape = inp.size()
ctx.weight_shape = quant_w.size()
ctx.weight_bit_width = weight_bit_width ctx.weight_bit_width = weight_bit_width
out_features = quant_w.size(0) out_features = quant_w.size(0)
inp = inp.contiguous().view(-1, inp.size(-1)) inp = inp.contiguous().view(-1, inp.size(-1))
weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
ctx.weight_shape = weight.size()
output = inp.mm(weight.t()) output = inp.mm(weight.t())
ctx.save_for_backward(inp, quant_w, scale_w) ctx.save_for_backward(inp, quant_w, scale_w)
return output.view(*(ctx.inp_shape[:-1] + (out_features,))) return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
@ -60,18 +63,18 @@ class W8A16Linear(torch.autograd.Function):
grad_output = grad_output.contiguous().view(-1, weight.size(0)) grad_output = grad_output.contiguous().view(-1, weight.size(0))
grad_input = grad_output.mm(weight) grad_input = grad_output.mm(weight)
grad_weight = grad_output.t().mm(inp) grad_weight = grad_output.t().mm(inp)
return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
class W8A16LinearCPU(torch.autograd.Function): class W8A16LinearCPU(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width, quantization_cache=None): def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width, quantization_cache=None):
ctx.inp_shape = inp.size() ctx.inp_shape = inp.size()
ctx.weight_shape = quant_w.size()
ctx.weight_bit_width = weight_bit_width ctx.weight_bit_width = weight_bit_width
out_features = quant_w.size(0) out_features = quant_w.size(0)
inp = inp.contiguous().view(-1, inp.size(-1)) inp = inp.contiguous().view(-1, inp.size(-1))
weight = extract_weight_to_float(quant_w, scale_w, weight_bit_width, quantization_cache=quantization_cache) weight = extract_weight_to_float(quant_w, scale_w, weight_bit_width, quantization_cache=quantization_cache)
ctx.weight_shape = weight.size()
output = inp.mm(weight.t()) output = inp.mm(weight.t())
ctx.save_for_backward(inp, quant_w, scale_w) ctx.save_for_backward(inp, quant_w, scale_w)
return output.view(*(ctx.inp_shape[:-1] + (out_features,))) return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
@ -83,7 +86,7 @@ class W8A16LinearCPU(torch.autograd.Function):
grad_output = grad_output.contiguous().view(-1, weight.size(0)) grad_output = grad_output.contiguous().view(-1, weight.size(0))
grad_input = grad_output.mm(weight) grad_input = grad_output.mm(weight)
grad_weight = grad_output.t().mm(inp) grad_weight = grad_output.t().mm(inp)
return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
default_cpu_kernel_code_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "quantization_kernels.c") default_cpu_kernel_code_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "quantization_kernels.c")
@ -168,7 +171,7 @@ class CPUKernel:
print("Load kernel :", kernel_file) print("Load kernel :", kernel_file)
else: else:
print("Failed to load kernel.") print("Failed to load kernel.")
if compile_parallel_kernel: if compile_parallel_kernel:
if parallel_num is None: if parallel_num is None:
parallel_num = max(os.cpu_count() // 2, 1) parallel_num = max(os.cpu_count() // 2, 1)
@ -176,7 +179,7 @@ class CPUKernel:
if parallel_num < 4: if parallel_num < 4:
print("Parallel kernel is not recommended when parallel num < 4.") print("Parallel kernel is not recommended when parallel num < 4.")
self.SetNumThreads(parallel_num) self.SetNumThreads(parallel_num)
self.parallel_num = parallel_num self.parallel_num = parallel_num
@ -284,10 +287,10 @@ def extract_weight_to_float(weight: torch.Tensor, scale_list: torch.Tensor, sour
class CacheTensor(): class CacheTensor():
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.tensor = torch.empty(*args, **kwargs) self.tensor = torch.empty(*args, **kwargs)
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
self.tensor = self.tensor.to(*args, **kwargs) self.tensor = self.tensor.to(*args, **kwargs)
def data_ptr(self): def data_ptr(self):
return self.tensor.data_ptr() return self.tensor.data_ptr()
@ -393,7 +396,7 @@ def load_cpu_kernel(**kwargs):
def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=False, **kwargs): def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=False, **kwargs):
"""Replace fp16 linear with quantized linear""" """Replace fp16 linear with quantized linear"""
query_key_value_quantization_cache = None query_key_value_quantization_cache = None
dense_quantization_cache = None dense_quantization_cache = None
dense_h_to_4h_quantization_cache = None dense_h_to_4h_quantization_cache = None

View File

@ -1,17 +1,14 @@
"""Tokenization classes for ChatGLM.""" """Tokenization classes for ChatGLM."""
import sys
import unicodedata
from typing import List, Optional, Union from typing import List, Optional, Union
from functools import lru_cache
import os import os
import collections
import re
from transformers.tokenization_utils import PreTrainedTokenizer from transformers.tokenization_utils import PreTrainedTokenizer
from icetk.text_tokenizer import TextTokenizer from icetk.text_tokenizer import TextTokenizer
from icetk.utils import auto_create
import icetk.sentencepiece_model_pb2 as sp_model import icetk.sentencepiece_model_pb2 as sp_model
from transformers.utils import logging from transformers.utils import logging, PaddingStrategy
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
from typing import Dict
import numpy as np
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -180,7 +177,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
vocab_files_names = {"vocab_file": "ice_text.model"} vocab_files_names = {"vocab_file": "ice_text.model"}
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids"] model_input_names = ["input_ids", "attention_mask", "position_ids"]
def __init__( def __init__(
self, self,
@ -210,7 +207,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
self.eos_token = eos_token self.eos_token = eos_token
self.eop_token = eop_token self.eop_token = eop_token
self.mask_token = mask_token self.mask_token = mask_token
self.gMASK_token = gmask_token self.gmask_token = gmask_token
self.sp_tokenizer = SPTokenizer(vocab_file) self.sp_tokenizer = SPTokenizer(vocab_file)
@ -299,7 +296,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
""" """
if os.path.isdir(save_directory): if os.path.isdir(save_directory):
vocab_file = os.path.join( vocab_file = os.path.join(
save_directory, VOCAB_FILES_NAMES["vocab_file"] save_directory, self.vocab_files_names["vocab_file"]
) )
else: else:
vocab_file = save_directory vocab_file = save_directory
@ -331,10 +328,9 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
Returns: Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
""" """
if token_ids_1 is not None:
token_ids_0 += token_ids_1
mask_ids = self.sp_tokenizer[self.mask_token] mask_ids = self.sp_tokenizer[self.mask_token]
gmask_ids = self.sp_tokenizer[self.gMASK_token] gmask_ids = self.sp_tokenizer[self.gmask_token]
eop_id = self.sp_tokenizer[self.eop_token]
if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0: if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0:
token_ids_0 += [gmask_ids] token_ids_0 += [gmask_ids]
@ -343,4 +339,101 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
token_ids_0 += [self.sp_tokenizer[self.bos_token]] token_ids_0 += [self.sp_tokenizer[self.bos_token]]
if token_ids_1 is not None:
if not token_ids_1 or token_ids_1[-1] != eop_id:
token_ids_1 += [eop_id]
token_ids_0 += token_ids_1
return token_ids_0 return token_ids_0
def _pad(
    self,
    encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
    max_length: Optional[int] = None,
    padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
    pad_to_multiple_of: Optional[int] = None,
    return_attention_mask: Optional[bool] = None,
) -> dict:
    """
    Left-pad a single encoded example and attach ChatGLM's attention mask and 2D position ids.

    `encoded_inputs` is modified in place and also returned.

    Args:
        encoded_inputs:
            Dictionary of tokenized inputs. The entry under `self.model_input_names[0]` is the
            sequence that gets measured and padded; it must support `len()`, `in` and `.index()`
            (presumably a `List[int]` -- confirm against callers).
        max_length:
            Target length after padding. Ignored (overwritten with the sequence length) when
            `padding_strategy` is `PaddingStrategy.LONGEST`. When it is `None`, no attention mask
            or position ids are generated. This method never truncates.
        padding_strategy: PaddingStrategy to use for padding.

            - PaddingStrategy.LONGEST: use the length of this sequence as `max_length`
            - PaddingStrategy.MAX_LENGTH: pad up to `max_length`
            - PaddingStrategy.DO_NOT_PAD: do not pad (default)

            Only left padding is supported; `self.padding_side` must be `"left"`.
        pad_to_multiple_of:
            (optional) If set, `max_length` is rounded up to the next multiple of this value.
        return_attention_mask:
            Accepted for signature compatibility but not read by this implementation: the
            attention mask is generated whenever `max_length` is not `None` and the caller did
            not provide one.

    Returns:
        The same `encoded_inputs` mapping, with:

        - `"attention_mask"`: boolean array of shape `(1, L, L)`, where `True` marks entries that
          were zero in the causal/context matrix (presumably "do not attend" -- confirm against
          the model). Padding rows/columns are filled with `True`.
        - `"position_ids"`: int64 array of shape `(2, L)` stacking the position ids and the
          block position ids. Padding columns are filled with `0`.
        - the input ids left-padded with `self.pad_token_id`.

    Raises:
        AssertionError: if `self.padding_side` is not `"left"`.
    """
    # Load from model defaults
    bos_token_id = self.sp_tokenizer[self.bos_token]
    mask_token_id = self.sp_tokenizer[self.mask_token]
    gmask_token_id = self.sp_tokenizer[self.gmask_token]
    assert self.padding_side == "left", "ChatGLMTokenizer._pad only supports left padding"

    required_input = encoded_inputs[self.model_input_names[0]]
    seq_length = len(required_input)

    if padding_strategy == PaddingStrategy.LONGEST:
        max_length = seq_length

    if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

    # Pad only when the sequence is strictly shorter than the target. A longer sequence would
    # give a negative pad width, which np.pad rejects; leave such inputs unpadded instead.
    needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and seq_length < max_length

    # Initialize attention mask and position ids if not present.
    if max_length is not None:
        build_attention_mask = "attention_mask" not in encoded_inputs
        build_position_ids = "position_ids" not in encoded_inputs

        if build_attention_mask or build_position_ids:
            # Both branches below need the context length, so compute it once up front.
            # Previously it was only bound inside the attention-mask branch, which raised
            # UnboundLocalError when a caller supplied `attention_mask` but not `position_ids`.
            if bos_token_id in required_input:
                context_length = required_input.index(bos_token_id)
            else:
                context_length = seq_length

        if build_attention_mask:
            # Lower-triangular (causal) matrix, with every context column fully visible.
            attention_mask = np.ones((1, seq_length, seq_length))
            attention_mask = np.tril(attention_mask)
            attention_mask[:, :, :context_length] = 1
            # Invert: True where the matrix above is zero.
            attention_mask = np.bool_(attention_mask < 0.5)
            encoded_inputs["attention_mask"] = attention_mask

        if build_position_ids:
            position_ids = np.arange(seq_length, dtype=np.int64)
            # Prefer the [MASK] token when present, otherwise fall back to [gMASK].
            mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
            if mask_token in required_input:
                mask_position = required_input.index(mask_token)
                # Everything from the context boundary onward shares the mask position.
                position_ids[context_length:] = mask_position
            # Block positions: 0 over the context, then 1, 2, ... over the remainder.
            block_position_ids = np.concatenate(
                [np.zeros(context_length, dtype=np.int64),
                 np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
            encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)

    if needs_to_be_padded:
        difference = max_length - seq_length

        if "attention_mask" in encoded_inputs:
            # Grow the last two axes on the left/top; new entries are True.
            encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"],
                                                      pad_width=[(0, 0), (difference, 0), (difference, 0)],
                                                      mode='constant', constant_values=True)
        if "token_type_ids" in encoded_inputs:
            encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
                "token_type_ids"
            ]
        if "special_tokens_mask" in encoded_inputs:
            encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
        if "position_ids" in encoded_inputs:
            # Default np.pad mode is 'constant' with value 0.
            encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"],
                                                    pad_width=[(0, 0), (difference, 0)])
        encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input

    return encoded_inputs