add beit-base-finetuned-ade-640-640

This commit is contained in:
songw 2023-04-06 16:29:37 +08:00
parent d67ec55dc2
commit 5f2f1881aa
5 changed files with 140 additions and 0 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 270 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 429 KiB

View File

@ -0,0 +1,13 @@
FROM python:3.7.4-slim
WORKDIR /app
COPY requirements.txt /app
RUN pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple/
RUN pip3 install --trusted-host pypi.python.org -r requirements.txt
COPY . /app
CMD ["python", "beit.py"]

View File

@ -0,0 +1,119 @@
#语义分割
from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation
import torch
from PIL import Image
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import io
import gradio as gr
def ade_palette():
"""ADE20K palette that maps each class to RGB values."""
return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
[102, 255, 0], [92, 0, 255]]
# Draw the bounding boxes on image.
def fig2img(fig):
buf = io.BytesIO()
fig.savefig(buf)
buf.seek(0)
img = Image.open(buf)
return img
# Draw the bounding boxes.
def visualize_prediction(outputs, image):
# First, rescale logits to original image size
logits = nn.functional.interpolate(outputs.logits,
size=image.size[::-1], # (height, width)
mode='bilinear',
align_corners=False)
# Second, apply argmax on the class dimension
seg = logits.argmax(dim=1)[0].cpu()
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
palette = np.array(ade_palette())
for label, color in enumerate(palette):
color_seg[seg == label, :] = color
# Convert to BGR
color_seg = color_seg[..., ::-1]
# Show image + mask
img = np.array(image) * 0.5 + color_seg * 0.5
img = img.astype(np.uint8)
plt.figure(figsize=(15, 10))
plt.axis('off')
plt.imshow(img)
return fig2img(plt.gcf())
def detect_objects(image_input):
model_name = "microsoft/beit-base-finetuned-ade-640-640"
feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False)
model = BeitForSemanticSegmentation.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# if image comes from upload
if image_input:
image = image_input
pixel_values = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
outputs = model(pixel_values)
#Visualize prediction
viz_img = visualize_prediction(outputs, image)
return viz_img
demo = gr.Interface(fn = detect_objects,
inputs = gr.Image(type = 'pil'),
outputs = gr.Image(shape = (650,650)),
title = "语义分割",
allow_flagging="never",
examples = ['1.png', '2.png'])
if __name__ == '__main__':
demo.queue().launch(server_name = "0.0.0.0", server_port = 7001, max_threads=40)

View File

@ -0,0 +1,8 @@
gradio
torch
transformers
datasets
matplotlib
huggingface_hub
pillow
numpy