diff --git a/modeling_chatglm.py b/modeling_chatglm.py index c5a0d31..12c993c 100644 --- a/modeling_chatglm.py +++ b/modeling_chatglm.py @@ -804,9 +804,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel): def set_input_embeddings(self, new_embeddings: torch.Tensor): self.word_embeddings = new_embeddings - def get_prompt(self, batch_size, device): + def get_prompt(self, batch_size, device, dtype=torch.half): prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) - past_key_values = self.prefix_encoder(prefix_tokens).half() + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) past_key_values = past_key_values.view( batch_size, self.pre_seq_len, @@ -896,9 +896,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel): else: raise ValueError("You have to specify either input_ids or inputs_embeds") + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + if past_key_values is None: if self.pre_seq_len is not None: - past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device) + past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device, + dtype=inputs_embeds.dtype) else: past_key_values = tuple([None] * len(self.layers)) @@ -927,8 +931,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel): gmask=use_gmask ) - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) # [seq_len, batch, hidden_size] hidden_states = inputs_embeds.transpose(0, 1)