Update CPU kernel loading method

This commit is contained in:
songxxzp 2023-04-14 18:52:35 +08:00
parent 3485994337
commit c7d8998bb3
1 changed file with 49 additions and 14 deletions

View File

@ -103,7 +103,7 @@ class CPUKernel:
self.int8WeightExtractionFloat = None self.int8WeightExtractionFloat = None
self.int4WeightExtractionFloat = None self.int4WeightExtractionFloat = None
self.int4WeightCompression = None self.int4WeightCompression = None
self.SetNumThreads = None self.SetNumThreads = lambda x: x
try: try:
if not os.path.exists(default_cpu_kernel_code_path): if not os.path.exists(default_cpu_kernel_code_path):
@ -127,38 +127,74 @@ class CPUKernel:
if compile_parallel_kernel and source_code == default_cpu_kernel_code_path: if compile_parallel_kernel and source_code == default_cpu_kernel_code_path:
source_code = default_cpu_parallel_kernel_code_path source_code = default_cpu_parallel_kernel_code_path
kernels = None
if (not kernel_file) or (not os.path.exists(kernel_file)): if (not kernel_file) or (not os.path.exists(kernel_file)):
print("No compiled kernel found.") print("No compiled kernel found.")
try: try:
if os.path.exists(source_code): if os.path.exists(source_code):
print("Compiling kernels :", source_code) print("Compiling kernels :", source_code)
kernel_file = source_code[:-2] + ".so" kernel_file = source_code[:-2] + ".so"
if compile_parallel_kernel: if compile_parallel_kernel:
compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format(source_code, kernel_file) compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format(source_code, kernel_file)
print("Compiling", compile_command) print("Compiling", compile_command)
exit_state = os.system(compile_command) exit_state = os.system(compile_command)
if exit_state: if not exit_state:
print("Compile failed, using default cpu kernel code.") try:
kernels = ctypes.cdll.LoadLibrary(kernel_file)
print("Load kernel :", kernel_file)
except:
kernels = None
print("Load parallel cpu kernel failed, using default cpu kernel code:")
import traceback
exception = traceback.format_exc()
print(exception)
else:
print("Compile default cpu kernel failed, using default cpu kernel code.")
if kernels is None: # adjust config, use default cpu kernel
compile_parallel_kernel = False compile_parallel_kernel = False
source_code = default_cpu_kernel_code_path source_code = default_cpu_kernel_code_path
kernel_file = source_code[:-2] + ".so" kernel_file = source_code[:-2] + ".so"
compile_command = "gcc -O3 -fPIC -std=c99 {} -shared -o {}".format(source_code, kernel_file)
print("Compiling", compile_command)
exit_state = os.system(compile_command)
else:
compile_command = "gcc -O3 -fPIC -std=c99 {} -shared -o {}".format(source_code, kernel_file)
print("Compiling", compile_command)
exit_state = os.system(compile_command)
print("Kernels compiled :", kernel_file) if kernels is None:
compile_command = "gcc -O3 -fPIC -std=c99 {} -shared -o {}".format(source_code, kernel_file)
print("Compiling", compile_command)
exit_state = os.system(compile_command)
if not exit_state:
try:
kernels = ctypes.cdll.LoadLibrary(kernel_file)
print("Load kernel :", kernel_file)
except:
kernels = None
print("Load default cpu kernel failed:")
import traceback
exception = traceback.format_exc()
print(exception)
else:
print("Compile default cpu kernel failed.")
else: else:
print("Kernel source code not found.") print("Kernel source code not found.")
return return
except: except:
print("Failed to build kernel.") print("Failed to build cpu kernel:")
import traceback
exception = traceback.format_exc()
print(exception)
return return
if kernel_file: else:
try:
kernels = ctypes.cdll.LoadLibrary(kernel_file) kernels = ctypes.cdll.LoadLibrary(kernel_file)
print("Load kernel :", kernel_file)
except:
kernels = None
print("Load custom cpu kernel failed:")
import traceback
exception = traceback.format_exc()
print(exception)
if kernels is not None:
self.int8WeightExtractionFloat = kernels.extract_int8_weight_to_float self.int8WeightExtractionFloat = kernels.extract_int8_weight_to_float
self.int4WeightExtractionFloat = kernels.extract_int4_weight_to_float self.int4WeightExtractionFloat = kernels.extract_int4_weight_to_float
self.int4WeightCompression = kernels.compress_int4_weight self.int4WeightCompression = kernels.compress_int4_weight
@ -167,11 +203,10 @@ class CPUKernel:
self.SetNumThreads = kernels.set_num_threads self.SetNumThreads = kernels.set_num_threads
except: except:
print("No set_num_threads() found in kernel.") print("No set_num_threads() found in kernel.")
self.SetNumThreads = lambda x: x
self.load = True self.load = True
print("Load kernel :", kernel_file)
else: else:
print("Failed to load kernel.") print("Failed to load kernel.")
return
if compile_parallel_kernel: if compile_parallel_kernel:
if parallel_num is None: if parallel_num is None: