# (viewer metadata — "24 lines / 839 B / Python" — artifact of the file viewer,
#  not part of the source)
# Gradio demo: expand a short persona seed (e.g. "photographer") into a full
# ChatGPT-style prompt using a fine-tuned BART seq2seq model.
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_ID = "merve/chatgpt-prompts-bart-long"

# Load once at import time so every request reuses the same weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# The checkpoint is stored as TensorFlow weights, hence from_tf=True
# (requires TensorFlow to be installed alongside PyTorch).
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, from_tf=True)
def generate(prompt):
    """Expand a short seed string into a ChatGPT-style prompt.

    Parameters
    ----------
    prompt : str
        Short seed text from the Gradio textbox, e.g. "photographer".

    Returns
    -------
    str
        The first (and only) decoded sequence produced by the model.
    """
    # Tokenize to PyTorch tensors; a single string yields a batch of one.
    batch = tokenizer(prompt, return_tensors="pt")
    # FIX: forward the attention mask explicitly. Passing only input_ids makes
    # transformers infer the mask from the pad token, which emits a warning and
    # can mis-mask inputs that happen to contain pad-token ids.
    generated_ids = model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        max_new_tokens=150,
    )
    # Decode the whole batch, stripping special tokens; one string per sequence.
    output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return output[0]
# Minimal web UI: one text input, one text output, with two clickable examples.
demo = gr.Interface(
    fn=generate,
    inputs="text",
    outputs="text",
    title="generate prompt",
    examples=[["photographer"], ["developer"]],
)
if __name__ == "__main__":
    # Enable request queuing (up to 3 concurrent workers), then bind on all
    # interfaces at port 7020.
    # NOTE(review): `concurrency_count` exists only in Gradio 3.x — confirm the
    # pinned Gradio version before upgrading past 4.0.
    app = demo.queue(concurrency_count=3)
    app.launch(server_name="0.0.0.0", server_port=7020)