# maskformer-swin-base-coco/app.py
# Panoptic segmentation (全景分割) Gradio demo
from PIL import Image
import io
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T
import numpy
import gradio as gr
import itertools
import seaborn as sns
from panopticapi.utils import rgb2id
from gradio.themes.utils import sizes
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
from torch import nn
import numpy as np
# Gradio theme: default look with square corners and a blue-on-white
# accent palette (#4D63FF text/border, white fill, pale-blue hover).
theme = gr.themes.Default(radius_size=sizes.radius_none).set(
block_label_text_color = '#4D63FF',
block_title_text_color = '#4D63FF',
button_primary_text_color = '#4D63FF',
button_primary_background_fill='#FFFFFF',
button_primary_border_color='#4D63FF',
button_primary_background_fill_hover='#EDEFFF',
)
# Load the MaskFormer Swin-base checkpoint fine-tuned on COCO panoptic
# segmentation from the Hugging Face Hub (downloads on first run).
feature_extractor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-base-coco")
model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-coco")
def detect_objects(image_input):
    """Run panoptic segmentation on a PIL image and return the blended overlay."""
    # Preprocess the image into model-ready tensors.
    encoded = feature_extractor(images=image_input, return_tensors="pt")
    # Inference only — no gradients needed.
    with torch.no_grad():
        prediction = model(**encoded)
    # PIL's .size is (width, height); the post-processor expects (height, width).
    panoptic = feature_extractor.post_process_panoptic_segmentation(
        prediction, target_sizes=[image_input.size[::-1]]
    )[0]
    # Overlay the predicted segmentation map on the original image.
    return visualize_prediction(panoptic["segmentation"], image_input)
def visualize_prediction(result, image):
    """Overlay a segmentation map on an image, one random color per segment.

    Args:
        result: 2-D integer segmentation map (torch tensor or numpy array);
            each value is a segment/label id.
        image: source image (PIL image or numpy array) with the same height
            and width as ``result``.

    Returns:
        uint8 numpy array of shape (H, W, 3): a 50/50 blend of the image
        and the randomly colored mask.
    """
    # np.asarray accepts numpy arrays directly and CPU torch tensors via
    # their __array__ protocol, so no explicit .numpy() call is needed.
    seg = np.asarray(result)
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # H, W, 3
    # Color exactly the ids present in the map. The original sized its
    # palette by len(model.config.id2label), which left any segment id
    # beyond the label count uncolored (black) and tied this helper to the
    # module-level model global.
    for label in np.unique(seg):
        color_seg[seg == label, :] = np.random.choice(range(256), size=3)
    # Convert RGB -> BGR (kept from the original visualization convention).
    color_seg = color_seg[..., ::-1]
    # 50/50 alpha blend of source image and colored mask.
    blended = np.array(image) * 0.5 + color_seg * 0.5
    return blended.astype(np.uint8)
# Build the web UI. NOTE(review): the original file's indentation was lost;
# the nesting below (output panel beside the input column) is the assumed
# layout — confirm against the deployed demo.
with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as demo:
    # Page title ("全景分割" = panoptic segmentation).
    gr.Markdown("""
<div align='center' ><font size='60'>全景分割</font></div>
""")
    with gr.Row():
        with gr.Column():
            # Input image ("图片" = image); delivered to the handler as PIL.
            image = gr.Image(label="图片", type="pil")
            with gr.Row():
                # Submit button ("提交" = submit).
                button = gr.Button("提交", variant="primary")
        # Output panel for the blended segmentation result.
        box2 = gr.Image(label="图片")
    button.click(fn=detect_objects, inputs=[image], outputs=box2)
    # Clickable example inputs ("例子" = examples).
    examples = gr.Examples(examples=[['1.jpg'], ['2.jpg']], inputs=[image], label="例子")

if __name__ == '__main__':
    # queue() serializes requests to the model; listen on all interfaces.
    demo.queue().launch(server_name="0.0.0.0")