chinese-langchain/main.py

181 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import shutil
import gradio as gr
from clc.langchain_application import LangChainApplication
# Set CUDA_VISIBLE_DEVICES so that only the GPU with index 1 is exposed to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
# Edit these settings to match your own environment!
class LangChainCFG:
    """Static settings handed to ``LangChainApplication``.

    All values are plain class attributes: model locations, the local
    vector-store cache, the upload directory, and the selectable knowledge
    bases.
    """

    # LLM weights: a local model directory or a Hugging Face remote repository.
    llm_model_name = '../../pretrained_models/chatglm-6b-int4-qe'

    # Retrieval model: a local model directory or a Hugging Face remote repository.
    embedding_model_name = '../../pretrained_models/text2vec-large-chinese'

    vector_store_path = './cache'
    docs_path = './docs'

    # Maps a knowledge-base display name to the path of its vector store.
    # Replace with your own knowledge bases; set to None if there are none
    # (i.e. kg_vector_stores = None).
    kg_vector_stores = {
        '中文维基百科': '/root/GoMall/Knowledge-ChatGLM/cache/zh_wikipedia',
        '大规模金融研报知识图谱': '/root/GoMall/Knowledge-ChatGLM/cache/financial_research_reports',
        '初始化知识库': '/root/GoMall/Knowledge-ChatGLM/cache',
    }
# Instantiate the settings object and build the application from it at import time.
config = LangChainCFG()
application = LangChainApplication(config)
def get_file_list():
    """Return the names of the entries in the local "docs" directory.

    Returns:
        A new list of entry names in the order ``os.listdir`` yields them,
        or an empty list when the "docs" directory does not exist.
    """
    try:
        # os.listdir already returns a fresh list, so no copy is needed, and
        # calling it directly avoids the exists()/listdir() race.
        return os.listdir("docs")
    except FileNotFoundError:
        return []
# Module-level list of the "docs" directory entries, captured once at import time.
file_list = get_file_list()
def upload_file(file):
    """Move an uploaded file into "docs" and add it to the knowledge base.

    Args:
        file: Upload object from the ``gr.File`` component; only its
            ``name`` attribute (the path of the uploaded file on disk) is
            read.

    Returns:
        A ``gr.Dropdown.update`` whose choices are the module-level
        ``file_list`` and whose selected value is the uploaded file's name.
    """
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs("docs", exist_ok=True)
    filename = os.path.basename(file.name)
    target_path = "docs/" + filename
    shutil.move(file.name, target_path)
    # Newest upload goes first. Drop any earlier entry with the same name so
    # re-uploading a file does not leave duplicate choices in the list.
    if filename in file_list:
        file_list.remove(filename)
    file_list.insert(0, filename)
    application.source_service.add_document(target_path)
    return gr.Dropdown.update(choices=file_list, value=filename)
def set_knowledge(kg_name, history):
    """Load the selected knowledge base and report the outcome in the chat.

    Args:
        kg_name: Key into ``config.kg_vector_stores`` naming the store to load.
        history: Current chatbot message list; ``None`` or an empty value is
            treated as an empty conversation.

    Returns:
        A new list holding the existing messages followed by one
        ``[None, status]`` entry stating whether loading succeeded. The
        ``history`` argument itself is not modified.
    """
    try:
        application.source_service.load_vector_store(config.kg_vector_stores[kg_name])
    except Exception as e:
        # UI boundary: keep the app alive and show the failure in the chat,
        # but print the cause so it is not silently lost.
        print(f'Failed to load knowledge base {kg_name!r}: {e!r}')
        msg_status = f'{kg_name}知识库未成功加载'
    else:
        msg_status = f'{kg_name}知识库已成功加载'
    # `history or []` guards against a None/empty chatbot value.
    return list(history or []) + [[None, msg_status]]
def clear_session():
    """Return the reset values used to wipe the conversation.

    Returns:
        A 2-tuple ``('', None)``: an empty string followed by ``None``.
    """
    cleared_chat = ''
    cleared_state = None
    return cleared_chat, cleared_state
def predict(input,
            large_language_model,
            embedding_model,
            history=None):
    """Answer a question against the knowledge base and update the chat.

    Args:
        input: The user's question. (The name shadows the builtin but is
            kept so existing callers are unaffected.)
        large_language_model: Selected LLM name; not used by this function.
        embedding_model: Selected embedding model name; not used by this
            function.
        history: List of ``(question, answer)`` pairs from earlier turns, or
            ``None`` to start a new conversation. It is appended to in place.

    Returns:
        A 4-tuple ``('', history, history, search_text)``: an empty string,
        the updated history twice, and the formatted text of up to the first
        four source documents.
    """
    print(input)
    if history is None:
        history = []
    resp = application.get_knowledge_based_answer(
        query=input,
        history_len=1,
        temperature=0.1,
        top_p=0.9,
        chat_history=history
    )
    history.append((input, resp['result']))
    # One numbered section per retrieved document (at most four), joined in a
    # single pass instead of repeated string concatenation.
    search_text = ''.join(
        f'----------【搜索结果{idx}:】---------------\n\n{source.page_content}\n\n'
        for idx, source in enumerate(resp['source_documents'][:4], start=1)
    )
    print(search_text)
    return '', history, history, search_text
# Assemble the Gradio interface.
# NOTE(review): SOURCE carried no indentation, so the container nesting below
# is reconstructed from the order of the components; confirm it matches the
# intended layout.
block = gr.Blocks()
with block as demo:
    gr.Markdown("""<h1><center>Chinese-LangChain</center></h1>
<center><font size=3>
</center></font>
""")
    # Carries the conversation history between calls to predict.
    state = gr.State()

    with gr.Row():
        # Settings column: model pickers, retrieval options, knowledge-base
        # selection and document upload.
        with gr.Column(scale=1):
            embedding_model = gr.Dropdown(
                ["text2vec-base"],
                label="Embedding model",
                value="text2vec-base",
            )
            large_language_model = gr.Dropdown(
                ["ChatGLM-6B-int4"],
                label="large language model",
                value="ChatGLM-6B-int4",
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=20,
                value=2,
                step=1,
                label="向量匹配 top k",
                interactive=True,
            )
            kg_name = gr.Radio(
                ['中文维基百科', '大规模金融研报知识图谱', '初始化知识库'],
                label="知识库",
                value='中文维基百科',
                interactive=True,
            )
            set_kg_btn = gr.Button("重新加载知识库")
            file = gr.File(
                label="将文件上传到数据库",
                visible=True,
                file_types=['.txt', '.md', '.docx', '.pdf'],
            )
            file.upload(upload_file, inputs=file, outputs=None)

        # Conversation column: chat window with its controls, next to the
        # retrieved search results.
        with gr.Column(scale=4):
            with gr.Row():
                with gr.Column(scale=4):
                    chatbot = gr.Chatbot(label='Chinese-LangChain').style(height=400)
                    message = gr.Textbox(label='请输入问题')
                    with gr.Row():
                        clear_history = gr.Button("🧹 清除历史对话")
                        send = gr.Button("🚀 发送")
                with gr.Column(scale=2):
                    search = gr.Textbox(label='搜索结果')

        set_kg_btn.click(
            set_knowledge,
            show_progress=True,
            inputs=[kg_name, chatbot],
            outputs=chatbot,
        )

        # The send button and pressing Enter in the text box trigger the same
        # handler with the same wiring.
        _predict_io = dict(
            inputs=[message, large_language_model, embedding_model, state],
            outputs=[message, chatbot, state, search],
        )
        # Send button: submit the question.
        send.click(predict, **_predict_io)
        # Clear-history button: reset the chat window and the stored state.
        clear_history.click(fn=clear_session,
                            inputs=[],
                            outputs=[chatbot, state],
                            queue=False)
        # Text box: submit the question on Enter.
        message.submit(predict, **_predict_io)

    gr.Markdown("""提醒:<br>
[Chinese-LangChain](https://github.com/yanqiangmiffy/Chinese-LangChain) <br>
有任何使用问题[Github Issue区](https://github.com/yanqiangmiffy/Chinese-LangChain)进行反馈. <br>
""")
# Enable request queuing with two concurrent workers, then start the web
# server on all network interfaces at port 8888.
# `enable_queue=True` is no longer passed to launch(): it is deprecated in
# Gradio 3.x and redundant here because .queue() already turns queuing on.
demo.queue(concurrency_count=2).launch(
    server_name='0.0.0.0',
    server_port=8888,
    share=False,
    show_error=True,
    debug=True,
)