diff --git a/app_modules/utils.py b/app_modules/utils.py index 0a0453d..b6c7728 100644 --- a/app_modules/utils.py +++ b/app_modules/utils.py @@ -352,12 +352,13 @@ def load_tokenizer_and_model(base_model,adapter_model,load_8bit=False): ) else: model = LlamaForCausalLM.from_pretrained( - base_model, device_map={"": device}, low_cpu_mem_usage=True + base_model, device_map={"": device}, low_cpu_mem_usage=True,torch_dtype=torch.float16 ) model = PeftModel.from_pretrained( model, adapter_model, device_map={"": device}, + torch_dtype=torch.float16 ) if not load_8bit: