# codegeex2-6b/app.py
#
# 47 lines
# 1.6 KiB
# Python

from transformers import AutoTokenizer, AutoModel
from gradio.themes.utils import sizes
import gradio as gr
# UI styling: Gradio's default theme with square corners (radius_none) and
# the accent colour #4D63FF on block labels, block titles and the primary
# button text/border; the primary button is white and turns #EDEFFF on hover.
theme = gr.themes.Default(radius_size=sizes.radius_none)
theme = theme.set(
    block_label_text_color='#4D63FF',
    block_title_text_color='#4D63FF',
    button_primary_text_color='#4D63FF',
    button_primary_border_color='#4D63FF',
    button_primary_background_fill='#FFFFFF',
    button_primary_background_fill_hover='#EDEFFF',
)
# Extra CSS injected into the page: make `footer` elements invisible.
css = "footer {visibility: hidden}"
# Load the CodeGeeX2-6B tokenizer and model from the Hugging Face Hub.
# trust_remote_code=True lets transformers execute the Python code shipped
# in the model repository (required for this custom architecture).
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/codegeex2-6b",
    trust_remote_code=True,
)
# NOTE(review): `device='cuda'` is hard-coded and forwarded to the remote model
# code, so this assumes a CUDA GPU is present; there is no CPU fallback.
# `.eval()` (which returns the module itself) switches to inference mode.
model = AutoModel.from_pretrained(
    "THUDM/codegeex2-6b",
    trust_remote_code=True,
    device='cuda',
).eval()
def gene_code(prompt: str, *, max_length: int = 256) -> str:
    """Complete the code in *prompt* with the module-level CodeGeeX2 model.

    Args:
        prompt: Source text to continue, e.g. a ``# language: python`` line
            followed by a comment describing the task. ``None``, empty or
            whitespace-only input is treated as "nothing to complete".
        max_length: Token budget forwarded to ``model.generate``. It bounds
            the TOTAL sequence length, prompt tokens included, so a prompt
            close to this length leaves little or no room for new tokens.
            Defaults to 256, the value previously hard-coded.

    Returns:
        The decoded first sequence returned by ``model.generate``, or ``""``
        for an empty prompt. For a decoder-only model this sequence
        presumably starts with the prompt text itself followed by the
        completion (verify against the model's remote code).
    """
    # An empty submission would otherwise trigger a full GPU generation that
    # is unrelated to any user input; short-circuit instead.
    if not (prompt and prompt.strip()):
        return ""
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    # top_k=1 keeps only the most likely token at each step, so decoding is
    # deterministic (greedy) whether or not sampling is enabled by default.
    outputs = model.generate(inputs, max_length=max_length, top_k=1)
    return tokenizer.decode(outputs[0])
# Build the web page: a centred title banner ("代码生成" = "code generation"),
# a row containing a column with the prompt box ("提示" = prompt) and submit
# button ("提交" = submit) plus the output box ("生成的代码" = generated code),
# and one clickable example prompt ("例子" = example).
with gr.Blocks(theme=theme, css=css) as demo:
    gr.Markdown(
        "\n<div align='center' ><font size='60'>代码生成(codegeex2-6b)</font></div>\n"
    )
    with gr.Row():
        with gr.Column():
            prompt = gr.TextArea(label="提示")
            with gr.Row():
                button = gr.Button("提交", variant="primary")
        generated_code = gr.TextArea(label='生成的代码')
    # Clicking the button passes the prompt text to gene_code and writes the
    # returned string into the output box.
    button.click(
        fn=gene_code,
        inputs=prompt,
        outputs=generated_code,
    )
    # Selecting the example fills the prompt box with a Python bubble-sort task.
    examples = gr.Examples(
        examples=["# language: python\n# write a bubble sort function\n"],
        inputs=prompt,
        label="例子",
    )
if __name__ == "__main__":
    # Turn on Gradio's request queue (queue() returns the same Blocks object),
    # then serve on all network interfaces so the app is reachable from
    # outside the host/container, not just via localhost.
    demo.queue()
    demo.launch(server_name="0.0.0.0")