import gradio as gr
import torch
from transformers import AutoProcessor, CLIPSegForImageSegmentation


def inference(img):
    """Run CLIPSeg segmentation on *img* for three fixed text prompts.

    Args:
        img: a PIL image (the Gradio component is declared with type="pil").

    Returns:
        The shape of the segmentation logits tensor, one row per prompt.
        Gradio renders it as text in the output Textbox.
    """
    model_path = "clipseg-rd64-refined"

    # Load the processor/model once and cache them on the function object:
    # re-reading the weights from disk on every button click is needlessly
    # slow and was the main cost of the original implementation.
    if not hasattr(inference, "_cached"):
        inference._cached = (
            AutoProcessor.from_pretrained(model_path),
            CLIPSegForImageSegmentation.from_pretrained(model_path),
        )
    processor, model = inference._cached

    # One (text, image) pair per prompt, so the image is repeated to match.
    texts = ["a cat", "a remote", "a blanket"]
    inputs = processor(text=texts, images=[img] * len(texts), padding=True, return_tensors="pt")

    # Inference only — disable autograd to avoid building a useless graph.
    with torch.no_grad():
        outputs = model(**inputs)

    return outputs.logits.shape

# Sample inputs shown under the image component; clicking one loads it.
examples = [['example_cat.jpg']]

with gr.Blocks() as demo:
    # Title and usage notes rendered at the top of the page.
    gr.Markdown(
    """
    # Semantic segmentation:clipseg-rd64-refined
    Gradio Demo for clipseg-rd64-refined. To use it, simply upload your image, or click one of the examples to load them.
    """)

    # Input image on the left, result textbox on the right.
    with gr.Row():
        uploaded_image = gr.Image(type="pil")
        result_box = gr.Textbox()

    # Clicking the button feeds the image through `inference` and shows
    # the returned logits shape in the textbox.
    run_button = gr.Button("上传")
    run_button.click(inference, inputs=uploaded_image, outputs=result_box)

    gr.Examples(examples, inputs=uploaded_image)

demo.launch()