feature@添加websearch
This commit is contained in:
parent
0e2455982e
commit
3c24b5a1a6
|
@ -2,12 +2,15 @@
|
||||||
|
|
||||||
> Chinese-LangChain:中文langchain项目,基于ChatGLM-6b+langchain实现本地化知识库检索与智能答案生成
|
> Chinese-LangChain:中文langchain项目,基于ChatGLM-6b+langchain实现本地化知识库检索与智能答案生成
|
||||||
|
|
||||||
|
俗称:小必应,Q.Talk,强聊,QiangTalk
|
||||||
|
|
||||||
## 🔥 效果演示
|
## 🔥 效果演示
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
## 🚀 特性
|
## 🚀 特性
|
||||||
|
|
||||||
|
- 🚀 2023/04/19 增加web search功能,需要确保网络畅通!
|
||||||
- 🚀 2023/04/18 webui增加知识库选择功能
|
- 🚀 2023/04/18 webui增加知识库选择功能
|
||||||
- 🚀 2023/04/18 修复推理预测超时5s报错问题
|
- 🚀 2023/04/18 修复推理预测超时5s报错问题
|
||||||
- 🎉 2023/04/17 支持多种文档上传与内容解析:pdf、docx,ppt等
|
- 🎉 2023/04/17 支持多种文档上传与内容解析:pdf、docx,ppt等
|
||||||
|
@ -29,7 +32,7 @@
|
||||||
* [x] 支持检索结果与LLM生成结果对比
|
* [x] 支持检索结果与LLM生成结果对比
|
||||||
* [ ] 支持检索生成结果与原始LLM生成结果对比
|
* [ ] 支持检索生成结果与原始LLM生成结果对比
|
||||||
* [ ] 检索结果过滤与排序
|
* [ ] 检索结果过滤与排序
|
||||||
* [ ] 互联网检索结果接入
|
* [x] 互联网检索结果接入
|
||||||
* [ ] 模型初始化有问题
|
* [ ] 模型初始化有问题
|
||||||
* [ ] 增加非LangChain策略
|
* [ ] 增加非LangChain策略
|
||||||
|
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -37,13 +37,24 @@ class LangChainApplication(object):
|
||||||
history_len=5,
|
history_len=5,
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
top_p=0.9,
|
top_p=0.9,
|
||||||
|
top_k=4,
|
||||||
|
web_content='',
|
||||||
chat_history=[]):
|
chat_history=[]):
|
||||||
prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
|
if web_content:
|
||||||
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
|
prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。
|
||||||
已知内容:
|
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
|
||||||
{context}
|
已知网络检索内容:{web_content}""" + """
|
||||||
问题:
|
已知内容:
|
||||||
{question}"""
|
{context}
|
||||||
|
问题:
|
||||||
|
{question}"""
|
||||||
|
else:
|
||||||
|
prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
|
||||||
|
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
|
||||||
|
已知内容:
|
||||||
|
{context}
|
||||||
|
问题:
|
||||||
|
{question}"""
|
||||||
prompt = PromptTemplate(template=prompt_template,
|
prompt = PromptTemplate(template=prompt_template,
|
||||||
input_variables=["context", "question"])
|
input_variables=["context", "question"])
|
||||||
self.llm_service.history = chat_history[-history_len:] if history_len > 0 else []
|
self.llm_service.history = chat_history[-history_len:] if history_len > 0 else []
|
||||||
|
@ -54,7 +65,7 @@ class LangChainApplication(object):
|
||||||
knowledge_chain = RetrievalQA.from_llm(
|
knowledge_chain = RetrievalQA.from_llm(
|
||||||
llm=self.llm_service,
|
llm=self.llm_service,
|
||||||
retriever=self.source_service.vector_store.as_retriever(
|
retriever=self.source_service.vector_store.as_retriever(
|
||||||
search_kwargs={"k": 4}),
|
search_kwargs={"k": top_k}),
|
||||||
prompt=prompt)
|
prompt=prompt)
|
||||||
knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
|
knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
|
||||||
input_variables=["page_content"], template="{page_content}")
|
input_variables=["page_content"], template="{page_content}")
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from duckduckgo_search import ddg
|
||||||
|
from duckduckgo_search.utils import SESSION
|
||||||
from langchain.document_loaders import UnstructuredFileLoader
|
from langchain.document_loaders import UnstructuredFileLoader
|
||||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||||
from langchain.vectorstores import FAISS
|
from langchain.vectorstores import FAISS
|
||||||
|
@ -53,6 +55,18 @@ class SourceService(object):
|
||||||
self.vector_store = FAISS.load_local(path, self.embeddings)
|
self.vector_store = FAISS.load_local(path, self.embeddings)
|
||||||
return self.vector_store
|
return self.vector_store
|
||||||
|
|
||||||
|
def search_web(self, query):
|
||||||
|
|
||||||
|
SESSION.proxies = {
|
||||||
|
"http": f"socks5h://localhost:7890",
|
||||||
|
"https": f"socks5h://localhost:7890"
|
||||||
|
}
|
||||||
|
results = ddg(query)
|
||||||
|
web_content = ''
|
||||||
|
if results:
|
||||||
|
for result in results:
|
||||||
|
web_content += result['body']
|
||||||
|
return web_content
|
||||||
# if __name__ == '__main__':
|
# if __name__ == '__main__':
|
||||||
# config = LangChainCFG()
|
# config = LangChainCFG()
|
||||||
# source_service = SourceService(config)
|
# source_service = SourceService(config)
|
||||||
|
|
42
main.py
42
main.py
|
@ -5,19 +5,19 @@ import gradio as gr
|
||||||
|
|
||||||
from clc.langchain_application import LangChainApplication
|
from clc.langchain_application import LangChainApplication
|
||||||
|
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
|
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
|
||||||
|
|
||||||
|
|
||||||
# 修改成自己的配置!!!
|
# 修改成自己的配置!!!
|
||||||
class LangChainCFG:
|
class LangChainCFG:
|
||||||
llm_model_name = '../../pretrained_models/chatglm-6b-int4-qe' # 本地模型文件 or huggingface远程仓库
|
llm_model_name = 'THUDM/chatglm-6b-int4-qe' # 本地模型文件 or huggingface远程仓库
|
||||||
embedding_model_name = '../../pretrained_models/text2vec-large-chinese' # 检索模型文件 or huggingface远程仓库
|
embedding_model_name = 'GanymedeNil/text2vec-large-chinese' # 检索模型文件 or huggingface远程仓库
|
||||||
vector_store_path = './cache'
|
vector_store_path = './cache'
|
||||||
docs_path = './docs'
|
docs_path = './docs'
|
||||||
kg_vector_stores = {
|
kg_vector_stores = {
|
||||||
'中文维基百科': '/root/GoMall/Knowledge-ChatGLM/cache/zh_wikipedia',
|
'中文维基百科': './cache/zh_wikipedia',
|
||||||
'大规模金融研报知识图谱': '/root/GoMall/Knowledge-ChatGLM/cache/financial_research_reports',
|
'大规模金融研报知识图谱': '.cache/financial_research_reports',
|
||||||
'初始化知识库': '/root/GoMall/Knowledge-ChatGLM/cache',
|
'初始化知识库': '.cache',
|
||||||
} # 可以替换成自己的知识库,如果没有需要设置为None
|
} # 可以替换成自己的知识库,如果没有需要设置为None
|
||||||
# kg_vector_stores=None
|
# kg_vector_stores=None
|
||||||
|
|
||||||
|
@ -62,24 +62,35 @@ def clear_session():
|
||||||
def predict(input,
|
def predict(input,
|
||||||
large_language_model,
|
large_language_model,
|
||||||
embedding_model,
|
embedding_model,
|
||||||
|
top_k,
|
||||||
|
use_web,
|
||||||
history=None):
|
history=None):
|
||||||
# print(large_language_model, embedding_model)
|
# print(large_language_model, embedding_model)
|
||||||
print(input)
|
print(input)
|
||||||
if history == None:
|
if history == None:
|
||||||
history = []
|
history = []
|
||||||
|
|
||||||
|
if use_web == '使用':
|
||||||
|
web_content = application.source_service.search_web(query=input)
|
||||||
|
else:
|
||||||
|
web_content = ''
|
||||||
resp = application.get_knowledge_based_answer(
|
resp = application.get_knowledge_based_answer(
|
||||||
query=input,
|
query=input,
|
||||||
history_len=1,
|
history_len=1,
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
top_p=0.9,
|
top_p=0.9,
|
||||||
|
top_k=top_k,
|
||||||
|
web_content=web_content,
|
||||||
chat_history=history
|
chat_history=history
|
||||||
)
|
)
|
||||||
history.append((input, resp['result']))
|
history.append((input, resp['result']))
|
||||||
search_text = ''
|
search_text = ''
|
||||||
for idx, source in enumerate(resp['source_documents'][:4]):
|
for idx, source in enumerate(resp['source_documents'][:4]):
|
||||||
sep = f'----------【搜索结果{idx+1}:】---------------\n'
|
sep = f'----------【搜索结果{idx + 1}:】---------------\n'
|
||||||
search_text += f'{sep}\n{source.page_content}\n\n'
|
search_text += f'{sep}\n{source.page_content}\n\n'
|
||||||
print(search_text)
|
print(search_text)
|
||||||
|
search_text += "----------【网络检索内容】-----------\n"
|
||||||
|
search_text += web_content
|
||||||
return '', history, history, search_text
|
return '', history, history, search_text
|
||||||
|
|
||||||
|
|
||||||
|
@ -108,20 +119,22 @@ with block as demo:
|
||||||
|
|
||||||
top_k = gr.Slider(1,
|
top_k = gr.Slider(1,
|
||||||
20,
|
20,
|
||||||
value=2,
|
value=4,
|
||||||
step=1,
|
step=1,
|
||||||
label="向量匹配 top k",
|
label="检索top-k文档",
|
||||||
interactive=True)
|
interactive=True)
|
||||||
kg_name = gr.Radio(['中文维基百科',
|
kg_name = gr.Radio(['中文维基百科',
|
||||||
'大规模金融研报知识图谱',
|
'大规模金融研报知识图谱',
|
||||||
'初始化知识库'
|
'初始化知识库'
|
||||||
],
|
],
|
||||||
label="知识库",
|
label="知识库",
|
||||||
value='中文维基百科',
|
value='初始化知识库',
|
||||||
interactive=True)
|
interactive=True)
|
||||||
set_kg_btn = gr.Button("重新加载知识库")
|
set_kg_btn = gr.Button("重新加载知识库")
|
||||||
|
|
||||||
file = gr.File(label="将文件上传到数据库",
|
use_web = gr.Radio(["使用", "不使用"], label="web search", info="是否使用网络搜索,使用时确保网络通常")
|
||||||
|
|
||||||
|
file = gr.File(label="将文件上传到知识库库,内容要尽量匹配",
|
||||||
visible=True,
|
visible=True,
|
||||||
file_types=['.txt', '.md', '.docx', '.pdf']
|
file_types=['.txt', '.md', '.docx', '.pdf']
|
||||||
)
|
)
|
||||||
|
@ -149,7 +162,9 @@ with block as demo:
|
||||||
send.click(predict,
|
send.click(predict,
|
||||||
inputs=[
|
inputs=[
|
||||||
message, large_language_model,
|
message, large_language_model,
|
||||||
embedding_model, state
|
embedding_model, top_k, use_web,
|
||||||
|
|
||||||
|
state
|
||||||
],
|
],
|
||||||
outputs=[message, chatbot, state, search])
|
outputs=[message, chatbot, state, search])
|
||||||
|
|
||||||
|
@ -163,7 +178,8 @@ with block as demo:
|
||||||
message.submit(predict,
|
message.submit(predict,
|
||||||
inputs=[
|
inputs=[
|
||||||
message, large_language_model,
|
message, large_language_model,
|
||||||
embedding_model, state
|
embedding_model, top_k, use_web,
|
||||||
|
state
|
||||||
],
|
],
|
||||||
outputs=[message, chatbot, state, search])
|
outputs=[message, chatbot, state, search])
|
||||||
gr.Markdown("""提醒:<br>
|
gr.Markdown("""提醒:<br>
|
||||||
|
|
160
requirements.txt
160
requirements.txt
|
@ -1,153 +1,7 @@
|
||||||
aiofiles==23.1.0
|
langchain
|
||||||
aiohttp==3.8.4
|
gradio
|
||||||
aiosignal==1.3.1
|
transformers
|
||||||
altair==4.2.2
|
sentence_transformers
|
||||||
antlr4-python3-runtime==4.9.3
|
faiss-cpu
|
||||||
anyio==3.6.2
|
unstructured
|
||||||
argilla==1.6.0
|
duckduckgo_search
|
||||||
async-timeout==4.0.2
|
|
||||||
attrs==23.1.0
|
|
||||||
backoff==2.2.1
|
|
||||||
beautifulsoup4==4.12.2
|
|
||||||
brotlipy==0.7.0
|
|
||||||
cachetools==5.3.0
|
|
||||||
cchardet==2.1.7
|
|
||||||
certifi
|
|
||||||
cffi
|
|
||||||
chardet==5.1.0
|
|
||||||
charset-normalizer==3.1.0
|
|
||||||
click==8.1.3
|
|
||||||
coloredlogs==15.0.1
|
|
||||||
commonmark==0.9.1
|
|
||||||
contourpy==1.0.7
|
|
||||||
cpm-kernels==1.0.11
|
|
||||||
cryptography
|
|
||||||
cycler==0.11.0
|
|
||||||
dataclasses-json==0.5.7
|
|
||||||
Deprecated==1.2.13
|
|
||||||
effdet==0.3.0
|
|
||||||
entrypoints==0.4
|
|
||||||
et-xmlfile==1.1.0
|
|
||||||
faiss-gpu==1.7.2
|
|
||||||
fastapi==0.95.1
|
|
||||||
ffmpy==0.3.0
|
|
||||||
filelock==3.11.0
|
|
||||||
flatbuffers==23.3.3
|
|
||||||
flit_core
|
|
||||||
fonttools==4.39.3
|
|
||||||
frozenlist==1.3.3
|
|
||||||
fsspec==2023.4.0
|
|
||||||
gmpy2
|
|
||||||
gptcache==0.1.14
|
|
||||||
gradio==3.27.0
|
|
||||||
gradio_client==0.1.3
|
|
||||||
greenlet==2.0.2
|
|
||||||
h11==0.14.0
|
|
||||||
httpcore==0.16.3
|
|
||||||
httpx==0.23.3
|
|
||||||
huggingface-hub==0.13.4
|
|
||||||
humanfriendly==10.0
|
|
||||||
icetk==0.0.7
|
|
||||||
idna
|
|
||||||
iopath==0.1.10
|
|
||||||
Jinja2
|
|
||||||
joblib==1.2.0
|
|
||||||
jsonschema==4.17.3
|
|
||||||
kiwisolver==1.4.4
|
|
||||||
langchain==0.0.142
|
|
||||||
layoutparser==0.3.4
|
|
||||||
linkify-it-py==2.0.0
|
|
||||||
lxml==4.9.2
|
|
||||||
Markdown==3.4.3
|
|
||||||
markdown-it-py==2.2.0
|
|
||||||
MarkupSafe==2.1.2
|
|
||||||
marshmallow==3.19.0
|
|
||||||
marshmallow-enum==1.5.1
|
|
||||||
matplotlib==3.7.1
|
|
||||||
mdit-py-plugins==0.3.3
|
|
||||||
mdurl==0.1.2
|
|
||||||
mkl-fft==1.3.1
|
|
||||||
mkl-random
|
|
||||||
mkl-service==2.4.0
|
|
||||||
monotonic==1.6
|
|
||||||
mpmath==1.2.1
|
|
||||||
msg-parser==1.2.0
|
|
||||||
multidict==6.0.4
|
|
||||||
mypy-extensions==1.0.0
|
|
||||||
networkx
|
|
||||||
nltk==3.8.1
|
|
||||||
numexpr==2.8.4
|
|
||||||
numpy
|
|
||||||
olefile==0.46
|
|
||||||
omegaconf==2.3.0
|
|
||||||
onnxruntime==1.14.1
|
|
||||||
openai==0.27.4
|
|
||||||
openapi-schema-pydantic==1.2.4
|
|
||||||
opencv-python==4.6.0.66
|
|
||||||
openpyxl==3.1.2
|
|
||||||
orjson==3.8.10
|
|
||||||
packaging==23.1
|
|
||||||
pandas==1.5.3
|
|
||||||
pdf2image==1.16.3
|
|
||||||
pdfminer.six==20221105
|
|
||||||
pdfplumber==0.9.0
|
|
||||||
Pillow==9.5.0
|
|
||||||
portalocker==2.7.0
|
|
||||||
protobuf==3.18.3
|
|
||||||
pycocotools==2.0.6
|
|
||||||
pycparser
|
|
||||||
pydantic==1.10.7
|
|
||||||
pydub==0.25.1
|
|
||||||
Pygments==2.15.0
|
|
||||||
pyOpenSSL
|
|
||||||
pypandoc==1.11
|
|
||||||
pyparsing==3.0.9
|
|
||||||
pyrsistent==0.19.3
|
|
||||||
PySocks
|
|
||||||
pytesseract==0.3.10
|
|
||||||
python-dateutil==2.8.2
|
|
||||||
python-docx==0.8.11
|
|
||||||
python-magic==0.4.27
|
|
||||||
python-multipart==0.0.6
|
|
||||||
python-pptx==0.6.21
|
|
||||||
pytz==2023.3
|
|
||||||
PyYAML==6.0
|
|
||||||
regex==2023.3.23
|
|
||||||
requests==2.28.2
|
|
||||||
rfc3986==1.5.0
|
|
||||||
rich==13.0.1
|
|
||||||
scikit-learn==1.2.2
|
|
||||||
scipy==1.10.1
|
|
||||||
semantic-version==2.10.0
|
|
||||||
sentence-transformers==2.2.2
|
|
||||||
sentencepiece==0.1.98
|
|
||||||
six
|
|
||||||
sniffio==1.3.0
|
|
||||||
soupsieve==2.4.1
|
|
||||||
SQLAlchemy==1.4.47
|
|
||||||
starlette==0.26.1
|
|
||||||
sympy
|
|
||||||
tenacity==8.2.2
|
|
||||||
threadpoolctl==3.1.0
|
|
||||||
timm==0.6.13
|
|
||||||
tokenizers==0.13.3
|
|
||||||
toolz==0.12.0
|
|
||||||
torch==2.0.0
|
|
||||||
torchaudio==2.0.0
|
|
||||||
torchvision==0.15.0
|
|
||||||
tqdm==4.65.0
|
|
||||||
transformers==4.28.1
|
|
||||||
triton==2.0.0
|
|
||||||
typing-inspect==0.8.0
|
|
||||||
typing_extensions==4.5.0
|
|
||||||
tzdata==2023.3
|
|
||||||
uc-micro-py==1.0.1
|
|
||||||
unstructured==0.5.12
|
|
||||||
unstructured-inference==0.3.2
|
|
||||||
urllib3
|
|
||||||
uvicorn==0.21.1
|
|
||||||
Wand==0.6.11
|
|
||||||
websockets==11.0.2
|
|
||||||
wrapt==1.14.1
|
|
||||||
XlsxWriter==3.1.0
|
|
||||||
yarl==1.8.2
|
|
|
@ -2,9 +2,15 @@ from duckduckgo_search import ddg
|
||||||
from duckduckgo_search.utils import SESSION
|
from duckduckgo_search.utils import SESSION
|
||||||
|
|
||||||
|
|
||||||
# SESSION.proxies = {
|
SESSION.proxies = {
|
||||||
# "http": f"socks5h://localhost:7890",
|
"http": f"socks5h://localhost:7890",
|
||||||
# "https": f"socks5h://localhost:7890"
|
"https": f"socks5h://localhost:7890"
|
||||||
# }
|
}
|
||||||
r = ddg("马保国")
|
r = ddg("马保国")
|
||||||
print(r)
|
print(r[:2])
|
||||||
|
"""
|
||||||
|
[{'title': '马保国 - 维基百科,自由的百科全书', 'href': 'https://zh.wikipedia.org/wiki/%E9%A9%AC%E4%BF%9D%E5%9B%BD', 'body': '马保国(1951年 — ) ,男,籍贯 山东 临沂,出生及长大于河南,中国大陆太极拳师,自称"浑元形意太极门掌门人" 。 马保国因2017年约战mma格斗家徐晓冬首次出现
|
||||||
|
大众视野中。 2020年5月,马保国在对阵民间武术爱好者王庆民的比赛中,30秒内被连续高速击倒三次,此事件成为了持续多日的社交 ...'}, {'title': '馬保國的主页 - 抖音', 'href': 'https://www.douyin.com/user/MS4wLjABAAAAW0E1ziOvxgUh3VVv5FE6xmoo3w5WtZalfphYZKj4mCg', 'body': '6.3万. #马马国教扛打功 最近有几个人模芳我动作,很危险啊,不可以的,朋友们不要受伤了。. 5.3万. #马保国直播带货榜第一 朋友们周末愉快,本周六早上湿点,我本人在此号进行第一次带货直播,活到老,学到老,越活越年轻。. 7.0万. #马保国击破红牛罐 昨天 ...'}]
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
Loading…
Reference in New Issue