Set ignore_index for CrossEntropyLoss

This commit is contained in:
duzx16 2023-03-29 21:19:38 +08:00
parent 8127ab6abf
commit 5c64357295
1 changed file with 1 addition and 1 deletion

View File

@@ -1124,7 +1124,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
- loss_fct = CrossEntropyLoss()
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
lm_logits = lm_logits.to(hidden_states.dtype)