Remove hardcode bos_token_id

duzx16 2023-03-19 14:56:15 +08:00
parent 42095d42ff
commit 2460dc2430
1 changed file with 11 additions and 13 deletions

@@ -753,9 +753,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
     def set_input_embeddings(self, new_embeddings: torch.Tensor):
         self.word_embeddings = new_embeddings
 
-    @staticmethod
-    def get_masks(seq, device):
-        context_length = seq.index(150004) + 1
+    def get_masks(self, seq, device):
+        context_length = seq.index(self.config.bos_token_id) + 1
 
         attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
         attention_mask.tril_()
@@ -766,9 +765,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return attention_mask
 
     def get_position_ids(self, seq, mask_position, device, gmask=False):
-        context_length = seq.index(150004) + 1
+        context_length = seq.index(self.config.bos_token_id) + 1
         if self.position_encoding_2d:
-            seq_length = seq.index(150004)
+            seq_length = seq.index(self.config.bos_token_id)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
             if not gmask:
                 position_ids[seq_length:] = mask_position
@@ -823,14 +822,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if past_key_values is None:
             past_key_values = tuple([None] * len(self.layers))
 
-            MASK, gMASK = 150000, 150001
-            mask_token = MASK if MASK in input_ids else gMASK
-            use_gmask = False if MASK in input_ids else gMASK
-
             seq = input_ids[0].tolist()
 
-            mask_position = seq.index(mask_token)
-
             if attention_mask is None:
                 attention_mask = self.get_masks(
                     seq=seq,
@@ -838,6 +831,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 )
 
             if position_ids is None:
+                MASK, gMASK = 150000, 150001
+                mask_token = MASK if MASK in input_ids else gMASK
+                use_gmask = False if MASK in input_ids else gMASK
+
+                mask_position = seq.index(mask_token)
                 position_ids = self.get_position_ids(
                     seq=seq,
                     mask_position=mask_position,
@@ -941,7 +939,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         attention_mask = (attention_mask < 0.5).bool()
 
         if self.position_encoding_2d:
-            seq_length = seq.index(150004)
+            seq_length = seq.index(self.config.bos_token_id)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
             if not gmask:
                 position_ids[seq_length:] = mask_position
@@ -979,7 +977,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
-            context_length = seq.index(150004)
+            context_length = seq.index(self.config.bos_token_id)
             last_token = input_ids[:, -1].unsqueeze(-1)
             if self.position_encoding_2d:
                 position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
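For context, here is a minimal, self-contained sketch of the pattern this commit introduces: special-token ids are read from the model config rather than hardcoded as integer literals (150000, 150001, 150004). The `ToyConfig` class, the toy token ids, and the example sequence are hypothetical stand-ins, not values from ChatGLM; only the use of `config.bos_token_id` and the mask construction visible in the hunks above come from the diff, and the middle of `get_masks` (not shown in the hunk) is an assumption.

```python
from dataclasses import dataclass

import torch


@dataclass
class ToyConfig:
    bos_token_id: int = 5    # stand-in; the real value comes from the model's config
    mask_token_id: int = 3   # hypothetical id used only for the example sequence


def get_masks(seq, device, config):
    # As in the patched get_masks: the prefix length is derived from the
    # configured BOS id rather than the literal 150004.
    context_length = seq.index(config.bos_token_id) + 1
    attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
    attention_mask.tril_()
    # The rest of the method falls outside the hunk shown above; assumed here:
    # the prompt prefix stays fully visible, then the mask is flipped to the
    # boolean "True = masked" convention seen later in the diff.
    attention_mask[..., :context_length - 1] = 1
    attention_mask = (attention_mask < 0.5).bool()
    return attention_mask


config = ToyConfig()
# toy sequence: prompt tokens, a mask token, BOS, then two generated tokens
seq = [11, 12, config.mask_token_id, 13, config.bos_token_id, 20, 21]

mask = get_masks(seq, device="cpu", config=config)
print(mask.shape)     # torch.Size([1, 7, 7])
print(mask[0].int())  # 0 = attend, 1 = masked
```

Reading the ids from the config keeps the modeling code in sync with the tokenizer: if the vocabulary or special-token ids change, only the config needs updating, which is the point of replacing the hardcoded literals with `self.config.bos_token_id`.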