add
Build-Deploy-Actions Details

This commit is contained in:
songw 2023-05-15 14:16:22 +08:00
parent ab527ef40c
commit e0f3566f42
9 changed files with 2099 additions and 0 deletions

View File

@ -0,0 +1,47 @@
name: Build
run-name: ${{ github.actor }} is upgrade release 🚀
on: [push]
env:
REPOSITORY: ${{ github.repository }}
COMMIT_ID: ${{ github.sha }}
jobs:
Build-Deploy-Actions:
runs-on: ubuntu-latest
steps:
- run: echo "🎉 The job was automatically triggered by a ${{ github.event_name }} event."
- run: echo "🐧 This job is now running on a ${{ runner.os }} server hosted by Gitea!"
- run: echo "🔎 The name of your branch is ${{ github.ref }} and your repository is ${{ github.repository }}."
- name: Check out repository code
uses: actions/checkout@v3
-
name: Setup Git LFS
run: |
git lfs install
git lfs fetch
git lfs checkout
- name: List files in the repository
run: |
ls ${{ github.workspace }}
-
name: Docker Image Info
id: image-info
run: |
echo "::set-output name=image_name::$(echo $REPOSITORY | tr '[:upper:]' '[:lower:]')"
echo "::set-output name=image_tag::${COMMIT_ID:0:10}"
-
name: Login to Docker Hub
uses: docker/login-action@v2
with:
registry: artifacts.iflytek.com
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
-
name: Build and push
run: |
docker version
docker buildx build -t artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }} . --file ${{ github.workspace }}/Dockerfile --load
docker push artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }}
docker rmi artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }}
- run: echo "🍏 This job's status is ${{ job.status }}."

62
app.py Normal file
View File

@ -0,0 +1,62 @@
import data
import torch
import gradio as gr
from models import imagebind_model
from models.imagebind_model import ModalityType
from gradio.themes.utils import sizes
theme = gr.themes.Default(radius_size=sizes.radius_none).set(
block_label_text_color = '#4D63FF',
block_title_text_color = '#4D63FF',
button_primary_text_color = '#4D63FF',
button_primary_background_fill='#FFFFFF',
button_primary_border_color='#4D63FF',
button_primary_background_fill_hover='#EDEFFF',
)
css = "footer {visibility: hidden}"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)
def audio_text(audio, text_list):
audio_paths = [audio]
labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]
inputs = {
ModalityType.TEXT: data.load_and_transform_text(labels, device),
ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
}
with torch.no_grad():
embeddings = model(inputs)
scores = torch.softmax(embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1).squeeze(0).tolist()
score_dict = {label:score for label, score in zip(labels, scores)}
print(score_dict)
return score_dict
with gr.Blocks(theme=theme, css=css) as demo:
gr.Markdown("""
<div align='center' ><font size='60'>音频分类</font></div>
""")
with gr.Row():
with gr.Column():
audio = gr.inputs.Audio(type='filepath',label="音频输入")
text = gr.inputs.Textbox(lines=1,label="类别")
with gr.Row():
button = gr.Button("提交", variant="primary")
outputs = gr.Label(label="类别")
button.click(fn=audio_text, inputs=[audio, text], outputs=outputs)
examples = gr.Examples(examples=[[".assets/dog_audio.wav", "A dog|A car|A bird"],[".assets/car_audio.wav", "A dog|A car|A bird"], [".assets/bird_audio.wav", "A dog|A car|A bird"]],inputs=[audio, text], label="例子")
if __name__ == "__main__":
demo.queue().launch(server_name = "0.0.0.0")

Binary file not shown.

350
data.py Normal file
View File

@ -0,0 +1,350 @@
#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torchaudio
import logging
from models.multimodal_preprocessors import SimpleTokenizer
from PIL import Image
from pytorchvideo import transforms as pv_transforms
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
from pytorchvideo.data.encoded_video import EncodedVideo
from torchvision import transforms
from torchvision.transforms._transforms_video import NormalizeVideo
DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds
BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz"
def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
# Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
waveform -= waveform.mean()
fbank = torchaudio.compliance.kaldi.fbank(
waveform,
htk_compat=True,
sample_frequency=sample_rate,
use_energy=False,
window_type="hanning",
num_mel_bins=num_mel_bins,
dither=0.0,
frame_length=25,
frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
)
# Convert to [mel_bins, num_frames] shape
fbank = fbank.transpose(0, 1)
# Pad to target_length
n_frames = fbank.size(1)
p = target_length - n_frames
# if p is too large (say >20%), flash a warning
if abs(p) / n_frames > 0.2:
logging.warning(
"Large gap between audio n_frames(%d) and "
"target_length (%d). Is the audio_target_length "
"setting correct?",
n_frames,
target_length,
)
# cut and pad
if p > 0:
fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
elif p < 0:
fbank = fbank[:, 0:target_length]
# Convert to [1, mel_bins, num_frames] shape, essentially like a 1
# channel image
fbank = fbank.unsqueeze(0)
return fbank
def get_clip_timepoints(clip_sampler, duration):
# Read out all clips in this video
all_clips_timepoints = []
is_last_clip = False
end = 0.0
while not is_last_clip:
start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
all_clips_timepoints.append((start, end))
return all_clips_timepoints
def load_and_transform_vision_data(image_paths, device):
if image_paths is None:
return None
image_ouputs = []
for image_path in image_paths:
data_transform = transforms.Compose(
[
transforms.Resize(
224, interpolation=transforms.InterpolationMode.BICUBIC
),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
]
)
with open(image_path, "rb") as fopen:
image = Image.open(fopen).convert("RGB")
image = data_transform(image).to(device)
image_ouputs.append(image)
return torch.stack(image_ouputs, dim=0)
def load_and_transform_text(text, device):
if text is None:
return None
tokenizer = SimpleTokenizer(bpe_path=BPE_PATH)
tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text]
tokens = torch.cat(tokens, dim=0)
return tokens
def load_and_transform_audio_data(
audio_paths,
device,
num_mel_bins=128,
target_length=204,
sample_rate=16000,
clip_duration=2,
clips_per_video=3,
mean=-4.268,
std=9.138,
):
if audio_paths is None:
return None
audio_outputs = []
clip_sampler = ConstantClipsPerVideoSampler(
clip_duration=clip_duration, clips_per_video=clips_per_video
)
for audio_path in audio_paths:
waveform, sr = torchaudio.load(audio_path)
if sample_rate != sr:
waveform = torchaudio.functional.resample(
waveform, orig_freq=sr, new_freq=sample_rate
)
all_clips_timepoints = get_clip_timepoints(
clip_sampler, waveform.size(1) / sample_rate
)
all_clips = []
for clip_timepoints in all_clips_timepoints:
waveform_clip = waveform[
:,
int(clip_timepoints[0] * sample_rate) : int(
clip_timepoints[1] * sample_rate
),
]
waveform_melspec = waveform2melspec(
waveform_clip, sample_rate, num_mel_bins, target_length
)
all_clips.append(waveform_melspec)
normalize = transforms.Normalize(mean=mean, std=std)
all_clips = [normalize(ac).to(device) for ac in all_clips]
all_clips = torch.stack(all_clips, dim=0)
audio_outputs.append(all_clips)
return torch.stack(audio_outputs, dim=0)
def get_clip_timepoints(clip_sampler, duration):
# Read out all clips in this video
all_clips_timepoints = []
is_last_clip = False
end = 0.0
while not is_last_clip:
start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
all_clips_timepoints.append((start, end))
return all_clips_timepoints
def crop_boxes(boxes, x_offset, y_offset):
"""
Peform crop on the bounding boxes given the offsets.
Args:
boxes (ndarray or None): bounding boxes to peform crop. The dimension
is `num boxes` x 4.
x_offset (int): cropping offset in the x axis.
y_offset (int): cropping offset in the y axis.
Returns:
cropped_boxes (ndarray or None): the cropped boxes with dimension of
`num boxes` x 4.
"""
cropped_boxes = boxes.copy()
cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
return cropped_boxes
def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
"""
Perform uniform spatial sampling on the images and corresponding boxes.
Args:
images (tensor): images to perform uniform crop. The dimension is
`num frames` x `channel` x `height` x `width`.
size (int): size of height and weight to crop the images.
spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
is larger than height. Or 0, 1, or 2 for top, center, and bottom
crop if height is larger than width.
boxes (ndarray or None): optional. Corresponding boxes to images.
Dimension is `num boxes` x 4.
scale_size (int): optinal. If not None, resize the images to scale_size before
performing any crop.
Returns:
cropped (tensor): images with dimension of
`num frames` x `channel` x `size` x `size`.
cropped_boxes (ndarray or None): the cropped boxes with dimension of
`num boxes` x 4.
"""
assert spatial_idx in [0, 1, 2]
ndim = len(images.shape)
if ndim == 3:
images = images.unsqueeze(0)
height = images.shape[2]
width = images.shape[3]
if scale_size is not None:
if width <= height:
width, height = scale_size, int(height / width * scale_size)
else:
width, height = int(width / height * scale_size), scale_size
images = torch.nn.functional.interpolate(
images,
size=(height, width),
mode="bilinear",
align_corners=False,
)
y_offset = int(math.ceil((height - size) / 2))
x_offset = int(math.ceil((width - size) / 2))
if height > width:
if spatial_idx == 0:
y_offset = 0
elif spatial_idx == 2:
y_offset = height - size
else:
if spatial_idx == 0:
x_offset = 0
elif spatial_idx == 2:
x_offset = width - size
cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
if ndim == 3:
cropped = cropped.squeeze(0)
return cropped, cropped_boxes
class SpatialCrop(nn.Module):
"""
Convert the video into 3 smaller clips spatially. Must be used after the
temporal crops to get spatial crops, and should be used with
-2 in the spatial crop at the slowfast augmentation stage (so full
frames are passed in here). Will return a larger list with the
3x spatial crops as well.
"""
def __init__(self, crop_size: int = 224, num_crops: int = 3):
super().__init__()
self.crop_size = crop_size
if num_crops == 3:
self.crops_to_ext = [0, 1, 2]
self.flipped_crops_to_ext = []
elif num_crops == 1:
self.crops_to_ext = [1]
self.flipped_crops_to_ext = []
else:
raise NotImplementedError("Nothing else supported yet")
def forward(self, videos):
"""
Args:
videos: A list of C, T, H, W videos.
Returns:
videos: A list with 3x the number of elements. Each video converted
to C, T, H', W' by spatial cropping.
"""
assert isinstance(videos, list), "Must be a list of videos after temporal crops"
assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
res = []
for video in videos:
for spatial_idx in self.crops_to_ext:
res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
if not self.flipped_crops_to_ext:
continue
flipped_video = transforms.functional.hflip(video)
for spatial_idx in self.flipped_crops_to_ext:
res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
return res
def load_and_transform_video_data(
video_paths,
device,
clip_duration=2,
clips_per_video=5,
sample_rate=16000,
):
if video_paths is None:
return None
video_outputs = []
video_transform = transforms.Compose(
[
pv_transforms.ShortSideScale(224),
NormalizeVideo(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
]
)
clip_sampler = ConstantClipsPerVideoSampler(
clip_duration=clip_duration, clips_per_video=clips_per_video
)
frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
for video_path in video_paths:
video = EncodedVideo.from_path(
video_path,
decoder="decord",
decode_audio=False,
**{"sample_rate": sample_rate},
)
all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
all_video = []
for clip_timepoints in all_clips_timepoints:
# Read the clip, get frames
clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
if clip is None:
raise ValueError("No clip found")
video_clip = frame_sampler(clip["video"])
video_clip = video_clip / 255.0 # since this is float, need 0-1
all_video.append(video_clip)
all_video = [video_transform(clip) for clip in all_video]
all_video = SpatialCrop(224, num_crops=3)(all_video)
all_video = torch.stack(all_video, dim=0)
video_outputs.append(all_video)
return torch.stack(video_outputs, dim=0).to(device)

141
models/helpers.py Normal file
View File

@ -0,0 +1,141 @@
#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import einops
import numpy as np
import torch
import torch.nn as nn
class Normalize(nn.Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
def forward(self, x):
return torch.nn.functional.normalize(x, dim=self.dim, p=2)
class LearnableLogitScaling(nn.Module):
def __init__(
self,
logit_scale_init: float = 1 / 0.07,
learnable: bool = True,
max_logit_scale: float = 100,
) -> None:
super().__init__()
self.max_logit_scale = max_logit_scale
self.logit_scale_init = logit_scale_init
self.learnable = learnable
log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
if learnable:
self.log_logit_scale = nn.Parameter(log_logit_scale)
else:
self.register_buffer("log_logit_scale", log_logit_scale)
def forward(self, x):
return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
def extra_repr(self):
st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}"
return st
class EinOpsRearrange(nn.Module):
def __init__(self, rearrange_expr: str, **kwargs) -> None:
super().__init__()
self.rearrange_expr = rearrange_expr
self.kwargs = kwargs
def forward(self, x):
assert isinstance(x, torch.Tensor)
return einops.rearrange(x, self.rearrange_expr, **self.kwargs)
class VerboseNNModule(nn.Module):
"""
Wrapper around nn.Module that prints registered buffers and parameter names.
"""
@staticmethod
def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
st = (
"("
+ name
+ "): "
+ "tensor("
+ str(tuple(tensor[1].shape))
+ ", requires_grad="
+ str(tensor[1].requires_grad)
+ ")\n"
)
return st
def extra_repr(self) -> str:
named_modules = set()
for p in self.named_modules():
named_modules.update([p[0]])
named_modules = list(named_modules)
string_repr = ""
for p in self.named_parameters():
name = p[0].split(".")[0]
if name not in named_modules:
string_repr += self.get_readable_tensor_repr(name, p)
for p in self.named_buffers():
name = p[0].split(".")[0]
string_repr += self.get_readable_tensor_repr(name, p)
return string_repr
def cast_if_src_dtype(
tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
):
updated = False
if tensor.dtype == src_dtype:
tensor = tensor.to(dtype=tgt_dtype)
updated = True
return tensor, updated
class QuickGELU(nn.Module):
# From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class SelectElement(nn.Module):
def __init__(self, index) -> None:
super().__init__()
self.index = index
def forward(self, x):
assert x.ndim >= 3
return x[:, self.index, ...]
class SelectEOSAndProject(nn.Module):
"""
Text Pooling used in OpenCLIP
"""
def __init__(self, proj: nn.Module) -> None:
super().__init__()
self.proj = proj
def forward(self, x, seq_len):
assert x.ndim == 3
# x is of shape B x L x D
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), seq_len]
x = self.proj(x)
return x

517
models/imagebind_model.py Normal file
View File

@ -0,0 +1,517 @@
#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import urllib
from functools import partial
from types import SimpleNamespace
import torch
import torch.nn as nn
from models.helpers import (
EinOpsRearrange,
LearnableLogitScaling,
Normalize,
SelectElement,
SelectEOSAndProject,
)
from models.multimodal_preprocessors import (
AudioPreprocessor,
IMUPreprocessor,
PadIm2Video,
PatchEmbedGeneric,
RGBDTPreprocessor,
SpatioTemporalPosEmbeddingHelper,
TextPreprocessor,
ThermalPreprocessor,
)
from models.transformer import MultiheadAttention, SimpleTransformer
ModalityType = SimpleNamespace(
VISION="vision",
TEXT="text",
AUDIO="audio",
THERMAL="thermal",
DEPTH="depth",
IMU="imu",
)
class ImageBindModel(nn.Module):
def __init__(
self,
video_frames=2,
kernel_size=(2, 14, 14),
audio_kernel_size=16,
audio_stride=10,
out_embed_dim=768,
vision_embed_dim=1024,
vision_num_blocks=24,
vision_num_heads=16,
audio_embed_dim=768,
audio_num_blocks=12,
audio_num_heads=12,
audio_num_mel_bins=128,
audio_target_len=204,
audio_drop_path=0.1,
text_embed_dim=768,
text_num_blocks=12,
text_num_heads=12,
depth_embed_dim=384,
depth_kernel_size=16,
depth_num_blocks=12,
depth_num_heads=8,
depth_drop_path=0.0,
thermal_embed_dim=768,
thermal_kernel_size=16,
thermal_num_blocks=12,
thermal_num_heads=12,
thermal_drop_path=0.0,
imu_embed_dim=512,
imu_kernel_size=8,
imu_num_blocks=6,
imu_num_heads=8,
imu_drop_path=0.7,
):
super().__init__()
self.modality_preprocessors = self._create_modality_preprocessors(
video_frames,
vision_embed_dim,
kernel_size,
text_embed_dim,
audio_embed_dim,
audio_kernel_size,
audio_stride,
audio_num_mel_bins,
audio_target_len,
depth_embed_dim,
depth_kernel_size,
thermal_embed_dim,
thermal_kernel_size,
imu_embed_dim,
)
self.modality_trunks = self._create_modality_trunks(
vision_embed_dim,
vision_num_blocks,
vision_num_heads,
text_embed_dim,
text_num_blocks,
text_num_heads,
audio_embed_dim,
audio_num_blocks,
audio_num_heads,
audio_drop_path,
depth_embed_dim,
depth_num_blocks,
depth_num_heads,
depth_drop_path,
thermal_embed_dim,
thermal_num_blocks,
thermal_num_heads,
thermal_drop_path,
imu_embed_dim,
imu_num_blocks,
imu_num_heads,
imu_drop_path,
)
self.modality_heads = self._create_modality_heads(
out_embed_dim,
vision_embed_dim,
text_embed_dim,
audio_embed_dim,
depth_embed_dim,
thermal_embed_dim,
imu_embed_dim,
)
self.modality_postprocessors = self._create_modality_postprocessors(
out_embed_dim
)
def _create_modality_preprocessors(
self,
video_frames=2,
vision_embed_dim=1024,
kernel_size=(2, 14, 14),
text_embed_dim=768,
audio_embed_dim=768,
audio_kernel_size=16,
audio_stride=10,
audio_num_mel_bins=128,
audio_target_len=204,
depth_embed_dim=768,
depth_kernel_size=16,
thermal_embed_dim=768,
thermal_kernel_size=16,
imu_embed_dim=512,
):
rgbt_stem = PatchEmbedGeneric(
proj_stem=[
PadIm2Video(pad_type="repeat", ntimes=2),
nn.Conv3d(
in_channels=3,
kernel_size=kernel_size,
out_channels=vision_embed_dim,
stride=kernel_size,
bias=False,
),
]
)
rgbt_preprocessor = RGBDTPreprocessor(
img_size=[3, video_frames, 224, 224],
num_cls_tokens=1,
pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
rgbt_stem=rgbt_stem,
depth_stem=None,
)
text_preprocessor = TextPreprocessor(
context_length=77,
vocab_size=49408,
embed_dim=text_embed_dim,
causal_masking=True,
)
audio_stem = PatchEmbedGeneric(
proj_stem=[
nn.Conv2d(
in_channels=1,
kernel_size=audio_kernel_size,
stride=audio_stride,
out_channels=audio_embed_dim,
bias=False,
),
],
norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
)
audio_preprocessor = AudioPreprocessor(
img_size=[1, audio_num_mel_bins, audio_target_len],
num_cls_tokens=1,
pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
audio_stem=audio_stem,
)
depth_stem = PatchEmbedGeneric(
[
nn.Conv2d(
kernel_size=depth_kernel_size,
in_channels=1,
out_channels=depth_embed_dim,
stride=depth_kernel_size,
bias=False,
),
],
norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
)
depth_preprocessor = RGBDTPreprocessor(
img_size=[1, 224, 224],
num_cls_tokens=1,
pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
rgbt_stem=None,
depth_stem=depth_stem,
)
thermal_stem = PatchEmbedGeneric(
[
nn.Conv2d(
kernel_size=thermal_kernel_size,
in_channels=1,
out_channels=thermal_embed_dim,
stride=thermal_kernel_size,
bias=False,
),
],
norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
)
thermal_preprocessor = ThermalPreprocessor(
img_size=[1, 224, 224],
num_cls_tokens=1,
pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
thermal_stem=thermal_stem,
)
imu_stem = PatchEmbedGeneric(
[
nn.Linear(
in_features=48,
out_features=imu_embed_dim,
bias=False,
),
],
norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
)
imu_preprocessor = IMUPreprocessor(
img_size=[6, 2000],
num_cls_tokens=1,
kernel_size=8,
embed_dim=imu_embed_dim,
pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
imu_stem=imu_stem,
)
modality_preprocessors = {
ModalityType.VISION: rgbt_preprocessor,
ModalityType.TEXT: text_preprocessor,
ModalityType.AUDIO: audio_preprocessor,
ModalityType.DEPTH: depth_preprocessor,
ModalityType.THERMAL: thermal_preprocessor,
ModalityType.IMU: imu_preprocessor,
}
return nn.ModuleDict(modality_preprocessors)
def _create_modality_trunks(
self,
vision_embed_dim=1024,
vision_num_blocks=24,
vision_num_heads=16,
text_embed_dim=768,
text_num_blocks=12,
text_num_heads=12,
audio_embed_dim=768,
audio_num_blocks=12,
audio_num_heads=12,
audio_drop_path=0.0,
depth_embed_dim=768,
depth_num_blocks=12,
depth_num_heads=12,
depth_drop_path=0.0,
thermal_embed_dim=768,
thermal_num_blocks=12,
thermal_num_heads=12,
thermal_drop_path=0.0,
imu_embed_dim=512,
imu_num_blocks=6,
imu_num_heads=8,
imu_drop_path=0.7,
):
def instantiate_trunk(
embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
):
return SimpleTransformer(
embed_dim=embed_dim,
num_blocks=num_blocks,
ffn_dropout_rate=0.0,
drop_path_rate=drop_path,
attn_target=partial(
MultiheadAttention,
embed_dim=embed_dim,
num_heads=num_heads,
bias=True,
add_bias_kv=add_bias_kv,
),
pre_transformer_layer=nn.Sequential(
nn.LayerNorm(embed_dim, eps=1e-6)
if pre_transformer_ln
else nn.Identity(),
EinOpsRearrange("b l d -> l b d"),
),
post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
)
modality_trunks = {}
modality_trunks[ModalityType.VISION] = instantiate_trunk(
vision_embed_dim,
vision_num_blocks,
vision_num_heads,
pre_transformer_ln=True,
add_bias_kv=False,
drop_path=0.0,
)
modality_trunks[ModalityType.TEXT] = instantiate_trunk(
text_embed_dim,
text_num_blocks,
text_num_heads,
pre_transformer_ln=False,
add_bias_kv=False,
drop_path=0.0,
)
modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
audio_embed_dim,
audio_num_blocks,
audio_num_heads,
pre_transformer_ln=False,
add_bias_kv=True,
drop_path=audio_drop_path,
)
modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
depth_embed_dim,
depth_num_blocks,
depth_num_heads,
pre_transformer_ln=False,
add_bias_kv=True,
drop_path=depth_drop_path,
)
modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
thermal_embed_dim,
thermal_num_blocks,
thermal_num_heads,
pre_transformer_ln=False,
add_bias_kv=True,
drop_path=thermal_drop_path,
)
modality_trunks[ModalityType.IMU] = instantiate_trunk(
imu_embed_dim,
imu_num_blocks,
imu_num_heads,
pre_transformer_ln=False,
add_bias_kv=True,
drop_path=imu_drop_path,
)
return nn.ModuleDict(modality_trunks)
def _create_modality_heads(
self,
out_embed_dim,
vision_embed_dim,
text_embed_dim,
audio_embed_dim,
depth_embed_dim,
thermal_embed_dim,
imu_embed_dim,
):
modality_heads = {}
modality_heads[ModalityType.VISION] = nn.Sequential(
nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
SelectElement(index=0),
nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
)
modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
proj=nn.Sequential(
nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
nn.Linear(text_embed_dim, out_embed_dim, bias=False),
)
)
modality_heads[ModalityType.AUDIO] = nn.Sequential(
nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
SelectElement(index=0),
nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
)
modality_heads[ModalityType.DEPTH] = nn.Sequential(
nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
SelectElement(index=0),
nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
)
modality_heads[ModalityType.THERMAL] = nn.Sequential(
nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
SelectElement(index=0),
nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
)
modality_heads[ModalityType.IMU] = nn.Sequential(
nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
SelectElement(index=0),
nn.Dropout(p=0.5),
nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
)
return nn.ModuleDict(modality_heads)
def _create_modality_postprocessors(self, out_embed_dim):
modality_postprocessors = {}
modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
Normalize(dim=-1), LearnableLogitScaling(learnable=True)
)
modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
Normalize(dim=-1),
LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
)
modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
Normalize(dim=-1),
LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
)
modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
Normalize(dim=-1),
LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
)
modality_postprocessors[ModalityType.IMU] = nn.Sequential(
Normalize(dim=-1),
LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
)
return nn.ModuleDict(modality_postprocessors)
def forward(self, inputs):
outputs = {}
for modality_key, modality_value in inputs.items():
reduce_list = (
modality_value.ndim >= 5
) # Audio and Video inputs consist of multiple clips
if reduce_list:
B, S = modality_value.shape[:2]
modality_value = modality_value.reshape(
B * S, *modality_value.shape[2:]
)
if modality_value is not None:
modality_value = self.modality_preprocessors[modality_key](
**{modality_key: modality_value}
)
trunk_inputs = modality_value["trunk"]
head_inputs = modality_value["head"]
modality_value = self.modality_trunks[modality_key](**trunk_inputs)
modality_value = self.modality_heads[modality_key](
modality_value, **head_inputs
)
modality_value = self.modality_postprocessors[modality_key](
modality_value
)
if reduce_list:
modality_value = modality_value.reshape(B, S, -1)
modality_value = modality_value.mean(dim=1)
outputs[modality_key] = modality_value
return outputs
def imagebind_huge(pretrained=False):
model = ImageBindModel(
vision_embed_dim=1280,
vision_num_blocks=32,
vision_num_heads=16,
text_embed_dim=1024,
text_num_blocks=24,
text_num_heads=16,
out_embed_dim=1024,
audio_drop_path=0.1,
imu_drop_path=0.7,
)
if pretrained:
if not os.path.exists(".checkpoints/imagebind_huge.pth"):
print(
"Downloading imagebind weights to .checkpoints/imagebind_huge.pth ..."
)
os.makedirs(".checkpoints", exist_ok=True)
torch.hub.download_url_to_file(
"https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
".checkpoints/imagebind_huge.pth",
progress=True,
)
model.load_state_dict(torch.load(".checkpoints/imagebind_huge.pth"))
return model

