From 42095d42ff56b81ff5647dec51cf31045b6eccb5 Mon Sep 17 00:00:00 2001
From: duzx16
Date: Sun, 19 Mar 2023 14:31:26 +0800
Subject: [PATCH] Add support for streaming output

---
 modeling_chatglm.py | 156 +++++++++++++++++++++++++++++++++-----------
 1 file changed, 117 insertions(+), 39 deletions(-)

diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index 648a1d4..a2a158b 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -3,7 +3,7 @@
 import math
 import copy
 import os
-import time
+import warnings
 
 import torch
 import torch.utils.checkpoint
@@ -11,7 +11,7 @@ import torch.nn.functional as F
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
 from torch.nn.utils import skip_init
-from typing import Optional, Tuple, Union, List
+from typing import Optional, Tuple, Union, List, Callable
 
 from transformers.utils import (
     add_code_sample_docstrings,
@@ -26,7 +26,7 @@ from transformers.modeling_outputs import (
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
 
 from .configuration_chatglm import ChatGLMConfig
 
@@ -1107,7 +1107,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
         input_ids = input_ids.to(self.device)
         outputs = self.generate(**input_ids, **gen_kwargs)
-        outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]) - 2:]
+        outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
         response = tokenizer.decode(outputs)
         response = response.strip()
         response = response.replace("[[训练时间]]", "2023年")
@@ -1115,55 +1115,133 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         return response, history
 
     @torch.no_grad()
-    def generate(
+    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
+                    do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
+                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+        if not history:
+            prompt = query
+        else:
+            prompt = ""
+            for i, (old_query, response) in enumerate(history):
+                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
+            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+        input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
+        input_ids = input_ids.to(self.device)
+        for outputs in self.stream_generate(**input_ids, **gen_kwargs):
+            outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
+            response = tokenizer.decode(outputs)
+            response = response.strip()
+            response = response.replace("[[训练时间]]", "2023年")
+            new_history = history + [(query, response)]
+            yield response, new_history
+
+    @torch.no_grad()
+    def stream_generate(
             self,
+            input_ids,
+            generation_config: Optional[GenerationConfig] = None,
+            logits_processor: Optional[LogitsProcessorList] = None,
+            stopping_criteria: Optional[StoppingCriteriaList] = None,
+            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
             **kwargs,
     ):
-        MASK, gMASK = 150000, 150001
-        bos, eos = 150004, 150005
+        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
 
-        if "eos_token_id" not in kwargs:
-            kwargs["eos_token_id"] = eos
+        if generation_config is None:
+            generation_config = self.generation_config
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)
+        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
 
-        stop = False
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
 
-        return_seqs = []
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        if has_default_max_length and generation_config.max_new_tokens is None:
+            warnings.warn(
+                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+                " recommend using `max_new_tokens` to control the maximum length of the generation.",
+                UserWarning,
+            )
+        elif generation_config.max_new_tokens is not None:
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+            if not has_default_max_length:
+                warnings.warn(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
+                    UserWarning,
+                )
+
+        if input_ids_seq_length >= generation_config.max_length:
+            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+            logger.warning(
+                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                " increasing `max_new_tokens`."
+            )
+
+        # 2. Set generation parameters if not already defined
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+        logits_processor = self._get_logits_processor(
+            generation_config=generation_config,
+            input_ids_seq_length=input_ids_seq_length,
+            encoder_input_ids=input_ids,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            logits_processor=logits_processor,
+        )
+
+        stopping_criteria = self._get_stopping_criteria(
+            generation_config=generation_config, stopping_criteria=stopping_criteria
+        )
+        logits_warper = self._get_logits_warper(generation_config)
+
+        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        scores = None
 
         while True:
-            output_ids = super().generate(**kwargs)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            # forward pass to get next token
+            outputs = self(
+                **model_inputs,
+                return_dict=True,
+                output_attentions=False,
+                output_hidden_states=False,
+            )
 
-            return_seqs = []
-            max_length = 0
+            next_token_logits = outputs.logits[:, -1, :]
 
-            for i in range(output_ids.shape[0]):
-                output_seq = output_ids[i].tolist()
-                mask_token = MASK if MASK in output_seq else gMASK
-                mask_position = output_seq.index(mask_token)
-                bos_position = output_seq.index(bos)
-                if eos in output_seq:
-                    eos_position = output_seq.index(eos)
-                else:
-                    eos_position = len(output_seq)
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+            next_token_scores = logits_warper(input_ids, next_token_scores)
 
-                return_seq = output_seq[:mask_position] + output_seq[bos_position + 1:eos_position] + output_seq[
-                                                                                                      mask_position + 1:bos_position]
-                max_length = max(max_length, len(return_seq))
-                return_seqs.append(return_seq)
+            # sample
+            probs = nn.functional.softmax(next_token_scores, dim=-1)
+            if generation_config.do_sample:
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(probs, dim=-1)
 
-            for i in range(output_ids.shape[0]):
-                return_seqs[i] = [0] * (max_length - len(return_seqs[i])) + return_seqs[i]  # padding
-                if mask_token not in return_seqs[i]:
-                    stop = True
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+            )
+            unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
 
-            if stop:
+            # stop when each sentence is finished, or if we exceed the maximum length
+            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                 break
-
-            for return_seq in return_seqs:
-                return_seq += [bos]
-
-            kwargs['input_ids'] = torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
-
-        return torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
+            yield input_ids
 
     def quantize(self, bits: int):
         from .quantization import quantize
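
Note (illustrative, not part of the patch): a minimal sketch of how the new
`stream_chat` generator can be consumed from caller code. The checkpoint path
and the `trust_remote_code` flag are assumptions for the example, not
something this diff establishes.

    from transformers import AutoModel, AutoTokenizer

    # checkpoint path is assumed for illustration
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda().eval()

    history = []
    # each iteration yields the full response decoded so far and the updated history
    for response, history in model.stream_chat(tokenizer, "你好", history=history):
        print(response)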
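Because every yield carries the complete response decoded so far, a caller
that wants incremental console output can print only the new suffix at each
step; a sketch reusing `model`, `tokenizer`, and a `query` string from the
example above:

    printed_length = 0  # characters already written to the console
    for response, history in model.stream_chat(tokenizer, query, history=history):
        print(response[printed_length:], end="", flush=True)
        printed_length = len(response)
    print()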
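For reviewers: the decoding loop in `stream_generate` terminates through the
`unfinished_sequences` mask rather than a per-sample break, which keeps
batched decoding correct -- an entry drops to 0 once its sequence emits an
EOS token, and the loop stops only when every entry is 0. A self-contained
toy run of the masking arithmetic (values are made up; 150005 is the EOS id
from the removed code above):

    import torch

    eos_token_id = [150005]                   # single EOS id, as in this model
    unfinished = torch.tensor([1, 1])         # two sequences, both still generating
    next_tokens = torch.tensor([42, 150005])  # the second sequence just emitted EOS
    # (next_tokens != i) is 0 exactly where token i was produced
    unfinished = unfinished.mul(sum(next_tokens != i for i in eos_token_id).long())
    print(unfinished)                         # tensor([1, 0])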