sync with huggingface

This commit is contained in:
root 2023-08-31 14:49:00 +08:00
parent 478de9d646
commit 3af7eae8ac
1 changed file with 101 additions and 30 deletions

View File

@ -16,6 +16,7 @@ from .configuration_baichuan import BaichuanConfig
logger = logging.get_logger(__name__)
def _get_interleave(n):
def _get_interleave_power_of_2(n):
start = (2 ** (-2 ** -(math.log2(n) - 3)))
@ -34,6 +35,7 @@ def _fill_with_neg_inf(t):
return t.float().fill_(float("-inf")).type_as(t)
def _gen_alibi_mask(n_head, max_pos):
"""used in inference only"""
slopes = torch.Tensor(_get_interleave(n_head))
alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand(
n_head, -1, -1)
@ -44,6 +46,16 @@ def _gen_alibi_mask(n_head, max_pos):
alibi_mask = alibi_mask.unsqueeze(0) + alibi
return alibi_mask
def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
"""used in training only"""
dim = tensor.size(1)
_future_mask = torch.triu(
_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1
)
_future_mask = _future_mask.unsqueeze(0) + alibi
_future_mask = _future_mask.to(tensor)
return _future_mask[:tensor.shape[0] * attn_heads, :maxpos, :maxpos]
class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, epsilon=1e-6):
@ -80,7 +92,6 @@ class MLP(torch.nn.Module):
class BaichuanAttention(torch.nn.Module):
def __init__(self, config: BaichuanConfig):
super().__init__()
self.config = config
@ -130,12 +141,16 @@ class BaichuanAttention(torch.nn.Module):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
if attn_weights.size(-2) == 1:
attention_mask = attention_mask[:, -1:, :]
attn_weights = attn_weights + attention_mask.unsqueeze(0)
if q_len == 1: # inference with cache
if len(attention_mask.size()) == 4:
attention_mask = attention_mask[:, :, -1:, :]
else:
attention_mask = attention_mask[:, -1:, :]
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2)
@ -221,7 +236,6 @@ class BaichuanPreTrainedModel(PreTrainedModel):
module.gradient_checkpointing = value
class BaichuanModel(BaichuanPreTrainedModel):
def __init__(self, config: BaichuanConfig):
super().__init__(config)
@ -235,27 +249,37 @@ class BaichuanModel(BaichuanPreTrainedModel):
self.gradient_checkpointing = config.gradient_checkpointing
self.post_init()
self.max_cache_pos = config.model_max_length
self.first_run = True
self.first_run = True
self.alibi_mask = None
def get_input_embeddings(self):
    """Return the object stored on ``self.embed_tokens``."""
    embeddings = self.embed_tokens
    return embeddings
def set_input_embeddings(self, value):
    """Store ``value`` on ``self.embed_tokens``, replacing the previous one.

    Args:
        value: The new object to use as the input embeddings.
    """
    self.embed_tokens = value
