47 lines
1.6 KiB
Python
47 lines
1.6 KiB
Python
from transformers import AutoTokenizer, AutoModel
|
|
from gradio.themes.utils import sizes
|
|
import gradio as gr
|
|
|
|
|
|
# Colour overrides applied on top of Gradio's Default theme: accent-blue
# (#4D63FF) block labels/titles, and a "ghost" primary button (white fill,
# blue text and border) that turns light blue (#EDEFFF) on hover.
_theme_overrides = {
    "block_label_text_color": '#4D63FF',
    "block_title_text_color": '#4D63FF',
    "button_primary_text_color": '#4D63FF',
    "button_primary_background_fill": '#FFFFFF',
    "button_primary_border_color": '#4D63FF',
    "button_primary_background_fill_hover": '#EDEFFF',
}

# radius_none gives square corners on all components.
theme = gr.themes.Default(radius_size=sizes.radius_none).set(**_theme_overrides)

# Custom page CSS: hides every <footer> element (presumably Gradio's
# built-in page footer).
css = "footer {visibility: hidden}"
|
|
|
|
|
|
# Load the CodeGeeX2-6B tokenizer and model once, at import time.
# trust_remote_code=True runs the modelling/tokenizer code shipped in the
# "THUDM/codegeex2-6b" repository, so that repo must be trusted.
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/codegeex2-6b",
    trust_remote_code=True,
)

# device='cuda' is forwarded to the repository's remote code; presumably it
# places the weights on the GPU (requires a CUDA device) — TODO confirm.
# .eval() switches to inference mode and returns the same module.
model = AutoModel.from_pretrained(
    "THUDM/codegeex2-6b",
    trust_remote_code=True,
    device='cuda',
).eval()
|
|
|
|
|
|
def gene_code(prompt):
    """Generate a code completion for *prompt* using the loaded CodeGeeX2 model.

    The prompt is tokenized and moved to ``model.device``. It is then passed
    to ``model.generate`` with ``max_length=256`` and ``top_k=1``. Note that
    in Transformers, ``max_length`` bounds the *total* sequence length (prompt
    tokens plus new tokens), so a long prompt leaves little or no room for
    the completion. ``top_k=1`` limits each step to the single most likely
    token, which makes the output effectively greedy.

    Args:
        prompt: Text from the prompt box, e.g.
            ``"# language: python\\n# write a bubble sort function\\n"``.
            ``None`` is treated like an empty prompt.

    Returns:
        The decoded first output sequence. For a decoder-only model,
        ``generate`` presumably returns the prompt ids followed by the new
        ids, so the text echoes the prompt before the completion (TODO
        confirm for this remote model). Returns ``""`` for an empty or
        whitespace-only prompt without invoking the model.
    """
    # Nothing to complete: avoid spending a full 6B-model generate() call.
    if prompt is None or not prompt.strip():
        return ""

    inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(inputs, max_length=256, top_k=1)
    response = tokenizer.decode(outputs[0])

    return response
|
|
|
|
|
|
with gr.Blocks(theme=theme, css=css) as demo:
    # Centered page title: "代码生成" means "Code generation".
    gr.Markdown("""
<div align='center' ><font size='60'>代码生成(codegeex2-6b)</font></div>
""")

    with gr.Row():
        # Left side: prompt input ("提示" = "Prompt") with the submit
        # button ("提交" = "Submit") on its own row beneath it.
        with gr.Column():
            prompt = gr.TextArea(label="提示")
            with gr.Row():
                button = gr.Button("提交", variant="primary")

        # Right side: generated output ("生成的代码" = "Generated code").
        generated_code = gr.TextArea(label='生成的代码')

    # Clicking submit sends the prompt text to gene_code and writes its
    # return value into the output box.
    button.click(fn=gene_code, inputs=[prompt], outputs=[generated_code])

    # A clickable sample prompt that fills the prompt box ("例子" = "Examples").
    examples = gr.Examples(
        examples=["# language: python\n# write a bubble sort function\n"],
        inputs=prompt,
        label="例子",
    )
|
|
|
|
if __name__ == "__main__":
    # Enable Gradio's request queue (queue() returns the same Blocks app).
    queued_app = demo.queue()
    # server_name="0.0.0.0" listens on all network interfaces, so the app is
    # reachable from other hosts. NOTE(review): confirm this exposure is intended.
    queued_app.launch(server_name="0.0.0.0")
|