# 全景分割 (panoptic segmentation) demo.
# Listing metadata captured with this file: 80 lines, 2.5 KiB, Python.

from PIL import Image
|
|
import io
|
|
import matplotlib.pyplot as plt
|
|
import torch
|
|
import torchvision.transforms as T
|
|
import numpy
|
|
import gradio as gr
|
|
import itertools
|
|
import seaborn as sns
|
|
from panopticapi.utils import rgb2id
|
|
from gradio.themes.utils import sizes
|
|
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
|
|
from torch import nn
|
|
import numpy as np
|
|
|
|
|
|
# UI theme: Gradio's default theme with square corners.
# - A single blue accent (#4D63FF) is used for block labels, block titles, and
#   the primary button's text and border.
# - The primary button has a white fill that switches to #EDEFFF on hover.
_ACCENT_BLUE = '#4D63FF'
theme = gr.themes.Default(radius_size=sizes.radius_none).set(
    block_label_text_color=_ACCENT_BLUE,
    block_title_text_color=_ACCENT_BLUE,
    button_primary_text_color=_ACCENT_BLUE,
    button_primary_border_color=_ACCENT_BLUE,
    button_primary_background_fill='#FFFFFF',
    button_primary_background_fill_hover='#EDEFFF',
)
|
|
|
|
|
|
# Pretrained MaskFormer checkpoint on the Hugging Face Hub. One id is shared by
# the preprocessor and the model so the two always match.
# Both are loaded once, at import time, and reused by every detect_objects call.
# NOTE(review): newer transformers releases deprecate MaskFormerFeatureExtractor
# in favour of MaskFormerImageProcessor -- confirm against the pinned version.
_CHECKPOINT = "facebook/maskformer-swin-base-coco"
feature_extractor = MaskFormerFeatureExtractor.from_pretrained(_CHECKPOINT)
model = MaskFormerForInstanceSegmentation.from_pretrained(_CHECKPOINT)
|
|
|
|
|
|
def detect_objects(image_input):
    """Run panoptic segmentation on one image and return an overlay to display.

    Args:
        image_input: PIL image coming from the ``gr.Image(type="pil")`` input,
            or ``None`` when the button is pressed without an uploaded image.

    Returns:
        A ``(height, width, 3)`` uint8 array blending the image with a colour-coded
        segment map, or ``None`` when no image was supplied (output stays empty).
    """
    # Gradio passes None for an empty input. The feature extractor would fail
    # on it with an unhelpful error, so leave the output blank instead.
    if image_input is None:
        return None

    inputs = feature_extractor(images=image_input, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL's .size is (width, height); target_sizes takes (height, width), so the
    # tuple is reversed to get a segment map at the input image's resolution.
    result = feature_extractor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image_input.size[::-1]]
    )[0]

    # Visualize prediction: only the per-pixel segment map is used; any other
    # entries of `result` are ignored.
    return visualize_prediction(result["segmentation"], image_input)
|
|
|
|
|
|
def visualize_prediction(result, image, num_colors=None):
    """Blend a randomly coloured segment map 50/50 with the source image.

    Args:
        result: 2-D map of integer segment ids with the same height and width as
            ``image``, e.g. the ``"segmentation"`` entry returned by
            ``post_process_panoptic_segmentation``. May be a torch tensor (on any
            device) or anything ``np.asarray`` accepts.
        image: the source image, either a PIL image (any mode; it is converted
            to RGB) or an ``(height, width, 3)`` array.
        num_colors: number of random palette entries; ids ``0 .. num_colors-1``
            are coloured. Defaults to ``len(model.config.id2label)``, the value
            that used to be hard-coded.

    Returns:
        ``(height, width, 3)`` uint8 array. Pixels whose id is negative or
        ``>= num_colors`` get no colour, so they show the image at half brightness.
    """
    if num_colors is None:
        num_colors = len(model.config.id2label)

    # Accept torch tensors (including ones on a GPU) as well as arrays/lists.
    if isinstance(result, torch.Tensor):
        result = result.detach().cpu().numpy()
    seg = np.asarray(result).astype(np.int64)

    # One random colour per id; colours are re-drawn on every call.
    palette = np.random.randint(0, 256, size=(num_colors, 3), dtype=np.uint8)

    # Vectorised palette lookup instead of one full-image comparison per id.
    color_seg = np.zeros(seg.shape + (3,), dtype=np.uint8)
    in_palette = (seg >= 0) & (seg < num_colors)
    color_seg[in_palette] = palette[seg[in_palette]]

    # Normalise PIL modes such as RGBA / L / P to 3-channel RGB so the blend
    # below broadcasts against the (H, W, 3) mask. (The output is displayed as
    # RGB, so no BGR channel swap is applied.)
    if hasattr(image, "convert"):
        image = image.convert("RGB")

    # Show image + mask
    blended = np.asarray(image) * 0.5 + color_seg * 0.5
    return blended.astype(np.uint8)
|
|
|
|
|
|
import os

# Example images are looked up next to this script rather than in the current
# working directory, so the demo also starts when launched from elsewhere.
_EXAMPLE_DIR = os.path.dirname(os.path.abspath(__file__))

# NOTE(review): the original indentation was lost; the nesting below
# (input column beside the output image) is reconstructed -- confirm the layout.
with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as demo:
    # Centred page title ("全景分割" = panoptic segmentation).
    gr.Markdown("""
<div align='center' ><font size='60'>全景分割</font></div>
""")

    with gr.Row():
        # Input column: uploaded image plus the submit button.
        with gr.Column():
            image = gr.Image(label="图片", type="pil")
            with gr.Row():
                button = gr.Button("提交", variant="primary")
        # Output: overlay returned by detect_objects.
        box2 = gr.Image(label="图片")

    button.click(fn=detect_objects, inputs=[image], outputs=box2)

    examples = gr.Examples(
        examples=[
            [os.path.join(_EXAMPLE_DIR, '1.jpg')],
            [os.path.join(_EXAMPLE_DIR, '2.jpg')],
        ],
        inputs=[image],
        label="例子",
    )
|
|
|
|
|
|
if __name__ == '__main__':
    # Turn on request queuing, then serve on every network interface (0.0.0.0)
    # rather than only localhost -- the app is reachable from other machines.
    app = demo.queue()
    app.launch(server_name="0.0.0.0")
|