diff --git a/quantization.py b/quantization.py
index 0ebb94a..5be8b0b 100644
--- a/quantization.py
+++ b/quantization.py
@@ -443,7 +443,7 @@ def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=F
     except:
         if kernels is None:  # CUDA kernels failed
             print("Cannot load cpu or cuda kernel, quantization failed:")
-            assert kernels is None
+            assert kernels is not None
         print("Cannot load cpu kernel, don't use quantized model on cpu.")

    current_device = model.device