Add pad_token_id in config.json
Fix position_ids in ChatGLMModel Add batch position_ids
This commit is contained in:
parent
db2249979c
commit
2200e2bc52
|
@ -10,6 +10,7 @@
|
||||||
},
|
},
|
||||||
"bos_token_id": 150004,
|
"bos_token_id": 150004,
|
||||||
"eos_token_id": 150005,
|
"eos_token_id": 150005,
|
||||||
|
"pad_token_id": 20003,
|
||||||
"hidden_size": 4096,
|
"hidden_size": 4096,
|
||||||
"inner_hidden_size": 16384,
|
"inner_hidden_size": 16384,
|
||||||
"layernorm_epsilon": 1e-05,
|
"layernorm_epsilon": 1e-05,
|
||||||
|
|
|
@ -850,8 +850,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||||
for i, context_length in enumerate(context_lengths):
|
for i, context_length in enumerate(context_lengths):
|
||||||
position_ids[context_length:] = mask_positions[i]
|
position_ids[context_length:] = mask_positions[i]
|
||||||
|
|
||||||
position_ids = position_ids.unsqueeze(0)
|
|
||||||
|
|
||||||
return position_ids
|
return position_ids
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
|
@ -1007,29 +1005,34 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
def set_output_embeddings(self, new_embeddings):
|
def set_output_embeddings(self, new_embeddings):
|
||||||
self.lm_head = new_embeddings
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False):
|
def get_masks_and_position_ids(self, input_ids, mask_positions, device, gmask=False):
|
||||||
attention_mask = torch.ones((1, context_length, context_length), device=device)
|
batch_size, seq_length = input_ids.shape
|
||||||
|
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
||||||
|
attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
|
||||||
attention_mask.tril_()
|
attention_mask.tril_()
|
||||||
attention_mask[..., :context_length - 1] = 1
|
for i, context_length in enumerate(context_lengths):
|
||||||
|
attention_mask[i, :, :context_length] = 1
|
||||||
attention_mask.unsqueeze_(1)
|
attention_mask.unsqueeze_(1)
|
||||||
attention_mask = (attention_mask < 0.5).bool()
|
attention_mask = (attention_mask < 0.5).bool()
|
||||||
|
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
||||||
if self.position_encoding_2d:
|
if self.position_encoding_2d:
|
||||||
seq_length = seq.index(self.config.bos_token_id)
|
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
|
||||||
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
|
|
||||||
if not gmask:
|
if not gmask:
|
||||||
position_ids[seq_length:] = mask_position
|
for i, context_length in enumerate(context_lengths):
|
||||||
block_position_ids = torch.cat((
|
position_ids[i, context_length:] = mask_positions[i]
|
||||||
torch.zeros(seq_length, dtype=torch.long, device=device),
|
block_position_ids = [torch.cat((
|
||||||
torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
|
torch.zeros(context_length, dtype=torch.long, device=device),
|
||||||
))
|
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
|
||||||
position_ids = torch.stack((position_ids, block_position_ids), dim=0)
|
)) for context_length in context_lengths]
|
||||||
|
block_position_ids = torch.stack(block_position_ids, dim=0)
|
||||||
|
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
|
||||||
else:
|
else:
|
||||||
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
|
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
|
||||||
if not gmask:
|
if not gmask:
|
||||||
position_ids[context_length - 1:] = mask_position
|
for i, context_length in enumerate(context_lengths):
|
||||||
|
position_ids[context_length:] = mask_positions[i]
|
||||||
position_ids = position_ids.unsqueeze(0)
|
|
||||||
|
|
||||||
return attention_mask, position_ids
|
return attention_mask, position_ids
|
||||||
|
|
||||||
|
@ -1041,25 +1044,24 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
MASK, gMASK = 150000, 150001
|
MASK, gMASK = 150000, 150001
|
||||||
mask_token = MASK if MASK in input_ids else gMASK
|
mask_token = MASK if MASK in input_ids else gMASK
|
||||||
use_gmask = False if MASK in input_ids else gMASK
|
use_gmask = False if MASK in input_ids else gMASK
|
||||||
seq = input_ids[0].tolist()
|
seqs = input_ids.tolist()
|
||||||
mask_position = seq.index(mask_token)
|
mask_positions = [seq.index(mask_token) for seq in seqs]
|
||||||
|
|
||||||
if mask_token not in seq:
|
|
||||||
raise ValueError("You have to add either [MASK] or [gMASK] in your input")
|
|
||||||
|
|
||||||
# only last token for input_ids if past is not None
|
# only last token for input_ids if past is not None
|
||||||
if past is not None or past_key_values is not None:
|
if past is not None or past_key_values is not None:
|
||||||
context_length = seq.index(self.config.bos_token_id)
|
context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
|
||||||
last_token = input_ids[:, -1].unsqueeze(-1)
|
last_token = input_ids[:, -1].unsqueeze(-1)
|
||||||
if self.position_encoding_2d:
|
if self.position_encoding_2d:
|
||||||
position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
|
position_ids = torch.tensor(
|
||||||
device=input_ids.device)
|
[[mask_position, seq_length - context_length] for mask_position, context_length in
|
||||||
|
zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
|
||||||
else:
|
else:
|
||||||
position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device)
|
position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
|
||||||
|
device=input_ids.device).unsqueeze(-1)
|
||||||
|
|
||||||
if past is None:
|
if past is None:
|
||||||
past = past_key_values
|
past = past_key_values
|
||||||
|
@ -1070,9 +1072,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
attention_mask, position_ids = self.get_masks_and_position_ids(
|
attention_mask, position_ids = self.get_masks_and_position_ids(
|
||||||
seq=seq,
|
input_ids,
|
||||||
mask_position=mask_position,
|
mask_positions=mask_positions,
|
||||||
context_length=len(seq),
|
|
||||||
device=input_ids.device,
|
device=input_ids.device,
|
||||||
gmask=use_gmask
|
gmask=use_gmask
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue