diff --git a/modeling_chatglm.py b/modeling_chatglm.py index 982c814..8ea3776 100644 --- a/modeling_chatglm.py +++ b/modeling_chatglm.py @@ -817,7 +817,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel): # past_key_values = [(v1,v2) for v1, v2 in zip(past_key_values[0], past_key_values[1])] return past_key_values - @staticmethod def get_masks(self, input_ids, device): batch_size, seq_length = input_ids.shape context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] @@ -900,7 +899,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel): ) if self.pre_seq_len is not None: - prefix_attention_mask = torch.ones(1, 1, input_ids.size(-1), self.pre_seq_len).to(attention_mask.device) + prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(attention_mask.device) prefix_attention_mask = (prefix_attention_mask < 0.5).bool() attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)