Add empty_init option
parent 6466cdcff5
commit 9333486c30
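The commit threads an `empty_init` flag through the model constructors: when it is True (the default), submodules are built with `torch.nn.utils.skip_init`, which allocates parameters without running their initializers; when it is False, the new `default_init` helper constructs them normally. Below is a minimal sketch of the difference, assuming PyTorch 1.10+ (where `skip_init` exists); the 1024-dim sizes are placeholders, not values from this commit, and `default_init` simply mirrors the helper the diff adds.

import torch
from torch.nn.utils import skip_init

def default_init(cls, *args, **kwargs):
    # Ordinary construction: parameters are allocated AND randomly initialized.
    return cls(*args, **kwargs)

# skip_init instantiates the module on the meta device and then materializes
# uninitialized parameter storage, skipping the init work entirely. That is
# only safe when every parameter will be overwritten, e.g. by load_state_dict.
fast = skip_init(torch.nn.Linear, 1024, 1024)   # uninitialized weights
slow = default_init(torch.nn.Linear, 1024, 1024)  # normally initialized weights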
@@ -346,10 +346,18 @@ def attention_fn(
     return outputs


+def default_init(cls, *args, **kwargs):
+    return cls(*args, **kwargs)
+
+
 class SelfAttention(torch.nn.Module):
     def __init__(self, hidden_size, num_attention_heads,
                  layer_id, hidden_size_per_attention_head=None, bias=True,
-                 params_dtype=torch.float, position_encoding_2d=True):
+                 params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         super(SelfAttention, self).__init__()

         self.layer_id = layer_id
@@ -377,7 +385,7 @@ class SelfAttention(torch.nn.Module):
         self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head

         # Strided linear layer.
-        self.query_key_value = skip_init(
+        self.query_key_value = init_method(
             torch.nn.Linear,
             hidden_size,
             3 * self.inner_hidden_size,
@@ -385,7 +393,7 @@ class SelfAttention(torch.nn.Module):
             dtype=params_dtype,
         )

-        self.dense = skip_init(
+        self.dense = init_method(
             torch.nn.Linear,
             self.inner_hidden_size,
             hidden_size,
@@ -498,8 +506,12 @@ class GEGLU(torch.nn.Module):

 class GLU(torch.nn.Module):
     def __init__(self, hidden_size, inner_hidden_size=None,
-                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float):
+                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
         super(GLU, self).__init__()
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         self.layer_id = layer_id
         self.activation_func = activation_func

@@ -508,7 +520,7 @@ class GLU(torch.nn.Module):
         if inner_hidden_size is None:
             inner_hidden_size = 4 * hidden_size
         self.inner_hidden_size = inner_hidden_size
-        self.dense_h_to_4h = skip_init(
+        self.dense_h_to_4h = init_method(
             torch.nn.Linear,
             self.hidden_size,
             self.inner_hidden_size,
@@ -516,7 +528,7 @@ class GLU(torch.nn.Module):
             dtype=params_dtype,
         )
         # Project back to h.
-        self.dense_4h_to_h = skip_init(
+        self.dense_4h_to_h = init_method(
             torch.nn.Linear,
             self.inner_hidden_size,
             self.hidden_size,
@@ -552,7 +564,8 @@ class GLMBlock(torch.nn.Module):
             use_bias=True,
             params_dtype=torch.float,
             num_layers=28,
-            position_encoding_2d=True
+            position_encoding_2d=True,
+            empty_init=True
     ):
         super(GLMBlock, self).__init__()
         # Set output layer initialization if not provided.
@@ -572,7 +585,8 @@ class GLMBlock(torch.nn.Module):
             hidden_size_per_attention_head=hidden_size_per_attention_head,
             bias=use_bias,
             params_dtype=params_dtype,
-            position_encoding_2d=self.position_encoding_2d
+            position_encoding_2d=self.position_encoding_2d,
+            empty_init=empty_init
         )

         # Layernorm on the input data.
@@ -587,6 +601,7 @@ class GLMBlock(torch.nn.Module):
             bias=use_bias,
             layer_id=layer_id,
             params_dtype=params_dtype,
+            empty_init=empty_init
         )

     def forward(
@@ -781,9 +796,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         `encoder_hidden_states` is then expected as an input to the forward pass.
     """

-    def __init__(self, config: ChatGLMConfig):
+    def __init__(self, config: ChatGLMConfig, empty_init=True):
         super().__init__(config)
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         # recording parameters
         self.max_sequence_length = config.max_sequence_length
         self.hidden_size = config.hidden_size
@@ -798,7 +816,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             self.pre_seq_len = config.pre_seq_len
             self.prefix_projection = config.prefix_projection

-        self.word_embeddings = skip_init(
+        self.word_embeddings = init_method(
             torch.nn.Embedding,
             num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
             dtype=self.params_dtype
@@ -817,6 +835,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 use_bias=True,
                 params_dtype=self.params_dtype,
                 position_encoding_2d=self.position_encoding_2d,
+                empty_init=empty_init
             )

         self.layers = torch.nn.ModuleList(
@@ -1004,8 +1023,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):


 class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
-    def __init__(self, config: ChatGLMConfig):
+    def __init__(self, config: ChatGLMConfig, empty_init=True):
         super().__init__(config)
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init

         # self.hidden_size = config.hidden_size
         # self.params_dtype = torch.half
@@ -1014,9 +1037,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):

         self.position_encoding_2d = config.position_encoding_2d

-        self.transformer = ChatGLMModel(config)
+        self.transformer = ChatGLMModel(config, empty_init=empty_init)

-        self.lm_head = skip_init(
+        self.lm_head = init_method(
             nn.Linear,
             config.hidden_size,
             config.vocab_size,
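A hedged usage sketch of the new constructor flag, assuming a `ChatGLMConfig` instance is already available (how the config is built is outside this diff); this illustrates the flag added here, not the project's documented loading path.

# Default path: submodules are built with skip_init, so parameters are
# uninitialized memory and must be filled (e.g. via load_state_dict) before use.
model = ChatGLMForConditionalGeneration(config)

# Opt out of empty init, e.g. when training from scratch or when skip_init is
# unavailable: every submodule then runs its normal PyTorch initializer.
model = ChatGLMForConditionalGeneration(config, empty_init=False)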