diff --git a/README.md b/README.md
index 25e2e02..a7d3ee2 100644
--- a/README.md
+++ b/README.md
@@ -11,8 +11,9 @@
 
 ## 🚀 特性
 
-- 🚀 2023/04/19 增加web search功能,需要确保网络畅通!
-- 🚀 2023/04/18 webui增加知识库选择功能
+- 🐯 2023/04/19 引入ChuanhuChatGPT皮肤
+- 📱 2023/04/19 增加web search功能,需要确保网络畅通!
+- 📚 2023/04/18 webui增加知识库选择功能
 - 🚀 2023/04/18 修复推理预测超时5s报错问题
 - 🎉 2023/04/17 支持多种文档上传与内容解析:pdf、docx,ppt等
 - 🎉 2023/04/17 支持知识增量更新
diff --git a/app_modules/__pycache__/presets.cpython-310.pyc b/app_modules/__pycache__/presets.cpython-310.pyc
new file mode 100644
index 0000000..8e7a651
Binary files /dev/null and b/app_modules/__pycache__/presets.cpython-310.pyc differ
diff --git a/app_modules/__pycache__/presets.cpython-39.pyc b/app_modules/__pycache__/presets.cpython-39.pyc
new file mode 100644
index 0000000..e7176c1
Binary files /dev/null and b/app_modules/__pycache__/presets.cpython-39.pyc differ
diff --git a/app_modules/overwrites.py b/app_modules/overwrites.py
index 4bfa339..7ef9614 100644
--- a/app_modules/overwrites.py
+++ b/app_modules/overwrites.py
@@ -1,25 +1,12 @@
 from __future__ import annotations
-import logging
 
-from llama_index import Prompt
 from typing import List, Tuple
-import mdtex2html
 
-from app_modules.presets import *
 from app_modules.utils import *
 
 
