From 77ff0b3fac729989ddfa397edcdcb42e1ccc9138 Mon Sep 17 00:00:00 2001
From: Baize
Date: Sat, 1 Apr 2023 10:24:42 +0000
Subject: [PATCH] Upload 2 files

---
 app.py           | 219 +++++++++++++++++++++++++++++++++++++++++++++++
 requirements.txt |  14 +++
 2 files changed, 233 insertions(+)
 create mode 100644 app.py
 create mode 100644 requirements.txt

diff --git a/app.py b/app.py
new file mode 100644
index 0000000..274da8d
--- /dev/null
+++ b/app.py
@@ -0,0 +1,219 @@
+# -*- coding:utf-8 -*-
+import os
+import logging
+import sys
+import gradio as gr
+import torch
+from app_modules.utils import *
+from app_modules.presets import *
+from app_modules.overwrites import *
+
+logging.basicConfig(
+    level=logging.DEBUG,
+    format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
+)
+
+base_model = "decapoda-research/llama-7b-hf"
+adapter_model = "project-baize/baize-lora-7B"
+tokenizer, model, device = load_tokenizer_and_model(base_model, adapter_model)
+
+
+def predict(text,
+            chatbot,
+            history,
+            top_p,
+            temperature,
+            max_length_tokens,
+            max_context_length_tokens,):
+    if text == "":
+        return history, history, "Empty Context"
+
+    inputs = generate_prompt_with_history(text, history, tokenizer, max_length=max_context_length_tokens)
+    if inputs is False:
+        return [[x[0], convert_to_markdown(x[1])] for x in history] + [[text, "Sorry, the input is too long."]], history, "Generate Fail"
+    else:
+        prompt, inputs = inputs
+        begin_length = len(prompt)
+
+    input_ids = inputs["input_ids"].to(device)
+
+    with torch.no_grad():
+        for x in greedy_search(input_ids, model, tokenizer, stop_words=["[|Human|]", "[|AI|]"], max_length=max_length_tokens, temperature=temperature, top_p=top_p):
+            if is_stop_word_or_prefix(x, ["[|Human|]", "[|AI|]"]) is False:
+                if "[|Human|]" in x:
+                    x = x[:x.index("[|Human|]")].strip()
+                if "[|AI|]" in x:
+                    x = x[:x.index("[|AI|]")].strip()
+                x = x.strip(" ")
+                a, b = [[y[0], convert_to_markdown(y[1])] for y in history] + [[text, convert_to_markdown(x)]], history + [[text, x]]
+                yield a, b, "Generating..."
+            if shared_state.interrupted:
+                shared_state.recover()
+                try:
+                    yield a, b, "Stop Success"
+                    return
+                except:
+                    pass
+    print(prompt)
+    print(x)
+    print("=" * 80)
+    try:
+        yield a, b, "Generate Success"
+    except:
+        pass
+
+
+def retry(
+    text,
+    chatbot,
+    history,
+    top_p,
+    temperature,
+    max_length_tokens,
+    max_context_length_tokens,
+):
+    logging.info("Retry...")
+    if len(history) == 0:
+        yield chatbot, history, "Empty context"
+        return
+    chatbot.pop()
+    inputs = history.pop()[0]
+    for x in predict(inputs, chatbot, history, top_p, temperature, max_length_tokens, max_context_length_tokens):
+        yield x
+
+
+gr.Chatbot.postprocess = postprocess
+
+with open("assets/custom.css", "r", encoding="utf-8") as f:
+    customCSS = f.read()
+
+with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
+    history = gr.State([])
+    user_question = gr.State("")
+    with gr.Row():
+        gr.HTML(title)
+        status_display = gr.Markdown("Success", elem_id="status_display")
+    gr.Markdown(description_top)
+    with gr.Row(scale=1).style(equal_height=True):
+        with gr.Column(scale=5):
+            with gr.Row(scale=1):
+                chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="80%")
+            with gr.Row(scale=1):
+                with gr.Column(scale=12):
+                    user_input = gr.Textbox(
+                        show_label=False, placeholder="Enter text"
+                    ).style(container=False)
+                with gr.Column(min_width=70, scale=1):
+                    submitBtn = gr.Button("Send")
+                with gr.Column(min_width=70, scale=1):
+                    cancelBtn = gr.Button("Stop")
+
+            with gr.Row(scale=1):
+                emptyBtn = gr.Button(
Conversation", + ) + retryBtn = gr.Button("🔄 Regenerate") + delLastBtn = gr.Button("🗑️ Remove Last Turn") + with gr.Column(): + with gr.Column(min_width=50, scale=1): + with gr.Tab(label="Parameter Setting"): + gr.Markdown("# Parameters") + top_p = gr.Slider( + minimum=-0, + maximum=1.0, + value=0.95, + step=0.05, + interactive=True, + label="Top-p", + ) + temperature = gr.Slider( + minimum=-0, + maximum=2.0, + value=1, + step=0.1, + interactive=True, + label="Temperature", + ) + max_length_tokens = gr.Slider( + minimum=0, + maximum=512, + value=512, + step=8, + interactive=True, + label="Max Generation Tokens", + ) + max_context_length_tokens = gr.Slider( + minimum=0, + maximum=4096, + value=2048, + step=128, + interactive=True, + label="Max History Tokens", + ) + gr.Markdown(description) + + predict_args = dict( + fn=predict, + inputs=[ + user_question, + chatbot, + history, + top_p, + temperature, + max_length_tokens, + max_context_length_tokens, + ], + outputs=[chatbot, history, status_display], + show_progress=True, + ) + retry_args = dict( + fn=retry, + inputs=[ + user_input, + chatbot, + history, + top_p, + temperature, + max_length_tokens, + max_context_length_tokens, + ], + outputs=[chatbot, history, status_display], + show_progress=True, + ) + + reset_args = dict( + fn=reset_textbox, inputs=[], outputs=[user_input, status_display] + ) + + # Chatbot + cancelBtn.click(cancel_outputing, [], [ status_display]) + transfer_input_args = dict( + fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input, submitBtn, cancelBtn], show_progress=True + ) + + user_input.submit(**transfer_input_args).then(**predict_args) + + submitBtn.click(**transfer_input_args).then(**predict_args) + + emptyBtn.click( + reset_state, + outputs=[chatbot, history, status_display], + show_progress=True, + ) + emptyBtn.click(**reset_args) + + retryBtn.click(**retry_args) + + delLastBtn.click( + delete_last_conversation, + [chatbot, history], + [chatbot, history, status_display], + show_progress=True, + ) + +demo.title = "Baize" + +if __name__ == "__main__": + reload_javascript() + demo.queue(concurrency_count=CONCURRENT_COUNT).launch( + share=True, favicon_path="./assets/favicon.ico", inbrowser=True + ) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..cab40a5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +gradio +mdtex2html +pypinyin +tiktoken +socksio +tqdm +colorama +duckduckgo_search +Pygments +llama_index +langchain +markdown +markdown2 +