Fix backward for quantization

This commit is contained in:
duzx16 2023-03-30 21:49:06 +08:00
parent aea6cefcf5
commit 0cfae21ef8
2 changed files with 8 additions and 10 deletions

View File

@ -134,11 +134,11 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
class PrefixEncoder(torch.nn.Module):
r'''
"""
The torch.nn model to encode the prefix
Input shape: (batch-size, prefix-length)
Output shape: (batch-size, prefix-length, 2*layers*hidden)
'''
"""
def __init__(self, config):
super().__init__()
self.prefix_projection = config.prefix_projection
@ -148,7 +148,7 @@ class PrefixEncoder(torch.nn.Module):
self.trans = torch.nn.Sequential(
torch.nn.Linear(config.hidden_size, config.hidden_size),
torch.nn.Tanh(),
torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
)
else:
self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
@ -814,7 +814,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
self.num_attention_heads,
self.hidden_size // self.num_attention_heads
)
#seq_len, b, nh, hidden_size
# seq_len, b, nh, hidden_size
past_key_values = self.dropout(past_key_values)
past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
# past_key_values = [(v[0], v[1]) for v in past_key_values]
@ -909,7 +909,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
)
if self.pre_seq_len is not None:
prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(attention_mask.device)
prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
attention_mask.device)
prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
@ -942,9 +943,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
else:
attention_mask = attention_mask.to(input_ids.device)
if self.training:
hidden_states = hidden_states.requires_grad_(True)
for i, layer in enumerate(self.layers):
if output_hidden_states:

View File

@ -14,11 +14,11 @@ class W8A16Linear(torch.autograd.Function):
@staticmethod
def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
ctx.inp_shape = inp.size()
ctx.weight_shape = quant_w.size()
ctx.weight_bit_width = weight_bit_width
out_features = quant_w.size(0)
inp = inp.contiguous().view(-1, inp.size(-1))
weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
ctx.weight_shape = weight.size()
output = inp.mm(weight.t())
ctx.save_for_backward(inp, quant_w, scale_w)
return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
@ -30,7 +30,7 @@ class W8A16Linear(torch.autograd.Function):
grad_output = grad_output.contiguous().view(-1, weight.size(0))
grad_input = grad_output.mm(weight)
grad_weight = grad_output.t().mm(inp)
return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None
return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
class Kernel: