Add support for float32

duzx16 2023-03-14 14:49:14 +08:00
parent a034f2a1ed
commit d4832e8142
1 changed file with 2 additions and 2 deletions

@@ -254,13 +254,13 @@ def attention_fn(
         if not (attention_mask == 0).all():
             # if auto-regressive, skip
             attention_scores.masked_fill_(attention_mask, -10000.0)
+        dtype = attention_scores.type()
         attention_scores = attention_scores.float()
         attention_scores = attention_scores * query_key_layer_scaling_coeff
 
         attention_probs = F.softmax(attention_scores, dim=-1)
 
-        attention_probs = attention_probs.half()
+        attention_probs = attention_probs.type(dtype)
 
         # =========================
         # Context layer. [sq, b, hp]
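
Why this enables float32: the old code upcast the scores to float32 for a numerically stable softmax, then unconditionally cast the probabilities back with .half(), silently forcing half precision even when the model was loaded in float32. The fix records the incoming tensor type with Tensor.type() before the upcast and restores it afterwards. A minimal, standalone sketch of that pattern (the function name and the placeholder scaling coefficient below are illustrative, not from the repo):

    import torch
    import torch.nn.functional as F

    QUERY_KEY_LAYER_SCALING_COEFF = 1.0  # placeholder; the real value comes from the model config

    def softmax_preserving_dtype(attention_scores: torch.Tensor) -> torch.Tensor:
        # Tensor.type() with no argument returns the type string,
        # e.g. 'torch.HalfTensor' or 'torch.cuda.FloatTensor'.
        dtype = attention_scores.type()
        # Upcast so softmax runs in float32 regardless of input precision.
        attention_scores = attention_scores.float() * QUERY_KEY_LAYER_SCALING_COEFF
        attention_probs = F.softmax(attention_scores, dim=-1)
        # Cast back to the original type; a float32 input now round-trips
        # unchanged instead of being forced to half precision.
        return attention_probs.type(dtype)

    scores = torch.randn(2, 8, 16, 16)  # float32 input
    assert softmax_preserving_dtype(scores).dtype == torch.float32
    assert softmax_preserving_dtype(scores.half()).dtype == torch.float16

With the hard-coded .half(), the first assert would fail; saving and restoring the type string makes the same code path serve both half- and full-precision checkpoints.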