Remove hardcode bos_token_id
This commit is contained in:
parent 42095d42ff
commit 2460dc2430
@@ -753,9 +753,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
     def set_input_embeddings(self, new_embeddings: torch.Tensor):
         self.word_embeddings = new_embeddings
 
-    @staticmethod
-    def get_masks(seq, device):
-        context_length = seq.index(150004) + 1
+    def get_masks(self, seq, device):
+        context_length = seq.index(self.config.bos_token_id) + 1
 
         attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
         attention_mask.tril_()
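For orientation: `get_masks` builds GLM's prefix-style attention mask, in which the prompt up to the BOS token is visible to every position and attention is causal after it. The sketch below only illustrates that idea with a made-up `bos_token_id`; everything after `tril_()` (the prefix fill and the `< 0.5` boolean convention) is an assumption based on surrounding code in this diff, not part of this hunk.

import torch

def sketch_prefix_mask(seq, bos_token_id):
    # Prompt plus the BOS token form the bidirectional context.
    context_length = seq.index(bos_token_id) + 1
    attention_mask = torch.ones((1, len(seq), len(seq)))
    attention_mask.tril_()                         # causal everywhere ...
    attention_mask[..., :context_length - 1] = 1   # ... except the prompt, visible to all positions (assumed boundary)
    return (attention_mask < 0.5).bool()           # True marks positions to ignore, matching the convention shown later in this diff

# Toy sequence: three prompt tokens, a BOS stand-in (id 5), then two generated tokens.
print(sketch_prefix_mask([11, 12, 13, 5, 21, 22], bos_token_id=5))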
@@ -766,9 +765,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return attention_mask
 
     def get_position_ids(self, seq, mask_position, device, gmask=False):
-        context_length = seq.index(150004) + 1
+        context_length = seq.index(self.config.bos_token_id) + 1
         if self.position_encoding_2d:
-            seq_length = seq.index(150004)
+            seq_length = seq.index(self.config.bos_token_id)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
             if not gmask:
                 position_ids[seq_length:] = mask_position
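A quick illustration of why `seq.index(self.config.bos_token_id)` is the right anchor here: the BOS position splits the sequence into the bidirectional context and the generated tail. The ids below are invented, and only the first position channel shown in this hunk is sketched.

import torch

# Made-up ids purely for illustration: BOS stands in for config.bos_token_id
# (the value that used to be hard-coded as 150004) and 7 for a mask placeholder.
BOS = 5
seq = [101, 102, 103, 7, BOS, 201, 202]   # prompt, mask slot, BOS, generated tokens

seq_length = seq.index(BOS)               # 4: prompt tokens before BOS
context_length = seq.index(BOS) + 1       # 5: prompt plus the BOS token itself
mask_position = seq.index(7)              # 3: where the blank sits in the prompt

# First position channel, as in the hunk above: count over the context,
# then pin everything after the prompt to the mask slot (the non-gMASK case).
position_ids = torch.arange(context_length, dtype=torch.long)
position_ids[seq_length:] = mask_position
print(position_ids)                       # tensor([0, 1, 2, 3, 3])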
@@ -823,14 +822,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         if past_key_values is None:
             past_key_values = tuple([None] * len(self.layers))
 
-        MASK, gMASK = 150000, 150001
-        mask_token = MASK if MASK in input_ids else gMASK
-        use_gmask = False if MASK in input_ids else gMASK
         seq = input_ids[0].tolist()
-
-        mask_position = seq.index(mask_token)
-
         if attention_mask is None:
             attention_mask = self.get_masks(
                 seq=seq,
@@ -838,6 +831,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             )
 
         if position_ids is None:
+            MASK, gMASK = 150000, 150001
+            mask_token = MASK if MASK in input_ids else gMASK
+            use_gmask = False if MASK in input_ids else gMASK
+
+            mask_position = seq.index(mask_token)
             position_ids = self.get_position_ids(
                 seq=seq,
                 mask_position=mask_position,
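Moving the mask-token handling inside the `position_ids is None` branch keeps it next to its only consumers, `mask_position` and `use_gmask`. A self-contained sketch of that selection logic with a toy input (the non-special ids are made up):

import torch

MASK, gMASK = 150000, 150001               # special-token ids from the hunk above

# Toy input that contains [gMASK] but no [MASK].
input_ids = torch.tensor([[101, 102, gMASK, 150004, 201]])

mask_token = MASK if MASK in input_ids else gMASK   # -> 150001
use_gmask = False if MASK in input_ids else gMASK   # the truthy token id doubles as the flag
seq = input_ids[0].tolist()
mask_position = seq.index(mask_token)               # -> 2

print(mask_token, bool(use_gmask), mask_position)   # 150001 True 2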
@@ -941,7 +939,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         attention_mask = (attention_mask < 0.5).bool()
 
         if self.position_encoding_2d:
-            seq_length = seq.index(150004)
+            seq_length = seq.index(self.config.bos_token_id)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
             if not gmask:
                 position_ids[seq_length:] = mask_position
@@ -979,7 +977,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
-            context_length = seq.index(150004)
+            context_length = seq.index(self.config.bos_token_id)
             last_token = input_ids[:, -1].unsqueeze(-1)
             if self.position_encoding_2d:
                 position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
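During incremental decoding only the last token is passed in, so `prepare_inputs_for_generation` builds a single-step 2D position entry: the mask slot in the first channel and the distance past the context in the second. A toy sketch of the tensor being constructed (values assumed for illustration):

import torch

# Assumed toy state for a single decode step: 6 tokens so far, a 4-token
# context (prompt + BOS), and the mask placeholder at index 2.
seq = [101, 102, 150001, 150004, 201, 202]
context_length = 4
mask_position = 2

# Shape (batch=1, 2 channels, 1 step): channel 0 stays at the mask slot,
# channel 1 counts how many tokens have been generated past the context.
position_ids = torch.tensor(
    [[[mask_position], [len(seq) - context_length]]], dtype=torch.long
)
print(position_ids)          # tensor([[[2], [2]]])
print(position_ids.shape)    # torch.Size([1, 2, 1])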