Fix attention mask for prefix prompt
parent 4b7ffbf070
commit 08bc85104d
@@ -919,11 +919,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 device=input_ids.device
             )
 
-        if self.pre_seq_len is not None:
-            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
-                attention_mask.device)
-            prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
-            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
-
         if position_ids is None:
             MASK, gMASK = 150000, 150001
@@ -938,6 +933,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 gmask=use_gmask
             )
 
+        if self.pre_seq_len is not None and attention_mask is not None:
+            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
+                attention_mask.device)
+            prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
+            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
+
         # [seq_len, batch, hidden_size]
         hidden_states = inputs_embeds.transpose(0, 1)
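For reference, a minimal standalone sketch of what the relocated block computes, with illustrative shapes (batch_size=2, seq_len=4, pre_seq_len=3 are assumed values, not taken from the commit). Judging from the `(x < 0.5).bool()` pattern in the diff, a True entry marks a position as masked out, so the all-False prefix block keeps the learned prefix visible to every query position:

    import torch

    batch_size, seq_len, pre_seq_len = 2, 4, 3  # illustrative shapes only

    # Stand-in for the mask built earlier in forward(): [batch, 1, seq_len, seq_len], bool
    attention_mask = torch.ones(batch_size, 1, seq_len, seq_len).bool()

    # torch.ones(...) < 0.5 is all False, i.e. the prefix keys are never masked
    prefix_attention_mask = torch.ones(batch_size, 1, seq_len, pre_seq_len) < 0.5
    prefix_attention_mask = prefix_attention_mask.bool()

    # Prepend along the key dimension (dim=3) so every query can also attend to the prefix keys
    attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
    print(attention_mask.shape)  # torch.Size([2, 1, 4, 7]) == [batch, 1, seq_len, pre_seq_len + seq_len]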