Fix attention score on mps

This commit is contained in:
duzx16 2023-04-09 20:30:23 +08:00
parent a7272d4c93
commit 41fda88421
1 changed file with 2 additions and 4 deletions

View File

@@ -280,10 +280,8 @@ def attention_fn(
     # [sk, b, np, hn] -> [sk, b * np, hn]
     key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
-    matmul_result = torch.empty(
-        output_size[0] * output_size[1],
-        output_size[2],
-        output_size[3],
+    matmul_result = torch.zeros(
+        1, 1, 1,
         dtype=query_layer.dtype,
         device=query_layer.device,
     )