From 2460dc243020b034420cb759ff966b83a7601d06 Mon Sep 17 00:00:00 2001
From: duzx16
Date: Sun, 19 Mar 2023 14:56:15 +0800
Subject: [PATCH] Remove hardcode bos_token_id

---
 modeling_chatglm.py | 24 +++++++++++-------------
 1 file changed, 11 insertions(+), 13 deletions(-)

diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index a2a158b..10a1df2 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -753,9 +753,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
     def set_input_embeddings(self, new_embeddings: torch.Tensor):
         self.word_embeddings = new_embeddings
 
-    @staticmethod
-    def get_masks(seq, device):
-        context_length = seq.index(150004) + 1
+    def get_masks(self, seq, device):
+        context_length = seq.index(self.config.bos_token_id) + 1
 
         attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
         attention_mask.tril_()
@@ -766,9 +765,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return attention_mask
 
     def get_position_ids(self, seq, mask_position, device, gmask=False):
-        context_length = seq.index(150004) + 1
+        context_length = seq.index(self.config.bos_token_id) + 1
         if self.position_encoding_2d:
-            seq_length = seq.index(150004)
+            seq_length = seq.index(self.config.bos_token_id)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
             if not gmask:
                 position_ids[seq_length:] = mask_position
@@ -823,14 +822,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         if past_key_values is None:
             past_key_values = tuple([None] * len(self.layers))
-
-        MASK, gMASK = 150000, 150001
-        mask_token = MASK if MASK in input_ids else gMASK
-        use_gmask = False if MASK in input_ids else gMASK
         seq = input_ids[0].tolist()
-        mask_position = seq.index(mask_token)
-
         if attention_mask is None:
             attention_mask = self.get_masks(
                 seq=seq,
@@ -838,6 +831,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             )
 
         if position_ids is None:
+            MASK, gMASK = 150000, 150001
+            mask_token = MASK if MASK in input_ids else gMASK
+            use_gmask = False if MASK in input_ids else gMASK
+
+            mask_position = seq.index(mask_token)
             position_ids = self.get_position_ids(
                 seq=seq,
                 mask_position=mask_position,
@@ -941,7 +939,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             attention_mask = (attention_mask < 0.5).bool()
 
             if self.position_encoding_2d:
-                seq_length = seq.index(150004)
+                seq_length = seq.index(self.config.bos_token_id)
                 position_ids = torch.arange(context_length, dtype=torch.long, device=device)
                 if not gmask:
                     position_ids[seq_length:] = mask_position
@@ -979,7 +977,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         # only last token for input_ids if past is not None
        if past is not None or past_key_values is not None:
-            context_length = seq.index(150004)
+            context_length = seq.index(self.config.bos_token_id)
             last_token = input_ids[:, -1].unsqueeze(-1)
             if self.position_encoding_2d:
                 position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
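
Note (not part of the patch): a minimal sketch of the mask construction that this change de-hardcodes, for readers who don't have the rest of modeling_chatglm.py open. Only the token IDs 150000/150001/150004, the `seq.index(bos_token_id)` lookup, and the final `(attention_mask < 0.5).bool()` step are taken from the diff above; the helper name `build_prefix_lm_mask` and the line that re-opens the prompt columns are illustrative assumptions about the omitted body of `get_masks`.

    # Sketch (assumption, not the patched code): prefix-LM mask keyed on a
    # configurable bos_token_id instead of the hardcoded 150004.
    import torch

    def build_prefix_lm_mask(seq, bos_token_id):
        """Causal mask in which every position can still attend to the full prompt."""
        context_length = seq.index(bos_token_id) + 1        # prompt ends at <bos>
        attention_mask = torch.ones((1, len(seq), len(seq)))
        attention_mask.tril_()                               # causal lower triangle
        attention_mask[..., :context_length - 1] = 1         # prompt stays fully visible (assumed detail)
        return (attention_mask < 0.5).bool()                 # True = masked out, as in the diff

    # Example: prompt tokens, [gMASK] (150001), <bos> (150004), then generated tokens.
    seq = [5, 6, 150001, 150004, 7, 8]
    mask = build_prefix_lm_mask(seq, bos_token_id=150004)
    print(mask.shape)  # torch.Size([1, 6, 6])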