From 2200e2bc52d2de98427098915ce8297bbcbca3eb Mon Sep 17 00:00:00 2001
From: duzx16
Date: Wed, 29 Mar 2023 21:52:46 +0800
Subject: [PATCH] Add pad_token_id in config.json
 Fix position_ids in ChatGLMModel
 Add batch position_ids

---
 config.json         |  1 +
 modeling_chatglm.py | 61 +++++++++++++++++++++++----------------------
 2 files changed, 32 insertions(+), 30 deletions(-)

diff --git a/config.json b/config.json
index a907a4d..82b9211 100644
--- a/config.json
+++ b/config.json
@@ -10,6 +10,7 @@
   },
   "bos_token_id": 150004,
   "eos_token_id": 150005,
+  "pad_token_id": 20003,
   "hidden_size": 4096,
   "inner_hidden_size": 16384,
   "layernorm_epsilon": 1e-05,
diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index 9c06117..982c814 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -850,8 +850,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 for i, context_length in enumerate(context_lengths):
                     position_ids[context_length:] = mask_positions[i]
 
-        position_ids = position_ids.unsqueeze(0)
-
         return position_ids
 
     @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -1007,29 +1005,34 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
-    def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False):
-        attention_mask = torch.ones((1, context_length, context_length), device=device)
+    def get_masks_and_position_ids(self, input_ids, mask_positions, device, gmask=False):
+        batch_size, seq_length = input_ids.shape
+        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
+        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
         attention_mask.tril_()
-        attention_mask[..., :context_length - 1] = 1
+        for i, context_length in enumerate(context_lengths):
+            attention_mask[i, :, :context_length] = 1
         attention_mask.unsqueeze_(1)
         attention_mask = (attention_mask < 0.5).bool()
 
+        batch_size, seq_length = input_ids.shape
+        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
-            seq_length = seq.index(self.config.bos_token_id)
-            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
             if not gmask:
-                position_ids[seq_length:] = mask_position
-            block_position_ids = torch.cat((
-                torch.zeros(seq_length, dtype=torch.long, device=device),
-                torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
-            ))
-            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
+                for i, context_length in enumerate(context_lengths):
+                    position_ids[i, context_length:] = mask_positions[i]
+            block_position_ids = [torch.cat((
+                torch.zeros(context_length, dtype=torch.long, device=device),
+                torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
+            )) for context_length in context_lengths]
+            block_position_ids = torch.stack(block_position_ids, dim=0)
+            position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
-            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
             if not gmask:
-                position_ids[context_length - 1:] = mask_position
-
-        position_ids = position_ids.unsqueeze(0)
+                for i, context_length in enumerate(context_lengths):
+                    position_ids[context_length:] = mask_positions[i]
 
         return attention_mask, position_ids
 
@@ -1041,25 +1044,24 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             attention_mask: Optional[torch.Tensor] = None,
             **kwargs
     ) -> dict:
-
+        batch_size, seq_length = input_ids.shape
         MASK, gMASK = 150000, 150001
         mask_token = MASK if MASK in input_ids else gMASK
         use_gmask = False if MASK in input_ids else gMASK
-        seq = input_ids[0].tolist()
-        mask_position = seq.index(mask_token)
-
-        if mask_token not in seq:
-            raise ValueError("You have to add either [MASK] or [gMASK] in your input")
+        seqs = input_ids.tolist()
+        mask_positions = [seq.index(mask_token) for seq in seqs]
 
         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
-            context_length = seq.index(self.config.bos_token_id)
+            context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
             last_token = input_ids[:, -1].unsqueeze(-1)
             if self.position_encoding_2d:
-                position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
-                                            device=input_ids.device)
+                position_ids = torch.tensor(
+                    [[mask_position, seq_length - context_length] for mask_position, context_length in
+                     zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
             else:
-                position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device)
+                position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
+                                            device=input_ids.device).unsqueeze(-1)
 
             if past is None:
                 past = past_key_values
@@ -1070,9 +1072,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             }
         else:
            attention_mask, position_ids = self.get_masks_and_position_ids(
-                seq=seq,
-                mask_position=mask_position,
-                context_length=len(seq),
+                input_ids,
+                mask_positions=mask_positions,
                 device=input_ids.device,
                 gmask=use_gmask
            )
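
A minimal, standalone sketch of the batched mask / 2D position-id construction that the reworked
get_masks_and_position_ids performs. The token ids MASK=150000, gMASK=150001 and bos_token_id=150004
come from config.json; the helper name build_masks_and_position_ids, the toy input_ids values, and
the use of repeat() rather than expand() are assumptions made for illustration, not code from the
repository.

import torch


def build_masks_and_position_ids(input_ids, mask_positions, bos_token_id=150004, gmask=False):
    # Illustrative re-creation of the batched logic (hypothetical helper, not repo code).
    batch_size, seq_length = input_ids.shape
    device = input_ids.device
    context_lengths = [seq.tolist().index(bos_token_id) for seq in input_ids]

    # Causal mask, but every query position may attend to its row's full prompt (context).
    attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
    attention_mask.tril_()
    for i, context_length in enumerate(context_lengths):
        attention_mask[i, :, :context_length] = 1
    attention_mask.unsqueeze_(1)                     # -> [batch, 1, seq, seq]
    attention_mask = (attention_mask < 0.5).bool()   # True marks positions to mask out

    # First plane of the 2D encoding: absolute positions; for [MASK] (not [gMASK]),
    # everything after the prompt is pinned to that row's mask position.
    # repeat() (an assumption here) gives each row its own storage, so the in-place
    # write below cannot alias across rows.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
    if not gmask:
        for i, context_length in enumerate(context_lengths):
            position_ids[i, context_length:] = mask_positions[i]

    # Second plane: zeros over the prompt, then 1, 2, ... for generated tokens.
    block_position_ids = torch.stack([
        torch.cat((
            torch.zeros(context_length, dtype=torch.long, device=device),
            torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1,
        ))
        for context_length in context_lengths
    ], dim=0)

    return attention_mask, torch.stack((position_ids, block_position_ids), dim=1)


if __name__ == "__main__":
    MASK, gMASK, BOS = 150000, 150001, 150004
    # Two toy rows of equal length with prompts of different lengths:
    # [prompt tokens..., gMASK, <bos>, generated tokens...]
    input_ids = torch.tensor([
        [11, 12, 13, gMASK, BOS, 21],
        [14, 15, gMASK, BOS, 22, 23],
    ])
    mask_positions = [row.tolist().index(gMASK) for row in input_ids]
    attention_mask, position_ids = build_masks_and_position_ids(input_ids, mask_positions, gmask=True)
    print(attention_mask.shape)  # torch.Size([2, 1, 6, 6])
    print(position_ids.shape)    # torch.Size([2, 2, 6])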