Fix position ids expand

duzx16 2023-04-03 14:14:20 +08:00
parent fb23542cfe
commit f82b180d8d
1 changed file with 2 additions and 2 deletions


@@ -680,7 +680,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         batch_size, seq_length = input_ids.shape
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
-            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
             for i, context_length in enumerate(context_lengths):
                 position_ids[i, context_length:] = mask_positions[i]
             block_position_ids = [torch.cat((
@@ -690,7 +690,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
             block_position_ids = torch.stack(block_position_ids, dim=0)
             position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
-            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
             if not gmask:
                 for i, context_length in enumerate(context_lengths):
                     position_ids[context_length:] = mask_positions[i]
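
Why the change matters: torch.Tensor.expand returns a view whose expanded (batch) dimension has stride 0, so every batch row aliases the same underlying storage. The per-row in-place writes that follow (position_ids[i, context_length:] = mask_positions[i]) then bleed across the whole batch, leaving every row with the last sequence's values. unsqueeze(0).repeat(batch_size, 1) materializes an independent copy per row instead. A standalone sketch of the difference, with made-up mask_positions and context_lengths purely for illustration:

    import torch

    batch_size, seq_length = 2, 6
    mask_positions = [3, 5]   # hypothetical per-sequence mask positions
    context_lengths = [4, 2]  # hypothetical per-sequence context lengths

    # Before the fix: expand() yields a stride-0 view along the batch
    # dimension, so both rows share one buffer.
    buggy = torch.arange(seq_length, dtype=torch.long).expand(batch_size, seq_length)
    for i, context_length in enumerate(context_lengths):
        buggy[i, context_length:] = mask_positions[i]
    print(buggy.stride())  # (0, 1): the batch dimension owns no memory of its own
    print(buggy)           # both rows end up identical, clobbered by the last write

    # After the fix: repeat() copies the row batch_size times, so each row
    # owns its own memory and per-row writes stay local.
    fixed = torch.arange(seq_length, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    for i, context_length in enumerate(context_lengths):
        fixed[i, context_length:] = mask_positions[i]
    print(fixed)  # tensor([[0, 1, 2, 3, 3, 3],
                  #         [0, 1, 5, 5, 5, 5]])

repeat trades a small amount of extra memory for rows that can be mutated independently; .expand(...).contiguous() or .clone() would have the same effect.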