130 lines
3.8 KiB
Python
130 lines
3.8 KiB
Python
|
import matplotlib.cm as cm
|
||
|
import torch
|
||
|
import gradio as gr
|
||
|
from models.matching import Matching
|
||
|
from models.utils import (make_matching_plot_fast, process_image)
|
||
|
|
||
|
# Disable autograd globally; nothing below needs gradients.
torch.set_grad_enabled(False)

# Load the SuperPoint and SuperGlue models.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Settings shared by the indoor and outdoor pipelines.
resize = [640, 640]
max_keypoints = 1024
keypoint_threshold = 0.005
nms_radius = 4
sinkhorn_iterations = 20
match_threshold = 0.2
resize_float = False


def _build_config(weights):
    """Return a ``Matching`` config dict that uses SuperGlue ``weights``.

    The indoor and outdoor configurations differ only in the SuperGlue
    weights name; every other value comes from the module constants above.
    """
    superpoint_cfg = dict(
        nms_radius=nms_radius,
        keypoint_threshold=keypoint_threshold,
        max_keypoints=max_keypoints,
    )
    superglue_cfg = dict(
        weights=weights,
        sinkhorn_iterations=sinkhorn_iterations,
        match_threshold=match_threshold,
    )
    return {'superpoint': superpoint_cfg, 'superglue': superglue_cfg}


config_indoor = _build_config("indoor")
config_outdoor = _build_config("outdoor")

# Instantiate both pipelines up front in eval mode on the chosen device.
matching_indoor = Matching(config_indoor).eval().to(device)
matching_outdoor = Matching(config_outdoor).eval().to(device)
|
||
|
|
||
|
def run(input0, input1, superglue):
    """Match keypoints between two images and render the matches.

    Args:
        input0: Source image, forwarded unchanged to ``process_image``.
        input1: Destination image, forwarded unchanged to ``process_image``.
        superglue: ``"indoor"`` selects ``matching_indoor``; any other value
            selects ``matching_outdoor``.

    Returns:
        A 3-tuple with one item per Gradio output component:
        ``(plot, num_matches_str, match_percentage_str)``.
        ``plot`` is whatever ``make_matching_plot_fast`` returns.
        ``num_matches_str`` is the number of matched keypoints as a string.
        ``match_percentage_str`` is the share of image-0 keypoints that were
        matched, formatted as ``str(ratio * 100.0) + '%'``.
        If either image cannot be processed, returns
        ``(None, 'Problem reading image pair', '')``.
    """
    matching = matching_indoor if superglue == "indoor" else matching_outdoor

    name0 = 'image1'
    name1 = 'image2'

    # If a rotation integer is provided (e.g. from EXIF data), use it:
    rot0, rot1 = 0, 0

    # Load the image pair. The third value returned by process_image is unused.
    image0, inp0, _ = process_image(input0, device, resize, rot0, resize_float)
    image1, inp1, _ = process_image(input1, device, resize, rot1, resize_float)

    if image0 is None or image1 is None:
        message = 'Problem reading image pair'
        print(message)
        # The interface declares three outputs, so return one value for each
        # instead of a bare None.
        return None, message, ''

    # Perform the matching.
    pred = matching({'image0': inp0, 'image1': inp1})
    # The model runs on ``device`` (possibly CUDA), and ``Tensor.numpy()``
    # only accepts CPU tensors, so move results to the CPU first.
    pred = {k: v[0].detach().cpu().numpy() for k, v in pred.items()}
    kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
    matches, conf = pred['matches0'], pred['matching_scores0']

    # Negative entries in ``matches`` mark image-0 keypoints with no partner;
    # non-negative entries index into ``kpts1``.
    valid = matches > -1
    mkpts0 = kpts0[valid]
    mkpts1 = kpts1[matches[valid]]
    mconf = conf[valid]

    num_matches = len(mkpts0)
    # Guard against an image in which no keypoints were detected.
    match_ratio = num_matches / len(kpts0) if len(kpts0) else 0.0
    match_percentage = str(match_ratio * 100.0) + '%'

    # Visualize the matches, coloured by confidence with the jet colormap.
    color = cm.jet(mconf)
    text = [
        'SuperGlue',
        'Keypoints: {}:{}'.format(len(kpts0), len(kpts1)),
        '{}'.format(num_matches),
    ]
    if rot0 != 0 or rot1 != 0:
        text.append('Rotation: {}:{}'.format(rot0, rot1))

    # Display extra parameter info.
    k_thresh = matching.superpoint.config['keypoint_threshold']
    m_thresh = matching.superglue.config['match_threshold']
    small_text = [
        'Keypoint Threshold: {:.4f}'.format(k_thresh),
        'Match Threshold: {:.2f}'.format(m_thresh),
        'Image Pair: {}:{}'.format(name0, name1),
    ]

    output = make_matching_plot_fast(
        image0, image1, kpts0, kpts1, mkpts0, mkpts1, color,
        text, show_keypoints=True, small_text=small_text)

    # Log the same percentage string that is returned to the UI.
    print('Source Image - {}, Destination Image - {}, {}, Match Percentage - {}'.format(
        name0, name1, text[2], match_percentage))
    return output, text[2], match_percentage
|
||
|
|
||
|
if __name__ == '__main__':
    # Two images to match plus a selector for which SuperGlue weights to use.
    image_inputs = [
        gr.Image(label='Input Image'),
        gr.Image(label='Match Image'),
        gr.Radio(
            choices=["indoor", "outdoor"],
            value="indoor",
            type="value",
            label="SuperGlueType",
            interactive=True,
        ),
    ]
    # One output component per value returned by ``run``.
    result_outputs = [
        gr.Image(type="pil", label="Result"),
        gr.Textbox(label="Keypoints Matched"),
        gr.Textbox(label="Match Percentage"),
    ]
    # Example rows: (input image path, match image path, SuperGlue type).
    sample_pairs = [
        ['./taj-1.jpg', './taj-2.jpg', "outdoor"],
        ['./outdoor-1.JPEG', './outdoor-2.JPEG', "outdoor"],
    ]

    glue = gr.Interface(
        fn=run,
        inputs=image_inputs,
        outputs=result_outputs,
        examples=sample_pairs,
    )
    glue.queue()
    # Bind to all network interfaces rather than only localhost.
    glue.launch(server_name="0.0.0.0")
|