Use dynamic dtype for prompts: cast the prefix-encoder output to the dtype of inputs_embeds instead of always calling .half()
This commit is contained in:
parent
0cfae21ef8
commit
c949d03152
|
@ -804,9 +804,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||||
def set_input_embeddings(self, new_embeddings: torch.Tensor):
|
def set_input_embeddings(self, new_embeddings: torch.Tensor):
|
||||||
self.word_embeddings = new_embeddings
|
self.word_embeddings = new_embeddings
|
||||||
|
|
||||||
def get_prompt(self, batch_size, device):
|
def get_prompt(self, batch_size, device, dtype=torch.half):
|
||||||
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
|
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
|
||||||
past_key_values = self.prefix_encoder(prefix_tokens).half()
|
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
|
||||||
past_key_values = past_key_values.view(
|
past_key_values = past_key_values.view(
|
||||||
batch_size,
|
batch_size,
|
||||||
self.pre_seq_len,
|
self.pre_seq_len,
|
||||||
|
@ -896,9 +896,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
if past_key_values is None:
|
if past_key_values is None:
|
||||||
if self.pre_seq_len is not None:
|
if self.pre_seq_len is not None:
|
||||||
past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device)
|
past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device,
|
||||||
|
dtype=inputs_embeds.dtype)
|
||||||
else:
|
else:
|
||||||
past_key_values = tuple([None] * len(self.layers))
|
past_key_values = tuple([None] * len(self.layers))
|
||||||
|
|
||||||
|
@ -927,8 +931,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||||
gmask=use_gmask
|
gmask=use_gmask
|
||||||
)
|
)
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.word_embeddings(input_ids)
|
|
||||||
|
|
||||||
# [seq_len, batch, hidden_size]
|
# [seq_len, batch, hidden_size]
|
||||||
hidden_states = inputs_embeds.transpose(0, 1)
|
hidden_states = inputs_embeds.transpose(0, 1)
|
||||||
|
|
Loading…
Reference in New Issue