import gc
import os
import threading

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from gradio.themes.utils import sizes
# UI theme: Gradio's default theme with square corners (radius_none) and a
# single blue accent. Block labels, block titles and the primary button use
# the accent; the primary button is drawn as an outline (white fill, accent
# text and border) with a light blue tint on hover.
_ACCENT_COLOR = '#4D63FF'

theme = gr.themes.Default(radius_size=sizes.radius_none).set(
    block_label_text_color=_ACCENT_COLOR,
    block_title_text_color=_ACCENT_COLOR,
    button_primary_text_color=_ACCENT_COLOR,
    button_primary_border_color=_ACCENT_COLOR,
    button_primary_background_fill='#FFFFFF',
    button_primary_background_fill_hover='#EDEFFF',
)


# CUDA caching-allocator tuning. PyTorch reads this setting when CUDA memory
# is first allocated, so it has to be assigned before the pipeline is moved
# to the GPU below.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"

model_id = "stabilityai/stable-diffusion-2-1"

# Load the half-precision weights, replace the pipeline's default scheduler
# with a DPMSolverMultistepScheduler built from the same scheduler config,
# then move everything onto the GPU.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config
)
pipe = pipe.to("cuda")


# Serializes every call into the shared module-level ``pipe``. The app's queue
# allows several requests to run at once on worker threads, but there is only
# one pipeline object. Its scheduler is stateful during a run
# (DPMSolverMultistepScheduler stores its step index and history of model
# outputs on the instance), so overlapping calls could corrupt each other's
# denoising. Running them one at a time also avoids stacking several
# generations' activations in GPU memory at once.
_PIPE_LOCK = threading.Lock()


def text2image(prompt):
    """Generate a single image for ``prompt`` using the shared pipeline.

    Calls are serialized through ``_PIPE_LOCK`` so that concurrent Gradio
    requests never drive the shared pipeline/scheduler at the same time.

    Args:
        prompt: Text prompt entered in the Gradio textbox.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    with _PIPE_LOCK:
        image = pipe(prompt).images[0]
    return image


# Single-textbox -> single-image interface around text2image, with the
# custom theme and one clickable example prompt.
demo = gr.Interface(
    fn=text2image,
    inputs='text',
    outputs='image',
    examples=['a photo of an astronaut riding a horse on mars'],
    theme=theme,
    # Hide Gradio's default page footer.
    css="footer {visibility: hidden}",
    # Do not show the "Flag" button for outputs.
    allow_flagging="never",
)


if __name__ == "__main__":
    # Enable the request queue with up to 10 concurrent workers.
    # NOTE(review): `concurrency_count` is a Gradio 3.x `queue()` argument;
    # newer Gradio releases replaced it. Verify against the pinned version.
    demo.queue(concurrency_count=10)
    # Bind to all network interfaces so the app is reachable from other hosts
    # (e.g. when running inside a container).
    demo.launch(server_name="0.0.0.0")