def get_alibi_mask(self, tensor, seq_length_with_past):
    """Return the causal + ALiBi attention bias for the given length.

    When ``self.training`` is true the bias is recomputed on every call from
    the ``_get_interleave`` slopes and combined with the causal mask by
    ``_buffered_future_mask``.

    Otherwise the mask is read from the non-persistent ``future_mask``
    buffer. That buffer is built with ``_gen_alibi_mask`` on the first call
    and rebuilt whenever ``seq_length_with_past`` exceeds
    ``self.max_cache_pos``.

    Args:
        tensor: Reference tensor; masks are converted with ``.to(tensor)``.
        seq_length_with_past: Number of positions the mask must cover.

    Returns:
        The mask tensor for ``seq_length_with_past`` positions.

    Side effects (non-training path only): may update ``self.first_run``,
    ``self.max_cache_pos`` and the ``future_mask`` buffer.
    """
    if self.training:
        slopes = torch.Tensor(_get_interleave(self.n_head))
        positions = (
            torch.arange(seq_length_with_past)
            .unsqueeze(0)
            .unsqueeze(0)
            .expand(self.n_head, -1, -1)
        )
        alibi = slopes.unsqueeze(1).unsqueeze(1) * positions
        alibi = alibi.view(self.n_head, 1, seq_length_with_past)
        return _buffered_future_mask(tensor, seq_length_with_past, alibi, self.n_head)

    # Build the cached mask once per trigger: either it has never been built,
    # or the requested length outgrew the cached size.
    if self.first_run or seq_length_with_past > self.max_cache_pos:
        self.first_run = False
        self.max_cache_pos = max(self.max_cache_pos, seq_length_with_past)
        self.register_buffer(
            "future_mask",
            _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor),
            persistent=False,
        )
    return self.future_mask[: self.n_head, :seq_length_with_past, :seq_length_with_past]
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
@ -264,7 +288,6 @@ class BaichuanModel(BaichuanPreTrainedModel):
return_dict: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot provide both input_ids and inputs_embeds simultaneously")
elif input_ids is not None:
@ -274,6 +297,8 @@ class BaichuanModel(BaichuanPreTrainedModel):
else:
raise ValueError("You need to provide input_ids or inputs_embeds")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
seq_length_with_past = seq_length
if past_key_values is not None:
@ -283,8 +308,28 @@ class BaichuanModel(BaichuanPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
attention_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
if self.training:
if self.alibi_mask is None or self.alibi_mask.shape[-1] != seq_length_with_past:
self.alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
alibi_mask = self.alibi_mask
else:
alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
if attention_mask is not None:
if len(attention_mask.shape) == 2:
expanded_mask = attention_mask.to(alibi_mask.dtype)
expanded_mask = torch.tril(torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
else:
expanded_mask = attention_mask
bsz = inputs_embeds.size(0)
src_len, tgt_len = alibi_mask.size()[-2:]
expanded_mask = expanded_mask.unsqueeze(1).expand(bsz, 1, src_len, tgt_len).to(alibi_mask.dtype)
inverted_mask = 1.0 - expanded_mask
inverted_mask = inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(alibi_mask.dtype).min)
attention_mask = inverted_mask + alibi_mask.unsqueeze(0)
else:
attention_mask = alibi_mask
hidden_states = inputs_embeds
@ -353,7 +398,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class BaichuanForCausalLM(BaichuanPreTrainedModel):
def __init__(self, config):
@ -364,9 +409,28 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
    """Return ``embed_tokens`` of the wrapped ``self.model``."""
    decoder = self.model
    return decoder.embed_tokens
def set_input_embeddings(self, value):
    """Store ``value`` as ``embed_tokens`` on the wrapped ``self.model``."""
    decoder = self.model
    decoder.embed_tokens = value
def get_output_embeddings(self):
    """Return the object stored on ``self.lm_head``."""
    head = self.lm_head
    return head
def set_output_embeddings(self, new_embeddings):
    """Store ``new_embeddings`` on ``self.lm_head``, replacing the old one."""
    self.lm_head = new_embeddings
def set_decoder(self, decoder):
    """Store ``decoder`` on ``self.model``, replacing the previous one."""
    self.model = decoder
def get_decoder(self):
    """Return the object stored on ``self.model``."""
    decoder = self.model
    return decoder
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@ -377,17 +441,19 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
**kwargs
) -> Union[Tuple, CausalLMOutputWithPast]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
@ -418,8 +484,13 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
self,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs
):
if past_key_values:
input_ids = input_ids[:, -1:]
@ -430,11 +501,12 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
}
)
"attention_mask": attention_mask
}
)
return model_inputs
@staticmethod
@ -444,7 +516,6 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
for layer_past in past_key_values
)
def quantize(self, bits: int):
try:
from .quantizer import QLinear
@ -452,7 +523,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
raise ImportError(
f"Needs QLinear to run quantize."
)
for layer in self.model.layers:
layer.self_attn.W_pack = QLinear(
bits=bits,
@ -479,7 +550,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
weight=layer.mlp.up_proj.weight,
bias = None,
)
return self
return self
def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens