24 lines
839 B
Python
24 lines
839 B
Python
import gradio as gr
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
|
|
|
# Single source of truth for the checkpoint id, so the tokenizer and the
# model are always loaded from the same repository.
_CHECKPOINT = "merve/chatgpt-prompts-bart-long"

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
# from_tf=True makes transformers load the weights from a TensorFlow
# checkpoint (presumably the repository only ships TF weights -- confirm).
model = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT, from_tf=True)
|
|
|
|
def generate(prompt):
    """Run the seq2seq model on *prompt* and return the decoded text.

    The prompt is tokenized into PyTorch tensors, the model generates at
    most 150 new tokens from the resulting input ids, and the generated
    ids are decoded with special tokens stripped.

    Args:
        prompt: Text to feed to the model (the UI examples use short role
            names such as "photographer").

    Returns:
        The first decoded sequence produced by the model.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    token_ids = model.generate(encoded["input_ids"], max_new_tokens=150)
    # batch_decode yields one string per generated sequence; only the
    # first one is returned to the caller.
    decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0]
|
|
|
|
# Text-in / text-out web UI around generate(), with two clickable example
# inputs shown to the user.
demo = gr.Interface(
    fn=generate,
    inputs='text',
    outputs='text',
    title="generate prompt",
    examples=[["photographer"], ["developer"]],
)
|
|
|
|
|
|
if __name__ == "__main__":
    # Enable request queuing before starting the server.
    # NOTE(review): `concurrency_count` is version-sensitive in Gradio
    # (newer major versions replace it with `default_concurrency_limit`);
    # confirm against the pinned gradio version before upgrading.
    queued_demo = demo.queue(concurrency_count=3)
    # "0.0.0.0" binds to all network interfaces; the app listens on port 7020.
    queued_demo.launch(server_name="0.0.0.0", server_port=7020)
|