From 9333486c30bfc3860052b7ff4b2c006b576dcb4c Mon Sep 17 00:00:00 2001
From: duzx16
Date: Thu, 13 Apr 2023 20:35:45 +0800
Subject: [PATCH] Add empty_init option

---
 modeling_chatglm.py | 51 ++++++++++++++++++++++++++++++++-------------
 1 file changed, 37 insertions(+), 14 deletions(-)

diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index 883f774..49798d5 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -346,10 +346,18 @@ def attention_fn(
     return outputs
 
 
+def default_init(cls, *args, **kwargs):
+    return cls(*args, **kwargs)
+
+
 class SelfAttention(torch.nn.Module):
     def __init__(self, hidden_size, num_attention_heads,
                  layer_id, hidden_size_per_attention_head=None, bias=True,
-                 params_dtype=torch.float, position_encoding_2d=True):
+                 params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         super(SelfAttention, self).__init__()
 
         self.layer_id = layer_id
@@ -377,7 +385,7 @@ class SelfAttention(torch.nn.Module):
         self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
 
         # Strided linear layer.
-        self.query_key_value = skip_init(
+        self.query_key_value = init_method(
             torch.nn.Linear,
             hidden_size,
             3 * self.inner_hidden_size,
@@ -385,7 +393,7 @@ class SelfAttention(torch.nn.Module):
             dtype=params_dtype,
         )
 
-        self.dense = skip_init(
+        self.dense = init_method(
             torch.nn.Linear,
             self.inner_hidden_size,
             hidden_size,
@@ -498,8 +506,12 @@ class GEGLU(torch.nn.Module):
 
 class GLU(torch.nn.Module):
     def __init__(self, hidden_size, inner_hidden_size=None,
-                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float):
+                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
         super(GLU, self).__init__()
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         self.layer_id = layer_id
         self.activation_func = activation_func
 
@@ -508,7 +520,7 @@ class GLU(torch.nn.Module):
         if inner_hidden_size is None:
             inner_hidden_size = 4 * hidden_size
         self.inner_hidden_size = inner_hidden_size
-        self.dense_h_to_4h = skip_init(
+        self.dense_h_to_4h = init_method(
             torch.nn.Linear,
             self.hidden_size,
             self.inner_hidden_size,
@@ -516,7 +528,7 @@ class GLU(torch.nn.Module):
             dtype=params_dtype,
         )
         # Project back to h.
-        self.dense_4h_to_h = skip_init(
+        self.dense_4h_to_h = init_method(
             torch.nn.Linear,
             self.inner_hidden_size,
             self.hidden_size,
@@ -552,7 +564,8 @@ class GLMBlock(torch.nn.Module):
             use_bias=True,
             params_dtype=torch.float,
             num_layers=28,
-            position_encoding_2d=True
+            position_encoding_2d=True,
+            empty_init=True
     ):
         super(GLMBlock, self).__init__()
         # Set output layer initialization if not provided.
@@ -572,7 +585,8 @@ class GLMBlock(torch.nn.Module):
             hidden_size_per_attention_head=hidden_size_per_attention_head,
             bias=use_bias,
             params_dtype=params_dtype,
-            position_encoding_2d=self.position_encoding_2d
+            position_encoding_2d=self.position_encoding_2d,
+            empty_init=empty_init
         )
 
         # Layernorm on the input data.
@@ -587,6 +601,7 @@ class GLMBlock(torch.nn.Module):
             bias=use_bias,
             layer_id=layer_id,
             params_dtype=params_dtype,
+            empty_init=empty_init
         )
 
     def forward(
@@ -781,9 +796,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         `encoder_hidden_states` is then expected as an input to the forward pass.
     """
 
-    def __init__(self, config: ChatGLMConfig):
+    def __init__(self, config: ChatGLMConfig, empty_init=True):
         super().__init__(config)
-
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         # recording parameters
         self.max_sequence_length = config.max_sequence_length
         self.hidden_size = config.hidden_size
@@ -798,7 +816,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.pre_seq_len = config.pre_seq_len
         self.prefix_projection = config.prefix_projection
 
-        self.word_embeddings = skip_init(
+        self.word_embeddings = init_method(
             torch.nn.Embedding,
             num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
             dtype=self.params_dtype
@@ -817,6 +835,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 use_bias=True,
                 params_dtype=self.params_dtype,
                 position_encoding_2d=self.position_encoding_2d,
+                empty_init=empty_init
             )
 
         self.layers = torch.nn.ModuleList(
@@ -1004,8 +1023,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
 
 class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
-    def __init__(self, config: ChatGLMConfig):
+    def __init__(self, config: ChatGLMConfig, empty_init=True):
         super().__init__(config)
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
 
         # self.hidden_size = config.hidden_size
         # self.params_dtype = torch.half
@@ -1014,9 +1037,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         self.position_encoding_2d = config.position_encoding_2d
 
-        self.transformer = ChatGLMModel(config)
+        self.transformer = ChatGLMModel(config, empty_init=empty_init)
 
-        self.lm_head = skip_init(
+        self.lm_head = init_method(
             nn.Linear,
             config.hidden_size,
             config.vocab_size,