From aea6cefcf522f3ba3fb2e2c0a6c507259cf6a564 Mon Sep 17 00:00:00 2001
From: duzx16
Date: Thu, 30 Mar 2023 19:42:01 +0800
Subject: [PATCH] Implement gradient checkpointing

---
 modeling_chatglm.py | 59 ++++++++++++++++++++++++++++++---------------
 1 file changed, 40 insertions(+), 19 deletions(-)

diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index 8ea3776..c7ff677 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -244,7 +244,7 @@ def attention_fn(
         use_cache=False,
 ):
     if layer_past is not None:
-        past_key, past_value = layer_past
+        past_key, past_value = layer_past[0], layer_past[1]
         key_layer = torch.cat((past_key, key_layer), dim=0)
         value_layer = torch.cat((past_value, value_layer), dim=0)
 
@@ -644,7 +644,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
     """
 
     is_parallelizable = False
-    supports_gradient_checkpointing = False
+    supports_gradient_checkpointing = True
     config_class = ChatGLMConfig
     base_model_prefix = "transformer"
     _no_split_modules = ["GLM6BBlock"]
@@ -656,6 +656,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         """Initialize the weights."""
         return
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, ChatGLMModel):
+            module.gradient_checkpointing = value
+
 
 CHATGLM_6B_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
@@ -760,6 +764,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
             dtype=self.params_dtype
         )
+        self.gradient_checkpointing = False
 
         def get_layer(layer_id):
             return GLMBlock(
@@ -812,9 +817,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             #seq_len, b, nh, hidden_size
             past_key_values = self.dropout(past_key_values)
             past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
-            past_key_values = [(v[0], v[1]) for v in past_key_values]
-            # past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(self.num_layers)
-            # past_key_values = [(v1,v2) for v1, v2 in zip(past_key_values[0], past_key_values[1])]
+            # past_key_values = [(v[0], v[1]) for v in past_key_values]
         return past_key_values
 
     def get_masks(self, input_ids, device):
@@ -877,6 +880,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
@@ -926,31 +936,42 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
 
-        seq_length_with_past = seq_length
-        past_key_values_length = 0
-        if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[0]
-            seq_length_with_past = seq_length_with_past + past_key_values_length
         if attention_mask is None:
             attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
         else:
             attention_mask = attention_mask.to(input_ids.device)
 
+        if self.training:
+            hidden_states = hidden_states.requires_grad_(True)
+
         for i, layer in enumerate(self.layers):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
+            layer_past = past_key_values[i]
 
-            layer_ret = layer(
-                hidden_states,
-                position_ids=position_ids,
-                attention_mask=attention_mask,
-                layer_id=torch.tensor(i),
-                layer_past=past_key_values[i],
-                use_cache=use_cache,
-                output_attentions=output_attentions
-            )
+            if self.gradient_checkpointing and self.training:
+                layer_ret = torch.utils.checkpoint.checkpoint(
+                    layer,
+                    hidden_states,
+                    position_ids,
+                    attention_mask,
+                    torch.tensor(i),
+                    layer_past,
+                    use_cache,
+                    output_attentions
+                )
+            else:
+                layer_ret = layer(
+                    hidden_states,
+                    position_ids=position_ids,
+                    attention_mask=attention_mask,
+                    layer_id=torch.tensor(i),
+                    layer_past=layer_past,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions
+                )
 
         hidden_states = layer_ret[0]
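
With `supports_gradient_checkpointing = True` and the `_set_gradient_checkpointing` hook in place, checkpointing can be switched on through the standard `transformers` API instead of touching `ChatGLMModel.gradient_checkpointing` directly. A minimal usage sketch under stated assumptions: the patched `modeling_chatglm.py` is loaded with `trust_remote_code=True`, the `THUDM/chatglm-6b` checkpoint name and the toy batch are illustrative, and the labels-based loss follows the usual `ForConditionalGeneration` convention rather than anything shown in this patch:

    import torch
    from transformers import AutoModel, AutoTokenizer

    # Illustrative checkpoint name; any checkpoint shipping this patched
    # modeling_chatglm.py via trust_remote_code behaves the same way.
    MODEL_NAME = "THUDM/chatglm-6b"

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # .float() keeps this runnable as a CPU smoke test; fp16 backward needs a GPU.
    model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True).float()

    # gradient_checkpointing_enable() walks the module tree and calls
    # _set_gradient_checkpointing(module, value=True) on each submodule,
    # which flips ChatGLMModel.gradient_checkpointing to True.
    model.gradient_checkpointing_enable()
    model.train()

    # Toy training step: forward() forces use_cache=False and routes every
    # GLMBlock through torch.utils.checkpoint.checkpoint, so per-layer
    # activations are recomputed during backward instead of being stored.
    batch = tokenizer("gradient checkpointing test", return_tensors="pt")
    out = model(input_ids=batch["input_ids"], labels=batch["input_ids"])
    out.loss.backward()

The `hidden_states.requires_grad_(True)` call in `forward` matters for this path: `torch.utils.checkpoint.checkpoint` warns and returns outputs detached from the graph when none of its tensor inputs require grad, which is exactly the situation when the embeddings are frozen (e.g. prefix/p-tuning setups). The trade is the usual one for checkpointing: roughly one extra forward pass of compute per step in exchange for a large cut in activation memory, since each `GLMBlock`'s intermediates are dropped after the forward pass.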