Update app.py

This commit is contained in:
Baize 2023-04-04 06:41:22 +00:00 committed by huggingface-web
parent 7e2a3f2da9
commit 9d3530f9a9
1 changed files with 4 additions and 4 deletions

8
app.py
View File

@ -17,8 +17,8 @@ base_model = "decapoda-research/llama-7b-hf"
adapter_model = "project-baize/baize-lora-7B" adapter_model = "project-baize/baize-lora-7B"
tokenizer,model,device = load_tokenizer_and_model(base_model,adapter_model) tokenizer,model,device = load_tokenizer_and_model(base_model,adapter_model)
global total_cont global total_count
total_cont = 0 total_count = 0
def predict(text, def predict(text,
chatbot, chatbot,
history, history,
@ -44,8 +44,8 @@ def predict(text,
begin_length = len(prompt) begin_length = len(prompt)
torch.cuda.empty_cache() torch.cuda.empty_cache()
input_ids = inputs["input_ids"].to(device) input_ids = inputs["input_ids"].to(device)
total_cont += 1 total_count += 1
print(total_cont) print(total_count)
with torch.no_grad(): with torch.no_grad():
for x in greedy_search(input_ids,model,tokenizer,stop_words=["[|Human|]", "[|AI|]"],max_length=max_length_tokens,temperature=temperature,top_p=top_p): for x in greedy_search(input_ids,model,tokenizer,stop_words=["[|Human|]", "[|AI|]"],max_length=max_length_tokens,temperature=temperature,top_p=top_p):
if is_stop_word_or_prefix(x,["[|Human|]", "[|AI|]"]) is False: if is_stop_word_or_prefix(x,["[|Human|]", "[|AI|]"]) is False: