Add logit processor for NaN or Inf scores

This commit is contained in:
duzx16 2023-03-15 18:14:34 +08:00
parent 9d1509a1ad
commit c3dece3f01
1 changed files with 17 additions and 3 deletions

View File

@ -3,6 +3,7 @@
import math
import copy
import os
import time
import torch
import torch.utils.checkpoint
@ -23,8 +24,10 @@ from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList
from .configuration_chatglm import ChatGLMConfig
# flags required to enable jit fusion kernels
@ -44,6 +47,14 @@ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 20005] = 1e5
return scores
def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
"""Load tf checkpoints in a pytorch model."""
try:
@ -1078,11 +1089,14 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
@torch.no_grad()
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
do_sample=True, top_p=0.7, temperature=0.95, **kwargs):
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
"temperature": temperature, **kwargs}
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
if not history:
prompt = query
else: