diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..06483ed
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,16 @@
+FROM python:3.8-slim-buster
+
+WORKDIR /app
+
+COPY . /app
+
+#RUN pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple
+#RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple
+RUN sed -i "s@http://deb.debian.org@http://mirrors.tuna.tsinghua.edu.cn@g" /etc/apt/sources.list
+RUN apt-get clean
+
+RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
+RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
+RUN pip install -r requirements.txt
+
+CMD ["python", "app.py"]
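Note (not part of the patch): the app.py added below launches Gradio on 0.0.0.0, so a container built from this Dockerfile can be driven from another process once it is running with the default port published (e.g. docker run -p 7860:7860). An untested sketch using gradio_client; "/predict" is the endpoint name Gradio usually generates for a single .click() handler, but check Client.view_api() if it differs:

    # Untested sketch -- assumes the container is up and gradio_client is installed.
    from gradio_client import Client

    client = Client("http://localhost:7860/")
    result = client.predict(
        "assets/sa_8776.jpg",   # input image (file path)
        1024,                   # input size
        True,                   # high quality
        0.7,                    # iou_threshold
        0.25,                   # conf_threshold
        api_name="/predict",
    )
    print(result)               # server-side path of the segmented image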
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..4d8a8b3
--- /dev/null
+++ b/app.py
@@ -0,0 +1,220 @@
+from ultralytics import YOLO
+import numpy as np
+import matplotlib.pyplot as plt
+import gradio as gr
+import cv2
+import torch
+from PIL import Image
+
+# Load the pre-trained model
+model = YOLO('checkpoints/FastSAM.pt')
+
+# Description
+title = "# Fast Segment Anything"
+
+examples = [["assets/sa_8776.jpg"], ["assets/sa_414.jpg"],
+            ["assets/sa_1309.jpg"], ["assets/sa_11025.jpg"],
+            ["assets/sa_561.jpg"], ["assets/sa_192.jpg"],
+            ["assets/sa_10039.jpg"], ["assets/sa_862.jpg"]]
+
+default_example = examples[0]
+
+css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; } footer {visibility: hidden}"
+
+
+def fast_process(annotations, image, high_quality, device, scale):
+    if isinstance(annotations[0], dict):
+        annotations = [annotation['segmentation'] for annotation in annotations]
+
+    original_h = image.height
+    original_w = image.width
+    if high_quality:
+        if isinstance(annotations[0], torch.Tensor):
+            annotations = np.array(annotations.cpu())
+        # Smooth each mask with a morphological close then open
+        for i, mask in enumerate(annotations):
+            mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
+            annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
+    if device == 'cpu':
+        annotations = np.array(annotations)
+        inner_mask = fast_show_mask(annotations,
+                                    plt.gca(),
+                                    bbox=None,
+                                    points=None,
+                                    pointlabel=None,
+                                    retinamask=True,
+                                    target_height=original_h,
+                                    target_width=original_w)
+    else:
+        if isinstance(annotations[0], np.ndarray):
+            annotations = torch.from_numpy(annotations)
+        inner_mask = fast_show_mask_gpu(annotations,
+                                        plt.gca(),
+                                        bbox=None,
+                                        points=None,
+                                        pointlabel=None)
+    if isinstance(annotations, torch.Tensor):
+        annotations = annotations.cpu().numpy()
+
+    if high_quality:
+        contour_all = []
+        temp = np.zeros((original_h, original_w, 1))
+        for i, mask in enumerate(annotations):
+            if isinstance(mask, dict):
+                mask = mask['segmentation']
+            annotation = mask.astype(np.uint8)
+            contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+            for contour in contours:
+                contour_all.append(contour)
+        cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
+        color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
+        contour_mask = temp / 255 * color.reshape(1, 1, -1)
+    image = image.convert('RGBA')
+
+    overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
+    image.paste(overlay_inner, (0, 0), overlay_inner)
+
+    if high_quality:
+        overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
+        image.paste(overlay_contour, (0, 0), overlay_contour)
+
+    return image
+
+
+# CPU post process
+def fast_show_mask(annotation, ax, bbox=None,
+                   points=None, pointlabel=None,
+                   retinamask=True, target_height=960,
+                   target_width=960):
+    mask_sum = annotation.shape[0]
+    height = annotation.shape[1]
+    width = annotation.shape[2]
+    # Sort the annotations by area (ascending), so the per-pixel argmax
+    # below picks the smallest mask covering each pixel
+    areas = np.sum(annotation, axis=(1, 2))
+    sorted_indices = np.argsort(areas)
+    annotation = annotation[sorted_indices]
+
+    # For every pixel, the index of the first mask that covers it
+    index = (annotation != 0).argmax(axis=0)
+    color = np.random.random((mask_sum, 1, 1, 3))
+    transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
+    visual = np.concatenate([color, transparency], axis=-1)
+    mask_image = np.expand_dims(annotation, -1) * visual
+
+    mask = np.zeros((height, width, 4))
+
+    h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
+    indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+    # Update the output mask with vectorized indexing
+    mask[h_indices, w_indices, :] = mask_image[indices]
+    if bbox is not None:
+        x1, y1, x2, y2 = bbox
+        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
+    # Draw points
+    if points is not None:
+        plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
+                    [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
+                    s=20, c='y')
+        plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
+                    [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
+                    s=20, c='m')
+
+    if not retinamask:
+        mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
+
+    return mask
+
+
+def fast_show_mask_gpu(annotation, ax,
+                       bbox=None, points=None,
+                       pointlabel=None):
+    mask_sum = annotation.shape[0]
+    height = annotation.shape[1]
+    width = annotation.shape[2]
+    # Sort the annotations by area (ascending)
+    areas = torch.sum(annotation, dim=(1, 2))
+    sorted_indices = torch.argsort(areas, descending=False)
+    annotation = annotation[sorted_indices]
+    # For every pixel, the index of the first non-zero mask
+    index = (annotation != 0).to(torch.long).argmax(dim=0)
+    color = torch.rand((mask_sum, 1, 1, 3)).to(annotation.device)
+    transparency = torch.ones((mask_sum, 1, 1, 1)).to(annotation.device) * 0.6
+    visual = torch.cat([color, transparency], dim=-1)
+    mask_image = torch.unsqueeze(annotation, -1) * visual
+    # Gather by index: for each pixel pick the RGBA value of the selected
+    # mask, collapsing the per-mask stack into a single image
+    mask = torch.zeros((height, width, 4)).to(annotation.device)
+    h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij')
+    indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+    # Update the output mask with vectorized indexing
+    mask[h_indices, w_indices, :] = mask_image[indices]
+    mask_cpu = mask.cpu().numpy()
+    if bbox is not None:
+        x1, y1, x2, y2 = bbox
+        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
+    # Draw points
+    if points is not None:
+        plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
+                    [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
+                    s=20, c='y')
+        plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
+                    [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
+                    s=20, c='m')
+    return mask_cpu
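Note (not part of the patch): both renderers above share the same core trick, and a toy NumPy sketch makes it easy to sanity-check: sort the masks by area ascending, argmax over the mask axis to find, for each pixel, the first (smallest) mask covering it, then gather one RGBA value per pixel with vectorized indexing. The masks and sizes below are made up for illustration:

    import numpy as np

    masks = np.zeros((2, 4, 4), dtype=bool)
    masks[0] = True                  # large mask covering everything
    masks[1, 1:3, 1:3] = True        # small mask in the middle

    order = np.argsort(masks.sum(axis=(1, 2)))   # ascending by area
    masks = masks[order]                         # small mask is now first

    index = (masks != 0).argmax(axis=0)          # 0 where the small mask wins
    rgba = np.random.random((2, 1, 1, 4))
    mask_image = masks[..., None] * rgba         # (2, 4, 4, 4)

    h, w = np.meshgrid(np.arange(4), np.arange(4), indexing='ij')
    out = mask_image[index[h, w], h, w, :]       # (4, 4, 4): one RGBA per pixel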
+
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+def segment_image(image, input_size=1024, high_visual_quality=True, iou_threshold=0.7, conf_threshold=0.25):
+    input_size = int(input_size)  # make sure imgsz is an integer
+    # Resize so the longer side matches input_size.
+    # Thanks for the suggestion by hysts in HuggingFace.
+    w, h = image.size
+    scale = input_size / max(w, h)
+    new_w = int(w * scale)
+    new_h = int(h * scale)
+    image = image.resize((new_w, new_h))
+
+    results = model(image, device=device, retina_masks=True, iou=iou_threshold, conf=conf_threshold, imgsz=input_size)
+    fig = fast_process(annotations=results[0].masks.data,
+                       image=image, high_quality=high_visual_quality,
+                       device=device, scale=(1024 // input_size))
+    return fig
+
+
+cond_img = gr.Image(label="Input", value=default_example[0], type='pil')
+
+segm_img = gr.Image(label="Segmented Image", interactive=False, type='pil')
+
+input_size_slider = gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='Input Size')
+
+with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
+    with gr.Row():
+        gr.Markdown(title)
+
+    # Images
+    with gr.Row(variant="panel"):
+        with gr.Column(scale=1):
+            cond_img.render()
+
+        with gr.Column(scale=1):
+            segm_img.render()
+
+    # Submit & Clear
+    with gr.Row():
+        with gr.Column():
+            input_size_slider.render()
+
+            with gr.Row():
+                vis_check = gr.Checkbox(value=True, label='High Quality')
+
+                with gr.Column():
+                    segment_btn = gr.Button("Segment Anything", variant='primary')
+
+            gr.Examples(examples=examples,
+                        inputs=[cond_img],
+                        outputs=segm_img,
+                        fn=segment_image,
+                        cache_examples=True,
+                        examples_per_page=4,
+                        label="Example Images")
+
+        with gr.Column():
+            with gr.Accordion("Advanced options", open=False):
+                iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou_threshold')
+                conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf_threshold')
+
+    segment_btn.click(segment_image,
+                      inputs=[cond_img, input_size_slider, vis_check, iou_threshold, conf_threshold],
+                      outputs=segm_img)
+
+demo.queue().launch(server_name="0.0.0.0")
diff --git a/assets/sa_10039.jpg b/assets/sa_10039.jpg
new file mode 100644
index 0000000..499d0df
Binary files /dev/null and b/assets/sa_10039.jpg differ
diff --git a/assets/sa_11025.jpg b/assets/sa_11025.jpg
new file mode 100644
index 0000000..4a1677d
Binary files /dev/null and b/assets/sa_11025.jpg differ
diff --git a/assets/sa_1309.jpg b/assets/sa_1309.jpg
new file mode 100644
index 0000000..bde0e84
Binary files /dev/null and b/assets/sa_1309.jpg differ
diff --git a/assets/sa_192.jpg b/assets/sa_192.jpg
new file mode 100644
index 0000000..f06f499
Binary files /dev/null and b/assets/sa_192.jpg differ
diff --git a/assets/sa_414.jpg b/assets/sa_414.jpg
new file mode 100644
index 0000000..eaf5c92
Binary files /dev/null and b/assets/sa_414.jpg differ
diff --git a/assets/sa_561.jpg b/assets/sa_561.jpg
new file mode 100644
index 0000000..a47b92e
Binary files /dev/null and b/assets/sa_561.jpg differ
diff --git a/assets/sa_862.jpg b/assets/sa_862.jpg
new file mode 100644
index 0000000..b54c069
Binary files /dev/null and b/assets/sa_862.jpg differ
diff --git a/assets/sa_8776.jpg b/assets/sa_8776.jpg
new file mode 100644
index 0000000..a4abb4c
Binary files /dev/null and b/assets/sa_8776.jpg differ
diff --git a/checkpoints/FastSAM.pt b/checkpoints/FastSAM.pt
new file mode 100644
index 0000000..f741ea7
Binary files /dev/null and b/checkpoints/FastSAM.pt differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..bed70fa
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,19 @@
+# Base-----------------------------------
+matplotlib==3.2.2
+numpy
+opencv-python
+# Pillow>=7.1.2
+# PyYAML>=5.3.1
+# requests>=2.23.0
+# scipy>=1.4.1
+# torch
+# torchvision
+# tqdm>=4.64.0
+
+# pandas>=1.1.4
+# seaborn>=0.11.0
+
+# Ultralytics-----------------------------------
+ultralytics==8.0.121
+gradio
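Note (not part of the patch): a minimal local smoke test for segment_image before building the image. A sketch assuming checkpoints/FastSAM.pt and the bundled assets are present; it must run before the Gradio app launches, e.g. pasted temporarily into app.py right after segment_image's definition:

    from PIL import Image

    img = Image.open("assets/sa_8776.jpg").convert("RGB")
    out = segment_image(img, input_size=1024, high_visual_quality=True,
                        iou_threshold=0.7, conf_threshold=0.25)
    out.save("segmented.png")    # RGBA composite returned by fast_process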