From d4832e8142309f1e4b4261dc584c60a8e73775a2 Mon Sep 17 00:00:00 2001
From: duzx16
Date: Tue, 14 Mar 2023 14:49:14 +0800
Subject: [PATCH] Add support for float32

---
 modeling_chatglm.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index e6c6511..330077d 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -254,13 +254,13 @@ def attention_fn(
     if not (attention_mask == 0).all():  # if auto-regressive, skip
         attention_scores.masked_fill_(attention_mask, -10000.0)
-
+    dtype = attention_scores.type()
     attention_scores = attention_scores.float()
     attention_scores = attention_scores * query_key_layer_scaling_coeff
 
     attention_probs = F.softmax(attention_scores, dim=-1)
 
-    attention_probs = attention_probs.half()
+    attention_probs = attention_probs.type(dtype)
 
     # =========================
     # Context layer. [sq, b, hp]
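
For context, the patch makes the softmax round-trip dtype-agnostic: it records the incoming tensor type before up-casting to float32 and restores it after the softmax, instead of hard-coding .half(), which previously broke float32 inference. Below is a minimal standalone sketch of that pattern; the function name and tensor shapes are illustrative, and only query_key_layer_scaling_coeff is taken from the diff.

import torch
import torch.nn.functional as F

def softmax_in_float32(attention_scores, query_key_layer_scaling_coeff=1.0):
    # Tensor.type() with no argument returns the type as a string
    # (e.g. 'torch.HalfTensor'); the same string casts back below.
    dtype = attention_scores.type()
    # Up-cast so the softmax runs in float32 for numerical stability,
    # whether the model itself runs in float16 or float32.
    attention_scores = attention_scores.float() * query_key_layer_scaling_coeff
    attention_probs = F.softmax(attention_scores, dim=-1)
    # Restore the original type; the old hard-coded .half() forced
    # float16 even when the model was loaded in float32.
    return attention_probs.type(dtype)

# Both precisions now round-trip to their own dtype:
assert softmax_in_float32(torch.randn(2, 8, dtype=torch.float16)).dtype == torch.float16
assert softmax_in_float32(torch.randn(2, 8, dtype=torch.float32)).dtype == torch.float32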