diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..06483ed
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,16 @@
+FROM python:3.8-slim-buster
+
+WORKDIR /app
+
+# Copy the requirement spec first so the dependency layer is cached across code changes
+COPY requirements.txt /app/
+
+# Alternative PyPI mirrors (uncomment one if preferred):
+#RUN pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple
+#RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple
+
+# Switch to the Tsinghua Debian mirror, then install OpenCV's runtime dependencies
+RUN sed -i "s@http://deb.debian.org@http://mirrors.tuna.tsinghua.edu.cn@g" /etc/apt/sources.list
+RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 && rm -rf /var/lib/apt/lists/*
+RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
+RUN pip install -r requirements.txt
+
+COPY . /app
+
+CMD ["python", "app.py"]
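+
+# Minimal usage sketch (the image tag "fastsam-demo" is an assumption; 7860 is Gradio's default port):
+#   docker build -t fastsam-demo .
+#   docker run -p 7860:7860 fastsam-demo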
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..4d8a8b3
--- /dev/null
+++ b/app.py
@@ -0,0 +1,220 @@
+from ultralytics import YOLO
+import numpy as np
+import matplotlib.pyplot as plt
+import gradio as gr
+import cv2
+import torch
+from PIL import Image
+
+# Load the pre-trained FastSAM model via the Ultralytics YOLO interface
+model = YOLO('checkpoints/FastSAM.pt')
+
+# Description
+title = "
快速分割一切"
+
+examples = [["assets/sa_8776.jpg"], ["assets/sa_414.jpg"],
+ ["assets/sa_1309.jpg"], ["assets/sa_11025.jpg"],
+ ["assets/sa_561.jpg"], ["assets/sa_192.jpg"],
+ ["assets/sa_10039.jpg"], ["assets/sa_862.jpg"]]
+
+default_example = examples[0]
+
+css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; } footer {visibility: hidden}"
+
+def fast_process(annotations, image, high_quality, device, scale):
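+    """Overlay the predicted masks on `image` and return an RGBA PIL image.
+
+    With `high_quality` set, masks are smoothed by morphological close/open
+    operations and object contours are drawn in blue on top of the overlay.
+    """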
+    if isinstance(annotations[0], dict):
+        annotations = [annotation['segmentation'] for annotation in annotations]
+
+    original_h = image.height
+    original_w = image.width
+    if high_quality:
+        if isinstance(annotations[0], torch.Tensor):
+            annotations = np.array(annotations.cpu())
+ for i, mask in enumerate(annotations):
+ mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
+ annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
+ if device == 'cpu':
+ annotations = np.array(annotations)
+ inner_mask = fast_show_mask(annotations,
+ plt.gca(),
+ bbox=None,
+ points=None,
+ pointlabel=None,
+ retinamask=True,
+ target_height=original_h,
+ target_width=original_w)
+ else:
+ if isinstance(annotations[0],np.ndarray):
+ annotations = torch.from_numpy(annotations)
+ inner_mask = fast_show_mask_gpu(annotations,
+ plt.gca(),
+ bbox=None,
+ points=None,
+ pointlabel=None)
+ if isinstance(annotations, torch.Tensor):
+ annotations = annotations.cpu().numpy()
+
+ if high_quality:
+ contour_all = []
+        temp = np.zeros((original_h, original_w, 1))
+        for i, mask in enumerate(annotations):
+            if isinstance(mask, dict):
+ mask = mask['segmentation']
+ annotation = mask.astype(np.uint8)
+ contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+ for contour in contours:
+ contour_all.append(contour)
+ cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
+        color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])  # blue contours, mostly opaque
+ contour_mask = temp / 255 * color.reshape(1, 1, -1)
+ image = image.convert('RGBA')
+
+ overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
+ image.paste(overlay_inner, (0, 0), overlay_inner)
+
+ if high_quality:
+ overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
+ image.paste(overlay_contour, (0, 0), overlay_contour)
+
+ return image
+
+# CPU post process
+def fast_show_mask(annotation, ax, bbox=None,
+ points=None, pointlabel=None,
+ retinamask=True, target_height=960,
+ target_width=960):
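+    """Compose all masks into a single RGBA overlay (NumPy/CPU path).
+
+    Each mask gets a random color; where masks overlap, a pixel takes the
+    color of the first mask covering it after the area sort.
+    """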
+    mask_sum = annotation.shape[0]
+    height = annotation.shape[1]
+    width = annotation.shape[2]
+    # Sort annotations by area (ascending)
+    areas = np.sum(annotation, axis=(1, 2))
+    sorted_indices = np.argsort(areas)
+    annotation = annotation[sorted_indices]
+
+    # For each pixel, the index of the first mask that covers it
+    index = (annotation != 0).argmax(axis=0)
+    color = np.random.random((mask_sum, 1, 1, 3))
+    transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
+    visual = np.concatenate([color, transparency], axis=-1)
+    mask_image = np.expand_dims(annotation, -1) * visual
+
+    mask = np.zeros((height, width, 4))
+
+    h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
+    indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+    # Update the composite mask with a single vectorized assignment
+    mask[h_indices, w_indices, :] = mask_image[indices]
+ if bbox is not None:
+ x1, y1, x2, y2 = bbox
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
+    # Draw prompt points: yellow for positive labels, magenta for negative
+    if points is not None:
+        plt.scatter([p[0] for i, p in enumerate(points) if pointlabel[i] == 1],
+                    [p[1] for i, p in enumerate(points) if pointlabel[i] == 1],
+                    s=20, c='y')
+        plt.scatter([p[0] for i, p in enumerate(points) if pointlabel[i] == 0],
+                    [p[1] for i, p in enumerate(points) if pointlabel[i] == 0],
+                    s=20, c='m')
+
+    if not retinamask:
+ mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
+
+ return mask
+
+
+def fast_show_mask_gpu(annotation, ax,
+ bbox=None, points=None,
+ pointlabel=None):
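+    """GPU variant of fast_show_mask: the same composition performed with
+    torch tensors, returned as a NumPy array."""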
+    mask_sum = annotation.shape[0]
+    height = annotation.shape[1]
+    width = annotation.shape[2]
+    # Sort annotations by area (ascending)
+    areas = torch.sum(annotation, dim=(1, 2))
+    sorted_indices = torch.argsort(areas, descending=False)
+    annotation = annotation[sorted_indices]
+    # For each pixel, the index of the first non-zero (covering) mask
+    index = (annotation != 0).to(torch.long).argmax(dim=0)
+    color = torch.rand((mask_sum, 1, 1, 3)).to(annotation.device)
+    transparency = torch.ones((mask_sum, 1, 1, 1)).to(annotation.device) * 0.6
+    visual = torch.cat([color, transparency], dim=-1)
+    mask_image = torch.unsqueeze(annotation, -1) * visual
+    # Select, per pixel, the values of the chosen mask, collapsing the batch of
+    # masks into a single RGBA image
+    mask = torch.zeros((height, width, 4)).to(annotation.device)
+    h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij')
+    indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+    # Update the composite mask with a single vectorized assignment
+    mask[h_indices, w_indices, :] = mask_image[indices]
+ mask_cpu = mask.cpu().numpy()
+ if bbox is not None:
+ x1, y1, x2, y2 = bbox
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
+    # Draw prompt points: yellow for positive labels, magenta for negative
+    if points is not None:
+        plt.scatter([p[0] for i, p in enumerate(points) if pointlabel[i] == 1],
+                    [p[1] for i, p in enumerate(points) if pointlabel[i] == 1],
+                    s=20, c='y')
+        plt.scatter([p[0] for i, p in enumerate(points) if pointlabel[i] == 0],
+                    [p[1] for i, p in enumerate(points) if pointlabel[i] == 0],
+                    s=20, c='m')
+ return mask_cpu
+
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+def segment_image(input, input_size=1024, high_visual_quality=True, iou_threshold=0.7, conf_threshold=0.25):
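+    """Run FastSAM on a PIL image and return the image with the segmentation overlaid."""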
+    input_size = int(input_size)  # ensure imgsz is an integer
+
+    # Resize so the longer side matches input_size (thanks to hysts on Hugging Face for the suggestion).
+ w, h = input.size
+ scale = input_size / max(w, h)
+ new_w = int(w * scale)
+ new_h = int(h * scale)
+ input = input.resize((new_w, new_h))
+
+    results = model(input, device=device, retina_masks=True, iou=iou_threshold, conf=conf_threshold, imgsz=input_size)
+    if results[0].masks is None:
+        # No objects detected; return the resized input unchanged
+        return input
+    fig = fast_process(annotations=results[0].masks.data,
+ image=input, high_quality=high_visual_quality,
+ device=device, scale=(1024 // input_size))
+ return fig
+
+
+cond_img = gr.Image(label="Input", value=default_example[0], type='pil')
+
+segm_img = gr.Image(label="Segmented Image", interactive=False, type='pil')
+
+input_size_slider = gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='Input Size')
+
+with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
+ with gr.Row():
+ gr.Markdown(title)
+
+ # Images
+ with gr.Row(variant="panel"):
+ with gr.Column(scale=1):
+ cond_img.render()
+
+ with gr.Column(scale=1):
+ segm_img.render()
+
+ # Submit & Clear
+ with gr.Row():
+ with gr.Column():
+ input_size_slider.render()
+
+ with gr.Row():
+                    vis_check = gr.Checkbox(value=True, label='High Quality')
+
+ with gr.Column():
+                segment_btn = gr.Button("Segment Everything", variant='primary')
+
+
+ gr.Examples(examples=examples,
+ inputs=[cond_img],
+ outputs=segm_img,
+ fn=segment_image,
+ cache_examples=True,
+                        examples_per_page=4, label="Example Images")
+
+ with gr.Column():
+            with gr.Accordion("Advanced Options", open=False):
+                iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='IoU Threshold')
+                conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='Confidence Threshold')
+
+ segment_btn.click(segment_image,
+ inputs=[cond_img, input_size_slider, vis_check, iou_threshold, conf_threshold],
+ outputs=segm_img)
+
+
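+# Bind to 0.0.0.0 so the Gradio server is reachable from outside the container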
+demo.queue().launch(server_name="0.0.0.0")
+
diff --git a/assets/sa_10039.jpg b/assets/sa_10039.jpg
new file mode 100644
index 0000000..499d0df
Binary files /dev/null and b/assets/sa_10039.jpg differ
diff --git a/assets/sa_11025.jpg b/assets/sa_11025.jpg
new file mode 100644
index 0000000..4a1677d
Binary files /dev/null and b/assets/sa_11025.jpg differ
diff --git a/assets/sa_1309.jpg b/assets/sa_1309.jpg
new file mode 100644
index 0000000..bde0e84
Binary files /dev/null and b/assets/sa_1309.jpg differ
diff --git a/assets/sa_192.jpg b/assets/sa_192.jpg
new file mode 100644
index 0000000..f06f499
Binary files /dev/null and b/assets/sa_192.jpg differ
diff --git a/assets/sa_414.jpg b/assets/sa_414.jpg
new file mode 100644
index 0000000..eaf5c92
Binary files /dev/null and b/assets/sa_414.jpg differ
diff --git a/assets/sa_561.jpg b/assets/sa_561.jpg
new file mode 100644
index 0000000..a47b92e
Binary files /dev/null and b/assets/sa_561.jpg differ
diff --git a/assets/sa_862.jpg b/assets/sa_862.jpg
new file mode 100644
index 0000000..b54c069
Binary files /dev/null and b/assets/sa_862.jpg differ
diff --git a/assets/sa_8776.jpg b/assets/sa_8776.jpg
new file mode 100644
index 0000000..a4abb4c
Binary files /dev/null and b/assets/sa_8776.jpg differ
diff --git a/checkpoints/FastSAM.pt b/checkpoints/FastSAM.pt
new file mode 100644
index 0000000..f741ea7
Binary files /dev/null and b/checkpoints/FastSAM.pt differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..bed70fa
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,19 @@
+# Base-----------------------------------
+matplotlib==3.2.2
+numpy
+opencv-python
+# Pillow>=7.1.2
+# PyYAML>=5.3.1
+# requests>=2.23.0
+# scipy>=1.4.1
+# torch
+# torchvision
+# tqdm>=4.64.0
+
+# pandas>=1.1.4
+# seaborn>=0.11.0
+
+# Ultralytics-----------------------------------
+ultralytics==8.0.121
+gradio
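+# torch and torchvision are installed transitively by ultralytics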
+