From 41fda884213fbb7df5baa498a6a03624744a3324 Mon Sep 17 00:00:00 2001 From: duzx16 Date: Sun, 9 Apr 2023 20:30:23 +0800 Subject: [PATCH] Fix attention score on mps --- modeling_chatglm.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/modeling_chatglm.py b/modeling_chatglm.py index 3edaeb1..883f774 100644 --- a/modeling_chatglm.py +++ b/modeling_chatglm.py @@ -280,10 +280,8 @@ def attention_fn( # [sk, b, np, hn] -> [sk, b * np, hn] key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - matmul_result = torch.empty( - output_size[0] * output_size[1], - output_size[2], - output_size[3], + matmul_result = torch.zeros( + 1, 1, 1, dtype=query_layer.dtype, device=query_layer.device, )