From 08bc85104db4e8da2c215a29c469218953056251 Mon Sep 17 00:00:00 2001
From: duzx16
Date: Sun, 2 Apr 2023 02:25:03 +0800
Subject: [PATCH] Fix attention mask for prefix prompt

---
 modeling_chatglm.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index ef38154..cdcc39b 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -919,11 +919,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 device=input_ids.device
             )
 
-        if self.pre_seq_len is not None:
-            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
-                attention_mask.device)
-            prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
-            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
 
         if position_ids is None:
             MASK, gMASK = 150000, 150001
@@ -938,6 +933,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 gmask=use_gmask
             )
 
+        if self.pre_seq_len is not None and attention_mask is not None:
+            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
+                attention_mask.device)
+            prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
+            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
+
         # [seq_len, batch, hidden_size]
         hidden_states = inputs_embeds.transpose(0, 1)
 
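A minimal sketch of what the moved block does, with assumed toy shapes and a simplified lower-triangular mask standing in for the mask ChatGLM actually builds: in this boolean mask convention, True means "masked out" and False means "may attend", so the prefix block is all False (every query token can see the prefix) and is concatenated along the key dimension (dim=3).

```python
import torch

# Assumed toy sizes, not taken from the model config.
batch_size, seq_len, pre_seq_len = 2, 5, 3

# Simplified stand-in for the existing mask: [batch, 1, seq_len, seq_len],
# True above the diagonal = position is masked out.
attention_mask = torch.ones(batch_size, 1, seq_len, seq_len).tril() < 0.5

# Prefix positions are visible to every query token, so the prefix block is all False.
# This matches (torch.ones(...) < 0.5).bool() in the patch.
prefix_attention_mask = torch.zeros(batch_size, 1, seq_len, pre_seq_len, dtype=torch.bool)

# Concatenate along the key dimension: keys = [prefix, input tokens].
full_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
print(full_mask.shape)  # torch.Size([2, 1, 5, 8])
```

The added `attention_mask is not None` guard matters because the concatenation only makes sense once a mask tensor exists; doing it after the default mask and position ids are built avoids operating on a missing mask.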