diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index bab58f8..f0f3a3b 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -818,33 +818,37 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return past_key_values
 
     @staticmethod
-    def get_masks(self, seq, device):
-        context_length = seq.index(self.config.bos_token_id) + 1
-
-        attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
+    def get_masks(self, input_ids, device):
+        batch_size, seq_length = input_ids.shape
+        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
+        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
         attention_mask.tril_()
-        attention_mask[..., :context_length - 1] = 1
+        for i, context_length in enumerate(context_lengths):
+            attention_mask[i, :, :context_length] = 1
         attention_mask.unsqueeze_(1)
         attention_mask = (attention_mask < 0.5).bool()
 
         return attention_mask
 
-    def get_position_ids(self, seq, mask_position, device, gmask=False):
-        context_length = len(seq)
+    def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
+        batch_size, seq_length = input_ids.shape
+        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
-            seq_length = seq.index(self.config.bos_token_id)
-            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
             if not gmask:
-                position_ids[seq_length:] = mask_position
-            block_position_ids = torch.cat((
-                torch.zeros(seq_length, dtype=torch.long, device=device),
-                torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
-            ))
-            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
+                for i, context_length in enumerate(context_lengths):
+                    position_ids[i, context_length:] = mask_positions[i]
+            block_position_ids = [torch.cat((
+                torch.zeros(context_length, dtype=torch.long, device=device),
+                torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
+            )) for context_length in context_lengths]
+            block_position_ids = torch.stack(block_position_ids, dim=0)
+            position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
-            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
             if not gmask:
-                position_ids[context_length - 1:] = mask_position
+                for i, context_length in enumerate(context_lengths):
+                    position_ids[context_length:] = mask_positions[i]
 
         position_ids = position_ids.unsqueeze(0)
 
@@ -890,16 +894,15 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device)
             else:
                 past_key_values = tuple([None] * len(self.layers))
-            seq = input_ids[0].tolist()
 
             if attention_mask is None:
                 attention_mask = self.get_masks(
-                    seq=seq,
+                    input_ids,
                     device=input_ids.device
                 )
 
             if self.pre_seq_len is not None:
-                prefix_attention_mask = torch.ones(1, 1, len(seq), self.pre_seq_len).to(attention_mask.device)
+                prefix_attention_mask = torch.ones(1, 1, input_ids.size(-1), self.pre_seq_len).to(attention_mask.device)
                 prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
                 attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
 
@@ -908,10 +911,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 mask_token = MASK if MASK in input_ids else gMASK
                 use_gmask = False if MASK in input_ids else gMASK
 
-                mask_position = seq.index(mask_token)
+                mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
                 position_ids = self.get_position_ids(
-                    seq=seq,
-                    mask_position=mask_position,
+                    input_ids,
+                    mask_positions=mask_positions,
                     device=input_ids.device,
                     gmask=use_gmask
                 )
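
For reference, the snippet below is a minimal standalone sketch (not part of the patch) of the batched masking and 2D position-id scheme the diff introduces. The token ids MASK_ID and BOS_ID are made-up placeholders rather than the model's real special-token ids, it only illustrates the non-gMASK branch, and it uses .repeat instead of the patch's .expand so that each row owns its storage before the in-place writes.

import torch

MASK_ID, BOS_ID = 130000, 130004   # placeholder special-token ids, not ChatGLM's actual vocabulary
device = "cpu"

# Two equal-length sequences: [prompt tokens ..., MASK, BOS, generated tokens ...]
input_ids = torch.tensor([
    [11, 12, 13, MASK_ID, BOS_ID, 21],
    [31, 32, MASK_ID, 34, BOS_ID, 41],
], device=device)
batch_size, seq_length = input_ids.shape

# get_masks: causal (lower-triangular) mask, except every query may attend to the
# full per-sample context (tokens before BOS); True marks positions to mask out.
context_lengths = [seq.tolist().index(BOS_ID) for seq in input_ids]
attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
attention_mask.tril_()
for i, context_length in enumerate(context_lengths):
    attention_mask[i, :, :context_length] = 1
attention_mask.unsqueeze_(1)                      # (batch, 1, seq, seq)
attention_mask = (attention_mask < 0.5).bool()

# get_position_ids, 2D variant, MASK case: positions after the context are clamped
# to the [MASK] position; block positions count 1, 2, ... for tokens after BOS.
mask_positions = [seq.tolist().index(MASK_ID) for seq in input_ids]
position_ids = torch.arange(seq_length, dtype=torch.long, device=device) \
    .unsqueeze(0).repeat(batch_size, 1)           # repeat, not expand, so rows are writable independently
for i, context_length in enumerate(context_lengths):
    position_ids[i, context_length:] = mask_positions[i]
block_position_ids = torch.stack([
    torch.cat((
        torch.zeros(context_length, dtype=torch.long, device=device),
        torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1,
    ))
    for context_length in context_lengths
])
position_ids = torch.stack((position_ids, block_position_ids), dim=1)   # (batch, 2, seq)

print(attention_mask.shape)   # torch.Size([2, 1, 6, 6])
print(position_ids.shape)     # torch.Size([2, 2, 6])
print(position_ids[0])        # tensor([[0, 1, 2, 3, 3, 3], [0, 0, 0, 0, 1, 2]])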