Add support for streaming output

duzx16 2023-03-19 14:31:26 +08:00
parent 220f772e9a
commit 42095d42ff
1 changed file with 117 additions and 39 deletions

@@ -3,7 +3,7 @@
 import math
 import copy
 import os
-import time
+import warnings
 import torch
 import torch.utils.checkpoint
@@ -11,7 +11,7 @@ import torch.nn.functional as F
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
 from torch.nn.utils import skip_init
-from typing import Optional, Tuple, Union, List
+from typing import Optional, Tuple, Union, List, Callable
 from transformers.utils import (
     add_code_sample_docstrings,
@@ -26,7 +26,7 @@ from transformers.modeling_outputs import (
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
 from .configuration_chatglm import ChatGLMConfig
@@ -1107,7 +1107,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
         input_ids = input_ids.to(self.device)
         outputs = self.generate(**input_ids, **gen_kwargs)
-        outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]) - 2:]
+        outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
         response = tokenizer.decode(outputs)
         response = response.strip()
         response = response.replace("[[训练时间]]", "2023年")
@@ -1115,55 +1115,133 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         return response, history
 
-    @torch.no_grad()
-    def generate(
-            self,
-            **kwargs,
-    ):
-        MASK, gMASK = 150000, 150001
-        bos, eos = 150004, 150005
-
-        if "eos_token_id" not in kwargs:
-            kwargs["eos_token_id"] = eos
-
-        stop = False
-
-        return_seqs = []
-
-        while True:
-            output_ids = super().generate(**kwargs)
-
-            return_seqs = []
-            max_length = 0
-
-            for i in range(output_ids.shape[0]):
-                output_seq = output_ids[i].tolist()
-                mask_token = MASK if MASK in output_seq else gMASK
-                mask_position = output_seq.index(mask_token)
-                bos_position = output_seq.index(bos)
-                if eos in output_seq:
-                    eos_position = output_seq.index(eos)
-                else:
-                    eos_position = len(output_seq)
-
-                return_seq = output_seq[:mask_position] + output_seq[bos_position + 1:eos_position] + output_seq[
-                    mask_position + 1:bos_position]
-                max_length = max(max_length, len(return_seq))
-                return_seqs.append(return_seq)
-
-            for i in range(output_ids.shape[0]):
-                return_seqs[i] = [0] * (max_length - len(return_seqs[i])) + return_seqs[i]  # padding
-                if mask_token not in return_seqs[i]:
-                    stop = True
-
-            if stop:
-                break
-
-            for return_seq in return_seqs:
-                return_seq += [bos]
-
-            kwargs['input_ids'] = torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
-
-        return torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
+    @torch.no_grad()
+    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
+                    do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
+                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+        if not history:
+            prompt = query
+        else:
+            prompt = ""
+            for i, (old_query, response) in enumerate(history):
+                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
+            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+        input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
+        input_ids = input_ids.to(self.device)
+        for outputs in self.stream_generate(**input_ids, **gen_kwargs):
+            outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
+            response = tokenizer.decode(outputs)
+            response = response.strip()
+            response = response.replace("[[训练时间]]", "2023年")
+            new_history = history + [(query, response)]
+            yield response, new_history
+
+    @torch.no_grad()
+    def stream_generate(
+            self,
+            input_ids,
+            generation_config: Optional[GenerationConfig] = None,
+            logits_processor: Optional[LogitsProcessorList] = None,
+            stopping_criteria: Optional[StoppingCriteriaList] = None,
+            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+            **kwargs,
+    ):
+        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+
+        if generation_config is None:
+            generation_config = self.generation_config
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)
+        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
+
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        if has_default_max_length and generation_config.max_new_tokens is None:
+            warnings.warn(
+                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+                " recommend using `max_new_tokens` to control the maximum length of the generation.",
+                UserWarning,
+            )
+        elif generation_config.max_new_tokens is not None:
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+            if not has_default_max_length:
+                logger.warn(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
+                    UserWarning,
+                )
+
+        if input_ids_seq_length >= generation_config.max_length:
+            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+            logger.warning(
+                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                " increasing `max_new_tokens`."
+            )
+
+        # 2. Set generation parameters if not already defined
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+        logits_processor = self._get_logits_processor(
+            generation_config=generation_config,
+            input_ids_seq_length=input_ids_seq_length,
+            encoder_input_ids=input_ids,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            logits_processor=logits_processor,
+        )
+
+        stopping_criteria = self._get_stopping_criteria(
+            generation_config=generation_config, stopping_criteria=stopping_criteria
+        )
+        logits_warper = self._get_logits_warper(generation_config)
+
+        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        scores = None
+        while True:
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            # forward pass to get next token
+            outputs = self(
+                **model_inputs,
+                return_dict=True,
+                output_attentions=False,
+                output_hidden_states=False,
+            )
+
+            next_token_logits = outputs.logits[:, -1, :]
+
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+            next_token_scores = logits_warper(input_ids, next_token_scores)
+
+            # sample
+            probs = nn.functional.softmax(next_token_scores, dim=-1)
+            if generation_config.do_sample:
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(probs, dim=-1)
+
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+            )
+            unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
+
+            # stop when each sentence is finished, or if we exceed the maximum length
+            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+                break
+            yield input_ids
 
     def quantize(self, bits: int):
         from .quantization import quantize
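
For reference, the streaming API added in this commit can be consumed as in the minimal sketch below. The sketch is not part of the commit: the checkpoint path, the `trust_remote_code` loading step, and the `.half().cuda()` placement are assumptions about a typical ChatGLM-6B setup. `stream_chat` yields the full decoded response so far together with the updated history on every step, so the caller prints only the newly generated suffix.

# Usage sketch (assumed setup; the checkpoint path is illustrative).
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda().eval()

history = []
printed = 0  # length of the response prefix already printed
for response, history in model.stream_chat(tokenizer, "Hello", history=history):
    # Each yield carries the whole response so far; print only the new suffix.
    print(response[printed:], end="", flush=True)
    printed = len(response)
print()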