Fix input embeds

duzx16 2023-04-18 20:46:39 +08:00
parent 0829959f96
commit 35ca52301f
1 changed file with 2 additions and 3 deletions

@@ -918,7 +918,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape[:2]
         elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape[:2]
+            batch_size, seq_length = inputs_embeds.shape[:2]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
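
The first hunk fixes a tensor-unpacking bug: `inputs_embeds.shape[:2]` is a length-2 `torch.Size`, so the old three-target unpack raised a ValueError whenever a caller passed `inputs_embeds` instead of `input_ids`. A minimal sketch of the failure and the fix (tensor sizes are illustrative):

```python
import torch

# inputs_embeds has shape (batch, seq_len, hidden); sizes here are illustrative
inputs_embeds = torch.randn(2, 16, 4096)

# Old code: three targets, but shape[:2] holds only two values
try:
    batch_size, seq_length, _ = inputs_embeds.shape[:2]
except ValueError as e:
    print(e)  # not enough values to unpack (expected 3, got 2)

# Fixed code: unpack exactly the two leading dimensions
batch_size, seq_length = inputs_embeds.shape[:2]
print(batch_size, seq_length)  # 2 16
```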
@@ -972,9 +972,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
         else:
-            attention_mask = attention_mask.to(input_ids.device)
+            attention_mask = attention_mask.to(hidden_states.device)
         for i, layer in enumerate(self.layers):
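
The second hunk likely targets the same embeds-only path: when a caller supplies only `inputs_embeds`, `input_ids` is `None` inside `forward()`, so `input_ids.device` raises an AttributeError; `hidden_states` exists on both paths, so its device is a safe target. A minimal sketch of that failure mode, with illustrative sizes:

```python
import torch

# Embeds-only call: input_ids is None inside forward()
input_ids = None
hidden_states = torch.randn(2, 16, 4096)   # sizes are illustrative
attention_mask = torch.ones(2, 16, dtype=torch.bool)

# Old code: dereferences input_ids even when it is None
try:
    attention_mask = attention_mask.to(input_ids.device)
except AttributeError as e:
    print(e)  # 'NoneType' object has no attribute 'device'

# Fixed code: hidden_states exists on both the ids and embeds paths
attention_mask = attention_mask.to(hidden_states.device)
```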