Close CPU fusion on Mac
This commit is contained in:
parent
d2bbc82a2c
commit
4a9b711e61
|
@ -5,6 +5,7 @@ import copy
|
|||
import os
|
||||
import warnings
|
||||
import re
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
|
@ -32,6 +33,8 @@ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaL
|
|||
from .configuration_chatglm import ChatGLMConfig
|
||||
|
||||
# flags required to enable jit fusion kernels
|
||||
|
||||
if sys.platform != 'darwin':
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
torch._C._jit_override_can_fuse_on_cpu(True)
|
||||
|
@ -266,7 +269,7 @@ def attention_fn(
|
|||
if not (attention_mask == 0).all():
|
||||
# if auto-regressive, skip
|
||||
attention_scores.masked_fill_(attention_mask, -10000.0)
|
||||
dtype = attention_scores.type()
|
||||
dtype = attention_scores.dtype
|
||||
attention_scores = attention_scores.float()
|
||||
attention_scores = attention_scores * query_key_layer_scaling_coeff
|
||||
|
||||
|
|
Loading…
Reference in New Issue