app_demo_test_1/app.py

233 lines
7.2 KiB
Python
Raw Normal View History

2023-04-01 10:24:42 +00:00
# -*- coding:utf-8 -*-
"""Gradio demo app for the Baize chat model (LLaMA-7B base + LoRA adapter)."""
import gc
import logging
import os
import sys

import gradio as gr
import torch

# Project-local helpers: model loading, prompt building, decoding, UI presets.
from app_modules.utils import *
from app_modules.presets import *
from app_modules.overwrites import *

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
)

base_model = "decapoda-research/llama-7b-hf"
adapter_model = "project-baize/baize-lora-7B"

# Loaded once at import time; `device` is where input tensors are sent.
tokenizer, model, device = load_tokenizer_and_model(base_model, adapter_model)

# Number of predict() calls served so far; used to periodically dump GPU status.
total_count = 0
def predict(
    text,
    chatbot,
    history,
    top_p,
    temperature,
    max_length_tokens,
    max_context_length_tokens,
):
    """Stream a model response for *text*.

    Yields ``(chatbot, history, status)`` tuples as tokens are produced, where
    *chatbot* is the markdown-rendered display history and *history* the raw
    ``[[user, ai], ...]`` pairs.
    """
    if text == "":
        yield chatbot, history, "Empty context."
        return
    try:
        model  # fail fast if the global model never loaded
    except NameError:  # was a bare except; only NameError is expected here
        yield [[text, "No Model Found"]], [], "No Model Found"
        return

    inputs = generate_prompt_with_history(
        text, history, tokenizer, max_length=max_context_length_tokens
    )
    if inputs is None:
        yield chatbot, history, "Input too long."
        return
    prompt, inputs = inputs
    begin_length = len(prompt)
    # Keep only the most recent context tokens and move them to the model device.
    input_ids = inputs["input_ids"][:, -max_context_length_tokens:].to(device)
    torch.cuda.empty_cache()

    global total_count
    total_count += 1
    print(total_count)
    if total_count % 50 == 0:
        os.system("nvidia-smi")

    # Fallbacks so an interruption before the first accepted token cannot hit
    # an UnboundLocalError on `a`/`b` below.
    a, b = chatbot, history
    with torch.no_grad():
        for x in greedy_search(
            input_ids,
            model,
            tokenizer,
            stop_words=["[|Human|]", "[|AI|]"],
            max_length=max_length_tokens,
            temperature=temperature,
            top_p=top_p,
        ):
            if is_stop_word_or_prefix(x, ["[|Human|]", "[|AI|]"]) is False:
                # Trim anything generated past a stop marker.
                if "[|Human|]" in x:
                    x = x[: x.index("[|Human|]")].strip()
                if "[|AI|]" in x:
                    x = x[: x.index("[|AI|]")].strip()
                x = x.strip()
                a = [[y[0], convert_to_markdown(y[1])] for y in history] + [
                    [text, convert_to_markdown(x)]
                ]
                b = history + [[text, x]]
                yield a, b, "Generating..."
            if shared_state.interrupted:
                shared_state.recover()
                # If the consumer already closed this generator, the yield
                # raises GeneratorExit, which must propagate (the old bare
                # `except: pass` suppressed it and kept generating).
                yield a, b, "Stop: Success"
                return

    del input_ids
    gc.collect()
    torch.cuda.empty_cache()
    yield a, b, "Generate: Success"
def retry(
    text,
    chatbot,
    history,
    top_p,
    temperature,
    max_length_tokens,
    max_context_length_tokens,
):
    """Regenerate the last response by replaying the most recent user input.

    Pops the last turn off *chatbot* and *history*, then delegates to
    :func:`predict`, re-yielding its ``(chatbot, history, status)`` updates.
    """
    logging.info("Retry...")
    if len(history) == 0:
        yield chatbot, history, "Empty context"
        return
    chatbot.pop()  # drop the last displayed exchange; it will be regenerated
    inputs = history.pop()[0]  # the last user message
    yield from predict(
        inputs,
        chatbot,
        history,
        top_p,
        temperature,
        max_length_tokens,
        max_context_length_tokens,
    )
# Use the project's custom postprocess so streamed markdown renders correctly.
gr.Chatbot.postprocess = postprocess

with open("assets/custom.css", "r", encoding="utf-8") as f:
    customCSS = f.read()

with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
    # Per-session state: raw [[user, ai], ...] pairs and the pending question.
    history = gr.State([])
    user_question = gr.State("")
    with gr.Row():
        gr.HTML(title)
        status_display = gr.Markdown("Success", elem_id="status_display")
    gr.Markdown(description_top)
    with gr.Row(scale=1).style(equal_height=True):
        with gr.Column(scale=5):
            with gr.Row(scale=1):
                chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="100%")
            with gr.Row(scale=1):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False, placeholder="Enter text"
                    ).style(container=False)
                with gr.Column(min_width=70, scale=1):
                    submitBtn = gr.Button("Send")
                with gr.Column(min_width=70, scale=1):
                    cancelBtn = gr.Button("Stop")
            with gr.Row(scale=1):
                emptyBtn = gr.Button(
                    "🧹 New Conversation",
                )
                retryBtn = gr.Button("🔄 Regenerate")
                delLastBtn = gr.Button("🗑️ Remove Last Turn")
        with gr.Column():
            with gr.Column(min_width=50, scale=1):
                with gr.Tab(label="Parameter Setting"):
                    gr.Markdown("# Parameters")
                    top_p = gr.Slider(
                        minimum=0,
                        maximum=1.0,
                        value=0.95,
                        step=0.05,
                        interactive=True,
                        label="Top-p",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=1,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
                    max_length_tokens = gr.Slider(
                        minimum=0,
                        maximum=512,
                        value=256,
                        step=8,
                        interactive=True,
                        label="Max Generation Tokens",
                    )
                    max_context_length_tokens = gr.Slider(
                        minimum=0,
                        maximum=4096,
                        value=2048,
                        step=128,
                        interactive=True,
                        label="Max History Tokens",
                    )
    gr.Markdown(description)

    # Shared event argument bundles -------------------------------------------
    predict_args = dict(
        fn=predict,
        inputs=[
            user_question,
            chatbot,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens,
        ],
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )
    retry_args = dict(
        fn=retry,
        inputs=[
            user_input,
            chatbot,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens,
        ],
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )
    reset_args = dict(
        fn=reset_textbox, inputs=[], outputs=[user_input, status_display]
    )

    # Chatbot: move the textbox content into user_question, then stream predict.
    transfer_input_args = dict(
        fn=transfer_input,
        inputs=[user_input],
        outputs=[user_question, user_input, submitBtn],
        show_progress=True,
    )

    predict_event1 = user_input.submit(**transfer_input_args).then(**predict_args)
    predict_event2 = submitBtn.click(**transfer_input_args).then(**predict_args)

    emptyBtn.click(
        reset_state,
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )
    emptyBtn.click(**reset_args)

    predict_event3 = retryBtn.click(**retry_args)

    delLastBtn.click(
        delete_last_conversation,
        [chatbot, history],
        [chatbot, history, status_display],
        show_progress=True,
    )
    # Stop button cancels any in-flight generation events.
    cancelBtn.click(
        cancel_outputing, [], [status_display],
        cancels=[predict_event1, predict_event2, predict_event3],
    )

demo.title = "Baize"
# Single worker: only one generation runs at a time on the GPU.
demo.queue(concurrency_count=1).launch()