Add logit processor for NaN or Inf scores
This commit is contained in:
parent
9d1509a1ad
commit
c3dece3f01
|
@ -3,6 +3,7 @@
|
|||
import math
|
||||
import copy
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
|
@ -23,8 +24,10 @@ from transformers.modeling_outputs import (
|
|||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
)
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
from transformers.utils import logging
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
from transformers.generation.utils import LogitsProcessorList
|
||||
|
||||
from .configuration_chatglm import ChatGLMConfig
|
||||
|
||||
# flags required to enable jit fusion kernels
|
||||
|
@ -44,6 +47,14 @@ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|||
]
|
||||
|
||||
|
||||
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
||||
scores.zero_()
|
||||
scores[..., 20005] = 1e5
|
||||
return scores
|
||||
|
||||
|
||||
def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
|
||||
"""Load tf checkpoints in a pytorch model."""
|
||||
try:
|
||||
|
@ -1078,11 +1089,14 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
|
||||
@torch.no_grad()
|
||||
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
|
||||
do_sample=True, top_p=0.7, temperature=0.95, **kwargs):
|
||||
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
|
||||
if history is None:
|
||||
history = []
|
||||
if logits_processor is None:
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InvalidScoreLogitsProcessor())
|
||||
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
||||
"temperature": temperature, **kwargs}
|
||||
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
||||
if not history:
|
||||
prompt = query
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue