text_generation/distilgpt2/app.py

40 lines
1.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import functools

import gradio as gr
from transformers import pipeline, set_seed
@functools.lru_cache(maxsize=2)
def _get_generator(model_path):
    """Build and cache a ``text-generation`` pipeline for *model_path*.

    Constructing the pipeline loads the model weights, which is far more
    expensive than a single generation call, so it is done once per model
    path instead of on every request.
    """
    return pipeline("text-generation", model=model_path)


def inference(text, max_length=20, num_return_sequences=5, seed=42,
              model_path="distilgpt2"):
    """Generate continuations of *text* with a text-generation pipeline.

    Args:
        text: Prompt to continue.
        max_length: Forwarded to the pipeline call as ``max_length``.
        num_return_sequences: Forwarded to the pipeline call; how many
            continuations to request.
        seed: Passed to ``set_seed`` immediately before generating, so
            repeated calls with the same arguments are reproducible.
        model_path: Model identifier/path handed to ``pipeline``;
            defaults to ``"distilgpt2"``.

    Returns:
        A list with the ``generated_text`` field of each result the
        pipeline returns, in pipeline order.
    """
    generator = _get_generator(model_path)
    # Re-seed on every call (as the original did) so outputs do not drift
    # between requests.
    set_seed(seed)
    results = generator(
        text,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
    )
    return [result["generated_text"] for result in results]
# Example prompts listed under the input box; clicking one fills the textbox.
examples = [["Hello, I'm a language model."]]


def _generate_text(text):
    """Run ``inference`` on *text* and join its outputs for display.

    ``inference`` returns a list of strings. A ``gr.Textbox`` output would
    stringify that list as a Python literal, so the generated sequences are
    joined with blank lines between them instead.
    """
    return "\n\n".join(inference(text))


with gr.Blocks() as demo:
    gr.Markdown(
        "# Text generation: distilgpt2\n"
        "Gradio Demo for distilgpt2. To use it, simply type in text, "
        "or click one of the examples to load them."
    )
    with gr.Row():
        text_input = gr.Textbox(label="Input")
        text_output = gr.Textbox(label="Generated text")
    # Button label "上传" ("Upload") is kept unchanged.
    generate_button = gr.Button("上传")
    generate_button.click(_generate_text, inputs=text_input, outputs=text_output)
    gr.Examples(examples, inputs=text_input)

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()