diff --git a/modeling_chatglm.py b/modeling_chatglm.py index a0f3b9a..648a1d4 100644 --- a/modeling_chatglm.py +++ b/modeling_chatglm.py @@ -51,7 +51,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if torch.isnan(scores).any() or torch.isinf(scores).any(): scores.zero_() - scores[..., 20005] = 1e5 + scores[..., 20005] = 5e4 return scores