Add empty_init option

duzx16 2023-04-13 20:35:45 +08:00
parent 6466cdcff5
commit 9333486c30
1 changed file with 37 additions and 14 deletions
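In short: every module constructor touched below gains an empty_init flag. With empty_init=True (the default), submodules are built with torch.nn.utils.skip_init, which allocates parameter storage but skips the random weight initialization; that work is wasted when a checkpoint is loaded immediately afterwards. With empty_init=False, the newly added default_init helper constructs the module normally. A minimal sketch of the two paths (the layer size is illustrative, not from this commit):

    import torch
    from torch.nn.utils import skip_init

    def default_init(cls, *args, **kwargs):
        # Ordinary construction: reset_parameters() runs as usual.
        return cls(*args, **kwargs)

    fast = skip_init(torch.nn.Linear, 4096, 4096)     # storage allocated, weights uninitialized
    slow = default_init(torch.nn.Linear, 4096, 4096)  # weights fully initialized

Note that skip_init only supports module classes whose constructor accepts a device keyword argument; torch.nn.Linear and torch.nn.Embedding, the two classes it wraps below, both do.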


@@ -346,10 +346,18 @@ def attention_fn(
     return outputs
 
 
+def default_init(cls, *args, **kwargs):
+    return cls(*args, **kwargs)
+
+
 class SelfAttention(torch.nn.Module):
     def __init__(self, hidden_size, num_attention_heads,
                  layer_id, hidden_size_per_attention_head=None, bias=True,
-                 params_dtype=torch.float, position_encoding_2d=True):
+                 params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         super(SelfAttention, self).__init__()
 
         self.layer_id = layer_id
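For reference, skip_init works by constructing the module on the meta device, where reset_parameters() has no real storage to touch, and then materializing empty tensors; a simplified sketch of the idea (not the actual PyTorch source):

    import torch

    def skip_init_sketch(module_cls, *args, **kwargs):
        # On the meta device no storage exists, so initialization is a no-op.
        final_device = kwargs.pop("device", "cpu")
        module = module_cls(*args, **kwargs, device="meta")
        # to_empty() allocates real but uninitialized storage on the target device.
        return module.to_empty(device=final_device)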
@@ -377,7 +385,7 @@ class SelfAttention(torch.nn.Module):
         self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
 
         # Strided linear layer.
-        self.query_key_value = skip_init(
+        self.query_key_value = init_method(
             torch.nn.Linear,
             hidden_size,
             3 * self.inner_hidden_size,
@@ -385,7 +393,7 @@ class SelfAttention(torch.nn.Module):
             dtype=params_dtype,
         )
 
-        self.dense = skip_init(
+        self.dense = init_method(
             torch.nn.Linear,
             self.inner_hidden_size,
             hidden_size,
@@ -498,8 +506,12 @@ class GEGLU(torch.nn.Module):
 
 class GLU(torch.nn.Module):
     def __init__(self, hidden_size, inner_hidden_size=None,
-                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float):
+                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
         super(GLU, self).__init__()
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         self.layer_id = layer_id
         self.activation_func = activation_func
 
@@ -508,7 +520,7 @@ class GLU(torch.nn.Module):
         if inner_hidden_size is None:
             inner_hidden_size = 4 * hidden_size
         self.inner_hidden_size = inner_hidden_size
-        self.dense_h_to_4h = skip_init(
+        self.dense_h_to_4h = init_method(
             torch.nn.Linear,
             self.hidden_size,
             self.inner_hidden_size,
@@ -516,7 +528,7 @@ class GLU(torch.nn.Module):
             dtype=params_dtype,
         )
         # Project back to h.
-        self.dense_4h_to_h = skip_init(
+        self.dense_4h_to_h = init_method(
             torch.nn.Linear,
             self.inner_hidden_size,
             self.hidden_size,
@@ -552,7 +564,8 @@ class GLMBlock(torch.nn.Module):
             use_bias=True,
             params_dtype=torch.float,
             num_layers=28,
-            position_encoding_2d=True
+            position_encoding_2d=True,
+            empty_init=True
     ):
         super(GLMBlock, self).__init__()
         # Set output layer initialization if not provided.
@@ -572,7 +585,8 @@ class GLMBlock(torch.nn.Module):
             hidden_size_per_attention_head=hidden_size_per_attention_head,
             bias=use_bias,
             params_dtype=params_dtype,
-            position_encoding_2d=self.position_encoding_2d
+            position_encoding_2d=self.position_encoding_2d,
+            empty_init=empty_init
         )
 
         # Layernorm on the input data.
@@ -587,6 +601,7 @@ class GLMBlock(torch.nn.Module):
             bias=use_bias,
             layer_id=layer_id,
             params_dtype=params_dtype,
+            empty_init=empty_init
         )
 
     def forward(
@@ -781,9 +796,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
     `encoder_hidden_states` is then expected as an input to the forward pass.
     """
 
-    def __init__(self, config: ChatGLMConfig):
+    def __init__(self, config: ChatGLMConfig, empty_init=True):
         super().__init__(config)
-
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         # recording parameters
         self.max_sequence_length = config.max_sequence_length
         self.hidden_size = config.hidden_size
@@ -798,7 +816,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.pre_seq_len = config.pre_seq_len
         self.prefix_projection = config.prefix_projection
 
-        self.word_embeddings = skip_init(
+        self.word_embeddings = init_method(
             torch.nn.Embedding,
             num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
             dtype=self.params_dtype
@@ -817,6 +835,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 use_bias=True,
                 params_dtype=self.params_dtype,
                 position_encoding_2d=self.position_encoding_2d,
+                empty_init=empty_init
             )
 
         self.layers = torch.nn.ModuleList(
@@ -1004,8 +1023,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
 
 class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
-    def __init__(self, config: ChatGLMConfig):
+    def __init__(self, config: ChatGLMConfig, empty_init=True):
         super().__init__(config)
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
 
         # self.hidden_size = config.hidden_size
         # self.params_dtype = torch.half
@@ -1014,9 +1037,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         self.max_sequence_length = config.max_sequence_length
         self.position_encoding_2d = config.position_encoding_2d
 
-        self.transformer = ChatGLMModel(config)
+        self.transformer = ChatGLMModel(config, empty_init=empty_init)
 
-        self.lm_head = skip_init(
+        self.lm_head = init_method(
             nn.Linear,
             config.hidden_size,
             config.vocab_size,
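Downstream, the flag is threaded from ChatGLMForConditionalGeneration through ChatGLMModel, GLMBlock, SelfAttention, and GLU, so a single constructor argument controls every layer. A hypothetical usage sketch for callers who want the old eagerly-initialized behavior (checkpoint name and loading path assumed, not part of this commit):

    from transformers import AutoModel

    # Keyword arguments that from_pretrained does not recognize as config fields
    # are forwarded to the model constructor, so empty_init=False reaches
    # ChatGLMForConditionalGeneration.__init__ and disables skip_init.
    model = AutoModel.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True, empty_init=False
    )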