82 lines
2.0 KiB
Python
82 lines
2.0 KiB
Python
|
import gradio as gr
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
import matplotlib.pyplot as plt
|
||
|
import cv2
|
||
|
import sys
|
||
|
sys.path.append("..")
|
||
|
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
|
||
|
from PIL import Image
|
||
|
import io
|
||
|
|
||
|
# Checkpoint file for the ViT-H Segment Anything model. The path is relative,
# so it is resolved against the working directory of the process.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Prefer the GPU, but fall back to the CPU so that importing this module does
# not fail in ``sam.to`` on a machine without a usable CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Generator built with the library's default settings. ``segment_image`` does
# not use it; it is kept as a module-level name for comparison runs.
mask_generator = SamAutomaticMaskGenerator(sam)

# Generator with explicitly tuned settings; this is the one ``segment_image``
# calls.
mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,  # Requires open-cv to run post-processing
)
|
||
|
|
||
|
|
||
|
def fig2img(fig):
    """Render a figure into an in-memory buffer and open it as a PIL image.

    Args:
        fig: Object exposing ``savefig`` (a Matplotlib figure in this file).

    Returns:
        The image object produced by ``Image.open`` on the bytes that
        ``fig.savefig`` wrote.
    """
    buffer = io.BytesIO()
    fig.savefig(buffer)
    # Rewind so the image is read from the start of what was just written.
    buffer.seek(0)
    return Image.open(buffer)
|
||
|
|
||
|
def show_anns(anns, ax=None):
    """Overlay every annotation mask on an axes in a translucent random color.

    Args:
        anns: Sequence of dicts, each holding a 2-D mask array under
            ``'segmentation'`` and a sortable number under ``'area'``.
            An empty (or ``None``) value is a no-op.
        ax: Axes-like object to draw on; it must provide
            ``set_autoscale_on`` and ``imshow``. Defaults to pyplot's current
            axes, which keeps existing one-argument callers working.

    Returns:
        None. The overlays are drawn as a side effect on ``ax``.
    """
    if not anns:
        return

    if ax is None:
        ax = plt.gca()
    # Keep the view limits that were established before the overlays are added.
    ax.set_autoscale_on(False)

    # Largest masks are drawn first so that smaller masks, drawn later, are
    # not hidden underneath them.
    for ann in sorted(anns, key=lambda item: item['area'], reverse=True):
        mask = ann['segmentation']
        height, width = mask.shape[:2]

        # RGBA overlay: one random color for the whole mask, with an alpha
        # channel of 0.35 inside the mask and 0 (transparent) outside it.
        overlay = np.empty((height, width, 4))
        overlay[..., :3] = np.random.random(3)
        overlay[..., 3] = mask * 0.35
        ax.imshow(overlay)
|
||
|
|
||
|
|
||
|
def segment_image(image):
    """Segment an image and return a rendering with the masks drawn on top.

    Args:
        image: Image array as delivered by the ``gr.Image`` input component.
            Gradio hands numpy images over in RGB channel order.

    Returns:
        The image produced by ``fig2img`` for the rendered Matplotlib figure.
    """
    # The array is already RGB, which is what both the mask generator and
    # ``plt.imshow`` expect. The previous BGR->RGB conversion therefore swapped
    # the red and blue channels, so it is intentionally not applied here.
    image = image.astype('uint8')
    masks = mask_generator_2.generate(image)

    fig = plt.figure(figsize=(20, 20))
    try:
        plt.imshow(image)
        show_anns(masks)
        plt.axis('off')
        return fig2img(fig)
    finally:
        # pyplot retains every figure it creates until it is closed; without
        # this, each request would leave a 20x20 inch figure in memory.
        plt.close(fig)
|
||
|
|
||
|
|
||
|
# Gradio UI: one image input and one image output, both wired to
# ``segment_image``. The title is Chinese for "image segmentation".
demo = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(),
    outputs=gr.Image(),
    title="图像分割",
    examples=['dog.jpg'],
)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Turn on request queuing with concurrency_count=3, then serve on port
    # 7027. Binding to 0.0.0.0 makes the server listen on all interfaces.
    demo.queue(concurrency_count=3)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7027,
    )
|