From 5c643572950288250fa2dac1a02684e319f8b718 Mon Sep 17 00:00:00 2001 From: duzx16 Date: Wed, 29 Mar 2023 21:19:38 +0800 Subject: [PATCH] Set ignore_index for CrossEntropyLoss --- modeling_chatglm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modeling_chatglm.py b/modeling_chatglm.py index f0f3a3b..9c06117 100644 --- a/modeling_chatglm.py +++ b/modeling_chatglm.py @@ -1124,7 +1124,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() + loss_fct = CrossEntropyLoss(ignore_index=-100) loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) lm_logits = lm_logits.to(hidden_states.dtype)