Add pad_token_id in config.json
Fix position_ids in ChatGLMModel Add batch position_ids
This commit is contained in:
parent
db2249979c
commit
2200e2bc52
|
@ -10,6 +10,7 @@
|
|||
},
|
||||
"bos_token_id": 150004,
|
||||
"eos_token_id": 150005,
|
||||
"pad_token_id": 20003,
|
||||
"hidden_size": 4096,
|
||||
"inner_hidden_size": 16384,
|
||||
"layernorm_epsilon": 1e-05,
|
||||
|
|
|
@ -850,8 +850,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|||
for i, context_length in enumerate(context_lengths):
|
||||
position_ids[context_length:] = mask_positions[i]
|
||||
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
|
||||
return position_ids
|
||||
|
||||
@add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
|
@ -1007,29 +1005,34 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False):
|
||||
attention_mask = torch.ones((1, context_length, context_length), device=device)
|
||||
def get_masks_and_position_ids(self, input_ids, mask_positions, device, gmask=False):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
||||
attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
|
||||
attention_mask.tril_()
|
||||
attention_mask[..., :context_length - 1] = 1
|
||||
for i, context_length in enumerate(context_lengths):
|
||||
attention_mask[i, :, :context_length] = 1
|
||||
attention_mask.unsqueeze_(1)
|
||||
attention_mask = (attention_mask < 0.5).bool()
|
||||
|
||||
batch_size, seq_length = input_ids.shape
|
||||
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
||||
if self.position_encoding_2d:
|
||||
seq_length = seq.index(self.config.bos_token_id)
|
||||
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
|
||||
if not gmask:
|
||||
position_ids[seq_length:] = mask_position
|
||||
block_position_ids = torch.cat((
|
||||
torch.zeros(seq_length, dtype=torch.long, device=device),
|
||||
torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
|
||||
))
|
||||
position_ids = torch.stack((position_ids, block_position_ids), dim=0)
|
||||
for i, context_length in enumerate(context_lengths):
|
||||
position_ids[i, context_length:] = mask_positions[i]
|
||||
block_position_ids = [torch.cat((
|
||||
torch.zeros(context_length, dtype=torch.long, device=device),
|
||||
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
|
||||
)) for context_length in context_lengths]
|
||||
block_position_ids = torch.stack(block_position_ids, dim=0)
|
||||
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
|
||||
else:
|
||||
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
|
||||
if not gmask:
|
||||
position_ids[context_length - 1:] = mask_position
|
||||
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
for i, context_length in enumerate(context_lengths):
|
||||
position_ids[context_length:] = mask_positions[i]
|
||||
|
||||
return attention_mask, position_ids
|
||||
|
||||
|
@ -1041,25 +1044,24 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
|
||||
batch_size, seq_length = input_ids.shape
|
||||
MASK, gMASK = 150000, 150001
|
||||
mask_token = MASK if MASK in input_ids else gMASK
|
||||
use_gmask = False if MASK in input_ids else gMASK
|
||||
seq = input_ids[0].tolist()
|
||||
mask_position = seq.index(mask_token)
|
||||
|
||||
if mask_token not in seq:
|
||||
raise ValueError("You have to add either [MASK] or [gMASK] in your input")
|
||||
seqs = input_ids.tolist()
|
||||
mask_positions = [seq.index(mask_token) for seq in seqs]
|
||||
|
||||
# only last token for input_ids if past is not None
|
||||
if past is not None or past_key_values is not None:
|
||||
context_length = seq.index(self.config.bos_token_id)
|
||||
context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
|
||||
last_token = input_ids[:, -1].unsqueeze(-1)
|
||||
if self.position_encoding_2d:
|
||||
position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
|
||||
device=input_ids.device)
|
||||
position_ids = torch.tensor(
|
||||
[[mask_position, seq_length - context_length] for mask_position, context_length in
|
||||
zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
|
||||
else:
|
||||
position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device)
|
||||
position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
|
||||
device=input_ids.device).unsqueeze(-1)
|
||||
|
||||
if past is None:
|
||||
past = past_key_values
|
||||
|
@ -1070,9 +1072,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
}
|
||||
else:
|
||||
attention_mask, position_ids = self.get_masks_and_position_ids(
|
||||
seq=seq,
|
||||
mask_position=mask_position,
|
||||
context_length=len(seq),
|
||||
input_ids,
|
||||
mask_positions=mask_positions,
|
||||
device=input_ids.device,
|
||||
gmask=use_gmask
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue