Add logit processor for NaN or Inf scores
This commit is contained in:
parent
9d1509a1ad
commit
c3dece3f01
|
@ -3,6 +3,7 @@
|
||||||
import math
|
import math
|
||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
|
@ -23,8 +24,10 @@ from transformers.modeling_outputs import (
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
)
|
)
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
from transformers.generation.logits_process import LogitsProcessor
|
||||||
|
from transformers.generation.utils import LogitsProcessorList
|
||||||
|
|
||||||
from .configuration_chatglm import ChatGLMConfig
|
from .configuration_chatglm import ChatGLMConfig
|
||||||
|
|
||||||
# flags required to enable jit fusion kernels
|
# flags required to enable jit fusion kernels
|
||||||
|
@ -44,6 +47,14 @@ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
||||||
|
scores.zero_()
|
||||||
|
scores[..., 20005] = 1e5
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
|
def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
|
||||||
"""Load tf checkpoints in a pytorch model."""
|
"""Load tf checkpoints in a pytorch model."""
|
||||||
try:
|
try:
|
||||||
|
@ -1078,11 +1089,14 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
|
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
|
||||||
do_sample=True, top_p=0.7, temperature=0.95, **kwargs):
|
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
|
||||||
if history is None:
|
if history is None:
|
||||||
history = []
|
history = []
|
||||||
|
if logits_processor is None:
|
||||||
|
logits_processor = LogitsProcessorList()
|
||||||
|
logits_processor.append(InvalidScoreLogitsProcessor())
|
||||||
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
||||||
"temperature": temperature, **kwargs}
|
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
||||||
if not history:
|
if not history:
|
||||||
prompt = query
|
prompt = query
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue