# File-viewer metadata captured with this script: 71 lines, 2.3 KiB, Python.
import os
|
|
from gradio.themes.utils import sizes
|
|
|
|
|
|
# Extra stylesheet handed to gr.Blocks: hides every <footer> element on the
# page (presumably Gradio's default page footer).
css = 'footer {visibility: hidden}'
|
|
|
|
# Install/upgrade runtime dependencies before the third-party imports below.
import subprocess
import sys


def _pip_install(*args):
    """Run ``python -m pip install *args`` with the interpreter running this script.

    Using ``sys.executable -m pip`` (instead of whatever ``pip`` is first on
    PATH) guarantees the packages land in the environment that will import
    them, and an argument list avoids going through a shell.

    Failures are reported on stderr but are not fatal, mirroring the original
    best-effort ``os.system`` calls: e.g. an offline run whose packages are
    already installed still proceeds to the imports.

    Returns:
        pip's exit code.
    """
    cmd = [sys.executable, '-m', 'pip', 'install', *args]
    result = subprocess.run(cmd, check=False)
    if result.returncode != 0:
        print(f"warning: {' '.join(cmd)} exited with code {result.returncode}; continuing",
              file=sys.stderr)
    return result.returncode


_pip_install('tiktoken')
_pip_install('modelscope', '--upgrade', '-f', 'https://pypi.org/project/modelscope/')
_pip_install('transformers_stream_generator')
|
|
|
|
import gradio as gr
|
|
from modelscope.pipelines import pipeline
|
|
from modelscope.utils.constant import Tasks
|
|
from modelscope import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
|
import torch
|
|
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
|
|
|
|
|
|
# Gradio theme: the Default theme with no corner rounding (sizes.radius_none),
# #4D63FF block labels/titles, and a primary button drawn as white with
# #4D63FF text and border that fills with #EDEFFF on hover.
# NOTE(review): none of the gr.Button calls in this script pass
# variant="primary" -- confirm the button_primary_* overrides actually apply.
theme = gr.themes.Default(radius_size=sizes.radius_none)
theme = theme.set(
    block_label_text_color='#4D63FF',
    block_title_text_color='#4D63FF',
    button_primary_text_color='#4D63FF',
    button_primary_background_fill='#FFFFFF',
    button_primary_border_color='#4D63FF',
    button_primary_background_fill_hover='#EDEFFF',
)
|
|
|
|
def clear_session():
    """Return the values used to reset the chat UI.

    Returns:
        ``('', None)``: an empty string for the message textbox and ``None``
        for the chatbot history.
    """
    cleared_message = ''
    cleared_history = None
    return cleared_message, cleared_history
|
|
|
|
# Qwen-7B-Chat from ModelScope, with every artifact pinned to one revision so
# the tokenizer, the weights and the generation config come from the same snapshot.
model_id = 'qwen/Qwen-7B-Chat'
model_revision = 'v1.0.1'

tokenizer = AutoTokenizer.from_pretrained(model_id, revision=model_revision, trust_remote_code=True)

# device_map="auto" lets the loader place the weights on the available devices.
# fp16=True is not a standard transformers argument; presumably it is consumed
# by Qwen's remote modeling code (trust_remote_code=True) -- confirm.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    revision=model_revision,
    trust_remote_code=True,
    fp16=True,
).eval()

# Pass the same revision here too: without it the generation config is resolved
# from the repository's default revision, which can diverge from the pinned
# weights and tokenizer above.
model.generation_config = GenerationConfig.from_pretrained(
    model_id, revision=model_revision, trust_remote_code=True)
|
|
|
|
def generate_chat(input: str, history=None):
    """Stream a reply to ``input`` from the loaded chat model as Gradio UI updates.

    Args:
        input: The user's message; ``None`` is treated as an empty string.
        history: The chatbot's list of (user, bot) pairs, or ``None``. Only the
            five most recent pairs are kept: they are passed to the model as
            context, and they are also all that is shown in the chatbot
            afterwards.

    Yields:
        ``(None, pairs)`` tuples for the ``[message, chatbot]`` outputs:
        ``None`` is written to the message textbox, and ``pairs`` is the
        truncated history plus ``(input, reply)`` for the current turn.
        ``model.chat(..., stream=True)`` is assumed to yield the full reply
        generated so far (not per-token deltas) -- each step replaces the
        current turn's text.
    """
    query = '' if input is None else input
    past = history[-5:] if history else []

    response = ''
    for response in model.chat(tokenizer, query, history=past, stream=True):
        # Build a fresh list for every update so a value that has already been
        # yielded to Gradio is never mutated afterwards.
        yield None, past + [(query, response)]

    # Emit the final state explicitly: Gradio discards a generator's ``return``
    # value. This also covers a stream that produced no text (the turn is shown
    # with an empty reply instead of raising UnboundLocalError).
    yield None, past + [(query, response)]
|
|
|
|
block = gr.Blocks(theme=theme, css=css)

with block as demo:
    gr.Markdown("""<center><font size=8>Qwen-7B-Chat Bot</center>""")

    # `lines` is not a gr.Chatbot parameter (it was either ignored or rejected,
    # depending on the Gradio version), so it is no longer passed.
    chatbot = gr.Chatbot(label='Qwen-7B-Chat', elem_classes="control-height")
    message = gr.Textbox(lines=2, label='Input')

    with gr.Row():
        # Labels: "Clear chat history" and "Send".
        clear_history = gr.Button("🧹 清除历史对话")
        submit = gr.Button("🚀 发送")

    # Stream the model's reply into the chatbot; generate_chat also writes None
    # to the message textbox.
    submit.click(generate_chat,
                 inputs=[message, chatbot],
                 outputs=[message, chatbot])
    # Reset the textbox and the chatbot; queue=False runs this outside the
    # request queue so it is not blocked behind an in-progress generation.
    clear_history.click(fn=clear_session,
                        inputs=[],
                        outputs=[message, chatbot],
                        queue=False)

# The queue is required for generator (streaming) handlers in Gradio 3.x;
# server_name="0.0.0.0" listens on all network interfaces.
demo.queue().launch(server_name="0.0.0.0")
|