update@web_demo

yanqiangmiffy committed 2023-04-19 02:09:00 +08:00
parent b549679dde
commit 08afbac93b
11 changed files with 36 additions and 192 deletions

View File

@@ -11,8 +11,9 @@
 ## 🚀 Features
-- 🚀 2023/04/19 Added web search (requires a working network connection)
-- 🚀 2023/04/18 Added knowledge-base selection to the web UI
+- 🐯 2023/04/19 Adopted the ChuanhuChatGPT skin
+- 📱 2023/04/19 Added web search (requires a working network connection)
+- 📚 2023/04/18 Added knowledge-base selection to the web UI
 - 🚀 2023/04/18 Fixed the 5s inference timeout error
 - 🎉 2023/04/17 Support for uploading and parsing multiple document formats (pdf, docx, ppt, etc.)
 - 🎉 2023/04/17 Support for incremental knowledge-base updates

Binary file not shown.

Binary file not shown.

View File

@@ -1,25 +1,12 @@
 from __future__ import annotations
-import logging
-from llama_index import Prompt
 from typing import List, Tuple
-import mdtex2html
-from app_modules.presets import *
 from app_modules.utils import *
-
-
-def compact_text_chunks(self, prompt: Prompt, text_chunks: List[str]) -> List[str]:
-    logging.debug("Compacting text chunks...🚀🚀🚀")
-    combined_str = [c.strip() for c in text_chunks if c.strip()]
-    combined_str = [f"[{index+1}] {c}" for index, c in enumerate(combined_str)]
-    combined_str = "\n\n".join(combined_str)
-    # resplit based on self.max_chunk_overlap
-    text_splitter = self.get_text_splitter_given_prompt(prompt, 1, padding=1)
-    return text_splitter.split_text(combined_str)
-
-
 def postprocess(
         self, y: List[Tuple[str | None, str | None]]
 ) -> List[Tuple[str | None, str | None]]:
     """
     Parameters:
@@ -39,13 +26,17 @@ def postprocess(
         temp.append((user, bot))
     return temp

-with open("./assets/custom.js", "r", encoding="utf-8") as f, open("./assets/Kelpy-Codos.js", "r", encoding="utf-8") as f2:
+with open("./assets/custom.js", "r", encoding="utf-8") as f, open("./assets/Kelpy-Codos.js", "r",
+          encoding="utf-8") as f2:
     customJS = f.read()
     kelpyCodos = f2.read()


 def reload_javascript():
     print("Reloading javascript...")
     js = f'<script>{customJS}</script><script>{kelpyCodos}</script>'

     def template_response(*args, **kwargs):
         res = GradioTemplateResponseOriginal(*args, **kwargs)
         res.body = res.body.replace(b'</html>', f'{js}</html>'.encode("utf8"))
@@ -54,4 +45,5 @@ def reload_javascript():
     gr.routes.templates.TemplateResponse = template_response


 GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
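
For context: postprocess rewrites each chat turn so that markdown and TeX in the model output render as HTML, and reload_javascript patches Gradio's TemplateResponse so the two custom JS files are injected into every served page. As a rough, standalone illustration of the kind of conversion involved (not the project's exact helper), the mdtex2html package that this commit adds to requirements.txt can be used like this:

import mdtex2html

def render_turn(user_msg, bot_msg):
    # Convert markdown + TeX to HTML; pass None placeholders through unchanged.
    convert = lambda s: mdtex2html.convert(s) if s is not None else None
    return convert(user_msg), convert(bot_msg)

print(render_turn("What is $e^{i\\pi}$?", "It evaluates to $-1$."))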

View File

@@ -1,32 +1,16 @@
 # -*- coding:utf-8 -*-
 from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
-import logging
-import json
-import os
-import datetime
-import hashlib
-import csv
-import requests
-import re
-import html
-import markdown2
-import torch
-import sys
-import gc
-from pygments.lexers import guess_lexer, ClassNotFound
-import gradio as gr
-from pypinyin import lazy_pinyin
-import tiktoken
+import html
+import logging
+import re
 import mdtex2html
 from markdown import markdown
 from pygments import highlight
-from pygments.lexers import guess_lexer, get_lexer_by_name
 from pygments.formatters import HtmlFormatter
-import transformers
-from peft import PeftModel
-from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
+from pygments.lexers import ClassNotFound
+from pygments.lexers import guess_lexer, get_lexer_by_name
 from app_modules.presets import *
@@ -241,142 +225,3 @@ class State:
 shared_state = State()
-
-
-# Greedy Search
-def greedy_search(input_ids: torch.Tensor,
-                  model: torch.nn.Module,
-                  tokenizer: transformers.PreTrainedTokenizer,
-                  stop_words: list,
-                  max_length: int,
-                  temperature: float = 1.0,
-                  top_p: float = 1.0,
-                  top_k: int = 25) -> Iterator[str]:
-    generated_tokens = []
-    past_key_values = None
-    current_length = 1
-    for i in range(max_length):
-        with torch.no_grad():
-            if past_key_values is None:
-                outputs = model(input_ids)
-            else:
-                outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
-            logits = outputs.logits[:, -1, :]
-            past_key_values = outputs.past_key_values
-
-        # apply temperature
-        logits /= temperature
-
-        probs = torch.softmax(logits, dim=-1)
-        # apply top_p
-        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
-        probs_sum = torch.cumsum(probs_sort, dim=-1)
-        mask = probs_sum - probs_sort > top_p
-        probs_sort[mask] = 0.0
-
-        # apply top_k
-        # if top_k is not None:
-        #     probs_sort1, _ = torch.topk(probs_sort, top_k)
-        #     min_top_probs_sort = torch.min(probs_sort1, dim=-1, keepdim=True).values
-        #     probs_sort = torch.where(probs_sort < min_top_probs_sort, torch.full_like(probs_sort, float(0.0)), probs_sort)
-
-        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
-        next_token = torch.multinomial(probs_sort, num_samples=1)
-        next_token = torch.gather(probs_idx, -1, next_token)
-
-        input_ids = torch.cat((input_ids, next_token), dim=-1)
-
-        generated_tokens.append(next_token[0].item())
-        text = tokenizer.decode(generated_tokens)
-
-        yield text
-        if any([x in text for x in stop_words]):
-            del past_key_values
-            del logits
-            del probs
-            del probs_sort
-            del probs_idx
-            del probs_sum
-            gc.collect()
-            return
-
-
-def generate_prompt_with_history(text, history, tokenizer, max_length=2048):
-    prompt = "The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n[|Human|]Hello!\n[|AI|]Hi!"
-    history = ["\n[|Human|]{}\n[|AI|]{}".format(x[0], x[1]) for x in history]
-    history.append("\n[|Human|]{}\n[|AI|]".format(text))
-    history_text = ""
-    flag = False
-    for x in history[::-1]:
-        if tokenizer(prompt + history_text + x, return_tensors="pt")['input_ids'].size(-1) <= max_length:
-            history_text = x + history_text
-            flag = True
-        else:
-            break
-    if flag:
-        return prompt + history_text, tokenizer(prompt + history_text, return_tensors="pt")
-    else:
-        return None
-
-
-def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
-    for stop_word in stop_words:
-        if s.endswith(stop_word):
-            return True
-        for i in range(1, len(stop_word)):
-            if s.endswith(stop_word[:i]):
-                return True
-    return False
-
-
-def load_tokenizer_and_model(base_model, adapter_model, load_8bit=False):
-    if torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    try:
-        if torch.backends.mps.is_available():
-            device = "mps"
-    except:  # noqa: E722
-        pass
-    tokenizer = LlamaTokenizer.from_pretrained(base_model)
-    if device == "cuda":
-        model = LlamaForCausalLM.from_pretrained(
-            base_model,
-            load_in_8bit=load_8bit,
-            torch_dtype=torch.float16,
-            device_map="auto",
-        )
-        model = PeftModel.from_pretrained(
-            model,
-            adapter_model,
-            torch_dtype=torch.float16,
-        )
-    elif device == "mps":
-        model = LlamaForCausalLM.from_pretrained(
-            base_model,
-            device_map={"": device},
-            torch_dtype=torch.float16,
-        )
-        model = PeftModel.from_pretrained(
-            model,
-            adapter_model,
-            device_map={"": device},
-            torch_dtype=torch.float16,
-        )
-    else:
-        model = LlamaForCausalLM.from_pretrained(
-            base_model, device_map={"": device}, low_cpu_mem_usage=True
-        )
-        model = PeftModel.from_pretrained(
-            model,
-            adapter_model,
-            device_map={"": device},
-        )
-
-    if not load_8bit:
-        model.half()  # seems to fix bugs for some users.
-    model.eval()
-    return tokenizer, model, device
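
The removed greedy_search generator, despite its name, performs temperature scaling plus top-p (nucleus) sampling while reusing the model's past_key_values cache. For readers skimming the deletion, here is a self-contained sketch of just the nucleus-sampling step it implemented; the helper name sample_top_p is mine, not the project's:

import torch

def sample_top_p(probs: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    # Sort token probabilities in descending order and take their running sum.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Zero out everything outside the smallest set whose mass exceeds top_p,
    # renormalise, sample, and map back to original vocabulary indices.
    probs_sort[probs_sum - probs_sort > top_p] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, next_token)

# Toy example: sample one token id from a 5-token distribution.
print(sample_top_p(torch.tensor([[0.05, 0.50, 0.20, 0.20, 0.05]])))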

Binary image file changed (not shown): 128 KiB before, 109 KiB after.

main.py
View File

@@ -1,7 +1,6 @@
 import os
 import shutil
-import gradio as gr
 from app_modules.presets import *
 from clc.langchain_application import LangChainApplication
@@ -93,6 +92,7 @@ def predict(input,
         search_text += web_content
     return '', history, history, search_text

 with open("assets/custom.css", "r", encoding="utf-8") as f:
     customCSS = f.read()

 with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
@@ -147,14 +147,20 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
                 outputs=None)
         with gr.Column(scale=4):
             with gr.Row():
-                with gr.Column(scale=4):
-                    chatbot = gr.Chatbot(label='Chinese-LangChain').style(height=400)
+                chatbot = gr.Chatbot(label='Chinese-LangChain').style(height=400)
+            with gr.Row():
                 message = gr.Textbox(label='请输入问题')
             with gr.Row():
                 clear_history = gr.Button("🧹 清除历史对话")
                 send = gr.Button("🚀 发送")
-            with gr.Column(scale=2):
-                search = gr.Textbox(label='搜索结果')
+            with gr.Row():
+                gr.Markdown("""提醒:<br>
+                [Chinese-LangChain](https://github.com/yanqiangmiffy/Chinese-LangChain) <br>
+                有任何使用问题[Github Issue区](https://github.com/yanqiangmiffy/Chinese-LangChain)进行反馈. <br>
+                """)
+        with gr.Column(scale=2):
+            search = gr.Textbox(label='搜索结果')

     set_kg_btn.click(
         set_knowledge,
         show_progress=True,
@@ -185,10 +191,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
             state
         ],
         outputs=[message, chatbot, state, search])
-    gr.Markdown("""提醒:<br>
-    [Chinese-LangChain](https://github.com/yanqiangmiffy/Chinese-LangChain) <br>
-    有任何使用问题[Github Issue区](https://github.com/yanqiangmiffy/Chinese-LangChain)进行反馈. <br>
-    """)

 demo.queue(concurrency_count=2).launch(
     server_name='0.0.0.0',
     server_port=8888,
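
Net effect of the main.py layout changes: the chatbot, the question box, the buttons, and the reminder text now each sit in their own row inside the wide column, the gr.Markdown reminder moves up from the bottom of the page, and the search-results textbox gets its own narrower sibling column. A stripped-down sketch of the resulting nesting, assuming the Gradio 3.x Blocks API the file already uses; event wiring and the knowledge-base controls are omitted:

import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=4):              # main chat area
            chatbot = gr.Chatbot(label='Chinese-LangChain')
            message = gr.Textbox(label='请输入问题')
            with gr.Row():
                clear_history = gr.Button("🧹 清除历史对话")
                send = gr.Button("🚀 发送")
            gr.Markdown("Reminder: report issues on the GitHub issue tracker.")
        with gr.Column(scale=2):              # side pane for retrieved context
            search = gr.Textbox(label='搜索结果')

demo.queue(concurrency_count=2).launch(server_name='0.0.0.0', server_port=8888)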

View File

@@ -4,4 +4,7 @@ transformers
 sentence_transformers
 faiss-cpu
 unstructured
 duckduckgo_search
+mdtex2html
+chardet
+cchardet
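
mdtex2html backs the new chat rendering, while chardet and cchardet are presumably there for character-encoding detection when parsing uploaded documents; the commit itself does not show where they are called. A minimal chardet usage sketch, with a hypothetical file path:

import chardet

raw = open("uploaded_doc.txt", "rb").read()   # hypothetical uploaded file
guess = chardet.detect(raw)                   # e.g. {'encoding': 'GB2312', 'confidence': 0.99, ...}
text = raw.decode(guess["encoding"] or "utf-8", errors="replace")
print(guess, text[:80])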