Fix attention mask for prefix prompt
This commit is contained in:
parent
4b7ffbf070
commit
08bc85104d
|
@ -919,11 +919,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|||
device=input_ids.device
|
||||
)
|
||||
|
||||
if self.pre_seq_len is not None:
|
||||
prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
|
||||
attention_mask.device)
|
||||
prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
|
||||
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
|
||||
|
||||
if position_ids is None:
|
||||
MASK, gMASK = 150000, 150001
|
||||
|
@ -938,6 +933,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|||
gmask=use_gmask
|
||||
)
|
||||
|
||||
if self.pre_seq_len is not None and attention_mask is not None:
|
||||
prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
|
||||
attention_mask.device)
|
||||
prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
|
||||
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
|
||||
|
||||
# [seq_len, batch, hidden_size]
|
||||
hidden_states = inputs_embeds.transpose(0, 1)
|
||||
|
||||
|
|
Loading…
Reference in New Issue