diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index 1429dbb..bdd2d93 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -3,6 +3,7 @@
 import math
 import copy
 import os
+import time
 
 import torch
 import torch.utils.checkpoint
@@ -23,8 +24,10 @@ from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
 )
 from transformers.modeling_utils import PreTrainedModel
-
 from transformers.utils import logging
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList
+
 from .configuration_chatglm import ChatGLMConfig
 
 # flags required to enable jit fusion kernels
@@ -44,6 +47,14 @@ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]
 
 
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if torch.isnan(scores).any() or torch.isinf(scores).any():
+            scores.zero_()
+            scores[..., 20005] = 1e5
+        return scores
+
+
 def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
     """Load tf checkpoints in a pytorch model."""
     try:
@@ -1078,11 +1089,14 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
     @torch.no_grad()
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
-             do_sample=True, top_p=0.7, temperature=0.95, **kwargs):
+             do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
         if history is None:
             history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
         gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
-                      "temperature": temperature, **kwargs}
+                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
         if not history:
             prompt = query
         else:
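
The diff registers a LogitsProcessor that sanitizes NaN/inf logits before sampling, so a numerically broken step forces a single fixed token (id 20005, taken from the patch) instead of crashing generation. Below is a minimal standalone sketch of that behavior, not part of the patch itself: it imports LogitsProcessor and LogitsProcessorList from the top-level transformers package rather than the submodule paths used in the diff, and the toy vocabulary size of 130528 is only an assumption chosen so that index 20005 is valid.

# Sketch of the processor added in this diff (assumptions noted above).
import torch
from transformers import LogitsProcessor, LogitsProcessorList


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # If any logit is NaN or inf, wipe the scores in place and force
        # token id 20005 so sampling still has one valid candidate.
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 20005] = 1e5
        return scores


processors = LogitsProcessorList([InvalidScoreLogitsProcessor()])
scores = torch.zeros(1, 130528)                  # assumed toy vocab size
scores[0, 0] = float("nan")                      # simulate a corrupted logits row
fixed = processors(torch.tensor([[1]]), scores)  # list applies each processor to (input_ids, scores)
print(fixed[0, 20005])                           # tensor(100000.) -> token 20005 dominates sampling

In chat(), the processor is appended to the caller-supplied (or newly created) LogitsProcessorList and forwarded via gen_kwargs, so every generate() call made through chat() is covered without changing its public return values.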