diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index e6c6511..330077d 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -254,13 +254,13 @@ def attention_fn(
     if not (attention_mask == 0).all():
         # if auto-regressive, skip
         attention_scores.masked_fill_(attention_mask, -10000.0)
-
+    dtype = attention_scores.type()
     attention_scores = attention_scores.float()
     attention_scores = attention_scores * query_key_layer_scaling_coeff
 
     attention_probs = F.softmax(attention_scores, dim=-1)
 
-    attention_probs = attention_probs.half()
+    attention_probs = attention_probs.type(dtype)
 
     # =========================
     # Context layer. [sq, b, hp]
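The patch records the incoming dtype before the scores are upcast to float32 for a numerically stable softmax, then casts the probabilities back to that dtype instead of hardcoding `.half()`. This keeps the attention path correct when the model is not running in fp16 (e.g. float32 on CPU), where the old `.half()` silently downcast the result. Below is a minimal standalone sketch of the same dtype round-trip; the helper name `stable_softmax`, the scaling default, and the toy tensor shapes are illustrative assumptions, not part of the patch.

```python
import torch
import torch.nn.functional as F

def stable_softmax(attention_scores: torch.Tensor, scaling_coeff: float = 1.0) -> torch.Tensor:
    """Softmax computed in float32, returned in the caller's original dtype.

    Mirrors the pattern in the patch above: remember the incoming type,
    upcast for a numerically stable softmax, then cast back rather than
    hardcoding .half().
    """
    dtype = attention_scores.type()      # e.g. 'torch.HalfTensor' or 'torch.cuda.HalfTensor'
    scores = attention_scores.float() * scaling_coeff
    probs = F.softmax(scores, dim=-1)
    return probs.type(dtype)             # restores fp16 on GPU, fp32 on CPU, etc.

# The old hardcoded .half() forced float32 inputs down to fp16 on CPU runs;
# the round-trip preserves whatever precision the caller was using.
scores_fp32 = torch.randn(2, 4, 8, 8)            # float32 scores, as in a CPU run
assert stable_softmax(scores_fp32).dtype == torch.float32

if torch.cuda.is_available():
    scores_fp16 = scores_fp32.cuda().half()      # fp16 scores, as in a GPU run
    assert stable_softmax(scores_fp16).dtype == torch.float16
```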