Upload 2 files
This commit is contained in:
parent
4bcccc42db
commit
77ff0b3fac
@@ -0,0 +1,219 @@
# -*- coding:utf-8 -*-
import os
import logging
import sys
import gradio as gr
import torch

from app_modules.utils import *
from app_modules.presets import *
from app_modules.overwrites import *

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
)

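# Base LLaMA checkpoint plus the Baize LoRA adapter layered on top of it.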
base_model = "decapoda-research/llama-7b-hf"
adapter_model = "project-baize/baize-lora-7B"
# Load once at startup; predict() reads these globals.
tokenizer, model, device = load_tokenizer_and_model(base_model, adapter_model)

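# Streaming chat handler: yields (chatbot display, raw history, status) tuples
# so Gradio can update the UI as tokens arrive.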
def predict(
    text,
    chatbot,
    history,
    top_p,
    temperature,
    max_length_tokens,
    max_context_length_tokens,
):
    if text == "":
        yield chatbot, history, "Empty context"
        return

    inputs = generate_prompt_with_history(
        text, history, tokenizer, max_length=max_context_length_tokens
    )
    if inputs is False:
        yield [[x[0], convert_to_markdown(x[1])] for x in history] + [
            [text, "Sorry, the input is too long."]
        ], history, "Generate Fail"
        return
    else:
        prompt, inputs = inputs
        begin_length = len(prompt)

    input_ids = inputs["input_ids"].to(device)

    with torch.no_grad():
        # Stream partial generations; each x is the text decoded so far.
        for x in greedy_search(
            input_ids,
            model,
            tokenizer,
            stop_words=["[|Human|]", "[|AI|]"],
            max_length=max_length_tokens,
            temperature=temperature,
            top_p=top_p,
        ):
            if is_stop_word_or_prefix(x, ["[|Human|]", "[|AI|]"]) is False:
                # Trim anything after a role tag so it never leaks into the reply.
                if "[|Human|]" in x:
                    x = x[: x.index("[|Human|]")].strip()
                if "[|AI|]" in x:
                    x = x[: x.index("[|AI|]")].strip()
                x = x.strip(" ")
                a, b = (
                    [[y[0], convert_to_markdown(y[1])] for y in history]
                    + [[text, convert_to_markdown(x)]],
                    history + [[text, x]],
                )
                yield a, b, "Generating…"
            if shared_state.interrupted:
                # The Stop button sets this flag; recover() clears it for the next run.
                shared_state.recover()
                try:
                    yield a, b, "Stop Success"
                    return
                except:
                    # The consumer may already be gone; ignore failures on the final yield.
                    pass
    print(prompt)
    print(x)
    print("=" * 80)
    try:
        yield a, b, "Generate Success"
    except:
        pass

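# Drop the last exchange and re-run predict() on the same user input.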
def retry(
    text,
    chatbot,
    history,
    top_p,
    temperature,
    max_length_tokens,
    max_context_length_tokens,
):
    logging.info("Retrying…")
    if len(history) == 0:
        yield chatbot, history, "Empty context"
        return
    chatbot.pop()
    inputs = history.pop()[0]
    for x in predict(
        inputs,
        chatbot,
        history,
        top_p,
        temperature,
        max_length_tokens,
        max_context_length_tokens,
    ):
        yield x


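# Monkey-patch Gradio's default Chatbot postprocessing with the Markdown-aware
# version pulled in from the app_modules wildcard imports.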
gr.Chatbot.postprocess = postprocess

with open("assets/custom.css", "r", encoding="utf-8") as f:
    customCSS = f.read()

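# Assemble the UI: chat panel on the left, sampling parameters in a tab on the right.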
with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
    history = gr.State([])
    user_question = gr.State("")
    with gr.Row():
        gr.HTML(title)
        status_display = gr.Markdown("Success", elem_id="status_display")
    gr.Markdown(description_top)
    with gr.Row(scale=1).style(equal_height=True):
        with gr.Column(scale=5):
            with gr.Row(scale=1):
                chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="80%")
            with gr.Row(scale=1):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False, placeholder="Enter text"
                    ).style(container=False)
                with gr.Column(min_width=70, scale=1):
                    submitBtn = gr.Button("Send")
                with gr.Column(min_width=70, scale=1):
                    cancelBtn = gr.Button("Stop")
            with gr.Row(scale=1):
                emptyBtn = gr.Button(
                    "🧹 New Conversation",
                )
                retryBtn = gr.Button("🔄 Regenerate")
                delLastBtn = gr.Button("🗑️ Remove Last Turn")
        with gr.Column():
            with gr.Column(min_width=50, scale=1):
                with gr.Tab(label="Parameter Setting"):
                    gr.Markdown("# Parameters")
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.95,
                        step=0.05,
                        interactive=True,
                        label="Top-p",
                    )
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=1,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
                    max_length_tokens = gr.Slider(
                        minimum=0,
                        maximum=512,
                        value=512,
                        step=8,
                        interactive=True,
                        label="Max Generation Tokens",
                    )
                    max_context_length_tokens = gr.Slider(
                        minimum=0,
                        maximum=4096,
                        value=2048,
                        step=128,
                        interactive=True,
                        label="Max History Tokens",
                    )
    gr.Markdown(description)

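# Shared handler kwargs so the Enter key and the Send button trigger identical pipelines.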
predict_args = dict(
    fn=predict,
    inputs=[
        user_question,
        chatbot,
        history,
        top_p,
        temperature,
        max_length_tokens,
        max_context_length_tokens,
    ],
    outputs=[chatbot, history, status_display],
    show_progress=True,
)
retry_args = dict(
    fn=retry,
    inputs=[
        user_input,
        chatbot,
        history,
        top_p,
        temperature,
        max_length_tokens,
        max_context_length_tokens,
    ],
    outputs=[chatbot, history, status_display],
    show_progress=True,
)

reset_args = dict(
    fn=reset_textbox, inputs=[], outputs=[user_input, status_display]
)

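# Wire up the controls. transfer_input (from app_modules) moves the textbox
# contents into the user_question state and updates the Send/Stop buttons.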
# Chatbot
cancelBtn.click(cancel_outputing, [], [status_display])
transfer_input_args = dict(
    fn=transfer_input,
    inputs=[user_input],
    outputs=[user_question, user_input, submitBtn, cancelBtn],
    show_progress=True,
)

user_input.submit(**transfer_input_args).then(**predict_args)

submitBtn.click(**transfer_input_args).then(**predict_args)

emptyBtn.click(
    reset_state,
    outputs=[chatbot, history, status_display],
    show_progress=True,
)
emptyBtn.click(**reset_args)

retryBtn.click(**retry_args)

delLastBtn.click(
    delete_last_conversation,
    [chatbot, history],
    [chatbot, history, status_display],
    show_progress=True,
)

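# queue() enables generator-based streaming; CONCURRENT_COUNT (from the
# app_modules imports) caps how many generations run at once.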
demo.title = "Baize"

if __name__ == "__main__":
    reload_javascript()
    demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
        share=True, favicon_path="./assets/favicon.ico", inbrowser=True
    )
@@ -0,0 +1,14 @@
gradio
mdtex2html
pypinyin
tiktoken
socksio
tqdm
colorama
duckduckgo_search
Pygments
llama_index
langchain
markdown
markdown2