Implement gradient checkpointing

duzx16 committed 2023-03-30 19:42:01 +08:00
parent 0564795e6e
commit aea6cefcf5
1 changed file with 40 additions and 19 deletions


@@ -244,7 +244,7 @@ def attention_fn(
         use_cache=False,
 ):
     if layer_past is not None:
-        past_key, past_value = layer_past
+        past_key, past_value = layer_past[0], layer_past[1]
         key_layer = torch.cat((past_key, key_layer), dim=0)
         value_layer = torch.cat((past_value, value_layer), dim=0)
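Note on the hunk above: indexing with [0]/[1] instead of tuple unpacking keeps attention_fn agnostic to whether layer_past arrives as a (key, value) tuple or as a single stacked tensor, which matters once layers are called through torch.utils.checkpoint. A minimal sketch of the equivalence; the shapes are illustrative, not the real ChatGLM-6B configuration:

import torch

key = torch.randn(4, 1, 32, 128)       # (seq, batch, heads, head_dim)
value = torch.randn(4, 1, 32, 128)

as_tuple = (key, value)
as_tensor = torch.stack([key, value])   # shape (2, 4, 1, 32, 128)

for layer_past in (as_tuple, as_tensor):
    past_key, past_value = layer_past[0], layer_past[1]
    assert torch.equal(past_key, key) and torch.equal(past_value, value)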
@@ -644,7 +644,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
     """
     is_parallelizable = False
-    supports_gradient_checkpointing = False
+    supports_gradient_checkpointing = True
     config_class = ChatGLMConfig
     base_model_prefix = "transformer"
     _no_split_modules = ["GLM6BBlock"]
@@ -656,6 +656,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         """Initialize the weights."""
         return
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, ChatGLMModel):
+            module.gradient_checkpointing = value
+
 
 CHATGLM_6B_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
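With supports_gradient_checkpointing = True and the _set_gradient_checkpointing hook in place, checkpointing can be toggled through the standard transformers API rather than by setting the flag by hand. A sketch of typical usage; the model id and trust_remote_code loading are assumptions about how this checkpoint is normally consumed:

from transformers import AutoModel

model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model.gradient_checkpointing_enable()   # applies _set_gradient_checkpointing(value=True) to ChatGLMModel
model.train()                           # checkpointing only takes effect in training mode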
@@ -760,6 +764,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
             dtype=self.params_dtype
         )
+        self.gradient_checkpointing = False
 
         def get_layer(layer_id):
             return GLMBlock(
@@ -812,9 +817,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         #seq_len, b, nh, hidden_size
         past_key_values = self.dropout(past_key_values)
         past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
-        past_key_values = [(v[0], v[1]) for v in past_key_values]
-        # past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(self.num_layers)
-        # past_key_values = [(v1,v2) for v1, v2 in zip(past_key_values[0], past_key_values[1])]
+        # past_key_values = [(v[0], v[1]) for v in past_key_values]
         return past_key_values
 
     def get_masks(self, input_ids, device):
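For context on the get_prompt change above: the prefix encoder's cache is now kept as the tuple returned by split(2), i.e. one stacked (key, value) tensor per layer, rather than being re-packed into Python tuples, and that is what the layer_past[0] / layer_past[1] indexing relies on. A small shape-bookkeeping sketch with illustrative dimensions:

import torch

batch, pre_seq_len, num_layers, heads, head_dim = 1, 4, 3, 2, 8

past = torch.randn(batch, pre_seq_len, num_layers * 2, heads, head_dim)
past = past.permute([2, 1, 0, 3, 4])   # (num_layers*2, pre_seq_len, batch, heads, head_dim)
past = past.split(2)                   # num_layers chunks, each (2, pre_seq_len, batch, heads, head_dim)

layer_past = past[0]                   # stacked key/value for layer 0
past_key, past_value = layer_past[0], layer_past[1]
assert past_key.shape == (pre_seq_len, batch, heads, head_dim)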
@@ -877,6 +880,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
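In a typical fine-tuning run this flag is not flipped manually; a Trainer-based setup would enable it via TrainingArguments, and the guard above then drops use_cache during training after logging the warning once. A hedged sketch, with illustrative argument values:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="output",
    per_device_train_batch_size=1,
    gradient_checkpointing=True,   # Trainer calls model.gradient_checkpointing_enable()
    fp16=True,
)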
@@ -926,28 +936,39 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
 
         seq_length_with_past = seq_length
         past_key_values_length = 0
         if past_key_values[0] is not None:
             past_key_values_length = past_key_values[0][0].shape[0]
             seq_length_with_past = seq_length_with_past + past_key_values_length
         if attention_mask is None:
             attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
         else:
             attention_mask = attention_mask.to(input_ids.device)
 
+        if self.training:
+            hidden_states = hidden_states.requires_grad_(True)
+
         for i, layer in enumerate(self.layers):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
+            layer_past = past_key_values[i]
+
+            if self.gradient_checkpointing and self.training:
+                layer_ret = torch.utils.checkpoint.checkpoint(
+                    layer,
+                    hidden_states,
+                    position_ids,
+                    attention_mask,
+                    torch.tensor(i),
+                    layer_past,
+                    use_cache,
+                    output_attentions
+                )
+            else:
                 layer_ret = layer(
                     hidden_states,
                     position_ids=position_ids,
                     attention_mask=attention_mask,
                     layer_id=torch.tensor(i),
-                layer_past=past_key_values[i],
+                    layer_past=layer_past,
                     use_cache=use_cache,
                     output_attentions=output_attentions
                 )
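Taken together, the forward changes above follow the usual activation-checkpointing pattern: make sure the block inputs require grad (hence the requires_grad_(True) call during training), then route each GLMBlock through torch.utils.checkpoint so its activations are recomputed during backward instead of being stored. A self-contained toy sketch of that pattern; the Block module below is a stand-in, not GLMBlock:

import torch
import torch.utils.checkpoint


class Block(torch.nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.linear = torch.nn.Linear(hidden, hidden)

    def forward(self, hidden_states):
        return torch.relu(self.linear(hidden_states))


layers = torch.nn.ModuleList(Block(16) for _ in range(4))
gradient_checkpointing = True

hidden_states = torch.randn(8, 16).requires_grad_(True)  # checkpointing needs grad-requiring inputs
for layer in layers:
    if gradient_checkpointing:
        # run the layer without storing its intermediate activations;
        # they are recomputed when backward reaches this layer
        hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states)
    else:
        hidden_states = layer(hidden_states)

hidden_states.sum().backward()

The trade-off is roughly one extra forward pass per checkpointed block in exchange for not keeping its activations in memory.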