Implement gradient checkpointing
This commit is contained in:
parent
0564795e6e
commit
aea6cefcf5
|
@ -244,7 +244,7 @@ def attention_fn(
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
):
|
):
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
past_key, past_value = layer_past
|
past_key, past_value = layer_past[0], layer_past[1]
|
||||||
key_layer = torch.cat((past_key, key_layer), dim=0)
|
key_layer = torch.cat((past_key, key_layer), dim=0)
|
||||||
value_layer = torch.cat((past_value, value_layer), dim=0)
|
value_layer = torch.cat((past_value, value_layer), dim=0)
|
||||||
|
|
||||||
|
@ -644,7 +644,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
is_parallelizable = False
|
is_parallelizable = False
|
||||||
supports_gradient_checkpointing = False
|
supports_gradient_checkpointing = True
|
||||||
config_class = ChatGLMConfig
|
config_class = ChatGLMConfig
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
_no_split_modules = ["GLM6BBlock"]
|
_no_split_modules = ["GLM6BBlock"]
|
||||||
|
@ -656,6 +656,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
|
||||||
"""Initialize the weights."""
|
"""Initialize the weights."""
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def _set_gradient_checkpointing(self, module, value=False):
|
||||||
|
if isinstance(module, ChatGLMModel):
|
||||||
|
module.gradient_checkpointing = value
|
||||||
|
|
||||||
|
|
||||||
CHATGLM_6B_START_DOCSTRING = r"""
|
CHATGLM_6B_START_DOCSTRING = r"""
|
||||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
|
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
|
||||||
|
@ -760,6 +764,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||||
num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
|
num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
|
||||||
dtype=self.params_dtype
|
dtype=self.params_dtype
|
||||||
)
|
)
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
def get_layer(layer_id):
|
def get_layer(layer_id):
|
||||||
return GLMBlock(
|
return GLMBlock(
|
||||||
|
@ -812,9 +817,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||||
#seq_len, b, nh, hidden_size
|
#seq_len, b, nh, hidden_size
|
||||||
past_key_values = self.dropout(past_key_values)
|
past_key_values = self.dropout(past_key_values)
|
||||||
past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
|
past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
|
||||||
past_key_values = [(v[0], v[1]) for v in past_key_values]
|
# past_key_values = [(v[0], v[1]) for v in past_key_values]
|
||||||
# past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(self.num_layers)
|
|
||||||
# past_key_values = [(v1,v2) for v1, v2 in zip(past_key_values[0], past_key_values[1])]
|
|
||||||
return past_key_values
|
return past_key_values
|
||||||
|
|
||||||
def get_masks(self, input_ids, device):
|
def get_masks(self, input_ids, device):
|
||||||
|
@ -877,6 +880,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
elif input_ids is not None:
|
elif input_ids is not None:
|
||||||
|
@ -926,31 +936,42 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|
||||||
seq_length_with_past = seq_length
|
|
||||||
past_key_values_length = 0
|
|
||||||
if past_key_values[0] is not None:
|
|
||||||
past_key_values_length = past_key_values[0][0].shape[0]
|
|
||||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
|
attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
attention_mask = attention_mask.to(input_ids.device)
|
attention_mask = attention_mask.to(input_ids.device)
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
hidden_states = hidden_states.requires_grad_(True)
|
||||||
|
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
layer_past = past_key_values[i]
|
||||||
|
|
||||||
layer_ret = layer(
|
if self.gradient_checkpointing and self.training:
|
||||||
hidden_states,
|
layer_ret = torch.utils.checkpoint.checkpoint(
|
||||||
position_ids=position_ids,
|
layer,
|
||||||
attention_mask=attention_mask,
|
hidden_states,
|
||||||
layer_id=torch.tensor(i),
|
position_ids,
|
||||||
layer_past=past_key_values[i],
|
attention_mask,
|
||||||
use_cache=use_cache,
|
torch.tensor(i),
|
||||||
output_attentions=output_attentions
|
layer_past,
|
||||||
)
|
use_cache,
|
||||||
|
output_attentions
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_ret = layer(
|
||||||
|
hidden_states,
|
||||||
|
position_ids=position_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
layer_id=torch.tensor(i),
|
||||||
|
layer_past=layer_past,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = layer_ret[0]
|
hidden_states = layer_ret[0]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue