Fix backward for quantization
parent aea6cefcf5
commit 0cfae21ef8
@@ -134,11 +134,11 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
 
 
 class PrefixEncoder(torch.nn.Module):
-    r'''
+    """
     The torch.nn model to encode the prefix
     Input shape: (batch-size, prefix-length)
     Output shape: (batch-size, prefix-length, 2*layers*hidden)
-    '''
+    """
     def __init__(self, config):
         super().__init__()
         self.prefix_projection = config.prefix_projection
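Note: this hunk only switches the docstring from `r'''` to `"""`. For readers unfamiliar with the class, a minimal sketch of a P-tuning-v2-style prefix encoder matching the documented shapes; everything past `self.prefix_projection` (the field names and the MLP layout) is an assumption, not shown in this diff:

```python
import torch
from types import SimpleNamespace

class PrefixEncoderSketch(torch.nn.Module):
    """Maps prefix token ids (batch, prefix_len) to past key/values
    of shape (batch, prefix_len, 2 * num_layers * hidden_size)."""

    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        out_dim = config.num_layers * config.hidden_size * 2
        if self.prefix_projection:
            # Embed into hidden_size, then project with a two-layer MLP (assumed layout).
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.hidden_size, out_dim),
            )
        else:
            # Embed directly into the full key/value dimension.
            self.embedding = torch.nn.Embedding(config.pre_seq_len, out_dim)

    def forward(self, prefix: torch.Tensor) -> torch.Tensor:
        if self.prefix_projection:
            return self.trans(self.embedding(prefix))
        return self.embedding(prefix)

# Illustrative sizes, not from the real config:
cfg = SimpleNamespace(prefix_projection=False, pre_seq_len=4, hidden_size=8, num_layers=2)
enc = PrefixEncoderSketch(cfg)
print(enc(torch.arange(4).unsqueeze(0)).shape)  # torch.Size([1, 4, 32]) = 2*layers*hidden
```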
@@ -909,7 +909,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             )
 
         if self.pre_seq_len is not None:
-            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(attention_mask.device)
+            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
+                attention_mask.device)
             prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
             attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
 
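This hunk only re-wraps the long mask line. The shape logic it preserves: prefix positions get an all-`False` mask block (ones compared `< 0.5`), so every query may attend to the prefix, and the block is concatenated in front of the existing mask along the key dimension. A standalone sketch with made-up sizes:

```python
import torch

batch_size, seq_len, pre_seq_len = 2, 5, 4  # illustrative sizes only

# Causal mask for the input tokens: True means "masked out".
causal = torch.ones(seq_len, seq_len).tril() < 0.5
attention_mask = causal.expand(batch_size, 1, seq_len, seq_len)

# Prefix positions: ones -> (< 0.5) -> all False, i.e. always attendable.
prefix_attention_mask = torch.ones(batch_size, 1, seq_len, pre_seq_len)
prefix_attention_mask = (prefix_attention_mask < 0.5).bool()

attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
print(attention_mask.shape)  # torch.Size([2, 1, 5, 9]): keys = prefix + input
```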
@@ -942,9 +943,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         else:
             attention_mask = attention_mask.to(input_ids.device)
 
-        if self.training:
-            hidden_states = hidden_states.requires_grad_(True)
-
         for i, layer in enumerate(self.layers):
 
             if output_hidden_states:
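The deleted `requires_grad_(True)` call looks like a workaround for the case where every upstream parameter is frozen (as in P-tuning), so activations carry no autograd graph; with `W8A16Linear.backward` fixed in the hunks below, gradient flow is presumably restored through the quantized layers and the hack can go. A toy illustration of the situation it papered over:

```python
import torch

emb = torch.nn.Embedding(10, 4)
emb.weight.requires_grad_(False)  # frozen backbone, as in P-tuning

hidden_states = emb(torch.tensor([[1, 2, 3]]))
print(hidden_states.requires_grad)  # False: nothing upstream tracks gradients

# The removed hack: force tracking so later layers still build a graph.
hidden_states = hidden_states.requires_grad_(True)
print(hidden_states.requires_grad)  # True
```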
@@ -14,11 +14,11 @@ class W8A16Linear(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
         ctx.inp_shape = inp.size()
-        ctx.weight_shape = quant_w.size()
         ctx.weight_bit_width = weight_bit_width
         out_features = quant_w.size(0)
         inp = inp.contiguous().view(-1, inp.size(-1))
         weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
+        ctx.weight_shape = weight.size()
         output = inp.mm(weight.t())
         ctx.save_for_backward(inp, quant_w, scale_w)
         return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
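Why `ctx.weight_shape` moves: for 4-bit quantization the packed `quant_w` holds two weights per int8 element, so `quant_w.size()` is half the width of the dequantized `weight`, and `grad_weight.view(ctx.weight_shape)` in `backward` would fail. Caching the shape after `extract_weight_to_half` records the real matmul operand. A sketch of the mismatch (the two-nibbles-per-byte packing is an assumption about the kernel's layout):

```python
import torch

out_features, in_features = 8, 16

# Assumed int4 packing: two 4-bit weights per int8 byte.
quant_w = torch.randint(-128, 127, (out_features, in_features // 2), dtype=torch.int8)

dequant_shape = torch.Size((out_features, in_features))
print(quant_w.size())   # torch.Size([8, 8])  -- what the old code cached
print(dequant_shape)    # torch.Size([8, 16]) -- what grad_weight.view() needs
```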
@@ -30,7 +30,7 @@ class W8A16Linear(torch.autograd.Function):
         grad_output = grad_output.contiguous().view(-1, weight.size(0))
         grad_input = grad_output.mm(weight)
         grad_weight = grad_output.t().mm(inp)
-        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None
+        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
 
 
 class Kernel:
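The second fix: `torch.autograd.Function.backward` must return one value per argument of `forward`, and `forward` here takes four (`inp`, `quant_w`, `scale_w`, `weight_bit_width`), so returning only three gradients makes autograd raise an error about the gradient count. A self-contained sketch of the fixed pattern, with a naive per-row rescale standing in for `extract_weight_to_half` (a CUDA kernel in the real code):

```python
import torch

class W8A16LinearSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, quant_w, scale_w, weight_bit_width):
        ctx.inp_shape = inp.size()
        inp = inp.contiguous().view(-1, inp.size(-1))
        # Naive stand-in for extract_weight_to_half: per-row rescale.
        weight = quant_w.to(inp.dtype) * scale_w[:, None]
        ctx.weight_shape = weight.size()  # shape of the dequantized operand
        ctx.save_for_backward(inp, quant_w, scale_w)
        out = inp.mm(weight.t())
        return out.view(*(ctx.inp_shape[:-1] + (weight.size(0),)))

    @staticmethod
    def backward(ctx, grad_output):
        inp, quant_w, scale_w = ctx.saved_tensors
        weight = quant_w.to(grad_output.dtype) * scale_w[:, None]
        grad_output = grad_output.contiguous().view(-1, weight.size(0))
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(inp)
        # One slot per forward argument: inp, quant_w, scale_w, weight_bit_width.
        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None

x = torch.randn(2, 3, 4, requires_grad=True)
qw = torch.randint(-128, 127, (5, 4), dtype=torch.int8)
scale = torch.full((5,), 0.01)
W8A16LinearSketch.apply(x, qw, scale, 8).sum().backward()
print(x.grad.shape)  # torch.Size([2, 3, 4]): gradient now flows through the layer
```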