diff --git a/config.json b/config.json index c9ee2e2..06322c1 100644 --- a/config.json +++ b/config.json @@ -10,6 +10,7 @@ }, "bos_token_id": 150004, "eos_token_id": 150005, + "pad_token_id": 20003, "hidden_size": 4096, "inner_hidden_size": 16384, "layernorm_epsilon": 1e-05, diff --git a/configuration_chatglm.py b/configuration_chatglm.py index 0916053..70b8174 100644 --- a/configuration_chatglm.py +++ b/configuration_chatglm.py @@ -72,6 +72,8 @@ class ChatGLMConfig(PretrainedConfig): position_encoding_2d=True, quantization_bit=0, quantization_embeddings=False, + pre_seq_len=None, + prefix_projection=False, **kwargs ): self.num_layers = num_layers @@ -86,8 +88,11 @@ class ChatGLMConfig(PretrainedConfig): self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id self.position_encoding_2d = position_encoding_2d - self.quantization_bit=quantization_bit - self.quantization_embeddings=quantization_embeddings + self.quantization_bit = quantization_bit + self.quantization_embeddings = quantization_embeddings + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/modeling_chatglm.py b/modeling_chatglm.py index 96f37e4..78e6233 100644 --- a/modeling_chatglm.py +++ b/modeling_chatglm.py @@ -5,6 +5,7 @@ import copy import os import warnings import re +import sys import torch import torch.utils.checkpoint @@ -12,7 +13,7 @@ import torch.nn.functional as F from torch import nn from torch.nn import CrossEntropyLoss, LayerNorm from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Callable +from typing import Optional, Tuple, Union, List, Callable, Dict, Any from transformers.utils import ( add_code_sample_docstrings, @@ -27,16 +28,18 @@ from transformers.modeling_outputs import ( from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput from .configuration_chatglm import ChatGLMConfig # flags required to enable jit fusion kernels -torch._C._jit_set_profiling_mode(False) -torch._C._jit_set_profiling_executor(False) -torch._C._jit_override_can_fuse_on_cpu(True) -torch._C._jit_override_can_fuse_on_gpu(True) + +if sys.platform != 'darwin': + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) logger = logging.get_logger(__name__) @@ -131,6 +134,36 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): return model +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(config.hidden_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2) + ) + else: + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + @torch.jit.script def gelu_impl(x): """OpenAI's gelu implementation.""" @@ -219,7 +252,7 @@ def attention_fn( use_cache=False, ): if layer_past is not None: - past_key, past_value = layer_past + past_key, past_value = layer_past[0], layer_past[1] key_layer = torch.cat((past_key, key_layer), dim=0) value_layer = torch.cat((past_value, value_layer), dim=0) @@ -273,7 +306,7 @@ def attention_fn( if not (attention_mask == 0).all(): # if auto-regressive, skip attention_scores.masked_fill_(attention_mask, -10000.0) - dtype = attention_scores.type() + dtype = attention_scores.dtype attention_scores = attention_scores.float() attention_scores = attention_scores * query_key_layer_scaling_coeff @@ -619,10 +652,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel): """ is_parallelizable = False - supports_gradient_checkpointing = False + supports_gradient_checkpointing = True config_class = ChatGLMConfig base_model_prefix = "transformer" - _no_split_modules = ["GLM6BBlock"] + _no_split_modules = ["GLMBlock"] def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -631,6 +664,43 @@ class ChatGLMPreTrainedModel(PreTrainedModel): """Initialize the weights.""" return + def get_masks(self, input_ids, device): + batch_size, seq_length = input_ids.shape + context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] + attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) + attention_mask.tril_() + for i, context_length in enumerate(context_lengths): + attention_mask[i, :, :context_length] = 1 + attention_mask.unsqueeze_(1) + attention_mask = (attention_mask < 0.5).bool() + + return attention_mask + + def get_position_ids(self, input_ids, mask_positions, device, gmask=False): + batch_size, seq_length = input_ids.shape + context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] + if self.position_encoding_2d: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length) + for i, context_length in enumerate(context_lengths): + position_ids[i, context_length:] = mask_positions[i] + block_position_ids = [torch.cat(( + torch.zeros(context_length, dtype=torch.long, device=device), + torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 + )) for context_length in context_lengths] + block_position_ids = torch.stack(block_position_ids, dim=0) + position_ids = torch.stack((position_ids, block_position_ids), dim=1) + else: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length) + if not gmask: + for i, context_length in enumerate(context_lengths): + position_ids[context_length:] = mask_positions[i] + + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ChatGLMModel): + module.gradient_checkpointing = value + CHATGLM_6B_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. @@ -727,12 +797,15 @@ class ChatGLMModel(ChatGLMPreTrainedModel): self.inner_hidden_size = config.inner_hidden_size self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads self.position_encoding_2d = config.position_encoding_2d + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection self.word_embeddings = skip_init( torch.nn.Embedding, num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, dtype=self.params_dtype ) + self.gradient_checkpointing = False def get_layer(layer_id): return GLMBlock( @@ -755,43 +828,38 @@ class ChatGLMModel(ChatGLMPreTrainedModel): # Final layer norm before output. self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + # total_params = sum(p.numel() for p in self.parameters()) + # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params)) + def get_input_embeddings(self): return self.word_embeddings def set_input_embeddings(self, new_embeddings: torch.Tensor): self.word_embeddings = new_embeddings - def get_masks(self, seq, device): - context_length = seq.index(self.config.bos_token_id) + 1 - - attention_mask = torch.ones((1, len(seq), len(seq)), device=device) - attention_mask.tril_() - attention_mask[..., :context_length - 1] = 1 - attention_mask.unsqueeze_(1) - attention_mask = (attention_mask < 0.5).bool() - - return attention_mask - - def get_position_ids(self, seq, mask_position, device, gmask=False): - context_length = seq.index(self.config.bos_token_id) + 1 - if self.position_encoding_2d: - seq_length = seq.index(self.config.bos_token_id) - position_ids = torch.arange(context_length, dtype=torch.long, device=device) - if not gmask: - position_ids[seq_length:] = mask_position - block_position_ids = torch.cat(( - torch.zeros(seq_length, dtype=torch.long, device=device), - torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1 - )) - position_ids = torch.stack((position_ids, block_position_ids), dim=0) - else: - position_ids = torch.arange(context_length, dtype=torch.long, device=device) - if not gmask: - position_ids[context_length - 1:] = mask_position - - position_ids = position_ids.unsqueeze(0) - - return position_ids + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.num_attention_heads, + self.hidden_size // self.num_attention_heads + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + # past_key_values = [(v[0], v[1]) for v in past_key_values] + return past_key_values @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( @@ -819,6 +887,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -828,31 +903,41 @@ class ChatGLMModel(ChatGLMPreTrainedModel): else: raise ValueError("You have to specify either input_ids or inputs_embeds") + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + if past_key_values is None: - past_key_values = tuple([None] * len(self.layers)) - seq = input_ids[0].tolist() + if self.pre_seq_len is not None: + past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device, + dtype=inputs_embeds.dtype) + else: + past_key_values = tuple([None] * len(self.layers)) if attention_mask is None: attention_mask = self.get_masks( - seq=seq, + input_ids, device=input_ids.device ) + if position_ids is None: MASK, gMASK = 150000, 150001 mask_token = MASK if MASK in input_ids else gMASK use_gmask = False if MASK in input_ids else gMASK - mask_position = seq.index(mask_token) + mask_positions = [seq.tolist().index(mask_token) for seq in input_ids] position_ids = self.get_position_ids( - seq=seq, - mask_position=mask_position, + input_ids, + mask_positions=mask_positions, device=input_ids.device, gmask=use_gmask ) - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) + if self.pre_seq_len is not None and attention_mask is not None: + prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to( + attention_mask.device) + prefix_attention_mask = (prefix_attention_mask < 0.5).bool() + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) # [seq_len, batch, hidden_size] hidden_states = inputs_embeds.transpose(0, 1) @@ -861,11 +946,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel): all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[0] - seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() @@ -876,16 +956,29 @@ class ChatGLMModel(ChatGLMPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + layer_past = past_key_values[i] - layer_ret = layer( - hidden_states, - position_ids=position_ids, - attention_mask=attention_mask, - layer_id=torch.tensor(i), - layer_past=past_key_values[i], - use_cache=use_cache, - output_attentions=output_attentions - ) + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + position_ids, + attention_mask, + torch.tensor(i), + layer_past, + use_cache, + output_attentions + ) + else: + layer_ret = layer( + hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + layer_id=torch.tensor(i), + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions + ) hidden_states = layer_ret[0] @@ -946,31 +1039,40 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings - def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False): - attention_mask = torch.ones((1, context_length, context_length), device=device) - attention_mask.tril_() - attention_mask[..., :context_length - 1] = 1 - attention_mask.unsqueeze_(1) - attention_mask = (attention_mask < 0.5).bool() + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) - if self.position_encoding_2d: - seq_length = seq.index(self.config.bos_token_id) - position_ids = torch.arange(context_length, dtype=torch.long, device=device) - if not gmask: - position_ids[seq_length:] = mask_position - block_position_ids = torch.cat(( - torch.zeros(seq_length, dtype=torch.long, device=device), - torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1 - )) - position_ids = torch.stack((position_ids, block_position_ids), dim=0) - else: - position_ids = torch.arange(context_length, dtype=torch.long, device=device) - if not gmask: - position_ids[context_length - 1:] = mask_position + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + if attention_mask is not None and attention_mask.dtype == torch.bool: + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3) + new_attention_mask = attention_mask[:, :, -1:].clone() + new_attention_mask[..., -1] = False + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, new_attention_mask], dim=2 + ) - position_ids = position_ids.unsqueeze(0) + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id[:, 1, :] += 1 + model_kwargs["position_ids"] = torch.cat( + [position_ids, new_position_id], dim=-1 + ) - return attention_mask, position_ids + return model_kwargs def prepare_inputs_for_generation( self, @@ -978,27 +1080,34 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): past: Optional[torch.Tensor] = None, past_key_values: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, **kwargs ) -> dict: - + batch_size, seq_length = input_ids.shape MASK, gMASK = 150000, 150001 mask_token = MASK if MASK in input_ids else gMASK use_gmask = False if MASK in input_ids else gMASK - seq = input_ids[0].tolist() - mask_position = seq.index(mask_token) - - if mask_token not in seq: - raise ValueError("You have to add either [MASK] or [gMASK] in your input") + seqs = input_ids.tolist() + mask_positions = [seq.index(mask_token) for seq in seqs] # only last token for input_ids if past is not None if past is not None or past_key_values is not None: - context_length = seq.index(self.config.bos_token_id) last_token = input_ids[:, -1].unsqueeze(-1) - if self.position_encoding_2d: - position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long, - device=input_ids.device) + if attention_mask is not None and attention_mask.dtype == torch.bool: + attention_mask = attention_mask[:, :, -1:] else: - position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device) + attention_mask = None + if position_ids is not None: + position_ids = position_ids[..., -1:] + else: + context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] + if self.position_encoding_2d: + position_ids = torch.tensor( + [[mask_position, seq_length - context_length] for mask_position, context_length in + zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1) + else: + position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long, + device=input_ids.device).unsqueeze(-1) if past is None: past = past_key_values @@ -1006,15 +1115,24 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): "input_ids": last_token, "past_key_values": past, "position_ids": position_ids, + "attention_mask": attention_mask } else: - attention_mask, position_ids = self.get_masks_and_position_ids( - seq=seq, - mask_position=mask_position, - context_length=len(seq), - device=input_ids.device, - gmask=use_gmask - ) + if attention_mask is not None and attention_mask.dtype != torch.bool: + logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") + attention_mask = None + if attention_mask is None: + attention_mask = self.get_masks( + input_ids, + device=input_ids.device + ) + if position_ids is None: + position_ids = self.get_position_ids( + input_ids, + device=input_ids.device, + mask_positions=mask_positions, + gmask=use_gmask + ) return { "input_ids": input_ids, @@ -1063,7 +1181,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() + loss_fct = CrossEntropyLoss(ignore_index=-100) loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) lm_logits = lm_logits.to(hidden_states.dtype) @@ -1132,10 +1250,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): for i, (old_query, response) in enumerate(history): prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) - input_ids = tokenizer([prompt], return_tensors="pt", padding=True) - input_ids = input_ids.to(self.device) - outputs = self.generate(**input_ids, **gen_kwargs) - outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):] + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] response = tokenizer.decode(outputs) response = self.process_response(response) history = history + [(query, response)] @@ -1158,10 +1276,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): for i, (old_query, response) in enumerate(history): prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) - input_ids = tokenizer([prompt], return_tensors="pt", padding=True) - input_ids = input_ids.to(self.device) - for outputs in self.stream_generate(**input_ids, **gen_kwargs): - outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):] + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + for outputs in self.stream_generate(**inputs, **gen_kwargs): + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] response = tokenizer.decode(outputs) response = self.process_response(response) new_history = history + [(query, response)] @@ -1298,6 +1416,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): num_embeddings=self.transformer.word_embeddings.num_embeddings, embedding_dim=self.transformer.word_embeddings.embedding_dim, dtype=torch.half, + empty_init=True, device=self.transformer.word_embeddings.weight.device, ) self.lm_head = QuantizedLinear( @@ -1310,6 +1429,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): quantized_weight=self.transformer.word_embeddings.weight, quantized_weight_scale=self.transformer.word_embeddings.weight_scale, dtype=torch.half, + empty_init=True, device=self.lm_head.weight.device, ) diff --git a/quantization.py b/quantization.py index 228f382..61d69c2 100644 --- a/quantization.py +++ b/quantization.py @@ -7,10 +7,13 @@ import bz2 import torch import base64 import ctypes +from transformers.utils import logging from typing import List from functools import partial +logger = logging.get_logger(__name__) + try: from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up @@ -37,18 +40,18 @@ try: ) except Exception as exception: kernels = None - print("Failed to load cpm_kernels:", exception) + logger.warning("Failed to load cpm_kernels:", exception) class W8A16Linear(torch.autograd.Function): @staticmethod def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width): ctx.inp_shape = inp.size() - ctx.weight_shape = quant_w.size() ctx.weight_bit_width = weight_bit_width out_features = quant_w.size(0) inp = inp.contiguous().view(-1, inp.size(-1)) weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) + ctx.weight_shape = weight.size() output = inp.mm(weight.t()) ctx.save_for_backward(inp, quant_w, scale_w) return output.view(*(ctx.inp_shape[:-1] + (out_features,))) @@ -60,18 +63,18 @@ class W8A16Linear(torch.autograd.Function): grad_output = grad_output.contiguous().view(-1, weight.size(0)) grad_input = grad_output.mm(weight) grad_weight = grad_output.t().mm(inp) - return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None + return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None class W8A16LinearCPU(torch.autograd.Function): @staticmethod def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width, quantization_cache=None): ctx.inp_shape = inp.size() - ctx.weight_shape = quant_w.size() ctx.weight_bit_width = weight_bit_width out_features = quant_w.size(0) inp = inp.contiguous().view(-1, inp.size(-1)) weight = extract_weight_to_float(quant_w, scale_w, weight_bit_width, quantization_cache=quantization_cache) + ctx.weight_shape = weight.size() output = inp.mm(weight.t()) ctx.save_for_backward(inp, quant_w, scale_w) return output.view(*(ctx.inp_shape[:-1] + (out_features,))) @@ -83,7 +86,7 @@ class W8A16LinearCPU(torch.autograd.Function): grad_output = grad_output.contiguous().view(-1, weight.size(0)) grad_input = grad_output.mm(weight) grad_weight = grad_output.t().mm(inp) - return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None + return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None default_cpu_kernel_code_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "quantization_kernels.c") @@ -168,7 +171,7 @@ class CPUKernel: print("Load kernel :", kernel_file) else: print("Failed to load kernel.") - + if compile_parallel_kernel: if parallel_num is None: parallel_num = max(os.cpu_count() // 2, 1) @@ -176,7 +179,7 @@ class CPUKernel: if parallel_num < 4: print("Parallel kernel is not recommended when parallel num < 4.") self.SetNumThreads(parallel_num) - + self.parallel_num = parallel_num @@ -284,10 +287,10 @@ def extract_weight_to_float(weight: torch.Tensor, scale_list: torch.Tensor, sour class CacheTensor(): def __init__(self, *args, **kwargs): self.tensor = torch.empty(*args, **kwargs) - + def to(self, *args, **kwargs): self.tensor = self.tensor.to(*args, **kwargs) - + def data_ptr(self): return self.tensor.data_ptr() @@ -393,7 +396,7 @@ def load_cpu_kernel(**kwargs): def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=False, **kwargs): """Replace fp16 linear with quantized linear""" - + query_key_value_quantization_cache = None dense_quantization_cache = None dense_h_to_4h_quantization_cache = None diff --git a/tokenization_chatglm.py b/tokenization_chatglm.py index aedbcbe..3062c7c 100644 --- a/tokenization_chatglm.py +++ b/tokenization_chatglm.py @@ -1,17 +1,14 @@ """Tokenization classes for ChatGLM.""" -import sys -import unicodedata from typing import List, Optional, Union -from functools import lru_cache import os -import collections -import re from transformers.tokenization_utils import PreTrainedTokenizer from icetk.text_tokenizer import TextTokenizer -from icetk.utils import auto_create import icetk.sentencepiece_model_pb2 as sp_model -from transformers.utils import logging +from transformers.utils import logging, PaddingStrategy +from transformers.tokenization_utils_base import EncodedInput, BatchEncoding +from typing import Dict +import numpy as np logger = logging.get_logger(__name__) @@ -180,7 +177,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer): vocab_files_names = {"vocab_file": "ice_text.model"} max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - model_input_names = ["input_ids"] + model_input_names = ["input_ids", "attention_mask", "position_ids"] def __init__( self, @@ -210,7 +207,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer): self.eos_token = eos_token self.eop_token = eop_token self.mask_token = mask_token - self.gMASK_token = gmask_token + self.gmask_token = gmask_token self.sp_tokenizer = SPTokenizer(vocab_file) @@ -299,7 +296,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer): """ if os.path.isdir(save_directory): vocab_file = os.path.join( - save_directory, VOCAB_FILES_NAMES["vocab_file"] + save_directory, self.vocab_files_names["vocab_file"] ) else: vocab_file = save_directory @@ -331,10 +328,9 @@ class ChatGLMTokenizer(PreTrainedTokenizer): Returns: `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. """ - if token_ids_1 is not None: - token_ids_0 += token_ids_1 mask_ids = self.sp_tokenizer[self.mask_token] - gmask_ids = self.sp_tokenizer[self.gMASK_token] + gmask_ids = self.sp_tokenizer[self.gmask_token] + eop_id = self.sp_tokenizer[self.eop_token] if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0: token_ids_0 += [gmask_ids] @@ -343,4 +339,101 @@ class ChatGLMTokenizer(PreTrainedTokenizer): token_ids_0 += [self.sp_tokenizer[self.bos_token]] + if token_ids_1 is not None: + if not token_ids_1 or token_ids_1[-1] != eop_id: + token_ids_1 += [eop_id] + token_ids_0 += token_ids_1 + return token_ids_0 + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + bos_token_id = self.sp_tokenizer[self.bos_token] + mask_token_id = self.sp_tokenizer[self.mask_token] + gmask_token_id = self.sp_tokenizer[self.gmask_token] + assert self.padding_side == "left" + + required_input = encoded_inputs[self.model_input_names[0]] + seq_length = len(required_input) + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. + if max_length is not None: + if "attention_mask" not in encoded_inputs: + if bos_token_id in required_input: + context_length = required_input.index(bos_token_id) + else: + context_length = seq_length + attention_mask = np.ones((1, seq_length, seq_length)) + attention_mask = np.tril(attention_mask) + attention_mask[:, :, :context_length] = 1 + attention_mask = np.bool_(attention_mask < 0.5) + encoded_inputs["attention_mask"] = attention_mask + + if "position_ids" not in encoded_inputs: + position_ids = np.arange(seq_length, dtype=np.int64) + mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id + if mask_token in required_input: + mask_position = required_input.index(mask_token) + position_ids[context_length:] = mask_position + block_position_ids = np.concatenate( + [np.zeros(context_length, dtype=np.int64), + np.arange(1, seq_length - context_length + 1, dtype=np.int64)]) + encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0) + + if needs_to_be_padded: + difference = max_length - len(required_input) + + if "attention_mask" in encoded_inputs: + encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"], + pad_width=[(0, 0), (difference, 0), (difference, 0)], + mode='constant', constant_values=True) + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + if "position_ids" in encoded_inputs: + encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"], + pad_width=[(0, 0), (difference, 0)]) + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + + return encoded_inputs