2023-04-17 16:20:32 +08:00
|
|
|
|
import os
|
|
|
|
|
import shutil
|
|
|
|
|
|
2023-04-20 02:06:02 +08:00
|
|
|
|
from app_modules.overwrites import postprocess
|
2023-04-19 01:46:46 +08:00
|
|
|
|
from app_modules.presets import *
|
2023-04-17 16:20:32 +08:00
|
|
|
|
from clc.langchain_application import LangChainApplication
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 修改成自己的配置!!!
|
|
|
|
|
class LangChainCFG:
    """Settings object handed to ``LangChainApplication``.

    Adjust the values below to match your own environment.
    """

    # --- models -----------------------------------------------------------
    # Chat model: a local model directory or a Hugging Face repository id.
    llm_model_name = 'THUDM/chatglm-6b-int4-qe'
    # Retrieval/embedding model: a local directory or a Hugging Face
    # repository id.
    embedding_model_name = 'GanymedeNil/text2vec-large-chinese'

    # --- storage ----------------------------------------------------------
    vector_store_path = './cache'
    docs_path = './docs'

    # --- knowledge bases --------------------------------------------------
    # Display name -> value passed to ``load_vector_store`` when that
    # knowledge base is selected. Replace the entries with your own
    # knowledge bases; when you have none, use:
    #     kg_vector_stores = None
    kg_vector_stores = {
        '中文维基百科': './cache/zh_wikipedia',
        '大规模金融研报': './cache/financial_research_reports',
        '初始化': './cache',
    }

    # Labels of the two answering modes (the same strings the UI offers in
    # its mode selector).
    patterns = ['模型问答', '知识库问答']
|
2023-04-17 16:20:32 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level singletons: the settings instance and the application object
# built from it. The handlers defined below read both as globals.
config = LangChainCFG()
application = LangChainApplication(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_file_list():
    """Return the names of the entries inside the local ``docs`` directory.

    Returns:
        The entry names as reported by ``os.listdir("docs")``, or an empty
        list when ``docs`` is missing or is not a directory.
    """
    # isdir (rather than exists) also covers a stray regular file named
    # "docs", which would otherwise make os.listdir raise
    # NotADirectoryError.
    if not os.path.isdir("docs"):
        return []
    return os.listdir("docs")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Snapshot of the document names taken once at start-up; upload_file()
# mutates this list in place when a new file arrives.
file_list = get_file_list()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def upload_file(file):
    """Move an uploaded file into ``docs`` and add it to the knowledge base.

    Args:
        file: Upload object from the file component. Only its ``name``
            attribute is read, and it is used as the source path of the
            move.

    Returns:
        The value of ``gr.Dropdown.update`` with the known file names as
        choices and the new upload selected.
    """
    # exist_ok=True replaces the exists()+mkdir() pair, which could fail if
    # the directory appeared between the check and the creation.
    os.makedirs("docs", exist_ok=True)

    filename = os.path.basename(file.name)
    target = "docs/" + filename
    shutil.move(file.name, target)

    # Newest upload goes first. Drop any earlier entry with the same name so
    # that re-uploading a file does not list it twice.
    if filename in file_list:
        file_list.remove(filename)
    file_list.insert(0, filename)

    application.source_service.add_document(target)
    return gr.Dropdown.update(choices=file_list, value=filename)
|
|
|
|
|
|
|
|
|
|
|
2023-04-18 23:45:16 +08:00
|
|
|
|
def set_knowledge(kg_name, history):
    """Load the vector store registered under ``kg_name`` and report the result.

    Args:
        kg_name: Key looked up in ``config.kg_vector_stores``.
        history: Current chatbot message list; ``None`` is treated as empty.

    Returns:
        A new list made of ``history`` followed by one ``[None, status]``
        entry, where ``status`` says whether loading succeeded.
    """
    try:
        application.source_service.load_vector_store(config.kg_vector_stores[kg_name])
        msg_status = f'{kg_name}知识库已成功加载'
    except Exception as e:
        # Deliberately best-effort: the UI should keep working when a store
        # cannot be loaded. Print the reason instead of discarding it so the
        # failure can be diagnosed from the console.
        print(f'Failed to load knowledge base {kg_name!r}: {e!r}')
        msg_status = f'{kg_name}知识库未成功加载'
    # `history or []` avoids a TypeError when no history has been set yet.
    return (history or []) + [[None, msg_status]]
|
|
|
|
|
|
|
|
|
|
|
2023-04-17 16:20:32 +08:00
|
|
|
|
def clear_session():
    """Return the reset values for the two outputs wired to the clear button.

    Returns:
        A ``(chatbot_value, state_value)`` pair: an empty string and ``None``.
    """
    chatbot_value = ''
    state_value = None
    return chatbot_value, state_value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(input,
            large_language_model,
            embedding_model,
            top_k,
            use_web,
            use_pattern,
            history=None):
    """Answer one user message and return the values for the UI outputs.

    Args:
        input: The question text. The name shadows the builtin but is kept
            so the function's interface stays unchanged.
        large_language_model: Selected model label; not used in this function.
        embedding_model: Selected embedding label; not used in this function.
        top_k: Forwarded as ``top_k`` to ``get_knowledge_based_answer``.
        use_web: When equal to '使用', a web search is run for the question
            and its result is passed along as ``web_content``.
        use_pattern: '模型问答' asks the model directly; any other value
            takes the knowledge-base path.
        history: List of ``(question, answer)`` tuples from earlier turns;
            ``None`` starts a fresh list. The list is appended to in place.

    Returns:
        A 4-tuple ``('', history, history, search_text)``: an empty string
        (wired to the message box), the updated history twice (chatbot and
        state outputs), and the text for the search-results box.
    """
    print(input)
    # Identity check: `== None` would invoke a custom __eq__ if one existed.
    if history is None:
        history = []

    if use_web == '使用':
        web_content = application.source_service.search_web(query=input)
    else:
        web_content = ''

    search_text = ''
    if use_pattern == '模型问答':
        result = application.get_llm_answer(query=input, web_content=web_content)
        history.append((input, result))
        search_text += web_content
        return '', history, history, search_text

    resp = application.get_knowledge_based_answer(
        query=input,
        history_len=1,
        temperature=0.1,
        top_p=0.9,
        top_k=top_k,
        web_content=web_content,
        chat_history=history
    )
    history.append((input, resp['result']))

    # Only the first four source documents are displayed, whatever top_k is.
    sections = []
    for idx, source in enumerate(resp['source_documents'][:4]):
        sep = f'----------【搜索结果{idx + 1}:】---------------\n'
        sections.append(f'{sep}\n{source.page_content}\n\n')
    search_text = ''.join(sections)
    print(search_text)

    search_text += "----------【网络检索内容】-----------\n"
    search_text += web_content
    return '', history, history, search_text
|
2023-04-17 16:20:32 +08:00
|
|
|
|
|
2023-04-19 02:09:00 +08:00
|
|
|
|
|
2023-04-19 01:46:46 +08:00
|
|
|
|
# Read the stylesheet that is passed to gr.Blocks(css=...) below.
with open("assets/custom.css", mode="r", encoding="utf-8") as css_file:
    customCSS = css_file.read()
|
|
|
|
|
# Page layout and event wiring for the chat UI.
with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
    gr.Markdown(
        "<h1><center>Chinese-LangChain</center></h1>\n"
        "<center><font size=3>\n"
        "</center></font>\n"
    )
    # Holds the conversation history that predict() receives and returns.
    state = gr.State()

    with gr.Row():
        # ---- left column: settings --------------------------------------
        with gr.Column(scale=1):
            embedding_model = gr.Dropdown(
                ["text2vec-base"],
                label="Embedding model",
                value="text2vec-base")

            large_language_model = gr.Dropdown(
                ["ChatGLM-6B-int4"],
                label="large language model",
                value="ChatGLM-6B-int4")

            top_k = gr.Slider(1,
                              20,
                              value=4,
                              step=1,
                              label="检索top-k文档",
                              interactive=True)

            use_web = gr.Radio(["使用", "不使用"],
                               label="web search",
                               info="是否使用网络搜索,使用时确保网络通常",
                               value="不使用")

            use_pattern = gr.Radio(
                ['模型问答', '知识库问答'],
                label="模式",
                value='模型问答',
                interactive=True)

            # Offer exactly the keys of config.kg_vector_stores: set_knowledge()
            # looks the selection up in that mapping, so a label that is not
            # a key can never be loaded. `or {}` covers the documented
            # "set it to None" configuration.
            kg_name = gr.Radio(list(config.kg_vector_stores or {}),
                               label="知识库",
                               value=None,
                               info="使用知识库问答,请加载知识库",
                               interactive=True)
            set_kg_btn = gr.Button("加载知识库")

            file = gr.File(label="将文件上传到知识库库,内容要尽量匹配",
                           visible=True,
                           file_types=['.txt', '.md', '.docx', '.pdf'])

        # ---- middle column: conversation --------------------------------
        with gr.Column(scale=4):
            with gr.Row():
                chatbot = gr.Chatbot(label='Chinese-LangChain').style(height=400)
            with gr.Row():
                message = gr.Textbox(label='请输入问题')
            with gr.Row():
                clear_history = gr.Button("🧹 清除历史对话")
                send = gr.Button("🚀 发送")
            with gr.Row():
                gr.Markdown(
                    "提醒:<br>\n"
                    "[Chinese-LangChain](https://github.com/yanqiangmiffy/Chinese-LangChain) <br>\n"
                    "有任何使用问题[Github Issue区](https://github.com/yanqiangmiffy/Chinese-LangChain)进行反馈. <br>\n"
                )

        # ---- right column: retrieved text -------------------------------
        with gr.Column(scale=2):
            search = gr.Textbox(label='搜索结果')

    # ============= event wiring =============
    # The send button and pressing Enter in the text box trigger the same
    # call, so they share one input list and one output list.
    predict_inputs = [
        message,
        large_language_model,
        embedding_model,
        top_k,
        use_web,
        use_pattern,
        state,
    ]
    predict_outputs = [message, chatbot, state, search]

    file.upload(upload_file,
                inputs=file,
                outputs=None)

    set_kg_btn.click(
        set_knowledge,
        show_progress=True,
        inputs=[kg_name, chatbot],
        outputs=chatbot
    )

    # Send button
    send.click(predict,
               inputs=predict_inputs,
               outputs=predict_outputs)

    # Clear-history button
    clear_history.click(fn=clear_session,
                        inputs=[],
                        outputs=[chatbot, state],
                        queue=False)

    # Enter key in the text box
    message.submit(predict,
                   inputs=predict_inputs,
                   outputs=predict_outputs)
|
2023-04-19 02:09:00 +08:00
|
|
|
|
|
2023-04-18 17:44:03 +08:00
|
|
|
|
# Keyword arguments for launch(), gathered in one place so they are easy to
# review and edit.
launch_options = {
    'server_name': '0.0.0.0',
    'share': False,
    'show_error': True,
    'debug': True,
    'enable_queue': True,
    'inbrowser': True,
}

# Enable the request queue, then start the server.
demo.queue(concurrency_count=2).launch(**launch_options)
|