From 2a7b7b19119b1fa0796ef85ae229202031664c37 Mon Sep 17 00:00:00 2001 From: yanqiangmiffy <1185918903@qq.com> Date: Wed, 19 Apr 2023 22:50:11 +0800 Subject: [PATCH] =?UTF-8?q?feature@=E5=88=9B=E5=BB=BAhf=E5=BA=94=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- app.py | 4 +- clc/source_service.py | 8 +- main.py | 201 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 208 insertions(+), 7 deletions(-) create mode 100644 main.py diff --git a/README.md b/README.md index a49790f..1a2ac79 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ cchardet ### 启动Gradio ```shell -python app.py +python main.py ``` ## 🚀 特性 diff --git a/app.py b/app.py index b67618a..de0d62f 100644 --- a/app.py +++ b/app.py @@ -192,9 +192,9 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo: demo.queue(concurrency_count=2).launch( server_name='0.0.0.0', - server_port=8888, share=False, show_error=True, debug=True, - enable_queue=True + enable_queue=True, + inbrowser=True, ) diff --git a/clc/source_service.py b/clc/source_service.py index 87ebc05..6e5590b 100644 --- a/clc/source_service.py +++ b/clc/source_service.py @@ -57,10 +57,10 @@ class SourceService(object): def search_web(self, query): - SESSION.proxies = { - "http": f"socks5h://localhost:7890", - "https": f"socks5h://localhost:7890" - } + # SESSION.proxies = { + # "http": f"socks5h://localhost:7890", + # "https": f"socks5h://localhost:7890" + # } results = ddg(query) web_content = '' if results: diff --git a/main.py b/main.py new file mode 100644 index 0000000..fa66fba --- /dev/null +++ b/main.py @@ -0,0 +1,201 @@ +import os +import shutil + +from app_modules.presets import * +from clc.langchain_application import LangChainApplication + + +# 修改成自己的配置!!! 
+class LangChainCFG: + llm_model_name = 'THUDM/chatglm-6b-int4-qe' # 本地模型文件 or huggingface远程仓库 + embedding_model_name = 'GanymedeNil/text2vec-large-chinese' # 检索模型文件 or huggingface远程仓库 + vector_store_path = './cache' + docs_path = './docs' + kg_vector_stores = { + '中文维基百科': './cache/zh_wikipedia', + '大规模金融研报知识图谱': '.cache/financial_research_reports', + '初始化知识库': '.cache', + } # 可以替换成自己的知识库,如果没有需要设置为None + # kg_vector_stores=None + + +config = LangChainCFG() +application = LangChainApplication(config) + + +def get_file_list(): + if not os.path.exists("docs"): + return [] + return [f for f in os.listdir("docs")] + + +file_list = get_file_list() + + +def upload_file(file): + if not os.path.exists("docs"): + os.mkdir("docs") + filename = os.path.basename(file.name) + shutil.move(file.name, "docs/" + filename) + # file_list首位插入新上传的文件 + file_list.insert(0, filename) + application.source_service.add_document("docs/" + filename) + return gr.Dropdown.update(choices=file_list, value=filename) + + +def set_knowledge(kg_name, history): + try: + application.source_service.load_vector_store(config.kg_vector_stores[kg_name]) + msg_status = f'{kg_name}知识库已成功加载' + except Exception as e: + msg_status = f'{kg_name}知识库未成功加载' + return history + [[None, msg_status]] + + +def clear_session(): + return '', None + + +def predict(input, + large_language_model, + embedding_model, + top_k, + use_web, + history=None): + # print(large_language_model, embedding_model) + print(input) + if history == None: + history = [] + + if use_web == '使用': + web_content = application.source_service.search_web(query=input) + else: + web_content = '' + resp = application.get_knowledge_based_answer( + query=input, + history_len=1, + temperature=0.1, + top_p=0.9, + top_k=top_k, + web_content=web_content, + chat_history=history + ) + history.append((input, resp['result'])) + search_text = '' + for idx, source in enumerate(resp['source_documents'][:4]): + sep = f'----------【搜索结果{idx + 1}:】---------------\n' + 
search_text += f'{sep}\n{source.page_content}\n\n' + print(search_text) + search_text += "----------【网络检索内容】-----------\n" + search_text += web_content + return '', history, history, search_text + + +with open("assets/custom.css", "r", encoding="utf-8") as f: + customCSS = f.read() +with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo: + gr.Markdown("""

Chinese-LangChain

+
+
+ """) + state = gr.State() + + with gr.Row(): + with gr.Column(scale=1): + embedding_model = gr.Dropdown([ + "text2vec-base" + ], + label="Embedding model", + value="text2vec-base") + + large_language_model = gr.Dropdown( + [ + "ChatGLM-6B-int4", + ], + label="large language model", + value="ChatGLM-6B-int4") + + top_k = gr.Slider(1, + 20, + value=4, + step=1, + label="检索top-k文档", + interactive=True) + kg_name = gr.Radio(['中文维基百科', + '大规模金融研报知识图谱', + '初始化知识库' + ], + label="知识库", + value='初始化知识库', + interactive=True) + set_kg_btn = gr.Button("重新加载知识库") + + use_web = gr.Radio(["使用", "不使用"], label="web search", + info="是否使用网络搜索,使用时确保网络通常", + value="不使用" + ) + + file = gr.File(label="将文件上传到知识库库,内容要尽量匹配", + visible=True, + file_types=['.txt', '.md', '.docx', '.pdf'] + ) + + file.upload(upload_file, + inputs=file, + outputs=None) + with gr.Column(scale=4): + with gr.Row(): + chatbot = gr.Chatbot(label='Chinese-LangChain').style(height=400) + with gr.Row(): + message = gr.Textbox(label='请输入问题') + with gr.Row(): + clear_history = gr.Button("🧹 清除历史对话") + send = gr.Button("🚀 发送") + with gr.Row(): + gr.Markdown("""提醒:
+ [Chinese-LangChain](https://github.com/yanqiangmiffy/Chinese-LangChain)
如有任何使用问题,请在[Github Issue区](https://github.com/yanqiangmiffy/Chinese-LangChain)进行反馈。
+ """) + with gr.Column(scale=2): + search = gr.Textbox(label='搜索结果') + + set_kg_btn.click( + set_knowledge, + show_progress=True, + inputs=[kg_name, chatbot], + outputs=chatbot + ) + # 发送按钮 提交 + send.click(predict, + inputs=[ + message, large_language_model, + embedding_model, top_k, use_web, + + state + ], + outputs=[message, chatbot, state, search]) + + # 清空历史对话按钮 提交 + clear_history.click(fn=clear_session, + inputs=[], + outputs=[chatbot, state], + queue=False) + + # 输入框 回车 + message.submit(predict, + inputs=[ + message, large_language_model, + embedding_model, top_k, use_web, + state + ], + outputs=[message, chatbot, state, search]) + +demo.queue(concurrency_count=2).launch( + server_name='0.0.0.0', + server_port=8888, + share=False, + show_error=True, + debug=True, + enable_queue=True, + inbrowser=True, +)