# ailab/chatgpt-prompts-bart-long/app.py
#
# 24 lines
# 839 B
# Python

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Hugging Face Hub checkpoint shared by the tokenizer and the model. It is kept
# in one place so the two can never be loaded from different checkpoints.
MODEL_ID = "merve/chatgpt-prompts-bart-long"

# Both objects are created once at module import and reused for every call.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# from_tf=True loads TensorFlow weights and converts them to PyTorch, which
# requires TensorFlow to be installed. Presumably the checkpoint ships only TF
# weights; TODO confirm before dropping the flag.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, from_tf=True)
def generate(prompt, max_new_tokens=150):
    """Generate text from *prompt* with the module-level seq2seq model.

    The Gradio examples ("photographer", "developer") suggest the input is a
    short role/topic, but any string is accepted.

    Args:
        prompt: Input text to encode and feed to the model.
        max_new_tokens: Upper bound on the number of tokens to generate.
            Defaults to 150, the value previously hard-coded.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens removed.
    """
    # truncation=True clips over-long input to the tokenizer's
    # model_max_length, so a long paste cannot overflow the encoder's position
    # embeddings. This assumes the tokenizer config sets that limit; TODO confirm.
    batch = tokenizer(prompt, return_tensors="pt", truncation=True)
    # Pass the attention mask explicitly instead of leaving generate() to
    # infer it from input_ids.
    generated_ids = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch.get("attention_mask"),
        max_new_tokens=max_new_tokens,
    )
    output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return output[0]
# Text-in / text-out web UI around `generate`, with two clickable example
# inputs pre-filled in the interface.
demo = gr.Interface(
    fn=generate,
    inputs="text",
    outputs="text",
    title="generate prompt",
    examples=[
        ["photographer"],
        ["developer"],
    ],
)
if __name__ == "__main__":
    # Allow up to 3 requests to be processed concurrently. Gradio 4 removed
    # queue(concurrency_count=...) (passing it raises) in favour of
    # default_concurrency_limit. Gradio 3.x lacks that keyword and raises
    # TypeError, so fall back to the old spelling there.
    try:
        demo.queue(default_concurrency_limit=3)
    except TypeError:
        demo.queue(concurrency_count=3)
    # "0.0.0.0" listens on all network interfaces (e.g. for container use).
    demo.launch(server_name="0.0.0.0", server_port=7020)