Update app.py
This commit is contained in:
parent
7e2a3f2da9
commit
9d3530f9a9
8
app.py
8
app.py
|
@ -17,8 +17,8 @@ base_model = "decapoda-research/llama-7b-hf"
|
||||||
adapter_model = "project-baize/baize-lora-7B"
|
adapter_model = "project-baize/baize-lora-7B"
|
||||||
tokenizer,model,device = load_tokenizer_and_model(base_model,adapter_model)
|
tokenizer,model,device = load_tokenizer_and_model(base_model,adapter_model)
|
||||||
|
|
||||||
global total_cont
|
global total_count
|
||||||
total_cont = 0
|
total_count = 0
|
||||||
def predict(text,
|
def predict(text,
|
||||||
chatbot,
|
chatbot,
|
||||||
history,
|
history,
|
||||||
|
@ -44,8 +44,8 @@ def predict(text,
|
||||||
begin_length = len(prompt)
|
begin_length = len(prompt)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
input_ids = inputs["input_ids"].to(device)
|
input_ids = inputs["input_ids"].to(device)
|
||||||
total_cont += 1
|
total_count += 1
|
||||||
print(total_cont)
|
print(total_count)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for x in greedy_search(input_ids,model,tokenizer,stop_words=["[|Human|]", "[|AI|]"],max_length=max_length_tokens,temperature=temperature,top_p=top_p):
|
for x in greedy_search(input_ids,model,tokenizer,stop_words=["[|Human|]", "[|AI|]"],max_length=max_length_tokens,temperature=temperature,top_p=top_p):
|
||||||
if is_stop_word_or_prefix(x,["[|Human|]", "[|AI|]"]) is False:
|
if is_stop_word_or_prefix(x,["[|Human|]", "[|AI|]"]) is False:
|
||||||
|
|
Loading…
Reference in New Issue