# Panoptic segmentation demo: Mask2Former (Swin-B, COCO panoptic) behind a Gradio UI.
#
# NOTE(review): this file arrived with every statement collapsed onto one physical
# line behind a leading "#", which turned the whole program into a comment. The
# layout below is a reconstruction; the gr.Row / gr.Column nesting in particular was
# inferred from statement order and should be confirmed against the intended UI.
import io
import itertools

import gradio as gr
import matplotlib.pyplot as plt
import numpy
import numpy as np
import seaborn as sns
import torch
import torchvision.transforms as T
from gradio.themes.utils import sizes
from panopticapi.utils import rgb2id
from PIL import Image
from torch import nn
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

# NOTE(review): io, itertools, plt, numpy (bare name), sns, T, rgb2id, Image and nn
# are not referenced anywhere in this file. They are kept unchanged here; consider
# removing them (panopticapi especially is a heavy dependency for an unused import).

# Default theme with zero corner radius, #4D63FF accent text on block labels/titles
# and the primary button, and a white primary button that tints on hover.
theme = gr.themes.Default(radius_size=sizes.radius_none).set(
    block_label_text_color='#4D63FF',
    block_title_text_color='#4D63FF',
    button_primary_text_color='#4D63FF',
    button_primary_background_fill='#FFFFFF',
    button_primary_border_color='#4D63FF',
    button_primary_background_fill_hover='#EDEFFF',
)

MODEL_ID = "facebook/mask2former-swin-base-coco-panoptic"

# Loaded once at import time and shared by every request.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = Mask2FormerForUniversalSegmentation.from_pretrained(MODEL_ID)

# Fixed seed so a given segment id is drawn in the same colour on every request.
PALETTE_SEED = 0


def _segment_palette(num_colors, seed=PALETTE_SEED):
    """Return a reproducible ``(num_colors, 3)`` uint8 array of random RGB colours."""
    rng = np.random.default_rng(seed)
    return rng.integers(0, 256, size=(num_colors, 3)).astype(np.uint8)


def detect_objects(image_input):
    """Run panoptic segmentation on an uploaded image and return a colour overlay.

    Args:
        image_input: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the button is clicked without an image.

    Returns:
        A ``uint8`` RGB array (see ``visualize_prediction``), or ``None`` when there
        is no input image, which leaves the output component empty.
    """
    if image_input is None:
        return None

    # Both the processor and the 50/50 blend assume 3 channels; converting is a
    # no-op for images that are already RGB.
    image_input = image_input.convert("RGB")

    inputs = processor(images=image_input, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL's .size is (width, height); target_sizes wants (height, width) so the
    # segmentation map is resized back to the input resolution.
    result = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image_input.size[::-1]]
    )[0]

    return visualize_prediction(result["segmentation"], image_input)


def visualize_prediction(result, image):
    """Blend a colour-coded segment map 50/50 over the original image.

    Args:
        result: 2-D map of segment ids, e.g. the ``"segmentation"`` entry returned
            by ``post_process_panoptic_segmentation`` (torch tensor or array-like).
            Negative ids are left uncoloured.
        image: the RGB source image (PIL image or array) with the same height and
            width as ``result``.

    Returns:
        ``np.uint8`` array of shape ``(height, width, 3)``.
    """
    # Work on a numpy copy so the tensor is never used to index a numpy array.
    if torch.is_tensor(result):
        seg = result.detach().cpu().numpy()
    else:
        seg = np.asarray(result)
    seg = seg.astype(np.int64)

    # The map holds segment ids, not class ids, so make sure every id present gets
    # a colour even if it exceeds the number of classes.
    num_colors = len(model.config.id2label)
    if seg.size:
        num_colors = max(num_colors, int(seg.max()) + 1)
    palette = _segment_palette(num_colors)

    color_seg = np.zeros(seg.shape + (3,), dtype=np.uint8)  # height, width, 3
    valid = seg >= 0
    color_seg[valid] = palette[seg[valid]]

    # Show image + mask at equal weight.
    blended = np.asarray(image) * 0.5 + color_seg * 0.5
    return blended.astype(np.uint8)


with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as demo:
    gr.Markdown("""
全景分割
""")
    with gr.Row():
        # Left: input image and submit button; right: segmentation overlay.
        with gr.Column():
            image = gr.Image(label="图片", type="pil")
            with gr.Row():
                button = gr.Button("提交", variant="primary")
        box2 = gr.Image(label="图片")
    button.click(fn=detect_objects, inputs=[image], outputs=box2)
    # Example paths are resolved relative to the working directory.
    examples = gr.Examples(examples=[['1.jpg'], ['2.jpg']], inputs=[image], label="例子")

if __name__ == '__main__':
    # Listens on all network interfaces.
    demo.queue().launch(server_name="0.0.0.0")