This commit is contained in:
duzx16 2023-03-30 17:35:58 +08:00
parent 2200e2bc52
commit 0564795e6e
1 changed files with 1 additions and 2 deletions

View File

@ -817,7 +817,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
# past_key_values = [(v1,v2) for v1, v2 in zip(past_key_values[0], past_key_values[1])]
return past_key_values
@staticmethod
def get_masks(self, input_ids, device):
batch_size, seq_length = input_ids.shape
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
@ -900,7 +899,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
)
if self.pre_seq_len is not None:
prefix_attention_mask = torch.ones(1, 1, input_ids.size(-1), self.pre_seq_len).to(attention_mask.device)
prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(attention_mask.device)
prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)