# chinese-langchain/main.py
import os
import shutil
import gradio as gr
from clc.langchain_application import LangChainApplication

# Pin the process to GPU 0; must be set before any CUDA-using code runs.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
# NOTE: adjust these settings for your own deployment.
class LangChainCFG:
    """Static configuration for the LangChain application.

    Holds model identifiers (local paths or HuggingFace repo ids) and the
    on-disk locations of the document/vector stores.
    """

    # Local model path or HuggingFace repo id of the LLM.
    llm_model_name = 'THUDM/chatglm-6b-int4-qe'
    # Local model path or HuggingFace repo id of the retrieval/embedding model.
    embedding_model_name = 'GanymedeNil/text2vec-large-chinese'
    vector_store_path = './cache'
    docs_path = './docs'
    # Named knowledge bases -> vector-store directories.
    # Replace with your own knowledge bases, or set to None if you have none.
    kg_vector_stores = {
        '中文维基百科': './cache/zh_wikipedia',
        # Fixed: the two entries below previously read '.cache/...' (missing
        # '/'), which resolved to a hidden '.cache' directory instead of the
        # './cache' directory used by vector_store_path above.
        '大规模金融研报知识图谱': './cache/financial_research_reports',
        '初始化知识库': './cache',
    }
    # kg_vector_stores = None
# Global config and the LangChain application backing every handler below.
# NOTE(review): constructing LangChainApplication presumably loads the models
# and vector store at import time — confirm in clc.langchain_application.
config = LangChainCFG()
application = LangChainApplication(config)
def get_file_list():
    """Return the names of files in the local ``docs`` directory.

    Returns:
        list[str]: file names, or an empty list when the directory
        does not exist yet.
    """
    if not os.path.exists("docs"):
        return []
    # os.listdir already returns a fresh list; the previous identity
    # comprehension ([f for f in os.listdir("docs")]) was a useless copy.
    return os.listdir("docs")
file_list = get_file_list()
def upload_file(file):
    """Persist an uploaded file into ./docs and index it.

    Moves the gradio temp file into the docs folder, registers the document
    with the retrieval service, and refreshes the file dropdown with the
    newly uploaded file preselected.
    """
    if not os.path.exists("docs"):
        os.mkdir("docs")
    name = os.path.basename(file.name)
    target = "docs/" + name
    shutil.move(file.name, target)
    # Newest upload goes to the front of the dropdown choices.
    file_list.insert(0, name)
    application.source_service.add_document(target)
    return gr.Dropdown.update(choices=file_list, value=name)
def set_knowledge(kg_name, history):
    """Switch the active knowledge base and report the outcome in chat.

    Tries to load the vector store registered under ``kg_name``; whether it
    succeeds or fails, a status message is appended to the chat history.
    """
    try:
        store_path = config.kg_vector_stores[kg_name]
        application.source_service.load_vector_store(store_path)
        msg_status = f'{kg_name}知识库已成功加载'
    except Exception:
        # Best effort: a missing or corrupt store must not crash the UI.
        msg_status = f'{kg_name}知识库未成功加载'
    return history + [[None, msg_status]]
def clear_session():
    """Reset the chat UI: empty the textbox and clear the stored history."""
    cleared_textbox, cleared_state = '', None
    return cleared_textbox, cleared_state
def predict(input,
            large_language_model,
            embedding_model,
            top_k,
            use_web,
            history=None):
    """Answer *input* from the knowledge base, optionally using web search.

    Args:
        input: the user's question (name kept for gradio wiring; shadows
            the builtin).
        large_language_model / embedding_model: dropdown selections
            (currently unused by this handler).
        top_k: number of documents to retrieve.
        use_web: '使用' enables web search, anything else disables it.
        history: chat history as a list of (question, answer) pairs.

    Returns:
        4-tuple consumed by gradio: cleared textbox text, chatbot history,
        state, and the formatted retrieval results text.
    """
    # print(large_language_model, embedding_model)
    print(input)
    # Fixed: was `history == None`; identity comparison is correct for None.
    if history is None:
        history = []
    # Only hit the web-search service when the user opted in ("使用").
    if use_web == '使用':
        web_content = application.source_service.search_web(query=input)
    else:
        web_content = ''
    resp = application.get_knowledge_based_answer(
        query=input,
        history_len=1,
        temperature=0.1,
        top_p=0.9,
        top_k=top_k,
        web_content=web_content,
        chat_history=history
    )
    history.append((input, resp['result']))
    # Show at most the top-4 retrieved source documents in the side panel.
    search_text = ''
    for idx, source in enumerate(resp['source_documents'][:4]):
        sep = f'----------【搜索结果{idx + 1}:】---------------\n'
        search_text += f'{sep}\n{source.page_content}\n\n'
    print(search_text)
    search_text += "----------【网络检索内容】-----------\n"
    search_text += web_content
    return '', history, history, search_text
# --- Gradio UI layout and event wiring -------------------------------------
# Left column: model/knowledge-base controls and file upload.
# Right column: chat window plus a panel showing retrieved sources.
block = gr.Blocks()
with block as demo:
    gr.Markdown("""<h1><center>Chinese-LangChain</center></h1>
<center><font size=3>
</center></font>
""")
    state = gr.State()
    with gr.Row():
        with gr.Column(scale=1):
            embedding_model = gr.Dropdown([
                "text2vec-base"
            ],
                label="Embedding model",
                value="text2vec-base")
            large_language_model = gr.Dropdown(
                [
                    "ChatGLM-6B-int4",
                ],
                label="large language model",
                value="ChatGLM-6B-int4")
            # Slider label: "retrieve top-k documents".
            top_k = gr.Slider(1,
                              20,
                              value=4,
                              step=1,
                              label="检索top-k文档",
                              interactive=True)
            # Radio label: "knowledge base"; choices are the keys of
            # config.kg_vector_stores.
            kg_name = gr.Radio(['中文维基百科',
                                '大规模金融研报知识图谱',
                                '初始化知识库'
                                ],
                               label="知识库",
                               value='初始化知识库',
                               interactive=True)
            # Button: "reload knowledge base".
            set_kg_btn = gr.Button("重新加载知识库")
            # Radio: whether to use web search alongside retrieval.
            use_web = gr.Radio(["使用", "不使用"], label="web search", info="是否使用网络搜索,使用时确保网络通常")
            # File upload widget: uploaded docs are indexed into the store.
            file = gr.File(label="将文件上传到知识库库,内容要尽量匹配",
                           visible=True,
                           file_types=['.txt', '.md', '.docx', '.pdf']
                           )
            file.upload(upload_file,
                        inputs=file,
                        outputs=None)
        with gr.Column(scale=4):
            with gr.Row():
                with gr.Column(scale=4):
                    chatbot = gr.Chatbot(label='Chinese-LangChain').style(height=400)
                    # Textbox label: "enter your question".
                    message = gr.Textbox(label='请输入问题')
                    with gr.Row():
                        # Buttons: "clear chat history" / "send".
                        clear_history = gr.Button("🧹 清除历史对话")
                        send = gr.Button("🚀 发送")
                with gr.Column(scale=2):
                    # Panel label: "search results".
                    search = gr.Textbox(label='搜索结果')
            # Reload-knowledge-base button: swap the active vector store and
            # post a status message into the chat.
            set_kg_btn.click(
                set_knowledge,
                show_progress=True,
                inputs=[kg_name, chatbot],
                outputs=chatbot
            )
            # Send button: run prediction.
            send.click(predict,
                       inputs=[
                           message, large_language_model,
                           embedding_model, top_k, use_web,
                           state
                       ],
                       outputs=[message, chatbot, state, search])
            # Clear-history button: reset chat window and state.
            clear_history.click(fn=clear_session,
                                inputs=[],
                                outputs=[chatbot, state],
                                queue=False)
            # Pressing Enter in the textbox also triggers prediction.
            message.submit(predict,
                           inputs=[
                               message, large_language_model,
                               embedding_model, top_k, use_web,
                               state
                           ],
                           outputs=[message, chatbot, state, search])
    # Footer: project link and issue tracker for feedback.
    gr.Markdown("""提醒:<br>
[Chinese-LangChain](https://github.com/yanqiangmiffy/Chinese-LangChain) <br>
有任何使用问题[Github Issue区](https://github.com/yanqiangmiffy/Chinese-LangChain)进行反馈. <br>
""")

# Serve on all interfaces, port 8888, with a 2-worker request queue.
demo.queue(concurrency_count=2).launch(
    server_name='0.0.0.0',
    server_port=8888,
    share=False,
    show_error=True,
    debug=True,
    enable_queue=True
)