panorama_segmentation/app.py

44 lines
1.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import functools

import gradio as gr
from transformers import MaskFormerImageProcessor, MaskFormerForInstanceSegmentation
def inference(img):
# load MaskFormer fine-tuned on COCO panoptic segmentation
model_path = "maskformer-swin-large-coco"
processor = MaskFormerImageProcessor.from_pretrained(model_path)
model = MaskFormerForInstanceSegmentation.from_pretrained(model_path)
inputs = processor(images=img, return_tensors="pt")
outputs = model(**inputs)
# model predicts class_queries_logits of shape `(batch_size, num_queries)`
# and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
class_queries_logits = outputs.class_queries_logits
masks_queries_logits = outputs.masks_queries_logits
# you can pass them to processor for postprocessing
result = processor.post_process_panoptic_segmentation(outputs, target_sizes=[img.size[::-1]])[0]
# we refer to the demo notebooks for visualization (see "Resources" section in the MaskFormer docs)
predicted_panoptic_map = result["segmentation"]
return predicted_panoptic_map
# Example inputs shown under the uploader; each inner list holds one value per
# input component (here: a single image path, resolved relative to the CWD).
examples = [["example_cat.jpg"]]

with gr.Blocks() as demo:
    # Title fixed: the model performs *panoptic* segmentation (全景分割), not
    # "panorama" segmentation.
    gr.Markdown(
        "# Panoptic segmentation: maskformer-swin-large-coco\n"
        "这是maskformer-swin-large-coco的Gradio Demo用于全景分割。上传你想要的图像或者点击下面的示例来加载它。\n"
    )
    with gr.Row():
        image_input = gr.Image(type="pil")
        # NOTE(review): `inference` returns a segment-id map, which a Textbox
        # can only show as its string form; consider an image-type output.
        text_output = gr.Textbox()
    # Button label reads "Upload", but clicking it runs `inference` on the image.
    image_button = gr.Button("上传")
    image_button.click(inference, inputs=image_input, outputs=text_output)
    # Clicking an example only loads it into `image_input` (no fn/outputs given).
    gr.Examples(examples, inputs=image_input)

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    # 0.0.0.0 binds every network interface (e.g. for container/Spaces hosting).
    demo.launch(server_name="0.0.0.0")