Disable JIT CPU fusion on macOS
This commit is contained in:
parent
d2bbc82a2c
commit
4a9b711e61
|
@ -5,6 +5,7 @@ import copy
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
import re
|
import re
|
||||||
|
import sys
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
|
@ -32,6 +33,8 @@ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaL
|
||||||
from .configuration_chatglm import ChatGLMConfig
|
from .configuration_chatglm import ChatGLMConfig
|
||||||
|
|
||||||
# flags required to enable jit fusion kernels
|
# flags required to enable jit fusion kernels
|
||||||
|
|
||||||
|
if sys.platform != 'darwin':
|
||||||
torch._C._jit_set_profiling_mode(False)
|
torch._C._jit_set_profiling_mode(False)
|
||||||
torch._C._jit_set_profiling_executor(False)
|
torch._C._jit_set_profiling_executor(False)
|
||||||
torch._C._jit_override_can_fuse_on_cpu(True)
|
torch._C._jit_override_can_fuse_on_cpu(True)
|
||||||
|
@ -266,7 +269,7 @@ def attention_fn(
|
||||||
if not (attention_mask == 0).all():
|
if not (attention_mask == 0).all():
|
||||||
# if auto-regressive, skip
|
# if auto-regressive, skip
|
||||||
attention_scores.masked_fill_(attention_mask, -10000.0)
|
attention_scores.masked_fill_(attention_mask, -10000.0)
|
||||||
dtype = attention_scores.type()
|
dtype = attention_scores.dtype
|
||||||
attention_scores = attention_scores.float()
|
attention_scores = attention_scores.float()
|
||||||
attention_scores = attention_scores * query_key_layer_scaling_coeff
|
attention_scores = attention_scores * query_key_layer_scaling_coeff
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue