Support batch training
commit 8127ab6abf
parent fbda1206cb
@@ -818,33 +818,37 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return past_key_values
 
     @staticmethod
-    def get_masks(self, seq, device):
-        context_length = seq.index(self.config.bos_token_id) + 1
-
-        attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
+    def get_masks(self, input_ids, device):
+        batch_size, seq_length = input_ids.shape
+        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
+        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
         attention_mask.tril_()
-        attention_mask[..., :context_length - 1] = 1
+        for i, context_length in enumerate(context_lengths):
+            attention_mask[i, :, :context_length] = 1
         attention_mask.unsqueeze_(1)
         attention_mask = (attention_mask < 0.5).bool()
 
         return attention_mask
 
-    def get_position_ids(self, seq, mask_position, device, gmask=False):
-        context_length = len(seq)
+    def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
+        batch_size, seq_length = input_ids.shape
+        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
-            seq_length = seq.index(self.config.bos_token_id)
-            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
             if not gmask:
-                position_ids[seq_length:] = mask_position
-            block_position_ids = torch.cat((
-                torch.zeros(seq_length, dtype=torch.long, device=device),
-                torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
-            ))
-            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
+                for i, context_length in enumerate(context_lengths):
+                    position_ids[i, context_length:] = mask_positions[i]
+            block_position_ids = [torch.cat((
+                torch.zeros(context_length, dtype=torch.long, device=device),
+                torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
+            )) for context_length in context_lengths]
+            block_position_ids = torch.stack(block_position_ids, dim=0)
+            position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
-            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
             if not gmask:
-                position_ids[context_length - 1:] = mask_position
+                for i, context_length in enumerate(context_lengths):
+                    position_ids[context_length:] = mask_positions[i]
 
         position_ids = position_ids.unsqueeze(0)
 
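The hunk above turns the single-sequence helpers into batched ones: get_masks now builds one lower-triangular mask per sequence and marks each sequence's context (everything before bos_token_id) as fully visible. Below is a minimal standalone sketch of that mask construction, using an invented bos_id and toy input_ids rather than the model's real vocabulary or API.

import torch

# Standalone sketch of the batched mask logic above -- not the model's API.
# bos_id and the toy input_ids are invented values for illustration only.
bos_id = 5
input_ids = torch.tensor([
    [11, 12, 13, bos_id, 21, 22],  # 3 context tokens before bos
    [11, 12, bos_id, 21, 22, 23],  # 2 context tokens before bos
])
batch_size, seq_length = input_ids.shape
context_lengths = [seq.tolist().index(bos_id) for seq in input_ids]

attention_mask = torch.ones((batch_size, seq_length, seq_length))
attention_mask.tril_()                               # causal part for the generated tokens
for i, context_length in enumerate(context_lengths):
    attention_mask[i, :, :context_length] = 1        # context tokens are visible to every position
attention_mask.unsqueeze_(1)                         # broadcast dimension for the attention heads
attention_mask = (attention_mask < 0.5).bool()       # True marks positions to mask out

print(attention_mask.shape)  # torch.Size([2, 1, 6, 6])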
@@ -890,16 +894,15 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device)
             else:
                 past_key_values = tuple([None] * len(self.layers))
-        seq = input_ids[0].tolist()
 
         if attention_mask is None:
             attention_mask = self.get_masks(
-                seq=seq,
+                input_ids,
                 device=input_ids.device
             )
 
         if self.pre_seq_len is not None:
-            prefix_attention_mask = torch.ones(1, 1, len(seq), self.pre_seq_len).to(attention_mask.device)
+            prefix_attention_mask = torch.ones(1, 1, input_ids.size(-1), self.pre_seq_len).to(attention_mask.device)
             prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
             attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
 
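With the batched helpers in place, forward no longer flattens the batch to input_ids[0]: the full input_ids tensor is passed to get_masks, and the P-tuning prefix mask is sized from input_ids.size(-1). A rough sketch of how that prefix block is prepended to the mask, with an illustrative pre_seq_len and a batch of one:

import torch

# Rough sketch of the prefix-mask concatenation above; pre_seq_len, seq_length and
# the batch of one are illustrative values, not a real configuration.
pre_seq_len, seq_length = 4, 6
attention_mask = torch.zeros(1, 1, seq_length, seq_length).bool()   # stand-in for the get_masks output
prefix_attention_mask = torch.ones(1, 1, seq_length, pre_seq_len)
prefix_attention_mask = (prefix_attention_mask < 0.5).bool()        # all False: prefix tokens are never masked
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)

print(attention_mask.shape)  # torch.Size([1, 1, 6, 10])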
@@ -908,10 +911,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             mask_token = MASK if MASK in input_ids else gMASK
             use_gmask = False if MASK in input_ids else gMASK
 
-            mask_position = seq.index(mask_token)
+            mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
             position_ids = self.get_position_ids(
-                seq=seq,
-                mask_position=mask_position,
+                input_ids,
+                mask_positions=mask_positions,
                 device=input_ids.device,
                 gmask=use_gmask
             )
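Similarly, mask_positions is now computed per sequence and handed to the batched get_position_ids. The sketch below shows the 2D position ids this produces for a toy batch, assuming the gMASK case so the first channel stays plain positions; bos_id and mask_id are invented values, not the model's token ids.

import torch

# Standalone sketch of the batched 2D position ids (position_encoding_2d branch).
bos_id, mask_id = 5, 4
input_ids = torch.tensor([
    [11, 12, 13, mask_id, bos_id, 22],
    [11, 12, mask_id, bos_id, 22, 23],
])
batch_size, seq_length = input_ids.shape
context_lengths = [seq.tolist().index(bos_id) for seq in input_ids]
mask_positions = [seq.tolist().index(mask_id) for seq in input_ids]
print(mask_positions)  # [3, 2] -- one mask position per sequence in the batch

position_ids = torch.arange(seq_length, dtype=torch.long).expand(batch_size, seq_length)
# gMASK assumed here, so position_ids is left as plain 0..seq_length-1 per row
block_position_ids = [torch.cat((
    torch.zeros(context_length, dtype=torch.long),          # 0 for the context part
    torch.arange(seq_length - context_length, dtype=torch.long) + 1  # 1,2,... for the generated part
)) for context_length in context_lengths]
block_position_ids = torch.stack(block_position_ids, dim=0)
position_ids = torch.stack((position_ids, block_position_ids), dim=1)

print(position_ids.shape)  # torch.Size([2, 2, 6])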