Fix position id for training
This commit is contained in:
parent
9c7416d834
commit
11c270c26c
|
@ -845,9 +845,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||||
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
||||||
if self.position_encoding_2d:
|
if self.position_encoding_2d:
|
||||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
|
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
|
||||||
if not gmask:
|
for i, context_length in enumerate(context_lengths):
|
||||||
for i, context_length in enumerate(context_lengths):
|
position_ids[i, context_length:] = mask_positions[i]
|
||||||
position_ids[i, context_length:] = mask_positions[i]
|
|
||||||
block_position_ids = [torch.cat((
|
block_position_ids = [torch.cat((
|
||||||
torch.zeros(context_length, dtype=torch.long, device=device),
|
torch.zeros(context_length, dtype=torch.long, device=device),
|
||||||
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
|
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
|
||||||
|
@ -1053,9 +1052,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
||||||
if self.position_encoding_2d:
|
if self.position_encoding_2d:
|
||||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
|
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
|
||||||
if not gmask:
|
for i, context_length in enumerate(context_lengths):
|
||||||
for i, context_length in enumerate(context_lengths):
|
position_ids[i, context_length:] = mask_positions[i]
|
||||||
position_ids[i, context_length:] = mask_positions[i]
|
|
||||||
block_position_ids = [torch.cat((
|
block_position_ids = [torch.cat((
|
||||||
torch.zeros(context_length, dtype=torch.long, device=device),
|
torch.zeros(context_length, dtype=torch.long, device=device),
|
||||||
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
|
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
|
||||||
|
|
Loading…
Reference in New Issue