From cc96a2271aff0cdff7efbdbf57a78104d98a9047 Mon Sep 17 00:00:00 2001 From: duzx16 Date: Sat, 1 Apr 2023 19:41:28 +0800 Subject: [PATCH] Implement batch generation --- modeling_chatglm.py | 183 ++++++++++++++++++++++------------------ tokenization_chatglm.py | 115 ++++++++++++++++++++++--- 2 files changed, 203 insertions(+), 95 deletions(-) diff --git a/modeling_chatglm.py b/modeling_chatglm.py index fcbef2d..e2ff3f2 100644 --- a/modeling_chatglm.py +++ b/modeling_chatglm.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from torch import nn from torch.nn import CrossEntropyLoss, LayerNorm from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Callable +from typing import Optional, Tuple, Union, List, Callable, Dict, Any from transformers.utils import ( add_code_sample_docstrings, @@ -28,7 +28,7 @@ from transformers.modeling_outputs import ( from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput from .configuration_chatglm import ChatGLMConfig @@ -664,6 +664,39 @@ class ChatGLMPreTrainedModel(PreTrainedModel): """Initialize the weights.""" return + def get_masks(self, input_ids, device): + batch_size, seq_length = input_ids.shape + context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] + attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) + attention_mask.tril_() + for i, context_length in enumerate(context_lengths): + attention_mask[i, :, :context_length] = 1 + attention_mask.unsqueeze_(1) + attention_mask = (attention_mask < 0.5).bool() + + return attention_mask + + def get_position_ids(self, input_ids, mask_positions, device, gmask=False): + batch_size, seq_length = input_ids.shape + context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] + if self.position_encoding_2d: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length) + for i, context_length in enumerate(context_lengths): + position_ids[i, context_length:] = mask_positions[i] + block_position_ids = [torch.cat(( + torch.zeros(context_length, dtype=torch.long, device=device), + torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 + )) for context_length in context_lengths] + block_position_ids = torch.stack(block_position_ids, dim=0) + position_ids = torch.stack((position_ids, block_position_ids), dim=1) + else: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length) + if not gmask: + for i, context_length in enumerate(context_lengths): + position_ids[context_length:] = mask_positions[i] + + return position_ids + def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, ChatGLMModel): module.gradient_checkpointing = value @@ -828,39 +861,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel): # past_key_values = [(v[0], v[1]) for v in past_key_values] return past_key_values - def get_masks(self, input_ids, device): - batch_size, seq_length = input_ids.shape - context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] - attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) - 
attention_mask.tril_() - for i, context_length in enumerate(context_lengths): - attention_mask[i, :, :context_length] = 1 - attention_mask.unsqueeze_(1) - attention_mask = (attention_mask < 0.5).bool() - - return attention_mask - - def get_position_ids(self, input_ids, mask_positions, device, gmask=False): - batch_size, seq_length = input_ids.shape - context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] - if self.position_encoding_2d: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length) - for i, context_length in enumerate(context_lengths): - position_ids[i, context_length:] = mask_positions[i] - block_position_ids = [torch.cat(( - torch.zeros(context_length, dtype=torch.long, device=device), - torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 - )) for context_length in context_lengths] - block_position_ids = torch.stack(block_position_ids, dim=0) - position_ids = torch.stack((position_ids, block_position_ids), dim=1) - else: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length) - if not gmask: - for i, context_length in enumerate(context_lengths): - position_ids[context_length:] = mask_positions[i] - - return position_ids - @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1038,35 +1038,39 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings - def get_masks_and_position_ids(self, input_ids, mask_positions, device, gmask=False): - batch_size, seq_length = input_ids.shape - context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] - attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) - attention_mask.tril_() - for i, context_length in enumerate(context_lengths): - attention_mask[i, :, :context_length] = 1 - attention_mask.unsqueeze_(1) - attention_mask = (attention_mask < 0.5).bool() + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) - batch_size, seq_length = input_ids.shape - context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] - if self.position_encoding_2d: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length) - for i, context_length in enumerate(context_lengths): - position_ids[i, context_length:] = mask_positions[i] - block_position_ids = [torch.cat(( - torch.zeros(context_length, dtype=torch.long, device=device), - torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 - )) for context_length in context_lengths] - block_position_ids = torch.stack(block_position_ids, dim=0) - position_ids = torch.stack((position_ids, block_position_ids), dim=1) - else: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length) - if not gmask: - for i, context_length in enumerate(context_lengths): - position_ids[context_length:] = mask_positions[i] + # update attention mask + if "attention_mask" in 
model_kwargs: + attention_mask = model_kwargs["attention_mask"] + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3) + new_attention_mask = attention_mask[:, :, -1:].clone() + new_attention_mask[..., -1] = False + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, new_attention_mask], dim=2 + ) - return attention_mask, position_ids + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id[:, 1, :] += 1 + model_kwargs["position_ids"] = torch.cat( + [position_ids, new_position_id], dim=-1 + ) + + return model_kwargs def prepare_inputs_for_generation( self, @@ -1074,6 +1078,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): past: Optional[torch.Tensor] = None, past_key_values: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, **kwargs ) -> dict: batch_size, seq_length = input_ids.shape @@ -1085,15 +1090,20 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): # only last token for input_ids if past is not None if past is not None or past_key_values is not None: - context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] last_token = input_ids[:, -1].unsqueeze(-1) - if self.position_encoding_2d: - position_ids = torch.tensor( - [[mask_position, seq_length - context_length] for mask_position, context_length in - zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1) + if attention_mask is not None: + attention_mask = attention_mask[:, :, -1:] + if position_ids is not None: + position_ids = position_ids[..., -1:] else: - position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long, - device=input_ids.device).unsqueeze(-1) + context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] + if self.position_encoding_2d: + position_ids = torch.tensor( + [[mask_position, seq_length - context_length] for mask_position, context_length in + zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1) + else: + position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long, + device=input_ids.device).unsqueeze(-1) if past is None: past = past_key_values @@ -1101,14 +1111,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): "input_ids": last_token, "past_key_values": past, "position_ids": position_ids, + "attention_mask": attention_mask } else: - attention_mask, position_ids = self.get_masks_and_position_ids( - input_ids, - mask_positions=mask_positions, - device=input_ids.device, - gmask=use_gmask - ) + if attention_mask is None: + attention_mask = self.get_masks( + input_ids, + device=input_ids.device + ) + if position_ids is None: + position_ids = self.get_position_ids( + input_ids, + device=input_ids.device, + mask_positions=mask_positions, + gmask=use_gmask + ) return { "input_ids": input_ids, @@ -1226,10 +1243,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): for i, (old_query, response) in enumerate(history): prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) - input_ids = tokenizer([prompt], return_tensors="pt", padding=True) - input_ids = input_ids.to(self.device) - outputs = self.generate(**input_ids, **gen_kwargs) - outputs = 
outputs.tolist()[0][len(input_ids["input_ids"][0]):] + inputs = tokenizer([prompt], return_tensors="pt", padding=True) + inputs = inputs.to(self.device) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] response = tokenizer.decode(outputs) response = self.process_response(response) history = history + [(query, response)] @@ -1252,10 +1269,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): for i, (old_query, response) in enumerate(history): prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) - input_ids = tokenizer([prompt], return_tensors="pt", padding=True) - input_ids = input_ids.to(self.device) - for outputs in self.stream_generate(**input_ids, **gen_kwargs): - outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):] + inputs = tokenizer([prompt], return_tensors="pt", padding=True) + inputs = inputs.to(self.device) + for outputs in self.stream_generate(**inputs, **gen_kwargs): + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] response = tokenizer.decode(outputs) response = self.process_response(response) new_history = history + [(query, response)] diff --git a/tokenization_chatglm.py b/tokenization_chatglm.py index 3c70605..956746c 100644 --- a/tokenization_chatglm.py +++ b/tokenization_chatglm.py @@ -1,17 +1,14 @@ """Tokenization classes for ChatGLM.""" -import sys -import unicodedata from typing import List, Optional, Union -from functools import lru_cache import os -import collections -import re from transformers.tokenization_utils import PreTrainedTokenizer from icetk.text_tokenizer import TextTokenizer -from icetk.utils import auto_create import icetk.sentencepiece_model_pb2 as sp_model -from transformers.utils import logging +from transformers.utils import logging, PaddingStrategy +from transformers.tokenization_utils_base import EncodedInput, BatchEncoding +from typing import Dict +import numpy as np logger = logging.get_logger(__name__) @@ -192,7 +189,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer): eop_token='eop', mask_token='[MASK]', gmask_token='[gMASK]', - padding_side="right", + padding_side="left", **kwargs ) -> None: super().__init__( @@ -210,7 +207,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer): self.eos_token = eos_token self.eop_token = eop_token self.mask_token = mask_token - self.gMASK_token = gmask_token + self.gmask_token = gmask_token self.sp_tokenizer = SPTokenizer(vocab_file) @@ -331,10 +328,9 @@ class ChatGLMTokenizer(PreTrainedTokenizer): Returns: `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
""" - if token_ids_1 is not None: - token_ids_0 += token_ids_1 mask_ids = self.sp_tokenizer[self.mask_token] - gmask_ids = self.sp_tokenizer[self.gMASK_token] + gmask_ids = self.sp_tokenizer[self.gmask_token] + eop_id = self.sp_tokenizer[self.eop_token] if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0: token_ids_0 += [gmask_ids] @@ -343,4 +339,99 @@ class ChatGLMTokenizer(PreTrainedTokenizer): token_ids_0 += [self.sp_tokenizer[self.bos_token]] + if token_ids_1 is not None: + if token_ids_1[-1] != eop_id: + token_ids_1 += [eop_id] + token_ids_0 += token_ids_1 + return token_ids_0 + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + bos_token_id = self.sp_tokenizer[self.bos_token] + mask_token_id = self.sp_tokenizer[self.mask_token] + gmask_token_id = self.sp_tokenizer[self.gmask_token] + assert self.padding_side == "left" + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + required_input = encoded_inputs[self.model_input_names[0]] + seq_length = len(required_input) + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. 
+ if needs_to_be_padded or return_attention_mask: + context_length = required_input.index(bos_token_id) + attention_mask = np.ones((1, seq_length, seq_length)) + attention_mask = np.tril(attention_mask) + attention_mask[:, :, :context_length] = 1 + attention_mask = np.bool_(attention_mask < 0.5) + encoded_inputs["attention_mask"] = attention_mask + + if needs_to_be_padded or return_attention_mask: + mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id + mask_position = required_input.index(mask_token) + context_length = required_input.index(bos_token_id) + position_ids = np.arange(seq_length, dtype=np.int64) + position_ids[context_length:] = mask_position + block_position_ids = np.concatenate( + [np.zeros(context_length, dtype=np.int64), + np.arange(1, seq_length - context_length + 1, dtype=np.int64)]) + encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0) + + if needs_to_be_padded: + difference = max_length - len(required_input) + + if "attention_mask" in encoded_inputs: + encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"], + pad_width=[(0, 0), (difference, 0), (difference, 0)], + mode='constant', constant_values=True) + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + if "position_ids" in encoded_inputs: + encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"], + pad_width=[(0, 0), (difference, 0)]) + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + + return encoded_inputs
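
Usage sketch (illustrative, not part of the commit): with padding_side="left" and the _pad /
get_masks / get_position_ids overrides above, prompts of different lengths can be batched through
a single generate() call. The checkpoint id, dtype/device placement, and generation settings below
are assumptions; the prompt template and the slicing of the decoded output mirror the patched
chat() method.

    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    model = model.eval()

    queries = ["Hello", "What should I do if I can't sleep at night?"]
    # Same single-round prompt template used by chat()/stream_chat() in the patch.
    prompts = ["[Round 0]\n问:{}\n答:".format(q) for q in queries]

    # The tokenizer left-pads the batch and attaches the per-sample attention masks and
    # 2D position ids built in _pad, so one generate() call can serve the whole batch.
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    outputs = model.generate(**inputs, max_length=256, do_sample=False)

    for i, output in enumerate(outputs.tolist()):
        # Drop the (uniformly left-padded) prompt tokens, then decode only the continuation,
        # as the patched chat() does with outputs.tolist()[0][len(inputs["input_ids"][0]):].
        response = tokenizer.decode(output[len(inputs["input_ids"][i]):])
        print(response)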