diff --git a/modeling_chatglm.py b/modeling_chatglm.py index 8ee5cf1..9392c13 100644 --- a/modeling_chatglm.py +++ b/modeling_chatglm.py @@ -55,7 +55,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if torch.isnan(scores).any() or torch.isinf(scores).any(): scores.zero_() - scores[..., 20005] = 5e4 + scores[..., 5] = 5e4 return scores