diff --git a/configuration_chatglm.py b/configuration_chatglm.py index b4b196a..3bc4b6f 100644 --- a/configuration_chatglm.py +++ b/configuration_chatglm.py @@ -70,6 +70,8 @@ class ChatGLMConfig(PretrainedConfig): max_sequence_length=2048, inner_hidden_size=16384, position_encoding_2d=True, + pre_seq_len=None, + prefix_projection=False, **kwargs ): self.num_layers = num_layers @@ -84,6 +86,8 @@ class ChatGLMConfig(PretrainedConfig): self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id self.position_encoding_2d = position_encoding_2d + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/modeling_chatglm.py b/modeling_chatglm.py index 648a1d4..d961505 100644 --- a/modeling_chatglm.py +++ b/modeling_chatglm.py @@ -129,6 +129,35 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): return model +class PrefixEncoder(torch.nn.Module): + r''' + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + ''' + def __init__(self, config): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(config.hidden_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2) + ) + else: + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + @torch.jit.script def gelu_impl(x): """OpenAI's gelu implementation.""" @@ -719,6 +748,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel): self.inner_hidden_size = config.inner_hidden_size self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads self.position_encoding_2d = config.position_encoding_2d + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection self.word_embeddings = skip_init( torch.nn.Embedding, @@ -747,12 +778,41 @@ class ChatGLMModel(ChatGLMPreTrainedModel): # Final layer norm before output. self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + # total_params = sum(p.numel() for p in self.parameters()) + # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params)) + def get_input_embeddings(self): return self.word_embeddings def set_input_embeddings(self, new_embeddings: torch.Tensor): self.word_embeddings = new_embeddings + def get_prompt(self, batch_size, device): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).half() + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.num_attention_heads, + self.hidden_size // self.num_attention_heads + ) + #seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + past_key_values = [(v[0], v[1]) for v in past_key_values] + # past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(self.num_layers) + # past_key_values = [(v1,v2) for v1, v2 in zip(past_key_values[0], past_key_values[1])] + return past_key_values + @staticmethod def get_masks(seq, device): context_length = seq.index(150004) + 1 @@ -822,7 +882,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel): raise ValueError("You have to specify either input_ids or inputs_embeds") if past_key_values is None: - past_key_values = tuple([None] * len(self.layers)) + if self.pre_seq_len is not None: + past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device) + else: + past_key_values = tuple([None] * len(self.layers)) MASK, gMASK = 150000, 150001 mask_token = MASK if MASK in input_ids else gMASK @@ -837,6 +900,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel): device=input_ids.device ) + if self.pre_seq_len is not None: + prefix_attention_mask = torch.ones(1, 1, len(seq), self.pre_seq_len).to(attention_mask.device) + prefix_attention_mask = (prefix_attention_mask < 0.5).bool() + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) + if position_ids is None: position_ids = self.get_position_ids( seq=seq, @@ -1125,18 +1193,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): if "eos_token_id" not in kwargs: kwargs["eos_token_id"] = eos + truncate = kwargs.pop("truncate") if "truncate" in kwargs else False + stop = False return_seqs = [] while True: output_ids = super().generate(**kwargs) - return_seqs = [] max_length = 0 for i in range(output_ids.shape[0]): output_seq = output_ids[i].tolist() + if truncate: + output_seq = output_seq[len(kwargs["input_ids"][i]) - 2:] mask_token = MASK if MASK in output_seq else gMASK mask_position = output_seq.index(mask_token) bos_position = output_seq.index(bos)