diff --git a/quantization.py b/quantization.py
index 788b5e7..0ebb94a 100644
--- a/quantization.py
+++ b/quantization.py
@@ -441,10 +441,10 @@ def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=F
     try:
         load_cpu_kernel(**kwargs)
     except:
-        print("Cannot load cpu kernel, don't use quantized model on cpu.")
         if kernels is None:  # CUDA kernels failed
-            print("Cannot load cuda kernel, quantization failed.")
-            return model
+            print("Cannot load cpu or cuda kernel, quantization failed:")
+            assert kernels is None
+        print("Cannot load cpu kernel, don't use quantized model on cpu.")
 
     current_device = model.device