Sync with chatglm-6b

duzx16 2023-04-28 20:17:25 +08:00
parent e02ba894cf
commit f55a1089a2
3 changed files with 35 additions and 21 deletions

config.json

@@ -10,16 +10,16 @@
   },
   "bos_token_id": 130004,
   "eos_token_id": 130005,
-  "mask_token_id": 130000,
   "gmask_token_id": 130001,
-  "pad_token_id": 3,
   "hidden_size": 4096,
   "inner_hidden_size": 16384,
   "layernorm_epsilon": 1e-05,
+  "mask_token_id": 130000,
   "max_sequence_length": 2048,
   "model_type": "chatglm",
   "num_attention_heads": 32,
   "num_layers": 28,
+  "pad_token_id": 3,
   "position_encoding_2d": true,
   "quantization_bit": 4,
   "quantization_embeddings": false,

modeling_chatglm.py

@@ -918,7 +918,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape[:2]
         elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape[:2]
+            batch_size, seq_length = inputs_embeds.shape[:2]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
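The removed line here was a real bug, not a style fix: inputs_embeds.shape[:2] yields exactly two values, so unpacking into three targets raises a ValueError whenever a caller passes inputs_embeds instead of input_ids. A self-contained repro of the failure and the fix (the tensor shape is illustrative):

import torch

inputs_embeds = torch.randn(2, 8, 4096)  # (batch, seq, hidden); made-up sizes

# Old line: shape[:2] is a 2-tuple, but three targets are given.
try:
    batch_size, seq_length, _ = inputs_embeds.shape[:2]
except ValueError as err:
    print(err)  # not enough values to unpack (expected 3, got 2)

# New line: two targets match the two values.
batch_size, seq_length = inputs_embeds.shape[:2]
print(batch_size, seq_length)  # 2 8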
@@ -972,9 +972,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
         else:
-            attention_mask = attention_mask.to(input_ids.device)
+            attention_mask = attention_mask.to(hidden_states.device)
         for i, layer in enumerate(self.layers):
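A plausible reading of this second hunk: on the inputs_embeds path, input_ids is None by the time this branch runs, so input_ids.device raises; and under multi-device placement the mask should follow the activations anyway, making hidden_states the tensor that is guaranteed to exist here. A minimal sketch of the None case, with made-up shapes:

import torch

input_ids = None                          # caller passed inputs_embeds instead
hidden_states = torch.randn(8, 2, 4096)   # activations; layout irrelevant to the device move
attention_mask = torch.ones(2, 1, 8, 8).bool()

# Old line: AttributeError: 'NoneType' object has no attribute 'device'
try:
    attention_mask = attention_mask.to(input_ids.device)
except AttributeError as err:
    print(err)

# New line: move the mask to wherever the activations live.
attention_mask = attention_mask.to(hidden_states.device)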

File diff suppressed because one or more lines are too long