diff --git a/app.py b/app.py
index 5611940..9f98474 100644
--- a/app.py
+++ b/app.py
@@ -17,8 +17,8 @@
 base_model = "decapoda-research/llama-7b-hf"
 adapter_model = "project-baize/baize-lora-7B"
 tokenizer,model,device = load_tokenizer_and_model(base_model,adapter_model)
-global total_cont
-total_cont = 0
+global total_count
+total_count = 0
 def predict(text,
             chatbot,
             history,
@@ -44,8 +44,8 @@ def predict(text,
     begin_length = len(prompt)
     torch.cuda.empty_cache()
     input_ids = inputs["input_ids"].to(device)
-    total_cont += 1
-    print(total_cont)
+    total_count += 1
+    print(total_count)
     with torch.no_grad():
         for x in greedy_search(input_ids,model,tokenizer,stop_words=["[|Human|]", "[|AI|]"],max_length=max_length_tokens,temperature=temperature,top_p=top_p):
            if is_stop_word_or_prefix(x,["[|Human|]", "[|AI|]"]) is False:
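
For reference, below is a minimal standalone sketch of the counter pattern this diff renames. One detail worth noting: in Python, a `global` statement only has an effect inside a function body, so the module-level `global total_count` line is a no-op; for `total_count += 1` inside predict() to rebind the module-level name (rather than raise UnboundLocalError), predict() itself must contain a `global total_count` declaration, presumably on a line outside the context shown in the second hunk. The predict() body here is a placeholder, not the Baize implementation.

total_count = 0  # module-level request counter

def predict(text):
    # Required here: without this declaration, the += below would
    # treat total_count as a local variable and raise
    # UnboundLocalError on first use.
    global total_count
    total_count += 1
    print(total_count)
    return text.upper()  # stand-in for the real generation loop

if __name__ == "__main__":
    predict("hello")  # prints 1
    predict("again")  # prints 2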