generated from xuyuqing/ailab
sync with huggingface
This commit is contained in:
parent
478de9d646
commit
3af7eae8ac
|
@ -16,6 +16,7 @@ from .configuration_baichuan import BaichuanConfig
|
|||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _get_interleave(n):
|
||||
def _get_interleave_power_of_2(n):
|
||||
start = (2 ** (-2 ** -(math.log2(n) - 3)))
|
||||
|
@ -34,6 +35,7 @@ def _fill_with_neg_inf(t):
|
|||
return t.float().fill_(float("-inf")).type_as(t)
|
||||
|
||||
def _gen_alibi_mask(n_head, max_pos):
|
||||
"""used in inference only"""
|
||||
slopes = torch.Tensor(_get_interleave(n_head))
|
||||
alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand(
|
||||
n_head, -1, -1)
|
||||
|
@ -44,6 +46,16 @@ def _gen_alibi_mask(n_head, max_pos):
|
|||
alibi_mask = alibi_mask.unsqueeze(0) + alibi
|
||||
return alibi_mask
|
||||
|
||||
def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
|
||||
"""used in training only"""
|
||||
dim = tensor.size(1)
|
||||
_future_mask = torch.triu(
|
||||
_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1
|
||||
)
|
||||
_future_mask = _future_mask.unsqueeze(0) + alibi
|
||||
_future_mask = _future_mask.to(tensor)
|
||||
return _future_mask[:tensor.shape[0] * attn_heads, :maxpos, :maxpos]
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, epsilon=1e-6):
|
||||
|
@ -80,7 +92,6 @@ class MLP(torch.nn.Module):
|
|||
|
||||
|
||||
class BaichuanAttention(torch.nn.Module):
|
||||
|
||||
def __init__(self, config: BaichuanConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
@ -130,12 +141,16 @@ class BaichuanAttention(torch.nn.Module):
|
|||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attn_weights.size(-2) == 1:
|
||||
attention_mask = attention_mask[:, -1:, :]
|
||||
attn_weights = attn_weights + attention_mask.unsqueeze(0)
|
||||
if q_len == 1: # inference with cache
|
||||
if len(attention_mask.size()) == 4:
|
||||
attention_mask = attention_mask[:, :, -1:, :]
|
||||
else:
|
||||
attention_mask = attention_mask[:, -1:, :]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
@ -221,7 +236,6 @@ class BaichuanPreTrainedModel(PreTrainedModel):
|
|||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
|
||||
class BaichuanModel(BaichuanPreTrainedModel):
|
||||
def __init__(self, config: BaichuanConfig):
|
||||
super().__init__(config)
|
||||
|
@ -235,27 +249,37 @@ class BaichuanModel(BaichuanPreTrainedModel):
|
|||
self.gradient_checkpointing = config.gradient_checkpointing
|
||||
self.post_init()
|
||||
self.max_cache_pos = config.model_max_length
|
||||
self.first_run = True
|
||||
self.first_run = True
|
||||
self.alibi_mask = None
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
self.embed_tokens = value
|
||||
|
||||
def get_alibi_mask(self, tensor, seq_length_with_past):
|
||||
if self.first_run:
|
||||
self.first_run = False
|
||||
self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
|
||||
if seq_length_with_past > self.max_cache_pos:
|
||||
self.max_cache_pos = seq_length_with_past
|
||||
self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
|
||||
mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]
|
||||
if self.training:
|
||||
slopes = torch.Tensor(_get_interleave(self.n_head))
|
||||
alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(seq_length_with_past).unsqueeze(0).unsqueeze(0).expand(
|
||||
self.n_head,
|
||||
-1, -1)
|
||||
alibi = alibi.view(self.n_head, 1, seq_length_with_past)
|
||||
mask = _buffered_future_mask(tensor, seq_length_with_past, alibi, self.n_head)
|
||||
else:
|
||||
if self.first_run:
|
||||
self.first_run = False
|
||||
self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
|
||||
if seq_length_with_past > self.max_cache_pos:
|
||||
self.max_cache_pos = seq_length_with_past
|
||||
self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
|
||||
mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]
|
||||
return mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
|
@ -264,7 +288,6 @@ class BaichuanModel(BaichuanPreTrainedModel):
|
|||
return_dict: Optional[bool] = True,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot provide both input_ids and inputs_embeds simultaneously")
|
||||
elif input_ids is not None:
|
||||
|
@ -274,6 +297,8 @@ class BaichuanModel(BaichuanPreTrainedModel):
|
|||
else:
|
||||
raise ValueError("You need to provide input_ids or inputs_embeds")
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
|
||||
if past_key_values is not None:
|
||||
|
@ -283,8 +308,28 @@ class BaichuanModel(BaichuanPreTrainedModel):
|
|||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# embed positions
|
||||
attention_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
|
||||
if self.training:
|
||||
if self.alibi_mask is None or self.alibi_mask.shape[-1] != seq_length_with_past:
|
||||
self.alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
|
||||
alibi_mask = self.alibi_mask
|
||||
else:
|
||||
alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
|
||||
|
||||
if attention_mask is not None:
|
||||
if len(attention_mask.shape) == 2:
|
||||
expanded_mask = attention_mask.to(alibi_mask.dtype)
|
||||
expanded_mask = torch.tril(torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
|
||||
) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
|
||||
else:
|
||||
expanded_mask = attention_mask
|
||||
bsz = inputs_embeds.size(0)
|
||||
src_len, tgt_len = alibi_mask.size()[-2:]
|
||||
expanded_mask = expanded_mask.unsqueeze(1).expand(bsz, 1, src_len, tgt_len).to(alibi_mask.dtype)
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
inverted_mask = inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(alibi_mask.dtype).min)
|
||||
attention_mask = inverted_mask + alibi_mask.unsqueeze(0)
|
||||
else:
|
||||
attention_mask = alibi_mask
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
|
@ -353,7 +398,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
|
|||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class BaichuanForCausalLM(BaichuanPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
|
@ -364,9 +409,28 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
|
|||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
|
@ -377,17 +441,19 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
|
|||
**kwargs
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
@ -418,8 +484,13 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
|
|||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
||||
):
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
):
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1:]
|
||||
|
||||
|
@ -430,11 +501,12 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
|
|||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
{
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
}
|
||||
)
|
||||
"attention_mask": attention_mask
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
|
@ -444,7 +516,6 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
|
|||
for layer_past in past_key_values
|
||||
)
|
||||
|
||||
|
||||
def quantize(self, bits: int):
|
||||
try:
|
||||
from .quantizer import QLinear
|
||||
|
@ -452,7 +523,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
|
|||
raise ImportError(
|
||||
f"Needs QLinear to run quantize."
|
||||
)
|
||||
|
||||
|
||||
for layer in self.model.layers:
|
||||
layer.self_attn.W_pack = QLinear(
|
||||
bits=bits,
|
||||
|
@ -479,7 +550,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
|
|||
weight=layer.mlp.up_proj.weight,
|
||||
bias = None,
|
||||
)
|
||||
return self
|
||||
return self
|
||||
|
||||
def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
|
||||
max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
|
||||
|
|
Loading…
Reference in New Issue