Fix attention score on mps
This commit is contained in:
parent
a7272d4c93
commit
41fda88421
|
@ -280,10 +280,8 @@ def attention_fn(
|
||||||
# [sk, b, np, hn] -> [sk, b * np, hn]
|
# [sk, b, np, hn] -> [sk, b * np, hn]
|
||||||
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
|
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
|
||||||
|
|
||||||
matmul_result = torch.empty(
|
matmul_result = torch.zeros(
|
||||||
output_size[0] * output_size[1],
|
1, 1, 1,
|
||||||
output_size[2],
|
|
||||||
output_size[3],
|
|
||||||
dtype=query_layer.dtype,
|
dtype=query_layer.dtype,
|
||||||
device=query_layer.device,
|
device=query_layer.device,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue