diff --git a/quantization.py b/quantization.py index d739c90..2f6396a 100644 --- a/quantization.py +++ b/quantization.py @@ -7,6 +7,7 @@ import bz2 import torch import base64 import ctypes +import sys from transformers.utils import logging from typing import List @@ -142,8 +143,12 @@ class CPUKernel: kernel_file = source_code[:-2] + ".so" if compile_parallel_kernel: - compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format( - source_code, kernel_file) + if sys.platform != 'darwin': + compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format( + source_code, kernel_file) + else: + compile_command = "clang -O3 -fPIC -pthread -Xclang -fopenmp -lomp -std=c99 {} -shared -o {}".format( + source_code, kernel_file) print("Compiling", compile_command) exit_state = os.system(compile_command) if not exit_state: