Fix LogitsProcessor using slim checkpoint (#29)

- Fix LogitsProcessor using slim checkpoint (7f8f01fee41efeaac4f926bdb96aea42f1c6076b)


Co-authored-by: bcol <bcol@users.noreply.huggingface.co>
This commit is contained in:
Zhengxiao Du 2023-04-08 02:54:27 +00:00 committed by system
parent 9324de70a9
commit 61eee50c9f
1 changed files with 1 additions and 1 deletions

View File

@ -55,7 +55,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 20005] = 5e4
scores[..., 5] = 5e4
return scores