From c7d8998bb3c7f891e1f9a5f205d7bb8c7acf7b8d Mon Sep 17 00:00:00 2001 From: songxxzp Date: Fri, 14 Apr 2023 18:52:35 +0800 Subject: [PATCH] Update CPU kernel loading method --- quantization.py | 63 ++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 49 insertions(+), 14 deletions(-) diff --git a/quantization.py b/quantization.py index 07e78b1..788b5e7 100644 --- a/quantization.py +++ b/quantization.py @@ -103,7 +103,7 @@ class CPUKernel: self.int8WeightExtractionFloat = None self.int4WeightExtractionFloat = None self.int4WeightCompression = None - self.SetNumThreads = None + self.SetNumThreads = lambda x: x try: if not os.path.exists(default_cpu_kernel_code_path): @@ -127,38 +127,74 @@ class CPUKernel: if compile_parallel_kernel and source_code == default_cpu_kernel_code_path: source_code = default_cpu_parallel_kernel_code_path + kernels = None + if (not kernel_file) or (not os.path.exists(kernel_file)): print("No compiled kernel found.") try: if os.path.exists(source_code): print("Compiling kernels :", source_code) kernel_file = source_code[:-2] + ".so" + if compile_parallel_kernel: compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format(source_code, kernel_file) print("Compiling", compile_command) exit_state = os.system(compile_command) - if exit_state: - print("Compile failed, using default cpu kernel code.") + if not exit_state: + try: + kernels = ctypes.cdll.LoadLibrary(kernel_file) + print("Load kernel :", kernel_file) + except: + kernels = None + print("Load parallel cpu kernel failed, using default cpu kernel code:") + import traceback + exception = traceback.format_exc() + print(exception) + else: + print("Compile default cpu kernel failed, using default cpu kernel code.") + + if kernels is None: # adjust config, use default cpu kernel compile_parallel_kernel = False source_code = default_cpu_kernel_code_path kernel_file = source_code[:-2] + ".so" - compile_command = "gcc -O3 -fPIC -std=c99 {} -shared -o {}".format(source_code, kernel_file) - print("Compiling", compile_command) - exit_state = os.system(compile_command) - else: + + if kernels is None: compile_command = "gcc -O3 -fPIC -std=c99 {} -shared -o {}".format(source_code, kernel_file) print("Compiling", compile_command) exit_state = os.system(compile_command) - - print("Kernels compiled :", kernel_file) + if not exit_state: + try: + kernels = ctypes.cdll.LoadLibrary(kernel_file) + print("Load kernel :", kernel_file) + except: + kernels = None + print("Load default cpu kernel failed:") + import traceback + exception = traceback.format_exc() + print(exception) + else: + print("Compile default cpu kernel failed.") else: print("Kernel source code not found.") return except: - print("Failed to build kernel.") + print("Failed to build cpu kernel:") + import traceback + exception = traceback.format_exc() + print(exception) return - if kernel_file: - kernels = ctypes.cdll.LoadLibrary(kernel_file) + else: + try: + kernels = ctypes.cdll.LoadLibrary(kernel_file) + print("Load kernel :", kernel_file) + except: + kernels = None + print("Load custom cpu kernel failed:") + import traceback + exception = traceback.format_exc() + print(exception) + + if kernels is not None: self.int8WeightExtractionFloat = kernels.extract_int8_weight_to_float self.int4WeightExtractionFloat = kernels.extract_int4_weight_to_float self.int4WeightCompression = kernels.compress_int4_weight @@ -167,11 +203,10 @@ class CPUKernel: self.SetNumThreads = kernels.set_num_threads except: print("No set_num_threads() found in kernel.") - self.SetNumThreads = lambda x: x self.load = True - print("Load kernel :", kernel_file) else: print("Failed to load kernel.") + return if compile_parallel_kernel: if parallel_num is None: