269 lines
9.9 KiB
Python
269 lines
9.9 KiB
Python
|
# ------------------------------------------------------------------------
|
||
|
# Grounding DINO
|
||
|
# url: https://github.com/IDEA-Research/GroundingDINO
|
||
|
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||
|
# ------------------------------------------------------------------------
|
||
|
|
||
|
import copy
|
||
|
import math
|
||
|
|
||
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
from torch import Tensor, nn
|
||
|
|
||
|
|
||
|
def _get_clones(module, N, layer_share=False):
    """Build an ``nn.ModuleList`` with ``N`` entries derived from ``module``.

    Args:
        module: the layer to replicate.
        N: number of entries in the returned list.
        layer_share: when True every entry is the very same object, so all
            positions share one set of parameters; otherwise each entry is an
            independent ``copy.deepcopy`` of ``module``.

    Returns:
        An ``nn.ModuleList`` of length ``N``.
    """
    if layer_share:
        # Repeating the reference keeps a single set of weights for all layers.
        layers = [module] * N
    else:
        layers = [copy.deepcopy(module) for _ in range(N)]
    return nn.ModuleList(layers)
|
||
|
|
||
|
|
||
|
def get_sine_pos_embed(
    pos_tensor: torch.Tensor,
    num_pos_feats: int = 128,
    temperature: int = 10000,
    exchange_xy: bool = True,
):
    """Generate a sine/cosine position embedding from a position tensor.

    Each scalar along the last dimension of ``pos_tensor`` is expanded into
    ``num_pos_feats`` values (interleaved sin/cos at geometrically spaced
    frequencies) and the per-scalar embeddings are concatenated.

    Args:
        pos_tensor (torch.Tensor): shape: [..., n]. Any number of leading
            dimensions is accepted.
        num_pos_feats (int): projected size for each float in the tensor.
            Should be even, because the sin half and the cos half are stacked
            element-wise.
        temperature (int): temperature in the sine/cosine function.
        exchange_xy (bool, optional): exchange pos x and pos y.
            For example, if the input tensor is [x, y], the result will be
            [pos(y), pos(x)]. Ignored when there are fewer than two
            coordinates. Defaults to True.

    Returns:
        pos_embed (torch.Tensor): shape: [..., n*num_pos_feats].
    """
    scale = 2 * math.pi
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
    # Consecutive (even, odd) feature pairs share one frequency.
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)

    def sine_func(x: torch.Tensor):
        # x is [..., 1]; broadcasting against dim_t gives [..., num_pos_feats].
        angles = x * scale / dim_t
        # Interleave sin (even slots) and cos (odd slots). Negative dims keep
        # this valid for any number of leading dimensions.
        pairs = torch.stack((angles[..., 0::2].sin(), angles[..., 1::2].cos()), dim=-1)
        return pairs.flatten(-2)

    pos_res = [sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)]
    if exchange_xy and len(pos_res) > 1:
        pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
    return torch.cat(pos_res, dim=-1)
|
||
|
|
||
|
|
||
|
def gen_encoder_output_proposals(
    memory: Tensor, memory_padding_mask: Tensor, spatial_shapes: Tensor, learnedwh=None
):
    """Build one box proposal per encoder position and mask out unusable ones.

    For every feature level a grid of cell centres is generated and normalised
    by the number of non-padded rows/columns of each image. Each centre gets a
    box size that doubles with the level index. Proposals are returned in
    inverse-sigmoid (logit) space.

    Input:
        - memory: bs, sum(hw), d_model
        - memory_padding_mask: bs, sum(hw) (bool, True marks padding)
        - spatial_shapes: nlevel, 2 (H, W of each level)
        - learnedwh: 2 (optional tensor; its sigmoid replaces the fixed 0.05
          base box size)
    Output:
        - output_memory: bs, sum(hw), d_model (padded / invalid positions
          filled with 0)
        - output_proposals: bs, sum(hw), 4 (padded / invalid positions filled
          with inf)
    """
    # Unpacking three values also enforces that memory is 3-D.
    N_, _, _ = memory.shape
    proposals = []
    _cur = 0
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(N_, H_, W_, 1)
        # Count non-padded rows (down the first column) and non-padded columns
        # (along the first row) for each image.
        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

        # Pass indexing="ij" explicitly: it is the layout this code relies on,
        # and omitting the argument is deprecated and emits a warning.
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
            torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
            indexing="ij",
        )
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)  # H_, W_, 2

        # Normalise cell centres by the valid (non-padded) extent of each image.
        scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
        grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale

        # Box size doubles with each level.
        if learnedwh is not None:
            wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl)
        else:
            wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)

        proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
        proposals.append(proposal)
        _cur += H_ * W_

    output_proposals = torch.cat(proposals, 1)
    # A proposal is usable only if all four values lie strictly inside (0.01, 0.99).
    output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(
        -1, keepdim=True
    )
    output_proposals = torch.log(output_proposals / (1 - output_proposals))  # unsigmoid
    output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf"))
    output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))

    output_memory = memory
    output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
    output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))

    return output_memory, output_proposals
|
||
|
|
||
|
|
||
|
class RandomBoxPerturber:
    """Callable that randomly rescales each coordinate of reference boxes.

    Every value is multiplied by ``1 + (u - 0.5) * scale`` where ``u`` comes
    from ``torch.rand_like`` and ``scale`` is the noise scale configured for
    that coordinate. The result is clamped to ``[0, 1]``.
    """

    def __init__(
        self, x_noise_scale=0.2, y_noise_scale=0.2, w_noise_scale=0.2, h_noise_scale=0.2
    ) -> None:
        # Stored in (x, y, w, h) order; __call__ slices off what it needs.
        per_coord = [x_noise_scale, y_noise_scale, w_noise_scale, h_noise_scale]
        self.noise_scale = torch.Tensor(per_coord)

    def __call__(self, refanchors: Tensor) -> Tensor:
        """Return a perturbed copy of ``refanchors`` (a 3-D tensor).

        Only the first ``refanchors.shape[-1]`` noise scales are used, so
        inputs with fewer than four values per anchor are accepted.
        """
        # The unpacking requires a 3-D input, exactly like the shape contract
        # (nq, bs, query_dim).
        _, _, query_dim = refanchors.shape
        scales = self.noise_scale.to(refanchors.device)[:query_dim]

        centred_noise = torch.rand_like(refanchors) - 0.5
        perturbed = refanchors * (1 + centred_noise * scales)
        return perturbed.clamp_(0, 1)
|
||
|
|
||
|
|
||
|
def sigmoid_focal_loss(
    inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, no_reduction=False
):
    """
    Focal loss from RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of logits, one prediction per element.
        targets: A float tensor with the same shape as inputs. Stores the
                 binary classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        num_boxes: Normaliser the reduced loss is divided by.
        alpha: Weighting factor to balance positive vs negative examples.
               Default = 0.25; any negative value disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        no_reduction: If True, return the element-wise loss unreduced.
    Returns:
        The element-wise loss tensor when ``no_reduction`` is set; otherwise
        the loss averaged over dim 1, summed, and divided by ``num_boxes``.
    """
    probs = torch.sigmoid(inputs)
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the probability the model assigns to the true class.
    p_t = probs * targets + (1 - probs) * (1 - targets)
    focal = bce * ((1 - p_t) ** gamma)

    if alpha >= 0:
        class_weight = alpha * targets + (1 - alpha) * (1 - targets)
        focal = class_weight * focal

    return focal if no_reduction else focal.mean(1).sum() / num_boxes
|
||
|
|
||
|
|
||
|
class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN).

    A chain of ``nn.Linear`` layers with ReLU applied after every layer except
    the final one.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        """
        Args:
            input_dim: feature size of the input.
            hidden_dim: feature size of every intermediate layer.
            output_dim: feature size of the output.
            num_layers: total number of linear layers.
        """
        super().__init__()
        self.num_layers = num_layers
        hidden = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim] + hidden
        out_dims = hidden + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(in_dims, out_dims)]
        )

    def forward(self, x):
        """Run ``x`` through the layers; the last layer's output is not activated."""
        final_index = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            if index < final_index:
                x = F.relu(x)
        return x
|
||
|
|
||
|
|
||
|
def _get_activation_fn(activation, d_model=256, batch_dim=0):
    """Return an activation function given its name.

    Args:
        activation: one of "relu", "gelu", "glu", "prelu" or "selu".
        d_model: unused; kept so existing callers keep working.
        batch_dim: unused; kept so existing callers keep working.

    Returns:
        The matching function from ``torch.nn.functional``, or a freshly
        constructed ``nn.PReLU`` module for "prelu".

    Raises:
        RuntimeError: if ``activation`` is not one of the supported names.
    """
    functional = {
        "relu": F.relu,
        "gelu": F.gelu,
        "glu": F.glu,
        "selu": F.selu,
    }
    # The isinstance guard keeps unhashable inputs on the RuntimeError path
    # instead of failing inside the dict lookup.
    if isinstance(activation, str):
        if activation in functional:
            return functional[activation]
        if activation == "prelu":
            # PReLU has a learnable parameter, so every call gets its own module.
            return nn.PReLU()

    raise RuntimeError(f"activation should be relu/gelu/glu/prelu/selu, not {activation}.")
|
||
|
|
||
|
|
||
|
def gen_sineembed_for_position(pos_tensor, num_pos_feats=128, temperature=10000):
    """Sine/cosine embedding for 2-D points or 4-D boxes.

    Args:
        pos_tensor: 3-D tensor whose last dimension holds either (x, y) or
            (x, y, w, h).
        num_pos_feats: number of embedding values produced per coordinate.
            Should be even, because the sin half and the cos half are stacked
            element-wise. Defaults to 128.
        temperature: base of the geometric frequency progression. Defaults to
            10000.

    Returns:
        Tensor whose last dimension is ``num_pos_feats`` times the number of
        coordinates, ordered (y, x) for points and (y, x, w, h) for boxes.

    Raises:
        ValueError: if the last dimension of ``pos_tensor`` is neither 2 nor 4.
    """
    num_coords = pos_tensor.size(-1)
    # Validate before indexing so every unsupported width gets the same error.
    if num_coords not in (2, 4):
        raise ValueError("Unknown pos_tensor shape(-1):{}".format(num_coords))

    scale = 2 * math.pi
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
    # Consecutive (even, odd) feature pairs share one frequency.
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)

    def embed(index):
        # Embed one coordinate: sin on even slots, cos on odd slots, interleaved.
        angles = (pos_tensor[:, :, index] * scale)[:, :, None] / dim_t
        pairs = torch.stack((angles[:, :, 0::2].sin(), angles[:, :, 1::2].cos()), dim=3)
        return pairs.flatten(2)

    # y is emitted before x; w and h (when present) follow in their given order.
    order = (1, 0) if num_coords == 2 else (1, 0, 2, 3)
    return torch.cat([embed(index) for index in order], dim=2)
|
||
|
|
||
|
|
||
|
class ContrastiveEmbed(nn.Module):
    """Dot-product similarity between input features and encoded text tokens.

    Scores for padding tokens are set to ``-inf`` and the token axis is padded
    with ``-inf`` up to ``max_text_len``.
    """

    def __init__(self, max_text_len=256):
        """
        Args:
            max_text_len: max length of text.
        """
        super().__init__()
        self.max_text_len = max_text_len

    def forward(self, x, text_dict):
        """Compute padded similarity logits.

        Args:
            x: features to score; the last dimension must match the feature
                size of ``encoded_text`` (presumably bs, n_query, d_model;
                confirm against the caller).
            text_dict (dict): must contain
                {
                    'encoded_text': encoded_text, # bs, 195, d_model
                    'text_token_mask': text_token_mask, # bs, 195
                        # True for used tokens. False for padding tokens
                }

        Returns:
            Tensor shaped like ``x @ encoded_text^T`` but with the last
            dimension extended to ``max_text_len``. Every padding token and
            every extension slot holds ``-inf``.
        """
        assert isinstance(text_dict, dict)

        text_feats = text_dict["encoded_text"]
        is_padding = ~text_dict["text_token_mask"]

        logits = torch.matmul(x, text_feats.transpose(-1, -2))
        # Broadcast the per-token mask over the query axis.
        logits.masked_fill_(is_padding[:, None, :], float("-inf"))

        # padding to max_text_len
        num_tokens = logits.shape[-1]
        padded_shape = (*logits.shape[:-1], self.max_text_len)
        padded = torch.full(padded_shape, float("-inf"), device=logits.device)
        padded[..., :num_tokens] = logits

        return padded
|