Use dynamic dtype for prompts

duzx16 2023-03-31 01:13:32 +08:00
parent 0cfae21ef8
commit c949d03152
1 changed file with 7 additions and 5 deletions


@@ -804,9 +804,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
     def set_input_embeddings(self, new_embeddings: torch.Tensor):
         self.word_embeddings = new_embeddings

-    def get_prompt(self, batch_size, device):
+    def get_prompt(self, batch_size, device, dtype=torch.half):
         prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
-        past_key_values = self.prefix_encoder(prefix_tokens).half()
+        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
         past_key_values = past_key_values.view(
             batch_size,
             self.pre_seq_len,
@@ -896,9 +896,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
         if past_key_values is None:
             if self.pre_seq_len is not None:
-                past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device)
+                past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device,
+                                                  dtype=inputs_embeds.dtype)
             else:
                 past_key_values = tuple([None] * len(self.layers))
@@ -927,8 +931,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             gmask=use_gmask
         )

-        if inputs_embeds is None:
-            inputs_embeds = self.word_embeddings(input_ids)

         # [seq_len, batch, hidden_size]
         hidden_states = inputs_embeds.transpose(0, 1)
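In short: get_prompt previously forced the prefix-tuning key/value cache to float16 via .half(), so the prefix cache would not match the activations when the model ran in float32 or bfloat16. The commit threads the embedding dtype through instead, moving the inputs_embeds lookup ahead of the past_key_values branch so that dtype is already available. A minimal self-contained sketch of the pattern (the PrefixEncoder below is a hypothetical stand-in, not the repository's actual class):

import torch
import torch.nn as nn

class PrefixEncoder(nn.Module):
    # Toy stand-in: maps prefix token ids to key/value states.
    def __init__(self, pre_seq_len: int, hidden: int):
        super().__init__()
        self.embedding = nn.Embedding(pre_seq_len, hidden)

    def forward(self, prefix_tokens: torch.Tensor) -> torch.Tensor:
        return self.embedding(prefix_tokens)

pre_seq_len, hidden, batch_size = 4, 16, 2
encoder = PrefixEncoder(pre_seq_len, hidden)
prefix_tokens = torch.arange(pre_seq_len).unsqueeze(0).expand(batch_size, -1)

# Before the change: dtype hard-coded to float16, regardless of model precision.
fixed = encoder(prefix_tokens).half()

# After the change: the dtype follows the input embeddings, so the prefix
# cache always matches the activations it is later combined with.
inputs_embeds = torch.randn(batch_size, 8, hidden, dtype=torch.bfloat16)
dynamic = encoder(prefix_tokens).type(inputs_embeds.dtype)

assert dynamic.dtype == inputs_embeds.dtype  # bfloat16, matching activations

The design choice worth noting is that the dtype is read from inputs_embeds rather than from a config flag, so the prefix cache automatically tracks whatever precision the checkpoint was loaded in.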