diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index 3d0888e..9b344d0 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -680,7 +680,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         batch_size, seq_length = input_ids.shape
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
-            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
             for i, context_length in enumerate(context_lengths):
                 position_ids[i, context_length:] = mask_positions[i]
             block_position_ids = [torch.cat((
@@ -690,7 +690,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
             block_position_ids = torch.stack(block_position_ids, dim=0)
             position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
-            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
             if not gmask:
                 for i, context_length in enumerate(context_lengths):
                     position_ids[context_length:] = mask_positions[i]
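
Note (not part of the patch): a minimal PyTorch sketch of why the change matters. torch.Tensor.expand returns a stride-0 view, so the per-row in-place writes to position_ids that follow would go through memory shared by every batch row, while unsqueeze(0).repeat(batch_size, 1) materializes an independent copy per row. The batch_size/seq_length values below are arbitrary illustration values.

    import torch

    batch_size, seq_length = 2, 4
    base = torch.arange(seq_length, dtype=torch.long)

    # .expand() is a zero-copy view: stride 0 along the batch dimension,
    # so every "row" aliases the same underlying elements.
    expanded = base.expand(batch_size, seq_length)
    print(expanded.stride())                        # (0, 1)
    print(expanded.data_ptr() == base.data_ptr())   # True: no new storage

    # A per-row in-place write like the one in get_position_ids,
    #   position_ids[i, context_length:] = mask_positions[i]
    # would therefore write through the shared storage and bleed into other rows.

    # .unsqueeze(0).repeat() allocates an independent row per batch element,
    # so the same write only touches the intended row.
    repeated = base.unsqueeze(0).repeat(batch_size, 1)
    print(repeated.stride())                        # (4, 1)
    repeated[0, 2:] = 99
    print(repeated)
    # tensor([[ 0,  1, 99, 99],
    #         [ 0,  1,  2,  3]])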