Fix overflow in FP16

This commit is contained in:
duzx16 2023-03-16 09:26:05 +08:00
parent f9f74fda55
commit 220f772e9a
1 changed files with 1 additions and 1 deletions

View File

@ -51,7 +51,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 20005] = 1e5
scores[..., 20005] = 5e4
return scores