Add support for float32
parent a034f2a1ed
commit d4832e8142
@@ -254,13 +254,13 @@ def attention_fn(
     if not (attention_mask == 0).all():
         # if auto-regressive, skip
         attention_scores.masked_fill_(attention_mask, -10000.0)
+    dtype = attention_scores.type()
     attention_scores = attention_scores.float()
     attention_scores = attention_scores * query_key_layer_scaling_coeff

     attention_probs = F.softmax(attention_scores, dim=-1)

-    attention_probs = attention_probs.half()
+    attention_probs = attention_probs.type(dtype)

     # =========================
     # Context layer. [sq, b, hp]
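In isolation, the pattern the diff introduces looks like the sketch below: record the caller's dtype before up-casting to float32 for a numerically stable softmax, then restore that dtype afterwards, so float32 inputs are no longer forced down to half precision by the previous hardcoded .half(). This is a minimal standalone sketch, not the repository's actual attention_fn; the function name softmax_in_float32 and the default scaling value are stand-ins, while attention_scores, query_key_layer_scaling_coeff, and the F.softmax call are taken from the diff.

import torch
import torch.nn.functional as F

def softmax_in_float32(attention_scores: torch.Tensor,
                       query_key_layer_scaling_coeff: float = 1.0) -> torch.Tensor:
    # .type() with no argument returns the tensor's type string,
    # e.g. "torch.HalfTensor" or "torch.FloatTensor".
    dtype = attention_scores.type()
    # Up-cast so the softmax is computed in float32 for stability.
    attention_scores = attention_scores.float()
    attention_scores = attention_scores * query_key_layer_scaling_coeff
    attention_probs = F.softmax(attention_scores, dim=-1)
    # Cast back to whatever the input was, instead of a hardcoded
    # .half(), so float32 inputs keep their precision.
    return attention_probs.type(dtype)

# float16 input comes back as float16; float32 stays float32.
scores16 = torch.randn(2, 4, 8, 8, dtype=torch.float16)
assert softmax_in_float32(scores16).dtype == torch.float16
scores32 = torch.randn(2, 4, 8, 8)
assert softmax_in_float32(scores32).dtype == torch.float32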