-def compact_text_chunks(self, prompt: Prompt, text_chunks: List[str]) -> List[str]:
-    logging.debug("Compacting text chunks...🚀🚀🚀")
-    combined_str = [c.strip() for c in text_chunks if c.strip()]
-    combined_str = [f"[{index+1}] {c}" for index, c in enumerate(combined_str)]
-    combined_str = "\n\n".join(combined_str)
-    # resplit based on self.max_chunk_overlap
-    text_splitter = self.get_text_splitter_given_prompt(prompt, 1, padding=1)
-    return text_splitter.split_text(combined_str)
-
 def postprocess(
-    self, y: List[Tuple[str | None, str | None]]
+        self, y: List[Tuple[str | None, str | None]]
 ) -> List[Tuple[str | None, str | None]]:
     """
     Parameters:
@@ -39,13 +26,17 @@ def postprocess(
         temp.append((user, bot))
     return temp
 
-with open("./assets/custom.js", "r", encoding="utf-8") as f, open("./assets/Kelpy-Codos.js", "r", encoding="utf-8") as f2:
+
+with open("./assets/custom.js", "r", encoding="utf-8") as f, open("./assets/Kelpy-Codos.js", "r",
+          encoding="utf-8") as f2:
     customJS = f.read()
     kelpyCodos = f2.read()
 
+
 def reload_javascript():
     print("Reloading javascript...")
     js = f'<script>{customJS}</script><script>{kelpyCodos}</script>'
+
     def template_response(*args, **kwargs):
         res = GradioTemplateResponseOriginal(*args, **kwargs)
         res.body = res.body.replace(b'</html>', f'{js}</html>'.encode("utf8"))
@@ -54,4 +45,5 @@ def reload_javascript():
 
     gr.routes.templates.TemplateResponse = template_response
 
+
 GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
diff --git a/app_modules/utils.py b/app_modules/utils.py
index 6c8f55c..80a52ef 100644
--- a/app_modules/utils.py
+++ b/app_modules/utils.py
@@ -1,32 +1,16 @@
 # -*- coding:utf-8 -*-
 from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
-import logging
-import json
-import os
-import datetime
-import hashlib
-import csv
-import requests
-import re
-import html
-import markdown2
-import torch
-import sys
-import gc
-from pygments.lexers import guess_lexer, ClassNotFound
-import gradio as gr
-from pypinyin import lazy_pinyin
-import tiktoken
+
+import html
+import logging
+import re
+
 import mdtex2html
 from markdown import markdown
 from pygments import highlight
-from pygments.lexers import guess_lexer, get_lexer_by_name
 from pygments.formatters import HtmlFormatter
-import transformers
-from peft import PeftModel
-from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
+from pygments.lexers import ClassNotFound
+from pygments.lexers import guess_lexer, get_lexer_by_name
 
 from app_modules.presets import *
@@ -241,142 +225,3 @@ class State:
 
 
 shared_state = State()
-
-
-# Greedy Search
-def greedy_search(input_ids: torch.Tensor,
-                  model: torch.nn.Module,
-                  tokenizer: transformers.PreTrainedTokenizer,
-                  stop_words: list,
-                  max_length: int,
-                  temperature: float = 1.0,
-                  top_p: float = 1.0,
-                  top_k: int = 25) -> Iterator[str]:
-    generated_tokens = []
-    past_key_values = None
-    current_length = 1
-    for i in range(max_length):
-        with torch.no_grad():
-            if past_key_values is None:
-                outputs = model(input_ids)
-            else:
-                outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
-            logits = outputs.logits[:, -1, :]
-            past_key_values = outputs.past_key_values
-
-        # apply temperature
-        logits /= temperature
-
-        probs = torch.softmax(logits, dim=-1)
-        # apply top_p
-        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
-        probs_sum = torch.cumsum(probs_sort, dim=-1)
-        mask = probs_sum - probs_sort > top_p
-        probs_sort[mask] = 0.0
-
-        # apply top_k
-        # if top_k is not None:
-        #     probs_sort1, _ = torch.topk(probs_sort, top_k)
-        #     min_top_probs_sort = torch.min(probs_sort1, dim=-1, keepdim=True).values
-        #     probs_sort = torch.where(probs_sort < min_top_probs_sort, torch.full_like(probs_sort, float(0.0)), probs_sort)
-
-        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
-        next_token = torch.multinomial(probs_sort, num_samples=1)
-        next_token = torch.gather(probs_idx, -1, next_token)
-
-        input_ids = torch.cat((input_ids, next_token), dim=-1)
-
-        generated_tokens.append(next_token[0].item())
-        text = tokenizer.decode(generated_tokens)
-
-        yield text
-        if any([x in text for x in stop_words]):
-            del past_key_values
-            del logits
-            del probs
-            del probs_sort
-            del probs_idx
-            del probs_sum
-            gc.collect()
-            return
-
-
-def generate_prompt_with_history(text, history, tokenizer, max_length=2048):
-    prompt = "The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n[|Human|]Hello!\n[|AI|]Hi!"
-    history = ["\n[|Human|]{}\n[|AI|]{}".format(x[0], x[1]) for x in history]
-    history.append("\n[|Human|]{}\n[|AI|]".format(text))
-    history_text = ""
-    flag = False
-    for x in history[::-1]:
-        if tokenizer(prompt + history_text + x, return_tensors="pt")['input_ids'].size(-1) <= max_length:
-            history_text = x + history_text
-            flag = True
-        else:
-            break
-    if flag:
-        return prompt + history_text, tokenizer(prompt + history_text, return_tensors="pt")
-    else:
-        return None
-
-
-def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
-    for stop_word in stop_words:
-        if s.endswith(stop_word):
-            return True
-        for i in range(1, len(stop_word)):
-            if s.endswith(stop_word[:i]):
-                return True
-    return False
-
-
-def load_tokenizer_and_model(base_model, adapter_model, load_8bit=False):
-    if torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    try:
-        if torch.backends.mps.is_available():
-            device = "mps"
-    except:  # noqa: E722
-        pass
-    tokenizer = LlamaTokenizer.from_pretrained(base_model)
-    if device == "cuda":
-        model = LlamaForCausalLM.from_pretrained(
-            base_model,
-            load_in_8bit=load_8bit,
-            torch_dtype=torch.float16,
-            device_map="auto",
-        )
-        model = PeftModel.from_pretrained(
-            model,
-            adapter_model,
-            torch_dtype=torch.float16,
-        )
-    elif device == "mps":
-        model = LlamaForCausalLM.from_pretrained(
-            base_model,
-            device_map={"": device},
-            torch_dtype=torch.float16,
-        )
-        model = PeftModel.from_pretrained(
-            model,
-            adapter_model,
-            device_map={"": device},
-            torch_dtype=torch.float16,
-        )
-    else:
-        model = LlamaForCausalLM.from_pretrained(
-            base_model, device_map={"": device}, low_cpu_mem_usage=True
-        )
-        model = PeftModel.from_pretrained(
-            model,
-            adapter_model,
-            device_map={"": device},
-        )
-
-    if not load_8bit:
-        model.half()  # seems to fix bugs for some users.
-
-    model.eval()
-    return tokenizer, model, device
diff --git a/clc/__pycache__/langchain_application.cpython-39.pyc b/clc/__pycache__/langchain_application.cpython-39.pyc
index c93a0da..984444a 100644
Binary files a/clc/__pycache__/langchain_application.cpython-39.pyc and b/clc/__pycache__/langchain_application.cpython-39.pyc differ
diff --git a/clc/__pycache__/source_service.cpython-310.pyc b/clc/__pycache__/source_service.cpython-310.pyc
index 4b1540c..c25759c 100644
Binary files a/clc/__pycache__/source_service.cpython-310.pyc and b/clc/__pycache__/source_service.cpython-310.pyc differ
diff --git a/clc/__pycache__/source_service.cpython-39.pyc b/clc/__pycache__/source_service.cpython-39.pyc
index e82dc2a..90144d8 100644
Binary files a/clc/__pycache__/source_service.cpython-39.pyc and b/clc/__pycache__/source_service.cpython-39.pyc differ
diff --git a/images/web_demo_new.png b/images/web_demo_new.png
index 9a2aa81..41655b7 100644
Binary files a/images/web_demo_new.png and b/images/web_demo_new.png differ
diff --git a/main.py b/main.py
index 450f827..8e4f9f7 100644
--- a/main.py
+++ b/main.py
@@ -1,7 +1,6 @@
 import os
 import shutil
 
-import gradio as gr
 from app_modules.presets import *
 from clc.langchain_application import LangChainApplication
 
@@ -93,6 +92,7 @@ def predict(input,
         search_text += web_content
     return '', history, history, search_text
 
+
 with open("assets/custom.css", "r", encoding="utf-8") as f:
     customCSS = f.read()
 with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
@@ -147,14 +147,20 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
                         outputs=None)
         with gr.Column(scale=4):
             with gr.Row():
-                with gr.Column(scale=4):
-                    chatbot = gr.Chatbot(label='Chinese-LangChain').style(height=400)
-                    message = gr.Textbox(label='请输入问题')
-                    with gr.Row():
-                        clear_history = gr.Button("🧹 清除历史对话")
-                        send = gr.Button("🚀 发送")
-                with gr.Column(scale=2):
-                    search = gr.Textbox(label='搜索结果')
+                chatbot = gr.Chatbot(label='Chinese-LangChain').style(height=400)
+            with gr.Row():
+                message = gr.Textbox(label='请输入问题')
+            with gr.Row():
+                clear_history = gr.Button("🧹 清除历史对话")
+                send = gr.Button("🚀 发送")
+            with gr.Row():
+                gr.Markdown("""提醒：
+                [Chinese-LangChain](https://github.com/yanqiangmiffy/Chinese-LangChain)
+                有任何使用问题[Github Issue区](https://github.com/yanqiangmiffy/Chinese-LangChain)进行反馈.
+                """)
+        with gr.Column(scale=2):
+            search = gr.Textbox(label='搜索结果')
+
     set_kg_btn.click(
         set_knowledge,
         show_progress=True,
@@ -185,10 +191,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
             state
         ],
         outputs=[message, chatbot, state, search])
-    gr.Markdown("""提醒：
-    [Chinese-LangChain](https://github.com/yanqiangmiffy/Chinese-LangChain)
-    有任何使用问题[Github Issue区](https://github.com/yanqiangmiffy/Chinese-LangChain)进行反馈.
-    """)
+
 demo.queue(concurrency_count=2).launch(
     server_name='0.0.0.0',
     server_port=8888,
diff --git a/requirements.txt b/requirements.txt
index e132ab1..9af61a4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,7 @@ transformers
 sentence_transformers
 faiss-cpu
 unstructured
-duckduckgo_search
\ No newline at end of file
+duckduckgo_search
+mdtex2html
+chardet
+cchardet
\ No newline at end of file