View File

@ -0,0 +1,687 @@
#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import gzip
import html
import io
import math
from functools import lru_cache
from typing import Callable, List, Optional
import ftfy
import numpy as np
import regex as re
import torch
import torch.nn as nn
from iopath.common.file_io import g_pathmgr
from timm.models.layers import trunc_normal_
from models.helpers import cast_if_src_dtype, VerboseNNModule
def get_sinusoid_encoding_table(n_position, d_hid):
"""Sinusoid position encoding table"""
# TODO: make it with torch instead of numpy
def get_position_angle_vec(position):
return [
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
for hid_j in range(d_hid)
]
sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(n_position)]
)
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
def interpolate_pos_encoding_2d(target_spatial_size, pos_embed):
N = pos_embed.shape[1]
if N == target_spatial_size:
return pos_embed
dim = pos_embed.shape[-1]
# nn.functional.interpolate doesn't work with bfloat16 so we cast to float32
pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32)
pos_embed = nn.functional.interpolate(
pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
0, 3, 1, 2
),
scale_factor=math.sqrt(target_spatial_size / N),
mode="bicubic",
)
if updated:
pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16)
pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return pos_embed
def interpolate_pos_encoding(
npatch_per_img,
pos_embed,
patches_layout,
input_shape=None,
first_patch_idx=1,
):
assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none"
N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists
if npatch_per_img == N:
return pos_embed
assert (
patches_layout[-1] == patches_layout[-2]
), "Interpolation of pos embed not supported for non-square layouts"
class_emb = pos_embed[:, :first_patch_idx]
pos_embed = pos_embed[:, first_patch_idx:]
if input_shape is None or patches_layout[0] == 1:
# simple 2D pos embedding, no temporal component
pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed)
elif patches_layout[0] > 1:
# pos embed has a temporal component
assert len(input_shape) == 4, "temporal interpolation not supported"
# we only support 2D interpolation in this case
num_frames = patches_layout[0]
num_spatial_tokens = patches_layout[1] * patches_layout[2]
pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1)
# interpolate embedding for zeroth frame
pos_embed = interpolate_pos_encoding_2d(
npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0)
)
else:
raise ValueError("This type of interpolation isn't implemented")
return torch.cat((class_emb, pos_embed), dim=1)
def _get_pos_embedding(
npatch_per_img,
pos_embed,
patches_layout,
input_shape,
first_patch_idx=1,
):
pos_embed = interpolate_pos_encoding(
npatch_per_img,
pos_embed,
patches_layout,
input_shape=input_shape,
first_patch_idx=first_patch_idx,
)
return pos_embed
class PatchEmbedGeneric(nn.Module):
"""
PatchEmbed from Hydra
"""
def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None):
super().__init__()
if len(proj_stem) > 1:
self.proj = nn.Sequential(*proj_stem)
else:
# Special case to be able to load pre-trained models that were
# trained with a standard stem
self.proj = proj_stem[0]
self.norm_layer = norm_layer
def get_patch_layout(self, img_size):
with torch.no_grad():
dummy_img = torch.zeros(
[
1,
]
+ img_size
)
dummy_out = self.proj(dummy_img)
embed_dim = dummy_out.shape[1]
patches_layout = tuple(dummy_out.shape[2:])
num_patches = np.prod(patches_layout)
return patches_layout, num_patches, embed_dim
def forward(self, x):
x = self.proj(x)
# B C (T) H W -> B (T)HW C
x = x.flatten(2).transpose(1, 2)
if self.norm_layer is not None:
x = self.norm_layer(x)
return x
class SpatioTemporalPosEmbeddingHelper(VerboseNNModule):
def __init__(
self,
patches_layout: List,
num_patches: int,
num_cls_tokens: int,
embed_dim: int,
learnable: bool,
) -> None:
super().__init__()
self.num_cls_tokens = num_cls_tokens
self.patches_layout = patches_layout
self.num_patches = num_patches
self.num_tokens = num_cls_tokens + num_patches
self.learnable = learnable
if self.learnable:
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
trunc_normal_(self.pos_embed, std=0.02)
else:
self.register_buffer(
"pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim)
)
def get_pos_embedding(self, vision_input, all_vision_tokens):
input_shape = vision_input.shape
pos_embed = _get_pos_embedding(
all_vision_tokens.size(1) - self.num_cls_tokens,
pos_embed=self.pos_embed,
patches_layout=self.patches_layout,
input_shape=input_shape,
first_patch_idx=self.num_cls_tokens,
)
return pos_embed
class RGBDTPreprocessor(VerboseNNModule):
def __init__(
self,
rgbt_stem: PatchEmbedGeneric,
depth_stem: PatchEmbedGeneric,
img_size: List = (3, 224, 224),
num_cls_tokens: int = 1,
pos_embed_fn: Callable = None,
use_type_embed: bool = False,
init_param_style: str = "openclip",
) -> None:
super().__init__()
stem = rgbt_stem if rgbt_stem is not None else depth_stem
(
self.patches_layout,
self.num_patches,
self.embed_dim,
) = stem.get_patch_layout(img_size)
self.rgbt_stem = rgbt_stem
self.depth_stem = depth_stem
self.use_pos_embed = pos_embed_fn is not None
self.use_type_embed = use_type_embed
self.num_cls_tokens = num_cls_tokens
if self.use_pos_embed:
self.pos_embedding_helper = pos_embed_fn(
patches_layout=self.patches_layout,
num_cls_tokens=num_cls_tokens,
num_patches=self.num_patches,
embed_dim=self.embed_dim,
)
if self.num_cls_tokens > 0:
self.cls_token = nn.Parameter(
torch.zeros(1, self.num_cls_tokens, self.embed_dim)
)
if self.use_type_embed:
self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.init_parameters(init_param_style)
@torch.no_grad()
def init_parameters(self, init_param_style):
if init_param_style == "openclip":
# OpenCLIP style initialization
scale = self.embed_dim**-0.5
if self.use_pos_embed:
nn.init.normal_(self.pos_embedding_helper.pos_embed)
self.pos_embedding_helper.pos_embed *= scale
if self.num_cls_tokens > 0:
nn.init.normal_(self.cls_token)
self.cls_token *= scale
elif init_param_style == "vit":
self.cls_token.data.fill_(0)
else:
raise ValueError(f"Unknown init {init_param_style}")
if self.use_type_embed:
nn.init.normal_(self.type_embed)
def tokenize_input_and_cls_pos(self, input, stem, mask):
# tokens is of shape B x L x D
tokens = stem(input)
assert tokens.ndim == 3
assert tokens.shape[2] == self.embed_dim
B = tokens.shape[0]
if self.num_cls_tokens > 0:
class_tokens = self.cls_token.expand(
B, -1, -1
) # stole class_tokens impl from Phil Wang, thanks
tokens = torch.cat((class_tokens, tokens), dim=1)
if self.use_pos_embed:
pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens)
tokens = tokens + pos_embed
if self.use_type_embed:
tokens = tokens + self.type_embed.expand(B, -1, -1)
return tokens
def forward(self, vision=None, depth=None, patch_mask=None):
if patch_mask is not None:
raise NotImplementedError()
if vision is not None:
vision_tokens = self.tokenize_input_and_cls_pos(
vision, self.rgbt_stem, patch_mask
)
if depth is not None:
depth_tokens = self.tokenize_input_and_cls_pos(
depth, self.depth_stem, patch_mask
)
# aggregate tokens
if vision is not None and depth is not None:
final_tokens = vision_tokens + depth_tokens
else:
final_tokens = vision_tokens if vision is not None else depth_tokens
return_dict = {
"trunk": {
"tokens": final_tokens,
},
"head": {},
}
return return_dict
class AudioPreprocessor(RGBDTPreprocessor):
def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None:
super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs)
def forward(self, audio=None):
return super().forward(vision=audio)
class ThermalPreprocessor(RGBDTPreprocessor):
def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None:
super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs)
def forward(self, thermal=None):
return super().forward(vision=thermal)
def build_causal_attention_mask(context_length):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(context_length, context_length, requires_grad=False)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
class TextPreprocessor(VerboseNNModule):
def __init__(
self,
vocab_size: int,
context_length: int,
embed_dim: int,
causal_masking: bool,
supply_seq_len_to_head: bool = True,
num_cls_tokens: int = 0,
init_param_style: str = "openclip",
) -> None:
super().__init__()
self.vocab_size = vocab_size
self.context_length = context_length
self.token_embedding = nn.Embedding(vocab_size, embed_dim)
self.pos_embed = nn.Parameter(
torch.empty(1, self.context_length + num_cls_tokens, embed_dim)
)
self.causal_masking = causal_masking
if self.causal_masking:
mask = build_causal_attention_mask(self.context_length)
# register the mask as a buffer so it can be moved to the right device
self.register_buffer("mask", mask)
self.supply_seq_len_to_head = supply_seq_len_to_head
self.num_cls_tokens = num_cls_tokens
self.embed_dim = embed_dim
if num_cls_tokens > 0:
assert self.causal_masking is False, "Masking + CLS token isn't implemented"
self.cls_token = nn.Parameter(
torch.zeros(1, self.num_cls_tokens, embed_dim)
)
self.init_parameters(init_param_style)
@torch.no_grad()
def init_parameters(self, init_param_style="openclip"):
# OpenCLIP style initialization
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.pos_embed, std=0.01)
if init_param_style == "openclip":
# OpenCLIP style initialization
scale = self.embed_dim**-0.5
if self.num_cls_tokens > 0:
nn.init.normal_(self.cls_token)
self.cls_token *= scale
elif init_param_style == "vit":
self.cls_token.data.fill_(0)
else:
raise ValueError(f"Unknown init {init_param_style}")
def forward(self, text):
# text tokens are of shape B x L x D
text_tokens = self.token_embedding(text)
# concat CLS tokens if any
if self.num_cls_tokens > 0:
B = text_tokens.shape[0]
class_tokens = self.cls_token.expand(
B, -1, -1
) # stole class_tokens impl from Phil Wang, thanks
text_tokens = torch.cat((class_tokens, text_tokens), dim=1)
text_tokens = text_tokens + self.pos_embed
return_dict = {
"trunk": {
"tokens": text_tokens,
},
"head": {},
}
# Compute sequence length after adding CLS tokens
if self.supply_seq_len_to_head:
text_lengths = text.argmax(dim=-1)
return_dict["head"] = {
"seq_len": text_lengths,
}
if self.causal_masking:
return_dict["trunk"].update({"attn_mask": self.mask})
return return_dict
class Im2Video(nn.Module):
"""Convert an image into a trivial video."""
def __init__(self, time_dim=2):
super().__init__()
self.time_dim = time_dim
def forward(self, x):
if x.ndim == 4:
# B, C, H, W -> B, C, T, H, W
return x.unsqueeze(self.time_dim)
elif x.ndim == 5:
return x
else:
raise ValueError(f"Dimension incorrect {x.shape}")
class PadIm2Video(Im2Video):
def __init__(self, ntimes, pad_type, time_dim=2):
super().__init__(time_dim=time_dim)
assert ntimes > 0
assert pad_type in ["zero", "repeat"]
self.ntimes = ntimes
self.pad_type = pad_type
def forward(self, x):
x = super().forward(x)
if x.shape[self.time_dim] == 1:
if self.pad_type == "repeat":
new_shape = [1] * len(x.shape)
new_shape[self.time_dim] = self.ntimes
x = x.repeat(new_shape)
elif self.pad_type == "zero":
padarg = [0, 0] * len(x.shape)
padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim]
x = nn.functional.pad(x, padarg)
return x
# Modified from github.com/openai/CLIP
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path: str, context_length=77):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
with g_pathmgr.open(bpe_path, "rb") as fh:
bpe_bytes = io.BytesIO(fh.read())
merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
merges = merges[1 : 49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + "</w>" for v in vocab]
for merge in merges:
vocab.append("".join(merge))
vocab.extend(["<|startoftext|>", "<|endoftext|>"])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {
"<|startoftext|>": "<|startoftext|>",
"<|endoftext|>": "<|endoftext|>",
}
self.pat = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE,
)
self.context_length = context_length
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + "</w>",)
pairs = get_pairs(word)
if not pairs:
return token + "</w>"
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(
self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
)
return bpe_tokens
def decode(self, tokens):
text = "".join([self.decoder[token] for token in tokens])
text = (
bytearray([self.byte_decoder[c] for c in text])
.decode("utf-8", errors="replace")
.replace("</w>", " ")
)
return text
def __call__(self, texts, context_length=None):
if not context_length:
context_length = self.context_length
if isinstance(texts, str):
texts = [texts]
sot_token = self.encoder["<|startoftext|>"]
eot_token = self.encoder["<|endoftext|>"]
all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
tokens = tokens[:context_length]
result[i, : len(tokens)] = torch.tensor(tokens)
if len(result) == 1:
return result[0]
return result
class IMUPreprocessor(VerboseNNModule):
def __init__(
self,
kernel_size: int,
imu_stem: PatchEmbedGeneric,
embed_dim: int,
img_size: List = (6, 2000),
num_cls_tokens: int = 1,
pos_embed_fn: Callable = None,
init_param_style: str = "openclip",
) -> None:
super().__init__()
stem = imu_stem
self.imu_stem = imu_stem
self.embed_dim = embed_dim
self.use_pos_embed = pos_embed_fn is not None
self.num_cls_tokens = num_cls_tokens
self.kernel_size = kernel_size
self.pos_embed = nn.Parameter(
torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim)
)
if self.num_cls_tokens > 0:
self.cls_token = nn.Parameter(
torch.zeros(1, self.num_cls_tokens, self.embed_dim)
)
self.init_parameters(init_param_style)
@torch.no_grad()
def init_parameters(self, init_param_style):
nn.init.normal_(self.pos_embed, std=0.01)
if init_param_style == "openclip":
# OpenCLIP style initialization
scale = self.embed_dim**-0.5
if self.num_cls_tokens > 0:
nn.init.normal_(self.cls_token)
self.cls_token *= scale
elif init_param_style == "vit":
self.cls_token.data.fill_(0)
else:
raise ValueError(f"Unknown init {init_param_style}")
def tokenize_input_and_cls_pos(self, input, stem):
# tokens is of shape B x L x D
tokens = stem.norm_layer(stem.proj(input))
assert tokens.ndim == 3
assert tokens.shape[2] == self.embed_dim
B = tokens.shape[0]
if self.num_cls_tokens > 0:
class_tokens = self.cls_token.expand(
B, -1, -1
) # stole class_tokens impl from Phil Wang, thanks
tokens = torch.cat((class_tokens, tokens), dim=1)
if self.use_pos_embed:
tokens = tokens + self.pos_embed
return tokens
def forward(self, imu):
# Patchify
imu = imu.unfold(
-1,
self.kernel_size,
self.kernel_size,
).permute(0, 2, 1, 3)
imu = imu.reshape(imu.size(0), imu.size(1), -1)
imu_tokens = self.tokenize_input_and_cls_pos(
imu,
self.imu_stem,
)
return_dict = {
"trunk": {
"tokens": imu_tokens,
},
"head": {},
}
return return_dict

