Close CPU fusion on Mac

This commit is contained in:
duzx16 2023-03-23 22:43:06 +08:00
parent d2bbc82a2c
commit 4a9b711e61
1 changed files with 8 additions and 5 deletions

View File

@@ -5,6 +5,7 @@ import copy
 import os
 import warnings
 import re
+import sys
 import torch
 import torch.utils.checkpoint
@@ -32,10 +33,12 @@ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from .configuration_chatglm import ChatGLMConfig

 # flags required to enable jit fusion kernels
-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)
+if sys.platform != 'darwin':
+    torch._C._jit_set_profiling_mode(False)
+    torch._C._jit_set_profiling_executor(False)
+    torch._C._jit_override_can_fuse_on_cpu(True)
+    torch._C._jit_override_can_fuse_on_gpu(True)

 logger = logging.get_logger(__name__)
@@ -266,7 +269,7 @@ def attention_fn(
         if not (attention_mask == 0).all():
             # if auto-regressive, skip
             attention_scores.masked_fill_(attention_mask, -10000.0)
-        dtype = attention_scores.type()
+        dtype = attention_scores.dtype
         attention_scores = attention_scores.float()
         attention_scores = attention_scores * query_key_layer_scaling_coeff