Qwen-7B-Chat/app.py

import os
from gradio.themes.utils import sizes
css = "footer {visibility: hidden}"
os.system('pip install tiktoken')
os.system('pip install "modelscope" --upgrade -f https://pypi.org/project/modelscope/')
os.system('pip install transformers_stream_generator')
import gradio as gr
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
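# Custom theme: square corners and a blue accent for labels and the primary button.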
theme = gr.themes.Default(radius_size=sizes.radius_none).set(
    block_label_text_color='#4D63FF',
    block_title_text_color='#4D63FF',
    button_primary_text_color='#4D63FF',
    button_primary_background_fill='#FFFFFF',
    button_primary_border_color='#4D63FF',
    button_primary_background_fill_hover='#EDEFFF',
)

def clear_session():
    # Reset the input textbox and the chat history.
    return '', None
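
# Load Qwen-7B-Chat from ModelScope in fp16; device_map="auto" places the weights
# on the available GPU(s), and trust_remote_code is required for Qwen's custom
# modeling code.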
model_id = 'qwen/Qwen-7B-Chat'
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='v1.0.1', trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", revision='v1.0.1',
                                             trust_remote_code=True, fp16=True).eval()
model.generation_config = GenerationConfig.from_pretrained(model_id, trust_remote_code=True)

def generate_chat(input: str, history=None):
    if input is None:
        input = ''
    if history is None:
        history = []
    # Keep only the five most recent turns as context.
    history = history[-5:]
    gen = model.chat(tokenizer, input, history=history, stream=True)
    for x in gen:
        # Yield the partial response, then drop it so the next chunk replaces it.
        history.append((input, x))
        yield None, history
        history.pop()
    # Keep the final response in the history.
    history.append((input, x))
    return None, history
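
# Build the Gradio UI: a chatbot pane, an input box, and Send / Clear buttons.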
block = gr.Blocks(theme=theme, css=css)
with block as demo:
    gr.Markdown("""<center><font size=8>Qwen-7B-Chat Bot</center>""")
    chatbot = gr.Chatbot(lines=10, label='Qwen-7B-Chat', elem_classes="control-height")
    message = gr.Textbox(lines=2, label='Input')
    with gr.Row():
        clear_history = gr.Button("🧹 Clear chat history")
        submit = gr.Button("🚀 Send")
    submit.click(generate_chat,
                 inputs=[message, chatbot],
                 outputs=[message, chatbot])
    clear_history.click(fn=clear_session,
                        inputs=[],
                        outputs=[message, chatbot],
                        queue=False)

demo.queue().launch(server_name="0.0.0.0")