284
models/transformer.py Normal file
View File

@ -0,0 +1,284 @@
#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Code modified from
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ;
# https://github.com/facebookresearch/deit/blob/main/models.py
# and https://github.com/facebookresearch/vissl/blob/main/vissl/models/trunks/vision_transformer.py
import copy
import fnmatch
import logging
from functools import partial
from typing import Callable, List
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, trunc_normal_
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version,
# can set manually to be compat with prev weights
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class MultiheadAttention(nn.MultiheadAttention):
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
class ViTAttention(Attention):
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
assert attn_mask is None
return super().forward(x)
class BlockWithMasking(nn.Module):
def __init__(
self,
dim: int,
attn_target: Callable,
mlp_ratio: int = 4,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
ffn_dropout_rate: float = 0.0,
drop_path: float = 0.0,
layer_scale_type: str = None,
layer_scale_init_value: float = 1e-4,
):
super().__init__()
assert not isinstance(
attn_target, nn.Module
), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!"
self.attn = attn_target()
if drop_path > 0.0:
self.drop_path = DropPath(drop_path)
else:
self.drop_path = nn.Identity()
self.norm_1 = norm_layer(dim)
mlp_hidden_dim = int(mlp_ratio * dim)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=ffn_dropout_rate,
)
self.norm_2 = norm_layer(dim)
self.layer_scale_type = layer_scale_type
if self.layer_scale_type is not None:
assert self.layer_scale_type in [
"per_channel",
"scalar",
], f"Found Layer scale type {self.layer_scale_type}"
if self.layer_scale_type == "per_channel":
# one gamma value per channel
gamma_shape = [1, 1, dim]
elif self.layer_scale_type == "scalar":
# single gamma value for all channels
gamma_shape = [1, 1, 1]
# two gammas: for each part of the fwd in the encoder
self.layer_scale_gamma1 = nn.Parameter(
torch.ones(size=gamma_shape) * layer_scale_init_value,
requires_grad=True,
)
self.layer_scale_gamma2 = nn.Parameter(
torch.ones(size=gamma_shape) * layer_scale_init_value,
requires_grad=True,
)
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
if self.layer_scale_type is None:
x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask))
x = x + self.drop_path(self.mlp(self.norm_2(x)))
else:
x = (
x
+ self.drop_path(self.attn(self.norm_1(x), attn_mask))
* self.layer_scale_gamma1
)
x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2
return x
_LAYER_NORM = partial(nn.LayerNorm, eps=1e-6)
class SimpleTransformer(nn.Module):
def __init__(
self,
attn_target: Callable,
embed_dim: int,
num_blocks: int,
block: Callable = BlockWithMasking,
pre_transformer_layer: Callable = None,
post_transformer_layer: Callable = None,
drop_path_rate: float = 0.0,
drop_path_type: str = "progressive",
norm_layer: Callable = _LAYER_NORM,
mlp_ratio: int = 4,
ffn_dropout_rate: float = 0.0,
layer_scale_type: str = None, # from cait; possible values are None, "per_channel", "scalar"
layer_scale_init_value: float = 1e-4, # from cait; float
weight_init_style: str = "jax", # possible values jax or pytorch
):
"""
Simple Transformer with the following features
1. Supports masked attention
2. Supports DropPath
3. Supports LayerScale
4. Supports Dropout in Attention and FFN
5. Makes few assumptions about the input except that it is a Tensor
"""
super().__init__()
self.pre_transformer_layer = pre_transformer_layer
if drop_path_type == "progressive":
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)]
elif drop_path_type == "uniform":
dpr = [drop_path_rate for i in range(num_blocks)]
else:
raise ValueError(f"Unknown drop_path_type: {drop_path_type}")
self.blocks = nn.Sequential(
*[
block(
dim=embed_dim,
attn_target=attn_target,
mlp_ratio=mlp_ratio,
ffn_dropout_rate=ffn_dropout_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
layer_scale_type=layer_scale_type,
layer_scale_init_value=layer_scale_init_value,
)
for i in range(num_blocks)
]
)
self.post_transformer_layer = post_transformer_layer
self.weight_init_style = weight_init_style
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
if self.weight_init_style == "jax":
# Based on MAE and official Jax ViT implementation
torch.nn.init.xavier_uniform_(m.weight)
elif self.weight_init_style == "pytorch":
# PyTorch ViT uses trunc_normal_
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(
self,
tokens: torch.Tensor,
attn_mask: torch.Tensor = None,
use_checkpoint: bool = False,
checkpoint_every_n: int = 1,
checkpoint_blk_ids: List[int] = None,
):
"""
Inputs
- tokens: data of shape N x L x D (or L x N x D depending on the attention implementation)
- attn: mask of shape L x L
Output
- x: data of shape N x L x D (or L x N x D depending on the attention implementation)
"""
if self.pre_transformer_layer:
tokens = self.pre_transformer_layer(tokens)
if use_checkpoint and checkpoint_blk_ids is None:
checkpoint_blk_ids = [
blk_id
for blk_id in range(len(self.blocks))
if blk_id % checkpoint_every_n == 0
]
if checkpoint_blk_ids:
checkpoint_blk_ids = set(checkpoint_blk_ids)
for blk_id, blk in enumerate(self.blocks):
if use_checkpoint and blk_id in checkpoint_blk_ids:
tokens = checkpoint.checkpoint(
blk, tokens, attn_mask, use_reentrant=False
)
else:
tokens = blk(tokens, attn_mask=attn_mask)
if self.post_transformer_layer:
tokens = self.post_transformer_layer(tokens)
return tokens

11
requirements.txt Normal file
View File

@ -0,0 +1,11 @@
torch==1.13
torchvision==0.14.0
torchaudio==0.13.0
timm==0.6.7
ftfy
regex
einops
fvcore
decord==0.6.0
gradio
pytorchvideo