From 65d97b14b43d82ef714de64e585da61bfdcc6803 Mon Sep 17 00:00:00 2001 From: yanqiangmiffy <1185918903@qq.com> Date: Thu, 20 Apr 2023 02:06:02 +0800 Subject: [PATCH] =?UTF-8?q?feature@=E6=B7=BB=E5=8A=A0=E9=97=AE=E7=AD=94?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F=E9=80=89=E6=8B=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 4 ++ app.py | 100 ++++++++++++++++++++++------------- assets/custom.css | 6 +-- clc/config.py | 4 +- clc/langchain_application.py | 48 ++++++++++------- clc/source_service.py | 17 +++--- main.py | 98 ++++++++++++++++++++++------------ 7 files changed, 176 insertions(+), 101 deletions(-) diff --git a/README.md b/README.md index 1a2ac79..6a3bad8 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ colorTo: yellow pinned: true app_file: app.py --- + # Chinese-LangChain > Chinese-LangChain:中文langchain项目,基于ChatGLM-6b+langchain实现本地化知识库检索与智能答案生成 @@ -55,6 +56,8 @@ python main.py ## 🚀 特性 +- 📝 2023/04/20 支持模型问答与检索问答模式切换 +- 📝 2023/04/20 感谢HF官方提供免费算力,添加HuggingFace Spaces在线体验[[🤗 DEMO](https://huggingface.co/spaces/ChallengeHub/Chinese-LangChain)] - 📝 2023/04/19 发布45万Wikipedia的文本预处理语料以及FAISS索引向量 - 🐯 2023/04/19 引入ChuanhuChatGPT皮肤 - 📱 2023/04/19 增加web search功能,需要确保网络畅通!(感谢[@wanghao07456](https://github.com/wanghao07456),提供的idea) @@ -87,6 +90,7 @@ python main.py * [x] 支持加载不同知识库 * [x] 支持检索结果与LLM生成结果对比 * [ ] 支持检索生成结果与原始LLM生成结果对比 +* [ ] 支持模型问答与检索问答 * [ ] 检索结果过滤与排序 * [x] 互联网检索结果接入 * [ ] 模型初始化有问题 diff --git a/app.py b/app.py index 7776c8a..518e482 100644 --- a/app.py +++ b/app.py @@ -1,6 +1,7 @@ import os import shutil +from app_modules.overwrites import postprocess from app_modules.presets import * from clc.langchain_application import LangChainApplication @@ -8,15 +9,16 @@ from clc.langchain_application import LangChainApplication # 修改成自己的配置!!! 
class LangChainCFG: llm_model_name = 'THUDM/chatglm-6b-int4-qe' # 本地模型文件 or huggingface远程仓库 - embedding_model_name = 'GanymedeNil/text2vec-base-chinese' # 检索模型文件 or huggingface远程仓库 + embedding_model_name = 'GanymedeNil/text2vec-large-chinese' # 检索模型文件 or huggingface远程仓库 vector_store_path = './cache' docs_path = './docs' kg_vector_stores = { '中文维基百科': './cache/zh_wikipedia', - '大规模金融研报知识图谱': '.cache/financial_research_reports', - '初始化知识库': '.cache', + '大规模金融研报': './cache/financial_research_reports', + '初始化': './cache', } # 可以替换成自己的知识库,如果没有需要设置为None # kg_vector_stores=None + patterns = ['模型问答', '知识库问答'] # config = LangChainCFG() @@ -61,6 +63,7 @@ def predict(input, embedding_model, top_k, use_web, + use_pattern, history=None): # print(large_language_model, embedding_model) print(input) @@ -71,24 +74,31 @@ def predict(input, web_content = application.source_service.search_web(query=input) else: web_content = '' - resp = application.get_knowledge_based_answer( - query=input, - history_len=1, - temperature=0.1, - top_p=0.9, - top_k=top_k, - web_content=web_content, - chat_history=history - ) - history.append((input, resp['result'])) search_text = '' - for idx, source in enumerate(resp['source_documents'][:4]): - sep = f'----------【搜索结果{idx + 1}:】---------------\n' - search_text += f'{sep}\n{source.page_content}\n\n' - print(search_text) - search_text += "----------【网络检索内容】-----------\n" - search_text += web_content - return '', history, history, search_text + if use_pattern == '模型问答': + result = application.get_llm_answer(query=input, web_content=web_content) + history.append((input, result)) + search_text += web_content + return '', history, history, search_text + + else: + resp = application.get_knowledge_based_answer( + query=input, + history_len=1, + temperature=0.1, + top_p=0.9, + top_k=top_k, + web_content=web_content, + chat_history=history + ) + history.append((input, resp['result'])) + for idx, source in enumerate(resp['source_documents'][:4]): + sep = 
f'----------【搜索结果{idx + 1}:】---------------\n' + search_text += f'{sep}\n{source.page_content}\n\n' + print(search_text) + search_text += "----------【网络检索内容】-----------\n" + search_text += web_content + return '', history, history, search_text with open("assets/custom.css", "r", encoding="utf-8") as f: @@ -121,28 +131,35 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo: step=1, label="检索top-k文档", interactive=True) - kg_name = gr.Radio(['中文维基百科', - '大规模金融研报知识图谱', - '初始化知识库' - ], - label="知识库", - value='初始化知识库', - interactive=True) - set_kg_btn = gr.Button("重新加载知识库") use_web = gr.Radio(["使用", "不使用"], label="web search", info="是否使用网络搜索,使用时确保网络通常", value="不使用" ) + use_pattern = gr.Radio( + [ + '模型问答', + '知识库问答', + ], + label="模式", + value='模型问答', + interactive=True) + + kg_name = gr.Radio(['中文维基百科', + '大规模金融研报知识图谱', + '初始化知识库' + ], + label="知识库", + value=None, + info="使用知识库问答,请加载知识库", + interactive=True) + set_kg_btn = gr.Button("加载知识库") file = gr.File(label="将文件上传到知识库库,内容要尽量匹配", visible=True, file_types=['.txt', '.md', '.docx', '.pdf'] ) - file.upload(upload_file, - inputs=file, - outputs=None) with gr.Column(scale=4): with gr.Row(): chatbot = gr.Chatbot(label='Chinese-LangChain').style(height=400) @@ -159,6 +176,10 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo: with gr.Column(scale=2): search = gr.Textbox(label='搜索结果') + # ============= 触发动作============= + file.upload(upload_file, + inputs=file, + outputs=None) set_kg_btn.click( set_knowledge, show_progress=True, @@ -168,9 +189,12 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo: # 发送按钮 提交 send.click(predict, inputs=[ - message, large_language_model, - embedding_model, top_k, use_web, - + message, + large_language_model, + embedding_model, + top_k, + use_web, + use_pattern, state ], outputs=[message, chatbot, state, search]) @@ -184,8 +208,12 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo: # 输入框 回车 
message.submit(predict, inputs=[ - message, large_language_model, - embedding_model, top_k, use_web, + message, + large_language_model, + embedding_model, + top_k, + use_web, + use_pattern, state ], outputs=[message, chatbot, state, search]) diff --git a/assets/custom.css b/assets/custom.css index f8e3e4e..9c18b7c 100644 --- a/assets/custom.css +++ b/assets/custom.css @@ -1,5 +1,5 @@ :root { - --chatbot-color-light: #F3F3F3; + --chatbot-color-light: rgba(255, 255, 255, 0.08); --chatbot-color-dark: #121111; } @@ -40,7 +40,7 @@ ol:not(.options), ul:not(.options) { color: #000000 !important; } [data-testid = "bot"] { - background-color: #FFFFFF !important; + background-color: rgba(255, 255, 255, 0.08) !important; } [data-testid = "user"] { background-color: #95EC69 !important; @@ -49,7 +49,7 @@ ol:not(.options), ul:not(.options) { /* Dark mode */ .dark #chuanhu_chatbot { background-color: var(--chatbot-color-dark) !important; - color: #FFFFFF !important; + color: rgba(255, 255, 255, 0.08) !important; } .dark [data-testid = "bot"] { background-color: #2C2C2C !important; diff --git a/clc/config.py b/clc/config.py index 3426f5c..099fefe 100644 --- a/clc/config.py +++ b/clc/config.py @@ -12,7 +12,7 @@ class LangChainCFG: - llm_model_name = 'chatglm-6b' # 本地模型文件 or huggingface远程仓库 - embedding_model_name = 'text2vec-large-chinese' # 检索模型文件 or huggingface远程仓库 + llm_model_name = 'THUDM/chatglm-6b-int4-qe' # 本地模型文件 or huggingface远程仓库 + embedding_model_name = 'GanymedeNil/text2vec-large-chinese' # 检索模型文件 or huggingface远程仓库 vector_store_path = '.' docs_path = './docs' diff --git a/clc/langchain_application.py b/clc/langchain_application.py index 9090938..a2a8a4e 100644 --- a/clc/langchain_application.py +++ b/clc/langchain_application.py @@ -9,10 +9,10 @@ @software: PyCharm @description: coding.. 
""" - from langchain.chains import RetrievalQA from langchain.prompts.prompt import PromptTemplate +from clc.config import LangChainCFG from clc.gpt_service import ChatGLMService from clc.source_service import SourceService @@ -23,15 +23,16 @@ class LangChainApplication(object): self.llm_service = ChatGLMService() self.llm_service.load_model(model_name_or_path=self.config.llm_model_name) self.source_service = SourceService(config) - if self.config.kg_vector_stores is None: - print("init a source vector store") - self.source_service.init_source_vector() - else: - print("load zh_wikipedia source vector store ") - try: - self.source_service.load_vector_store(self.config.kg_vector_stores['初始化知识库']) - except Exception as e: - self.source_service.init_source_vector() + + # if self.config.kg_vector_stores is None: + # print("init a source vector store") + # self.source_service.init_source_vector() + # else: + # print("load zh_wikipedia source vector store ") + # try: + # self.source_service.load_vector_store(self.config.kg_vector_stores['初始化知识库']) + # except Exception as e: + # self.source_service.init_source_vector() def get_knowledge_based_answer(self, query, history_len=5, @@ -75,11 +76,22 @@ class LangChainApplication(object): result = knowledge_chain({"query": query}) return result -# if __name__ == '__main__': -# config = LangChainCFG() -# application = LangChainApplication(config) -# result = application.get_knowledge_based_answer('马保国是谁') -# print(result) -# application.source_service.add_document('/home/searchgpt/yq/Knowledge-ChatGLM/docs/added/马保国.txt') -# result = application.get_knowledge_based_answer('马保国是谁') -# print(result) + def get_llm_answer(self, query='', web_content=''): + if web_content: + prompt = f'基于网络检索内容:{web_content},回答以下问题{query}' + else: + prompt = query + result = self.llm_service._call(prompt) + return result + + +if __name__ == '__main__': + config = LangChainCFG() + application = LangChainApplication(config) + # result = 
application.get_knowledge_based_answer('马保国是谁') + # print(result) + # application.source_service.add_document('/home/searchgpt/yq/Knowledge-ChatGLM/docs/added/马保国.txt') + # result = application.get_knowledge_based_answer('马保国是谁') + # print(result) + result = application.get_llm_answer('马保国是谁') + print(result) diff --git a/clc/source_service.py b/clc/source_service.py index 6e5590b..12d9d46 100644 --- a/clc/source_service.py +++ b/clc/source_service.py @@ -13,7 +13,6 @@ import os from duckduckgo_search import ddg -from duckduckgo_search.utils import SESSION from langchain.document_loaders import UnstructuredFileLoader from langchain.embeddings.huggingface import HuggingFaceEmbeddings from langchain.vectorstores import FAISS @@ -61,12 +60,16 @@ class SourceService(object): # "http": f"socks5h://localhost:7890", # "https": f"socks5h://localhost:7890" # } - results = ddg(query) - web_content = '' - if results: - for result in results: - web_content += result['body'] - return web_content + try: + results = ddg(query) + web_content = '' + if results: + for result in results: + web_content += result['body'] + return web_content + except Exception as e: + print(f"网络检索异常:{query}") + return '' # if __name__ == '__main__': # config = LangChainCFG() # source_service = SourceService(config) diff --git a/main.py b/main.py index fa66fba..c656294 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ import os import shutil +from app_modules.overwrites import postprocess from app_modules.presets import * from clc.langchain_application import LangChainApplication @@ -13,10 +14,11 @@ class LangChainCFG: docs_path = './docs' kg_vector_stores = { '中文维基百科': './cache/zh_wikipedia', - '大规模金融研报知识图谱': '.cache/financial_research_reports', - '初始化知识库': '.cache', + '大规模金融研报': './cache/financial_research_reports', + '初始化': './cache', } # 可以替换成自己的知识库,如果没有需要设置为None # kg_vector_stores=None + patterns = ['模型问答', '知识库问答'] # config = LangChainCFG() @@ -61,6 +63,7 @@ def predict(input, embedding_model, 
top_k, use_web, + use_pattern, history=None): # print(large_language_model, embedding_model) print(input) @@ -71,24 +74,31 @@ def predict(input, web_content = application.source_service.search_web(query=input) else: web_content = '' - resp = application.get_knowledge_based_answer( - query=input, - history_len=1, - temperature=0.1, - top_p=0.9, - top_k=top_k, - web_content=web_content, - chat_history=history - ) - history.append((input, resp['result'])) search_text = '' - for idx, source in enumerate(resp['source_documents'][:4]): - sep = f'----------【搜索结果{idx + 1}:】---------------\n' - search_text += f'{sep}\n{source.page_content}\n\n' - print(search_text) - search_text += "----------【网络检索内容】-----------\n" - search_text += web_content - return '', history, history, search_text + if use_pattern == '模型问答': + result = application.get_llm_answer(query=input, web_content=web_content) + history.append((input, result)) + search_text += web_content + return '', history, history, search_text + + else: + resp = application.get_knowledge_based_answer( + query=input, + history_len=1, + temperature=0.1, + top_p=0.9, + top_k=top_k, + web_content=web_content, + chat_history=history + ) + history.append((input, resp['result'])) + for idx, source in enumerate(resp['source_documents'][:4]): + sep = f'----------【搜索结果{idx + 1}:】---------------\n' + search_text += f'{sep}\n{source.page_content}\n\n' + print(search_text) + search_text += "----------【网络检索内容】-----------\n" + search_text += web_content + return '', history, history, search_text with open("assets/custom.css", "r", encoding="utf-8") as f: @@ -121,28 +131,35 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo: step=1, label="检索top-k文档", interactive=True) - kg_name = gr.Radio(['中文维基百科', - '大规模金融研报知识图谱', - '初始化知识库' - ], - label="知识库", - value='初始化知识库', - interactive=True) - set_kg_btn = gr.Button("重新加载知识库") use_web = gr.Radio(["使用", "不使用"], label="web search", info="是否使用网络搜索,使用时确保网络通常", value="不使用" ) + 
use_pattern = gr.Radio( + [ + '模型问答', + '知识库问答', + ], + label="模式", + value='模型问答', + interactive=True) + + kg_name = gr.Radio(['中文维基百科', + '大规模金融研报知识图谱', + '初始化知识库' + ], + label="知识库", + value=None, + info="使用知识库问答,请加载知识库", + interactive=True) + set_kg_btn = gr.Button("加载知识库") file = gr.File(label="将文件上传到知识库库,内容要尽量匹配", visible=True, file_types=['.txt', '.md', '.docx', '.pdf'] ) - file.upload(upload_file, - inputs=file, - outputs=None) with gr.Column(scale=4): with gr.Row(): chatbot = gr.Chatbot(label='Chinese-LangChain').style(height=400) @@ -159,6 +176,10 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo: with gr.Column(scale=2): search = gr.Textbox(label='搜索结果') + # ============= 触发动作============= + file.upload(upload_file, + inputs=file, + outputs=None) set_kg_btn.click( set_knowledge, show_progress=True, @@ -168,9 +189,12 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo: # 发送按钮 提交 send.click(predict, inputs=[ - message, large_language_model, - embedding_model, top_k, use_web, - + message, + large_language_model, + embedding_model, + top_k, + use_web, + use_pattern, state ], outputs=[message, chatbot, state, search]) @@ -184,8 +208,12 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo: # 输入框 回车 message.submit(predict, inputs=[ - message, large_language_model, - embedding_model, top_k, use_web, + message, + large_language_model, + embedding_model, + top_k, + use_web, + use_pattern, state ], outputs=[message, chatbot, state, search])