Implement batch generation
This commit is contained in:
parent
11c270c26c
commit
cc96a2271a
|
@ -13,7 +13,7 @@ import torch.nn.functional as F
|
|||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, LayerNorm
|
||||
from torch.nn.utils import skip_init
|
||||
from typing import Optional, Tuple, Union, List, Callable
|
||||
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
||||
|
||||
from transformers.utils import (
|
||||
add_code_sample_docstrings,
|
||||
|
@ -28,7 +28,7 @@ from transformers.modeling_outputs import (
|
|||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import logging
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
|
||||
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
|
||||
|
||||
from .configuration_chatglm import ChatGLMConfig
|
||||
|
||||
|
@ -664,6 +664,39 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
|
|||
"""Initialize the weights."""
|
||||
return
|
||||
|
||||
def get_masks(self, input_ids, device):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
||||
attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
|
||||
attention_mask.tril_()
|
||||
for i, context_length in enumerate(context_lengths):
|
||||
attention_mask[i, :, :context_length] = 1
|
||||
attention_mask.unsqueeze_(1)
|
||||
attention_mask = (attention_mask < 0.5).bool()
|
||||
|
||||
return attention_mask
|
||||
|
||||
def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
||||
if self.position_encoding_2d:
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
|
||||
for i, context_length in enumerate(context_lengths):
|
||||
position_ids[i, context_length:] = mask_positions[i]
|
||||
block_position_ids = [torch.cat((
|
||||
torch.zeros(context_length, dtype=torch.long, device=device),
|
||||
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
|
||||
)) for context_length in context_lengths]
|
||||
block_position_ids = torch.stack(block_position_ids, dim=0)
|
||||
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
|
||||
else:
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
|
||||
if not gmask:
|
||||
for i, context_length in enumerate(context_lengths):
|
||||
position_ids[context_length:] = mask_positions[i]
|
||||
|
||||
return position_ids
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, ChatGLMModel):
|
||||
module.gradient_checkpointing = value
|
||||
|
@ -828,39 +861,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|||
# past_key_values = [(v[0], v[1]) for v in past_key_values]
|
||||
return past_key_values
|
||||
|
||||
def get_masks(self, input_ids, device):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
||||
attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
|
||||
attention_mask.tril_()
|
||||
for i, context_length in enumerate(context_lengths):
|
||||
attention_mask[i, :, :context_length] = 1
|
||||
attention_mask.unsqueeze_(1)
|
||||
attention_mask = (attention_mask < 0.5).bool()
|
||||
|
||||
return attention_mask
|
||||
|
||||
def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
||||
if self.position_encoding_2d:
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
|
||||
for i, context_length in enumerate(context_lengths):
|
||||
position_ids[i, context_length:] = mask_positions[i]
|
||||
block_position_ids = [torch.cat((
|
||||
torch.zeros(context_length, dtype=torch.long, device=device),
|
||||
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
|
||||
)) for context_length in context_lengths]
|
||||
block_position_ids = torch.stack(block_position_ids, dim=0)
|
||||
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
|
||||
else:
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
|
||||
if not gmask:
|
||||
for i, context_length in enumerate(context_lengths):
|
||||
position_ids[context_length:] = mask_positions[i]
|
||||
|
||||
return position_ids
|
||||
|
||||
@add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
|
@ -1038,35 +1038,39 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def get_masks_and_position_ids(self, input_ids, mask_positions, device, gmask=False):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
||||
attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
|
||||
attention_mask.tril_()
|
||||
for i, context_length in enumerate(context_lengths):
|
||||
attention_mask[i, :, :context_length] = 1
|
||||
attention_mask.unsqueeze_(1)
|
||||
attention_mask = (attention_mask < 0.5).bool()
|
||||
def _update_model_kwargs_for_generation(
|
||||
self,
|
||||
outputs: ModelOutput,
|
||||
model_kwargs: Dict[str, Any],
|
||||
is_encoder_decoder: bool = False,
|
||||
standardize_cache_format: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
# update past_key_values
|
||||
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
||||
outputs, standardize_cache_format=standardize_cache_format
|
||||
)
|
||||
|
||||
batch_size, seq_length = input_ids.shape
|
||||
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
||||
if self.position_encoding_2d:
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
|
||||
for i, context_length in enumerate(context_lengths):
|
||||
position_ids[i, context_length:] = mask_positions[i]
|
||||
block_position_ids = [torch.cat((
|
||||
torch.zeros(context_length, dtype=torch.long, device=device),
|
||||
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
|
||||
)) for context_length in context_lengths]
|
||||
block_position_ids = torch.stack(block_position_ids, dim=0)
|
||||
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
|
||||
else:
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
|
||||
if not gmask:
|
||||
for i, context_length in enumerate(context_lengths):
|
||||
position_ids[context_length:] = mask_positions[i]
|
||||
# update attention mask
|
||||
if "attention_mask" in model_kwargs:
|
||||
attention_mask = model_kwargs["attention_mask"]
|
||||
attention_mask = torch.cat(
|
||||
[attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
|
||||
new_attention_mask = attention_mask[:, :, -1:].clone()
|
||||
new_attention_mask[..., -1] = False
|
||||
model_kwargs["attention_mask"] = torch.cat(
|
||||
[attention_mask, new_attention_mask], dim=2
|
||||
)
|
||||
|
||||
return attention_mask, position_ids
|
||||
# update position ids
|
||||
if "position_ids" in model_kwargs:
|
||||
position_ids = model_kwargs["position_ids"]
|
||||
new_position_id = position_ids[..., -1:].clone()
|
||||
new_position_id[:, 1, :] += 1
|
||||
model_kwargs["position_ids"] = torch.cat(
|
||||
[position_ids, new_position_id], dim=-1
|
||||
)
|
||||
|
||||
return model_kwargs
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
|
@ -1074,6 +1078,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
past: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
@ -1085,15 +1090,20 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
|
||||
# only last token for input_ids if past is not None
|
||||
if past is not None or past_key_values is not None:
|
||||
context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
|
||||
last_token = input_ids[:, -1].unsqueeze(-1)
|
||||
if self.position_encoding_2d:
|
||||
position_ids = torch.tensor(
|
||||
[[mask_position, seq_length - context_length] for mask_position, context_length in
|
||||
zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, -1:]
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids[..., -1:]
|
||||
else:
|
||||
position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
|
||||
device=input_ids.device).unsqueeze(-1)
|
||||
context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
|
||||
if self.position_encoding_2d:
|
||||
position_ids = torch.tensor(
|
||||
[[mask_position, seq_length - context_length] for mask_position, context_length in
|
||||
zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
|
||||
else:
|
||||
position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
|
||||
device=input_ids.device).unsqueeze(-1)
|
||||
|
||||
if past is None:
|
||||
past = past_key_values
|
||||
|
@ -1101,14 +1111,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
"input_ids": last_token,
|
||||
"past_key_values": past,
|
||||
"position_ids": position_ids,
|
||||
"attention_mask": attention_mask
|
||||
}
|
||||
else:
|
||||
attention_mask, position_ids = self.get_masks_and_position_ids(
|
||||
input_ids,
|
||||
mask_positions=mask_positions,
|
||||
device=input_ids.device,
|
||||
gmask=use_gmask
|
||||
)
|
||||
if attention_mask is None:
|
||||
attention_mask = self.get_masks(
|
||||
input_ids,
|
||||
device=input_ids.device
|
||||
)
|
||||
if position_ids is None:
|
||||
position_ids = self.get_position_ids(
|
||||
input_ids,
|
||||
device=input_ids.device,
|
||||
mask_positions=mask_positions,
|
||||
gmask=use_gmask
|
||||
)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
|
@ -1226,10 +1243,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
for i, (old_query, response) in enumerate(history):
|
||||
prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
|
||||
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
|
||||
input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
|
||||
input_ids = input_ids.to(self.device)
|
||||
outputs = self.generate(**input_ids, **gen_kwargs)
|
||||
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
||||
inputs = tokenizer([prompt], return_tensors="pt", padding=True)
|
||||
inputs = inputs.to(self.device)
|
||||
outputs = self.generate(**inputs, **gen_kwargs)
|
||||
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
|
||||
response = tokenizer.decode(outputs)
|
||||
response = self.process_response(response)
|
||||
history = history + [(query, response)]
|
||||
|
@ -1252,10 +1269,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
for i, (old_query, response) in enumerate(history):
|
||||
prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
|
||||
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
|
||||
input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
|
||||
input_ids = input_ids.to(self.device)
|
||||
for outputs in self.stream_generate(**input_ids, **gen_kwargs):
|
||||
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
||||
inputs = tokenizer([prompt], return_tensors="pt", padding=True)
|
||||
inputs = inputs.to(self.device)
|
||||
for outputs in self.stream_generate(**inputs, **gen_kwargs):
|
||||
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
|
||||
response = tokenizer.decode(outputs)
|
||||
response = self.process_response(response)
|
||||
new_history = history + [(query, response)]
|
||||
|
|
|
@ -1,17 +1,14 @@
|
|||
"""Tokenization classes for ChatGLM."""
|
||||
import sys
|
||||
import unicodedata
|
||||
from typing import List, Optional, Union
|
||||
from functools import lru_cache
|
||||
import os
|
||||
import collections
|
||||
import re
|
||||
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from icetk.text_tokenizer import TextTokenizer
|
||||
from icetk.utils import auto_create
|
||||
import icetk.sentencepiece_model_pb2 as sp_model
|
||||
from transformers.utils import logging
|
||||
from transformers.utils import logging, PaddingStrategy
|
||||
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
|
||||
from typing import Dict
|
||||
import numpy as np
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@ -192,7 +189,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
|||
eop_token='eop',
|
||||
mask_token='[MASK]',
|
||||
gmask_token='[gMASK]',
|
||||
padding_side="right",
|
||||
padding_side="left",
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__(
|
||||
|
@ -210,7 +207,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
|||
self.eos_token = eos_token
|
||||
self.eop_token = eop_token
|
||||
self.mask_token = mask_token
|
||||
self.gMASK_token = gmask_token
|
||||
self.gmask_token = gmask_token
|
||||
|
||||
self.sp_tokenizer = SPTokenizer(vocab_file)
|
||||
|
||||
|
@ -331,10 +328,9 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
|||
Returns:
|
||||
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
||||
"""
|
||||
if token_ids_1 is not None:
|
||||
token_ids_0 += token_ids_1
|
||||
mask_ids = self.sp_tokenizer[self.mask_token]
|
||||
gmask_ids = self.sp_tokenizer[self.gMASK_token]
|
||||
gmask_ids = self.sp_tokenizer[self.gmask_token]
|
||||
eop_id = self.sp_tokenizer[self.eop_token]
|
||||
if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0:
|
||||
token_ids_0 += [gmask_ids]
|
||||
|
||||
|
@ -343,4 +339,99 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
|||
|
||||
token_ids_0 += [self.sp_tokenizer[self.bos_token]]
|
||||
|
||||
if token_ids_1 is not None:
|
||||
if token_ids_1[-1] != eop_id:
|
||||
token_ids_1 += [eop_id]
|
||||
token_ids_0 += token_ids_1
|
||||
|
||||
return token_ids_0
|
||||
|
||||
def _pad(
|
||||
self,
|
||||
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
||||
max_length: Optional[int] = None,
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
||||
|
||||
Args:
|
||||
encoded_inputs:
|
||||
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
||||
max_length: maximum length of the returned list and optionally padding length (see below).
|
||||
Will truncate by taking into account the special tokens.
|
||||
padding_strategy: PaddingStrategy to use for padding.
|
||||
|
||||
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
|
||||
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
||||
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
||||
The tokenizer padding sides are defined in self.padding_side:
|
||||
|
||||
- 'left': pads on the left of the sequences
|
||||
- 'right': pads on the right of the sequences
|
||||
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
||||
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
|
||||
`>= 7.5` (Volta).
|
||||
return_attention_mask:
|
||||
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
||||
"""
|
||||
# Load from model defaults
|
||||
bos_token_id = self.sp_tokenizer[self.bos_token]
|
||||
mask_token_id = self.sp_tokenizer[self.mask_token]
|
||||
gmask_token_id = self.sp_tokenizer[self.gmask_token]
|
||||
assert self.padding_side == "left"
|
||||
if return_attention_mask is None:
|
||||
return_attention_mask = "attention_mask" in self.model_input_names
|
||||
|
||||
required_input = encoded_inputs[self.model_input_names[0]]
|
||||
seq_length = len(required_input)
|
||||
|
||||
if padding_strategy == PaddingStrategy.LONGEST:
|
||||
max_length = len(required_input)
|
||||
|
||||
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
||||
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
||||
|
||||
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
||||
|
||||
# Initialize attention mask if not present.
|
||||
if needs_to_be_padded or return_attention_mask:
|
||||
context_length = required_input.index(bos_token_id)
|
||||
attention_mask = np.ones((1, seq_length, seq_length))
|
||||
attention_mask = np.tril(attention_mask)
|
||||
attention_mask[:, :, :context_length] = 1
|
||||
attention_mask = np.bool_(attention_mask < 0.5)
|
||||
encoded_inputs["attention_mask"] = attention_mask
|
||||
|
||||
if needs_to_be_padded or return_attention_mask:
|
||||
mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
|
||||
mask_position = required_input.index(mask_token)
|
||||
context_length = required_input.index(bos_token_id)
|
||||
position_ids = np.arange(seq_length, dtype=np.int64)
|
||||
position_ids[context_length:] = mask_position
|
||||
block_position_ids = np.concatenate(
|
||||
[np.zeros(context_length, dtype=np.int64),
|
||||
np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
|
||||
encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
|
||||
|
||||
if needs_to_be_padded:
|
||||
difference = max_length - len(required_input)
|
||||
|
||||
if "attention_mask" in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"],
|
||||
pad_width=[(0, 0), (difference, 0), (difference, 0)],
|
||||
mode='constant', constant_values=True)
|
||||
if "token_type_ids" in encoded_inputs:
|
||||
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
|
||||
"token_type_ids"
|
||||
]
|
||||
if "special_tokens_mask" in encoded_inputs:
|
||||
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
|
||||
if "position_ids" in encoded_inputs:
|
||||
encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"],
|
||||
pad_width=[(0, 0), (difference, 0)])
|
||||
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
||||
|
||||
return encoded_inputs
|
||||
|
|
Loading…
Reference in New Issue