detr-resnet-101-dc5/app.py

127 lines
4.2 KiB
Python

from transformers import DetrFeatureExtractor, DetrForObjectDetection
from PIL import Image
import requests
from torch import nn
from gradio.themes.utils import sizes
import torch
import matplotlib.pyplot as plt
import math
import gradio as gr
import torchvision.transforms as T
import io
theme = gr.themes.Default(radius_size=sizes.radius_none).set(
block_label_text_color = '#4D63FF',
block_title_text_color = '#4D63FF',
button_primary_text_color = '#4D63FF',
button_primary_background_fill='#FFFFFF',
button_primary_border_color='#4D63FF',
button_primary_background_fill_hover='#EDEFFF',
)
# COCO classes
CLASSES = [
'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
'toothbrush'
]
# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
# standard PyTorch mean-std input image normalization
transform = T.Compose([
T.Resize(800),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=1)
def rescale_bboxes(out_bbox, size):
img_w, img_h = size
b = box_cxcywh_to_xyxy(out_bbox)
b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
return b
def fig2img(fig):
buf = io.BytesIO()
fig.savefig(buf)
buf.seek(0)
img = Image.open(buf)
return img
def plot_results(pil_img, prob, boxes):
plt.figure(figsize=(16,10))
plt.imshow(pil_img)
ax = plt.gca()
colors = COLORS * 100
for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
fill=False, color=c, linewidth=3))
cl = p.argmax()
text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
ax.text(xmin, ymin, text, fontsize=15,
bbox=dict(facecolor='yellow', alpha=0.5))
plt.axis('off')
return fig2img(plt.gcf())
feature_extractor = DetrFeatureExtractor.from_pretrained('facebook/detr-resnet-101-dc5')
model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-101-dc5')
def detect_object(image):
inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
probas = outputs.logits.softmax(-1)[0, :, :-1]
keep = probas.max(-1).values > 0.9
# convert boxes from [0; 1] to image scales
bboxes_scaled = rescale_bboxes(outputs.pred_boxes[0, keep], image.size)
return plot_results(image, probas[keep], bboxes_scaled)
with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as demo:
gr.Markdown("""
<div align='center' ><font size='60'>目标检测</font></div>
""")
with gr.Row():
with gr.Column():
box1 = gr.Image(label="图片", type="pil")
with gr.Row():
button = gr.Button("提交", variant="primary")
box2 = gr.Image(label="图片")
button.click(fn=detect_object, inputs=box1, outputs=box2)
examples = gr.Examples(examples=[['cat.jpg']], inputs=[box1], label="例子")
if __name__ == "__main__":
demo.queue(concurrency_count=3).launch(server_name = "0.0.0.0")