update@web_demo

yanqiangmiffy 2023-04-19 02:09:00 +08:00
parent b549679dde
commit 08afbac93b
11 changed files with 36 additions and 192 deletions

View File

@@ -11,8 +11,9 @@
## 🚀 Features
- 🚀 2023/04/19 Added web search (requires a working network connection)
- 🚀 2023/04/18 Added knowledge-base selection to the WebUI
- 🐯 2023/04/19 Adopted the ChuanhuChatGPT skin
- 📱 2023/04/19 Added web search (requires a working network connection; see the sketch after this list)
- 📚 2023/04/18 Added knowledge-base selection to the WebUI
- 🚀 2023/04/18 Fixed the error caused by the 5s inference timeout
- 🎉 2023/04/17 Support uploading and parsing multiple document formats (pdf, docx, ppt, etc.)
- 🎉 2023/04/17 Support incremental knowledge-base updates
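The new web search path relies on the duckduckgo_search package added to requirements.txt in this same commit. A minimal helper might look like the sketch below; the function name `web_search` and the exact result handling are assumptions for illustration, not code taken from this commit.

```python
# Hypothetical sketch of a web search helper (requires network access).
# `ddg` comes from the duckduckgo_search package added in this commit;
# the helper name and snippet handling are assumptions.
from duckduckgo_search import ddg


def web_search(query: str, max_results: int = 3) -> str:
    results = ddg(query, max_results=max_results) or []
    # concatenate the result snippets into one context string for the model
    return "\n".join(r["body"] for r in results)
```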

Binary file not shown.

Binary file not shown.

View File

@@ -1,25 +1,12 @@
from __future__ import annotations
import logging
from llama_index import Prompt
from typing import List, Tuple
import mdtex2html
from app_modules.presets import *
from app_modules.utils import *
def compact_text_chunks(self, prompt: Prompt, text_chunks: List[str]) -> List[str]:
    logging.debug("Compacting text chunks...🚀🚀🚀")
    combined_str = [c.strip() for c in text_chunks if c.strip()]
    combined_str = [f"[{index+1}] {c}" for index, c in enumerate(combined_str)]
    combined_str = "\n\n".join(combined_str)
    # resplit based on self.max_chunk_overlap
    text_splitter = self.get_text_splitter_given_prompt(prompt, 1, padding=1)
    return text_splitter.split_text(combined_str)
def postprocess(
        self, y: List[Tuple[str | None, str | None]]
) -> List[Tuple[str | None, str | None]]:
"""
Parameters:
@@ -39,13 +26,17 @@ def postprocess(
            temp.append((user, bot))
    return temp
with open("./assets/custom.js", "r", encoding="utf-8") as f, open("./assets/Kelpy-Codos.js", "r",
          encoding="utf-8") as f2:
    customJS = f.read()
    kelpyCodos = f2.read()
def reload_javascript():
    print("Reloading javascript...")
    js = f'<script>{customJS}</script><script>{kelpyCodos}</script>'

    def template_response(*args, **kwargs):
        res = GradioTemplateResponseOriginal(*args, **kwargs)
        res.body = res.body.replace(b'</html>', f'{js}</html>'.encode("utf8"))
@@ -54,4 +45,5 @@ def reload_javascript():
    gr.routes.templates.TemplateResponse = template_response


GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
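For context, these overwrites are meant to be applied before the Gradio app is built. A minimal sketch of the expected wiring follows; the module path and call order are assumptions, not shown in this diff.

```python
# Hypothetical wiring: route Chatbot output through the custom postprocess and
# patch Gradio's TemplateResponse so custom.js / Kelpy-Codos.js are appended
# to every served page. The import path is an assumption.
import gradio as gr
from app_modules.overwrites import postprocess, reload_javascript

gr.Chatbot.postprocess = postprocess  # use the custom postprocess defined above
reload_javascript()                   # inject the custom scripts before </html>

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label='Chinese-LangChain')

demo.launch()
```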

View File

@@ -1,32 +1,16 @@
# -*- coding:utf-8 -*-
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
import logging
import json
import os
import datetime
import hashlib
import csv
import requests
import re
import html
import markdown2
import torch
import sys
import gc
from pygments.lexers import guess_lexer, ClassNotFound
import gradio as gr
from pypinyin import lazy_pinyin
import tiktoken
import html
import logging
import re
import mdtex2html
from markdown import markdown
from pygments import highlight
from pygments.lexers import guess_lexer, get_lexer_by_name
from pygments.formatters import HtmlFormatter
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from pygments.lexers import ClassNotFound
from pygments.lexers import guess_lexer, get_lexer_by_name
from app_modules.presets import *
@@ -241,142 +225,3 @@ class State:
shared_state = State()
# Greedy Search
def greedy_search(input_ids: torch.Tensor,
                  model: torch.nn.Module,
                  tokenizer: transformers.PreTrainedTokenizer,
                  stop_words: list,
                  max_length: int,
                  temperature: float = 1.0,
                  top_p: float = 1.0,
                  top_k: int = 25) -> Iterator[str]:
    generated_tokens = []
    past_key_values = None
    current_length = 1
    for i in range(max_length):
        with torch.no_grad():
            if past_key_values is None:
                outputs = model(input_ids)
            else:
                outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
            logits = outputs.logits[:, -1, :]
            past_key_values = outputs.past_key_values

        # apply temperature
        logits /= temperature

        probs = torch.softmax(logits, dim=-1)
        # apply top_p
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0

        # apply top_k
        # if top_k is not None:
        #     probs_sort1, _ = torch.topk(probs_sort, top_k)
        #     min_top_probs_sort = torch.min(probs_sort1, dim=-1, keepdim=True).values
        #     probs_sort = torch.where(probs_sort < min_top_probs_sort, torch.full_like(probs_sort, float(0.0)), probs_sort)

        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_token = torch.multinomial(probs_sort, num_samples=1)
        next_token = torch.gather(probs_idx, -1, next_token)

        input_ids = torch.cat((input_ids, next_token), dim=-1)

        generated_tokens.append(next_token[0].item())
        text = tokenizer.decode(generated_tokens)

        yield text
        if any([x in text for x in stop_words]):
            del past_key_values
            del logits
            del probs
            del probs_sort
            del probs_idx
            del probs_sum
            gc.collect()
            return
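For reference, the top_p step above implements nucleus sampling: after sorting, tokens are dropped once the cumulative probability of the tokens before them exceeds top_p, and the survivors are renormalised. A toy illustration with made-up numbers, not taken from any model:

```python
import torch

# Made-up 5-token distribution, purely illustrative.
probs = torch.tensor([[0.40, 0.30, 0.15, 0.10, 0.05]])
top_p = 0.8

probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)   # [0.40, 0.70, 0.85, 0.95, 1.00]
mask = probs_sum - probs_sort > top_p          # [False, False, False, True, True]
probs_sort[mask] = 0.0                         # tokens outside the nucleus are dropped
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
print(probs_sort)                              # ≈ [[0.4706, 0.3529, 0.1765, 0.0, 0.0]]
```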
def generate_prompt_with_history(text, history, tokenizer, max_length=2048):
    prompt = "The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n[|Human|]Hello!\n[|AI|]Hi!"
    history = ["\n[|Human|]{}\n[|AI|]{}".format(x[0], x[1]) for x in history]
    history.append("\n[|Human|]{}\n[|AI|]".format(text))
    history_text = ""
    flag = False
    for x in history[::-1]:
        if tokenizer(prompt + history_text + x, return_tensors="pt")['input_ids'].size(-1) <= max_length:
            history_text = x + history_text
            flag = True
        else:
            break
    if flag:
        return prompt + history_text, tokenizer(prompt + history_text, return_tensors="pt")
    else:
        return None
def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
    for stop_word in stop_words:
        if s.endswith(stop_word):
            return True
        for i in range(1, len(stop_word)):
            if s.endswith(stop_word[:i]):
                return True
    return False
def load_tokenizer_and_model(base_model, adapter_model, load_8bit=False):
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    try:
        if torch.backends.mps.is_available():
            device = "mps"
    except:  # noqa: E722
        pass
    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    if device == "cuda":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(
            model,
            adapter_model,
            torch_dtype=torch.float16,
        )
    elif device == "mps":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
        model = PeftModel.from_pretrained(
            model,
            adapter_model,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
    else:
        model = LlamaForCausalLM.from_pretrained(
            base_model, device_map={"": device}, low_cpu_mem_usage=True
        )
        model = PeftModel.from_pretrained(
            model,
            adapter_model,
            device_map={"": device},
        )

    if not load_8bit:
        model.half()  # seems to fix bugs for some users.

    model.eval()
    return tokenizer, model, device
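Taken together, the helpers removed here formed a small streaming pipeline for the Baize-style demo: load the base model plus LoRA adapter, build a length-bounded prompt from the chat history, then stream tokens until a stop word appears. A minimal usage sketch; model paths, prompt text, and sampling values are placeholders, not taken from this diff.

```python
# Sketch only: paths, prompt text and sampling values are placeholders.
stop_words = ["[|Human|]", "[|AI|]"]

tokenizer, model, device = load_tokenizer_and_model(
    base_model="path/to/llama-base",             # placeholder
    adapter_model="path/to/baize-lora-adapter",  # placeholder
)

result = generate_prompt_with_history("Hello!", history=[], tokenizer=tokenizer)
if result is not None:
    prompt, inputs = result
    text = ""
    for text in greedy_search(inputs["input_ids"].to(device), model, tokenizer,
                              stop_words=stop_words, max_length=512,
                              temperature=0.95, top_p=0.9):
        pass  # the generator stops itself once a stop word shows up in `text`
    for w in stop_words:  # trim the trailing stop word before display
        text = text.split(w)[0]
    print(text)
```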

Binary file not shown.

Image changed: 128 KiB (before) → 109 KiB (after).

main.py
View File

@@ -1,7 +1,6 @@
import os
import shutil
import gradio as gr
from app_modules.presets import *
from clc.langchain_application import LangChainApplication
@@ -93,6 +92,7 @@ def predict(input,
    search_text += web_content
    return '', history, history, search_text
with open("assets/custom.css", "r", encoding="utf-8") as f:
    customCSS = f.read()
with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
@@ -147,14 +147,20 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
outputs=None)
        with gr.Column(scale=4):
            with gr.Row():
                with gr.Column(scale=4):
                    chatbot = gr.Chatbot(label='Chinese-LangChain').style(height=400)
                    message = gr.Textbox(label='请输入问题')
                    with gr.Row():
                        clear_history = gr.Button("🧹 清除历史对话")
                        send = gr.Button("🚀 发送")
                with gr.Column(scale=2):
                    search = gr.Textbox(label='搜索结果')
                chatbot = gr.Chatbot(label='Chinese-LangChain').style(height=400)
            with gr.Row():
                message = gr.Textbox(label='请输入问题')
            with gr.Row():
                clear_history = gr.Button("🧹 清除历史对话")
                send = gr.Button("🚀 发送")
            with gr.Row():
                gr.Markdown("""提醒：<br>
                [Chinese-LangChain](https://github.com/yanqiangmiffy/Chinese-LangChain) <br>
                有任何使用问题[Github Issue区](https://github.com/yanqiangmiffy/Chinese-LangChain)进行反馈. <br>
                """)
        with gr.Column(scale=2):
            search = gr.Textbox(label='搜索结果')
    set_kg_btn.click(
        set_knowledge,
        show_progress=True,
@@ -185,10 +191,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
            state
        ],
        outputs=[message, chatbot, state, search])
    gr.Markdown("""提醒：<br>
    [Chinese-LangChain](https://github.com/yanqiangmiffy/Chinese-LangChain) <br>
    有任何使用问题[Github Issue区](https://github.com/yanqiangmiffy/Chinese-LangChain)进行反馈. <br>
    """)
demo.queue(concurrency_count=2).launch(
    server_name='0.0.0.0',
    server_port=8888,

View File

@@ -4,4 +4,7 @@ transformers
sentence_transformers
faiss-cpu
unstructured
duckduckgo_search
mdtex2html
chardet
cchardet