Fix attention score on mps
This commit is contained in:
parent
a7272d4c93
commit
41fda88421
@@ -280,10 +280,8 @@ def attention_fn(
     # [sk, b, np, hn] -> [sk, b * np, hn]
     key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

-    matmul_result = torch.empty(
-        output_size[0] * output_size[1],
-        output_size[2],
-        output_size[3],
+    matmul_result = torch.zeros(
+        1, 1, 1,
         dtype=query_layer.dtype,
         device=query_layer.device,
     )
Loading…
Reference in New Issue