diff --git a/quantization.py b/quantization.py index 49b53de..d739c90 100644 --- a/quantization.py +++ b/quantization.py @@ -442,7 +442,6 @@ class QuantizedEmbedding(Embedding): # TODO: backward, check empty_init def load_cpu_kernel(**kwargs): global cpu_kernels cpu_kernels = CPUKernel(**kwargs) - assert cpu_kernels.load def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=False, **kwargs): @@ -453,9 +452,8 @@ def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=F dense_h_to_4h_quantization_cache = None dense_4h_to_h_quantization_cache = None - try: - load_cpu_kernel(**kwargs) - except: + load_cpu_kernel(**kwargs) + if not cpu_kernels.load: if kernels is None: # CUDA kernels failed print("Cannot load cpu or cuda kernel, quantization failed:") assert kernels is not None