Fix input embeds
This commit is contained in:
parent
0829959f96
commit
35ca52301f
|
@ -918,7 +918,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape[:2]
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape[:2]
|
||||
batch_size, seq_length = inputs_embeds.shape[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
|
@ -972,9 +972,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
|
||||
|
||||
else:
|
||||
attention_mask = attention_mask.to(input_ids.device)
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
|
||||
|
|
Loading…
Reference in New Issue