Set ignore_index for CrossEntropyLoss
This commit is contained in:
parent
8127ab6abf
commit
5c64357295
|
@ -1124,7 +1124,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
# Flatten the tokens
|
# Flatten the tokens
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
||||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||||
|
|
||||||
lm_logits = lm_logits.to(hidden_states.dtype)
|
lm_logits = lm_logits.to(hidden_states.dtype)
|
||||||
|
|
Loading…
Reference in New Issue