Add pad_token_id in config.json

Fix position_ids in ChatGLMModel
Add batch position_ids
duzx16 2023-03-29 21:52:46 +08:00
parent db2249979c
commit 2200e2bc52
2 changed files with 32 additions and 30 deletions

config.json

@@ -10,6 +10,7 @@
},
"bos_token_id": 150004,
"eos_token_id": 150005,
"pad_token_id": 20003,
"hidden_size": 4096,
"inner_hidden_size": 16384,
"layernorm_epsilon": 1e-05,

modeling_chatglm.py

@@ -850,8 +850,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
for i, context_length in enumerate(context_lengths):
position_ids[context_length:] = mask_positions[i]
position_ids = position_ids.unsqueeze(0)
return position_ids
@add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
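
The pattern introduced here builds one row of positions per sample and then overwrites everything past that sample's context with its mask position; the per-row index i on the left-hand side is what makes it batched. A standalone sketch of the same indexing (function name and arguments are illustrative):

import torch

def batched_position_ids(seq_length, context_lengths, mask_positions, device="cpu"):
    # One row 0..seq_length-1 per sample.
    batch_size = len(context_lengths)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device).repeat(batch_size, 1)
    # Past each sample's context, every position points at that sample's [MASK]/[gMASK].
    for i, context_length in enumerate(context_lengths):
        position_ids[i, context_length:] = mask_positions[i]
    return position_ids

print(batched_position_ids(6, context_lengths=[3, 4], mask_positions=[2, 3]))
# tensor([[0, 1, 2, 2, 2, 2],
#         [0, 1, 2, 3, 3, 3]])

repeat() materializes a separate row per sample, so the per-row writes do not alias; an expanded view would share memory across rows.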
@@ -1007,29 +1005,34 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False):
attention_mask = torch.ones((1, context_length, context_length), device=device)
def get_masks_and_position_ids(self, input_ids, mask_positions, device, gmask=False):
batch_size, seq_length = input_ids.shape
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
attention_mask.tril_()
attention_mask[..., :context_length - 1] = 1
for i, context_length in enumerate(context_lengths):
attention_mask[i, :, :context_length] = 1
attention_mask.unsqueeze_(1)
attention_mask = (attention_mask < 0.5).bool()
batch_size, seq_length = input_ids.shape
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
if self.position_encoding_2d:
seq_length = seq.index(self.config.bos_token_id)
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
if not gmask:
position_ids[seq_length:] = mask_position
block_position_ids = torch.cat((
torch.zeros(seq_length, dtype=torch.long, device=device),
torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
))
position_ids = torch.stack((position_ids, block_position_ids), dim=0)
for i, context_length in enumerate(context_lengths):
position_ids[i, context_length:] = mask_positions[i]
block_position_ids = [torch.cat((
torch.zeros(context_length, dtype=torch.long, device=device),
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
)) for context_length in context_lengths]
block_position_ids = torch.stack(block_position_ids, dim=0)
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
else:
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
if not gmask:
position_ids[context_length - 1:] = mask_position
position_ids = position_ids.unsqueeze(0)
for i, context_length in enumerate(context_lengths):
position_ids[context_length:] = mask_positions[i]
return attention_mask, position_ids
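
The attention mask built above is the per-sample prefix-LM pattern: fully bidirectional over each prompt (up to its bos token) and causal afterwards, with True marking positions to mask out. A self-contained sketch of the same construction (illustrative names; shapes follow the code above):

import torch

def prefix_lm_mask(seq_length, context_lengths, device="cpu"):
    batch_size = len(context_lengths)
    attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
    attention_mask.tril_()                             # causal lower triangle
    for i, context_length in enumerate(context_lengths):
        attention_mask[i, :, :context_length] = 1      # prompt tokens attend bidirectionally
    attention_mask.unsqueeze_(1)                       # broadcast dim for heads: (b, 1, s, s)
    return (attention_mask < 0.5).bool()               # True = masked out

mask = prefix_lm_mask(5, context_lengths=[3, 2])
print(mask.shape)  # torch.Size([2, 1, 5, 5])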
@@ -1041,25 +1044,24 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
**kwargs
) -> dict:
batch_size, seq_length = input_ids.shape
MASK, gMASK = 150000, 150001
mask_token = MASK if MASK in input_ids else gMASK
use_gmask = False if MASK in input_ids else gMASK
seq = input_ids[0].tolist()
mask_position = seq.index(mask_token)
if mask_token not in seq:
raise ValueError("You have to add either [MASK] or [gMASK] in your input")
seqs = input_ids.tolist()
mask_positions = [seq.index(mask_token) for seq in seqs]
# only last token for input_ids if past is not None
if past is not None or past_key_values is not None:
context_length = seq.index(self.config.bos_token_id)
context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
last_token = input_ids[:, -1].unsqueeze(-1)
if self.position_encoding_2d:
position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
device=input_ids.device)
position_ids = torch.tensor(
[[mask_position, seq_length - context_length] for mask_position, context_length in
zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
else:
position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device)
position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
device=input_ids.device).unsqueeze(-1)
if past is None:
past = past_key_values
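
During incremental decoding only the last token is fed in, so the position ids collapse to a single step per sample: with 2D positional encoding each sample carries its mask position and its block position (seq_length - context_length), giving a (batch, 2, 1) tensor after the unsqueeze. A quick shape check with made-up values:

import torch

mask_positions, context_lengths, seq_length = [2, 3], [3, 4], 7
position_ids = torch.tensor(
    [[mask_position, seq_length - context_length]
     for mask_position, context_length in zip(mask_positions, context_lengths)],
    dtype=torch.long).unsqueeze(-1)
print(position_ids.shape)  # torch.Size([2, 2, 1])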
@@ -1070,9 +1072,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
}
else:
attention_mask, position_ids = self.get_masks_and_position_ids(
seq=seq,
mask_position=mask_position,
context_length=len(seq),
input_ids,
mask_positions=mask_positions,
device=input_ids.device,
gmask=use_gmask
)
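
Taken together, the changes thread per-sample mask positions and context lengths through mask and position construction, which is what batched generation needs. A hedged end-to-end sketch using the standard transformers remote-code path; the repo id is assumed to be THUDM/chatglm-6b, and whether this revision's tokenizer pads batches is an assumption:

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

prompts = ["What is a positional encoding?", "Summarize prefix-LM attention in one sentence."]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)  # padding support assumed
outputs = model.generate(**inputs, max_length=128)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))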