diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index c7ff677..c5a0d31 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -134,11 +134,11 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
 
 
 class PrefixEncoder(torch.nn.Module):
-    r'''
+    """
     The torch.nn model to encode the prefix
     Input shape: (batch-size, prefix-length)
     Output shape: (batch-size, prefix-length, 2*layers*hidden)
-    '''
+    """
     def __init__(self, config):
         super().__init__()
         self.prefix_projection = config.prefix_projection
@@ -148,7 +148,7 @@ class PrefixEncoder(torch.nn.Module):
             self.trans = torch.nn.Sequential(
                 torch.nn.Linear(config.hidden_size, config.hidden_size),
                 torch.nn.Tanh(),
-                torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
+                torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
             )
         else:
             self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
@@ -814,7 +814,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             self.num_attention_heads,
             self.hidden_size // self.num_attention_heads
         )
-        #seq_len, b, nh, hidden_size
+        # seq_len, b, nh, hidden_size
        past_key_values = self.dropout(past_key_values)
         past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
         # past_key_values = [(v[0], v[1]) for v in past_key_values]
@@ -909,7 +909,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             )
 
         if self.pre_seq_len is not None:
-            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(attention_mask.device)
+            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
+                attention_mask.device)
             prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
             attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
 
@@ -942,9 +943,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         else:
             attention_mask = attention_mask.to(input_ids.device)
 
-        if self.training:
-            hidden_states = hidden_states.requires_grad_(True)
-
         for i, layer in enumerate(self.layers):
 
             if output_hidden_states:
diff --git a/quantization.py b/quantization.py
index bf30790..861cbd9 100644
--- a/quantization.py
+++ b/quantization.py
@@ -14,11 +14,11 @@ class W8A16Linear(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
         ctx.inp_shape = inp.size()
-        ctx.weight_shape = quant_w.size()
         ctx.weight_bit_width = weight_bit_width
         out_features = quant_w.size(0)
         inp = inp.contiguous().view(-1, inp.size(-1))
         weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
+        ctx.weight_shape = weight.size()
         output = inp.mm(weight.t())
         ctx.save_for_backward(inp, quant_w, scale_w)
         return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
@@ -30,7 +30,7 @@ class W8A16Linear(torch.autograd.Function):
         grad_output = grad_output.contiguous().view(-1, weight.size(0))
         grad_input = grad_output.mm(weight)
         grad_weight = grad_output.t().mm(inp)
-        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None
+        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
 
 
 class Kernel:
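
Note (not part of the patch): the last hunk adds a second trailing None because torch.autograd.Function.backward must return exactly one value per argument of forward; W8A16Linear.forward takes (inp, quant_w, scale_w, weight_bit_width), so the non-differentiable scale_w and weight_bit_width each need their own None slot. Relatedly, ctx.weight_shape is now recorded from the dequantized weight rather than from quant_w, presumably because grad_output.t().mm(inp) has the shape of the half-precision weight, which can differ from quant_w.size() when low-bit weights are stored packed. The sketch below is an illustrative toy Function of my own (ScaledLinear and gain are made-up names), not code from this repository; it only demonstrates the return-value convention the hunk fixes.

# Minimal sketch, assuming nothing beyond stock PyTorch: a custom
# autograd.Function whose forward takes a non-tensor argument, showing why
# backward must return one value per forward input (None for the rest).
import torch


class ScaledLinear(torch.autograd.Function):
    """Toy example: y = (x @ w.t()) * gain, where gain is a plain float."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, w: torch.Tensor, gain: float):
        ctx.save_for_backward(x, w)
        ctx.gain = gain
        return x.mm(w.t()) * gain

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        x, w = ctx.saved_tensors
        grad_x = grad_output.mm(w) * ctx.gain          # shape of x
        grad_w = grad_output.t().mm(x) * ctx.gain      # shape of w
        # forward() had three inputs (x, w, gain), so backward() must return
        # three values; gain is not a tensor, so its gradient slot is None.
        return grad_x, grad_w, None


if __name__ == "__main__":
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(16, 8, requires_grad=True)
    ScaledLinear.apply(x, w, 2.0).sum().backward()
    print(x.grad.shape, w.grad.shape)  # torch.Size([4, 8]) torch.Size([16, 8])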