diff --git a/.gitea/workflows/build.yaml b/.gitea/workflows/build.yaml
new file mode 100644
index 0000000..ba0d002
--- /dev/null
+++ b/.gitea/workflows/build.yaml
@@ -0,0 +1,55 @@
+# Builds the repository's Docker image on every push and publishes it to
+# the private registry at artifacts.iflytek.com.
+name: Build
+run-name: ${{ github.actor }} is upgrade release 🚀
+on: [push]
+env:
+  REPOSITORY: ${{ github.repository }}
+  COMMIT_ID: ${{ github.sha }}
+jobs:
+  Build-Deploy-Actions:
+    runs-on: ubuntu-latest
+    steps:
+      - run: echo "🎉 The job was automatically triggered by a ${{ github.event_name }} event."
+      - run: echo "🐧 This job is now running on a ${{ runner.os }} server hosted by Gitea!"
+      - run: echo "🔎 The name of your branch is ${{ github.ref }} and your repository is ${{ github.repository }}."
+      - name: Check out repository code
+        uses: actions/checkout@v3
+      -
+        name: Setup Git LFS
+        run: |
+          git lfs install
+          git lfs fetch
+          git lfs checkout
+      - name: List files in the repository
+        run: |
+          ls ${{ github.workspace }}
+      -
+        name: Docker Image Info
+        id: image-info
+        # The image name is the lower-cased repository path (Docker image
+        # names must be lower case); the tag is the first 10 characters of
+        # the commit SHA. Outputs go to $GITHUB_OUTPUT because the
+        # `::set-output` workflow command is deprecated.
+        run: |
+          echo "image_name=$(echo "$REPOSITORY" | tr '[:upper:]' '[:lower:]')" >> "$GITHUB_OUTPUT"
+          echo "image_tag=${COMMIT_ID:0:10}" >> "$GITHUB_OUTPUT"
+      -
+        name: Login to Docker Hub
+        uses: docker/login-action@v2
+        with:
+          registry: artifacts.iflytek.com
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v2
+      -
+        name: Build and push
+        # Build into the local image store (--load), push the tag, then
+        # remove the local copy of the image.
+        run: |
+          docker version
+          docker buildx build -t artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }} . --file ${{ github.workspace }}/Dockerfile --load
+          docker push artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }}
+          docker rmi artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }}
+      - run: echo "🍏 This job's status is ${{ job.status }}."
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..6a0b5d8
--- /dev/null
+++ b/app.py
@@ -0,0 +1,85 @@
+"""Gradio demo: zero-shot audio classification with ImageBind."""
+import data
+import torch
+import gradio as gr
+from models import imagebind_model
+from models.imagebind_model import ModalityType
+from gradio.themes.utils import sizes
+
+
+theme = gr.themes.Default(radius_size=sizes.radius_none).set(
+    block_label_text_color = '#4D63FF',
+    block_title_text_color = '#4D63FF',
+    button_primary_text_color = '#4D63FF',
+    button_primary_background_fill='#FFFFFF',
+    button_primary_border_color='#4D63FF',
+    button_primary_background_fill_hover='#EDEFFF',
+)
+
+css = "footer {visibility: hidden}"
+
+
+# The model is loaded once at import time and shared by every request.
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+model = imagebind_model.imagebind_huge(pretrained=True)
+model.eval()
+model.to(device)
+
+
+def audio_text(audio, text_list):
+    """Score one audio clip against a "|"-separated list of candidate labels.
+
+    Args:
+        audio: Path of the audio file to classify.
+        text_list: Candidate labels separated by "|", e.g. "A dog|A car".
+
+    Returns:
+        dict mapping each label to its softmax probability.
+
+    Raises:
+        ValueError: if no audio is given or no non-empty label remains.
+    """
+    if not audio:
+        raise ValueError("An audio file is required.")
+    # Drop empty entries (e.g. from a trailing "|") so that they do not take
+    # part in the softmax as bogus classes.
+    labels = [label.strip(" ") for label in (text_list or "").split("|")]
+    labels = [label for label in labels if label]
+    if not labels:
+        raise ValueError("At least one non-empty label is required.")
+
+    inputs = {
+        ModalityType.TEXT: data.load_and_transform_text(labels, device),
+        ModalityType.AUDIO: data.load_and_transform_audio_data([audio], device),
+    }
+
+    with torch.no_grad():
+        embeddings = model(inputs)
+
+    # Audio-to-text similarity logits, normalised over the candidate labels.
+    scores = torch.softmax(embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1).squeeze(0).tolist()
+
+    score_dict = dict(zip(labels, scores))
+    print(score_dict)
+
+    return score_dict
+
+with gr.Blocks(theme=theme, css=css) as demo:
+    gr.Markdown("""
+音频分类
+    """)
+    # NOTE(review): the widget nesting below was reconstructed from a
+    # whitespace-collapsed patch -- confirm the intended layout.
+    with gr.Row():
+        with gr.Column():
+            audio = gr.inputs.Audio(type='filepath',label="音频输入")
+            text = gr.inputs.Textbox(lines=1,label="类别")
+            with gr.Row():
+                button = gr.Button("提交", variant="primary")
+        outputs = gr.Label(label="类别")
+
+    button.click(fn=audio_text, inputs=[audio, text], outputs=outputs)
+    examples = gr.Examples(examples=[[".assets/dog_audio.wav", "A dog|A car|A bird"],[".assets/car_audio.wav", "A dog|A car|A bird"], [".assets/bird_audio.wav", "A dog|A car|A bird"]],inputs=[audio, text], label="例子")
+
+if __name__ == "__main__":
+    demo.queue().launch(server_name = "0.0.0.0")
diff --git a/bpe/bpe_simple_vocab_16e6.txt.gz b/bpe/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000..7b5088a
Binary files /dev/null and b/bpe/bpe_simple_vocab_16e6.txt.gz differ
diff --git a/data.py b/data.py
new file mode 100644
index 0000000..80c7aca
--- /dev/null
+++ b/data.py
@@ -0,0 +1,350 @@
+#!/usr/bin/env python3
+# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+import torch.nn as nn
+import torchaudio
+import logging
+
+from models.multimodal_preprocessors import SimpleTokenizer
+from PIL import Image
+from pytorchvideo import transforms as pv_transforms
+from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
+from pytorchvideo.data.encoded_video import EncodedVideo
+
+from torchvision import transforms
+from torchvision.transforms._transforms_video import NormalizeVideo
+
+DEFAULT_AUDIO_FRAME_SHIFT_MS = 10  # in milliseconds
+
+BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz"
+
+
+def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
+    """Return a [1, num_mel_bins, target_length] Kaldi fbank of ``waveform``."""
+    # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
+    # Out-of-place: ``waveform`` may be a view of the caller's audio (e.g. a
+    # clip slice), so an in-place ``-=`` would corrupt overlapping clips.
+    waveform = waveform - waveform.mean()
+    fbank = torchaudio.compliance.kaldi.fbank(
+        waveform,
+        htk_compat=True,
+        sample_frequency=sample_rate,
+        use_energy=False,
+        window_type="hanning",
+        num_mel_bins=num_mel_bins,
+        dither=0.0,
+        frame_length=25,
+        frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
+    )
+    # Convert to [mel_bins, num_frames] shape
+    fbank = fbank.transpose(0, 1)
+    # Pad to target_length
+    n_frames = fbank.size(1)
+    p = target_length - n_frames
+    # if p is too large (say >20%), flash a warning
+    if abs(p) / n_frames > 0.2:
+        logging.warning(
+            "Large gap between audio n_frames(%d) and "
+            "target_length (%d). Is the audio_target_length "
+            "setting correct?",
+            n_frames,
+            target_length,
+        )
+    # cut and pad
+    if p > 0:
+        fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
+    elif p < 0:
+        fbank = fbank[:, 0:target_length]
+    # Convert to [1, mel_bins, num_frames] shape, essentially like a 1
+    # channel image
+    fbank = fbank.unsqueeze(0)
+    return fbank
+
+
+def get_clip_timepoints(clip_sampler, duration):
+    """Return the (start, end) times of every clip the sampler yields."""
+    # Read out all clips in this video
+    all_clips_timepoints = []
+    is_last_clip = False
+    end = 0.0
+    while not is_last_clip:
+        start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
+        all_clips_timepoints.append((start, end))
+    return all_clips_timepoints
+
+
+def load_and_transform_vision_data(image_paths, device):
+    """Load images as a normalized [N, 3, 224, 224] batch on ``device``."""
+    if image_paths is None:
+        return None
+
+    image_outputs = []
+    # The transform does not depend on the image, so build it once.
+    data_transform = transforms.Compose(
+        [
+            transforms.Resize(
+                224, interpolation=transforms.InterpolationMode.BICUBIC
+            ),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=(0.48145466, 0.4578275, 0.40821073),
+                std=(0.26862954, 0.26130258, 0.27577711),
+            ),
+        ]
+    )
+    for image_path in image_paths:
+        with open(image_path, "rb") as fopen:
+            image = Image.open(fopen).convert("RGB")
+
+        image = data_transform(image).to(device)
+        image_outputs.append(image)
+    return torch.stack(image_outputs, dim=0)
+
+
+def load_and_transform_text(text, device):
+    """Tokenize each string with the BPE tokenizer and stack the results."""
+    if text is None:
+        return None
+    tokenizer = SimpleTokenizer(bpe_path=BPE_PATH)
+    tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text]
+    tokens = torch.cat(tokens, dim=0)
+    return tokens
+
+
+def load_and_transform_audio_data(
+    audio_paths,
+    device,
+    num_mel_bins=128,
+    target_length=204,
+    sample_rate=16000,
+    clip_duration=2,
+    clips_per_video=3,
+    mean=-4.268,
+    std=9.138,
+):
+    """Turn audio files into normalized mel-spectrogram clips on ``device``."""
+    if audio_paths is None:
+        return None
+
+    audio_outputs = []
+    clip_sampler = ConstantClipsPerVideoSampler(
+        clip_duration=clip_duration, clips_per_video=clips_per_video
+    )
+
+    for audio_path in audio_paths:
+        waveform, sr = torchaudio.load(audio_path)
+        if sample_rate != sr:
+            waveform = torchaudio.functional.resample(
+                waveform, orig_freq=sr, new_freq=sample_rate
+            )
+        all_clips_timepoints = get_clip_timepoints(
+            clip_sampler, waveform.size(1) / sample_rate
+        )
+        all_clips = []
+        for clip_timepoints in all_clips_timepoints:
+            waveform_clip = waveform[
+                :,
+                int(clip_timepoints[0] * sample_rate) : int(
+                    clip_timepoints[1] * sample_rate
+                ),
+            ]
+            waveform_melspec = waveform2melspec(
+                waveform_clip, sample_rate, num_mel_bins, target_length
+            )
+            all_clips.append(waveform_melspec)
+
+        normalize = transforms.Normalize(mean=mean, std=std)
+        all_clips = [normalize(ac).to(device) for ac in all_clips]
+
+        all_clips = torch.stack(all_clips, dim=0)
+        audio_outputs.append(all_clips)
+
+    return torch.stack(audio_outputs, dim=0)
+
+
+def crop_boxes(boxes, x_offset, y_offset):
+    """
+    Perform crop on the bounding boxes given the offsets.
+    Args:
+        boxes (ndarray or None): bounding boxes to perform crop. The dimension
+            is `num boxes` x 4.
+        x_offset (int): cropping offset in the x axis.
+        y_offset (int): cropping offset in the y axis.
+    Returns:
+        cropped_boxes (ndarray or None): the cropped boxes with dimension of
+            `num boxes` x 4.
+    """
+    cropped_boxes = boxes.copy()
+    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
+    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
+
+    return cropped_boxes
+
+
+def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
+    """
+    Perform uniform spatial sampling on the images and corresponding boxes.
+    Args:
+        images (tensor): images to perform uniform crop. The dimension is
+            `num frames` x `channel` x `height` x `width`.
+        size (int): size of height and width to crop the images.
+        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
+            is larger than height. Or 0, 1, or 2 for top, center, and bottom
+            crop if height is larger than width.
+        boxes (ndarray or None): optional. Corresponding boxes to images.
+            Dimension is `num boxes` x 4.
+        scale_size (int): optional. If not None, resize the images to scale_size before
+            performing any crop.
+    Returns:
+        cropped (tensor): images with dimension of
+            `num frames` x `channel` x `size` x `size`.
+        cropped_boxes (ndarray or None): the cropped boxes with dimension of
+            `num boxes` x 4.
+    """
+    assert spatial_idx in [0, 1, 2]
+    ndim = len(images.shape)
+    if ndim == 3:
+        images = images.unsqueeze(0)
+    height = images.shape[2]
+    width = images.shape[3]
+
+    if scale_size is not None:
+        if width <= height:
+            width, height = scale_size, int(height / width * scale_size)
+        else:
+            width, height = int(width / height * scale_size), scale_size
+        images = torch.nn.functional.interpolate(
+            images,
+            size=(height, width),
+            mode="bilinear",
+            align_corners=False,
+        )
+
+    y_offset = int(math.ceil((height - size) / 2))
+    x_offset = int(math.ceil((width - size) / 2))
+
+    if height > width:
+        if spatial_idx == 0:
+            y_offset = 0
+        elif spatial_idx == 2:
+            y_offset = height - size
+    else:
+        if spatial_idx == 0:
+            x_offset = 0
+        elif spatial_idx == 2:
+            x_offset = width - size
+    cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
+    cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
+    if ndim == 3:
+        cropped = cropped.squeeze(0)
+    return cropped, cropped_boxes
+
+
+class SpatialCrop(nn.Module):
+    """
+    Convert the video into 3 smaller clips spatially. Must be used after the
+        temporal crops to get spatial crops, and should be used with
+        -2 in the spatial crop at the slowfast augmentation stage (so full
+        frames are passed in here). Will return a larger list with the
+        3x spatial crops as well.
+    """
+
+    def __init__(self, crop_size: int = 224, num_crops: int = 3):
+        super().__init__()
+        self.crop_size = crop_size
+        if num_crops == 3:
+            self.crops_to_ext = [0, 1, 2]
+            self.flipped_crops_to_ext = []
+        elif num_crops == 1:
+            self.crops_to_ext = [1]
+            self.flipped_crops_to_ext = []
+        else:
+            raise NotImplementedError("Nothing else supported yet")
+
+    def forward(self, videos):
+        """
+        Args:
+            videos: A list of C, T, H, W videos.
+        Returns:
+            videos: A list with 3x the number of elements. Each video converted
+                to C, T, H', W' by spatial cropping.
+        """
+        assert isinstance(videos, list), "Must be a list of videos after temporal crops"
+        assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
+        res = []
+        for video in videos:
+            for spatial_idx in self.crops_to_ext:
+                res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
+            if not self.flipped_crops_to_ext:
+                continue
+            flipped_video = transforms.functional.hflip(video)
+            for spatial_idx in self.flipped_crops_to_ext:
+                res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
+        return res
+
+
+def load_and_transform_video_data(
+    video_paths,
+    device,
+    clip_duration=2,
+    clips_per_video=5,
+    sample_rate=16000,
+):
+    """Decode videos into normalized, spatially cropped clips on ``device``."""
+    if video_paths is None:
+        return None
+
+    video_outputs = []
+    video_transform = transforms.Compose(
+        [
+            pv_transforms.ShortSideScale(224),
+            NormalizeVideo(
+                mean=(0.48145466, 0.4578275, 0.40821073),
+                std=(0.26862954, 0.26130258, 0.27577711),
+            ),
+        ]
+    )
+
+    clip_sampler = ConstantClipsPerVideoSampler(
+        clip_duration=clip_duration, clips_per_video=clips_per_video
+    )
+    # NOTE(review): the number of frames sampled per clip is taken from
+    # ``clip_duration`` -- confirm that coupling is intended.
+    frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
+
+    for video_path in video_paths:
+        video = EncodedVideo.from_path(
+            video_path,
+            decoder="decord",
+            decode_audio=False,
+            **{"sample_rate": sample_rate},
+        )
+
+        all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
+
+        all_video = []
+        for clip_timepoints in all_clips_timepoints:
+            # Read the clip, get frames
+            clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
+            if clip is None:
+                raise ValueError("No clip found")
+            video_clip = frame_sampler(clip["video"])
+            video_clip = video_clip / 255.0  # since this is float, need 0-1
+
+            all_video.append(video_clip)
+
+        all_video = [video_transform(clip) for clip in all_video]
+        all_video = SpatialCrop(224, num_crops=3)(all_video)
+
+        all_video = torch.stack(all_video, dim=0)
+        video_outputs.append(all_video)
+
+    return torch.stack(video_outputs, dim=0).to(device)
diff --git a/models/helpers.py b/models/helpers.py
new file mode 100644
index 0000000..049e1f1
--- /dev/null
+++ b/models/helpers.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python3
+# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import einops
+import numpy as np
+import torch
+
+import torch.nn as nn
+
+
+class Normalize(nn.Module):  # L2-normalizes its input along `dim`.
+    def __init__(self, dim: int) -> None:
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        return torch.nn.functional.normalize(x, dim=self.dim, p=2)
+
+
+class LearnableLogitScaling(nn.Module):  # Multiplies its input by exp(log_logit_scale), clipped at max_logit_scale.
+    def __init__(
+        self,
+        logit_scale_init: float = 1 / 0.07,
+        learnable: bool = True,
+        max_logit_scale: float = 100,
+    ) -> None:
+        super().__init__()
+        self.max_logit_scale = max_logit_scale
+        self.logit_scale_init = logit_scale_init
+        self.learnable = learnable
+        log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)  # scalar, kept in log space
+        if learnable:
+            self.log_logit_scale = nn.Parameter(log_logit_scale)
+        else:
+            self.register_buffer("log_logit_scale", log_logit_scale)  # fixed scale: a buffer, not a trainable parameter
+
+    def forward(self, x):
+        return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
+
+    def extra_repr(self):
+        st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}"
+        return st
+
+
+class EinOpsRearrange(nn.Module):  # nn.Module wrapper around einops.rearrange with a fixed expression.
+    def __init__(self, rearrange_expr: str, **kwargs) -> None:
+        super().__init__()
+        self.rearrange_expr = rearrange_expr
+        self.kwargs = kwargs
+
+    def forward(self, x):
+        assert isinstance(x, torch.Tensor)
+        return einops.rearrange(x, self.rearrange_expr, **self.kwargs)
+
+
+class VerboseNNModule(nn.Module):
+    """
+    Wrapper around nn.Module that prints registered buffers and parameter names.
+    """
+
+    @staticmethod
+    def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
+        st = (
+            "("
+            + name
+            + "): "
+            + "tensor("
+            + str(tuple(tensor[1].shape))  # `tensor` is the (name, tensor) pair passed in by extra_repr, hence [1]
+            + ", requires_grad="
+            + str(tensor[1].requires_grad)
+            + ")\n"
+        )
+        return st
+
+    def extra_repr(self) -> str:
+        named_modules = set()
+        for p in self.named_modules():
+            named_modules.update([p[0]])
+        named_modules = list(named_modules)
+
+        string_repr = ""
+        for p in self.named_parameters():
+            name = p[0].split(".")[0]
+            if name not in named_modules:  # skip parameters whose first name component is a named submodule
+                string_repr += self.get_readable_tensor_repr(name, p)
+
+        for p in self.named_buffers():
+            name = p[0].split(".")[0]
+            string_repr += self.get_readable_tensor_repr(name, p)
+
+        return string_repr
+
+
+def cast_if_src_dtype(  # Returns (tensor, updated); the cast happens only when tensor.dtype == src_dtype.
+    tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
+):
+    updated = False
+    if tensor.dtype == src_dtype:
+        tensor = tensor.to(dtype=tgt_dtype)
+        updated = True
+    return tensor, updated
+
+
+class QuickGELU(nn.Module):
+    # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class SelectElement(nn.Module):  # Selects position `index` along dim 1 of its input.
+    def __init__(self, index) -> None:
+        super().__init__()
+        self.index = index
+
+    def forward(self, x):
+        assert x.ndim >= 3
+        return x[:, self.index, ...]
+
+
+class SelectEOSAndProject(nn.Module):
+    """
+    Text Pooling used in OpenCLIP
+    """
+
+    def __init__(self, proj: nn.Module) -> None:
+        super().__init__()
+        self.proj = proj
+
+    def forward(self, x, seq_len):
+        assert x.ndim == 3
+        # x is of shape B x L x D
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        x = x[torch.arange(x.shape[0]), seq_len]
+        x = self.proj(x)
+        return x
diff --git a/models/imagebind_model.py b/models/imagebind_model.py
new file mode 100644
index 0000000..395aabf
--- /dev/null
+++ b/models/imagebind_model.py
@@ -0,0 +1,539 @@
+#!/usr/bin/env python3
+# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import os
+import urllib
+from functools import partial
+from types import SimpleNamespace
+
+import torch
+import torch.nn as nn
+
+from models.helpers import (
+    EinOpsRearrange,
+    LearnableLogitScaling,
+    Normalize,
+    SelectElement,
+    SelectEOSAndProject,
+)
+from models.multimodal_preprocessors import (
+    AudioPreprocessor,
+    IMUPreprocessor,
+    PadIm2Video,
+    PatchEmbedGeneric,
+    RGBDTPreprocessor,
+    SpatioTemporalPosEmbeddingHelper,
+    TextPreprocessor,
+    ThermalPreprocessor,
+)
+
+from models.transformer import MultiheadAttention, SimpleTransformer
+
+
+ModalityType = SimpleNamespace(
+    VISION="vision",
+    TEXT="text",
+    AUDIO="audio",
+    THERMAL="thermal",
+    DEPTH="depth",
+    IMU="imu",
+)
+
+
+class ImageBindModel(nn.Module):
+    """ImageBind model: per-modality preprocessor, trunk, head and postprocessor."""
+    def __init__(
+        self,
+        video_frames=2,
+        kernel_size=(2, 14, 14),
+        audio_kernel_size=16,
+        audio_stride=10,
+        out_embed_dim=768,
+        vision_embed_dim=1024,
+        vision_num_blocks=24,
+        vision_num_heads=16,
+        audio_embed_dim=768,
+        audio_num_blocks=12,
+        audio_num_heads=12,
+        audio_num_mel_bins=128,
+        audio_target_len=204,
+        audio_drop_path=0.1,
+        text_embed_dim=768,
+        text_num_blocks=12,
+        text_num_heads=12,
+        depth_embed_dim=384,
+        depth_kernel_size=16,
+        depth_num_blocks=12,
+        depth_num_heads=8,
+        depth_drop_path=0.0,
+        thermal_embed_dim=768,
+        thermal_kernel_size=16,
+        thermal_num_blocks=12,
+        thermal_num_heads=12,
+        thermal_drop_path=0.0,
+        imu_embed_dim=512,
+        imu_kernel_size=8,
+        imu_num_blocks=6,
+        imu_num_heads=8,
+        imu_drop_path=0.7,
+    ):
+        super().__init__()
+
+        self.modality_preprocessors = self._create_modality_preprocessors(
+            video_frames,
+            vision_embed_dim,
+            kernel_size,
+            text_embed_dim,
+            audio_embed_dim,
+            audio_kernel_size,
+            audio_stride,
+            audio_num_mel_bins,
+            audio_target_len,
+            depth_embed_dim,
+            depth_kernel_size,
+            thermal_embed_dim,
+            thermal_kernel_size,
+            imu_embed_dim,
+        )
+
+        self.modality_trunks = self._create_modality_trunks(
+            vision_embed_dim,
+            vision_num_blocks,
+            vision_num_heads,
+            text_embed_dim,
+            text_num_blocks,
+            text_num_heads,
+            audio_embed_dim,
+            audio_num_blocks,
+            audio_num_heads,
+            audio_drop_path,
+            depth_embed_dim,
+            depth_num_blocks,
+            depth_num_heads,
+            depth_drop_path,
+            thermal_embed_dim,
+            thermal_num_blocks,
+            thermal_num_heads,
+            thermal_drop_path,
+            imu_embed_dim,
+            imu_num_blocks,
+            imu_num_heads,
+            imu_drop_path,
+        )
+
+        self.modality_heads = self._create_modality_heads(
+            out_embed_dim,
+            vision_embed_dim,
+            text_embed_dim,
+            audio_embed_dim,
+            depth_embed_dim,
+            thermal_embed_dim,
+            imu_embed_dim,
+        )
+
+        self.modality_postprocessors = self._create_modality_postprocessors(
+            out_embed_dim
+        )
+
+    def _create_modality_preprocessors(
+        self,
+        video_frames=2,
+        vision_embed_dim=1024,
+        kernel_size=(2, 14, 14),
+        text_embed_dim=768,
+        audio_embed_dim=768,
+        audio_kernel_size=16,
+        audio_stride=10,
+        audio_num_mel_bins=128,
+        audio_target_len=204,
+        depth_embed_dim=768,
+        depth_kernel_size=16,
+        thermal_embed_dim=768,
+        thermal_kernel_size=16,
+        imu_embed_dim=512,
+    ):
+        rgbt_stem = PatchEmbedGeneric(
+            proj_stem=[
+                PadIm2Video(pad_type="repeat", ntimes=2),
+                nn.Conv3d(
+                    in_channels=3,
+                    kernel_size=kernel_size,
+                    out_channels=vision_embed_dim,
+                    stride=kernel_size,
+                    bias=False,
+                ),
+            ]
+        )
+        rgbt_preprocessor = RGBDTPreprocessor(
+            img_size=[3, video_frames, 224, 224],
+            num_cls_tokens=1,
+            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
+            rgbt_stem=rgbt_stem,
+            depth_stem=None,
+        )
+
+        text_preprocessor = TextPreprocessor(
+            context_length=77,
+            vocab_size=49408,
+            embed_dim=text_embed_dim,
+            causal_masking=True,
+        )
+
+        audio_stem = PatchEmbedGeneric(
+            proj_stem=[
+                nn.Conv2d(
+                    in_channels=1,
+                    kernel_size=audio_kernel_size,
+                    stride=audio_stride,
+                    out_channels=audio_embed_dim,
+                    bias=False,
+                ),
+            ],
+            norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
+        )
+        audio_preprocessor = AudioPreprocessor(
+            img_size=[1, audio_num_mel_bins, audio_target_len],
+            num_cls_tokens=1,
+            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
+            audio_stem=audio_stem,
+        )
+
+        depth_stem = PatchEmbedGeneric(
+            [
+                nn.Conv2d(
+                    kernel_size=depth_kernel_size,
+                    in_channels=1,
+                    out_channels=depth_embed_dim,
+                    stride=depth_kernel_size,
+                    bias=False,
+                ),
+            ],
+            norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
+        )
+
+        depth_preprocessor = RGBDTPreprocessor(
+            img_size=[1, 224, 224],
+            num_cls_tokens=1,
+            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
+            rgbt_stem=None,
+            depth_stem=depth_stem,
+        )
+
+        thermal_stem = PatchEmbedGeneric(
+            [
+                nn.Conv2d(
+                    kernel_size=thermal_kernel_size,
+                    in_channels=1,
+                    out_channels=thermal_embed_dim,
+                    stride=thermal_kernel_size,
+                    bias=False,
+                ),
+            ],
+            norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
+        )
+        thermal_preprocessor = ThermalPreprocessor(
+            img_size=[1, 224, 224],
+            num_cls_tokens=1,
+            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
+            thermal_stem=thermal_stem,
+        )
+
+        imu_stem = PatchEmbedGeneric(
+            [
+                nn.Linear(
+                    in_features=48,
+                    out_features=imu_embed_dim,
+                    bias=False,
+                ),
+            ],
+            norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
+        )
+
+        imu_preprocessor = IMUPreprocessor(
+            img_size=[6, 2000],
+            num_cls_tokens=1,
+            kernel_size=8,
+            embed_dim=imu_embed_dim,
+            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
+            imu_stem=imu_stem,
+        )
+
+        modality_preprocessors = {
+            ModalityType.VISION: rgbt_preprocessor,
+            ModalityType.TEXT: text_preprocessor,
+            ModalityType.AUDIO: audio_preprocessor,
+            ModalityType.DEPTH: depth_preprocessor,
+            ModalityType.THERMAL: thermal_preprocessor,
+            ModalityType.IMU: imu_preprocessor,
+        }
+
+        return nn.ModuleDict(modality_preprocessors)
+
+    def _create_modality_trunks(
+        self,
+        vision_embed_dim=1024,
+        vision_num_blocks=24,
+        vision_num_heads=16,
+        text_embed_dim=768,
+        text_num_blocks=12,
+        text_num_heads=12,
+        audio_embed_dim=768,
+        audio_num_blocks=12,
+        audio_num_heads=12,
+        audio_drop_path=0.0,
+        depth_embed_dim=768,
+        depth_num_blocks=12,
+        depth_num_heads=12,
+        depth_drop_path=0.0,
+        thermal_embed_dim=768,
+        thermal_num_blocks=12,
+        thermal_num_heads=12,
+        thermal_drop_path=0.0,
+        imu_embed_dim=512,
+        imu_num_blocks=6,
+        imu_num_heads=8,
+        imu_drop_path=0.7,
+    ):
+        def instantiate_trunk(
+            embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
+        ):
+            return SimpleTransformer(
+                embed_dim=embed_dim,
+                num_blocks=num_blocks,
+                ffn_dropout_rate=0.0,
+                drop_path_rate=drop_path,
+                attn_target=partial(
+                    MultiheadAttention,
+                    embed_dim=embed_dim,
+                    num_heads=num_heads,
+                    bias=True,
+                    add_bias_kv=add_bias_kv,
+                ),
+                pre_transformer_layer=nn.Sequential(
+                    nn.LayerNorm(embed_dim, eps=1e-6)
+                    if pre_transformer_ln
+                    else nn.Identity(),
+                    EinOpsRearrange("b l d -> l b d"),
+                ),
+                post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
+            )
+
+        modality_trunks = {}
+        modality_trunks[ModalityType.VISION] = instantiate_trunk(
+            vision_embed_dim,
+            vision_num_blocks,
+            vision_num_heads,
+            pre_transformer_ln=True,
+            add_bias_kv=False,
+            drop_path=0.0,
+        )
+        modality_trunks[ModalityType.TEXT] = instantiate_trunk(
+            text_embed_dim,
+            text_num_blocks,
+            text_num_heads,
+            pre_transformer_ln=False,
+            add_bias_kv=False,
+            drop_path=0.0,
+        )
+        modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
+            audio_embed_dim,
+            audio_num_blocks,
+            audio_num_heads,
+            pre_transformer_ln=False,
+            add_bias_kv=True,
+            drop_path=audio_drop_path,
+        )
+        modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
+            depth_embed_dim,
+            depth_num_blocks,
+            depth_num_heads,
+            pre_transformer_ln=False,
+            add_bias_kv=True,
+            drop_path=depth_drop_path,
+        )
+        modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
+            thermal_embed_dim,
+            thermal_num_blocks,
+            thermal_num_heads,
+            pre_transformer_ln=False,
+            add_bias_kv=True,
+            drop_path=thermal_drop_path,
+        )
+        modality_trunks[ModalityType.IMU] = instantiate_trunk(
+            imu_embed_dim,
+            imu_num_blocks,
+            imu_num_heads,
+            pre_transformer_ln=False,
+            add_bias_kv=True,
+            drop_path=imu_drop_path,
+        )
+
+        return nn.ModuleDict(modality_trunks)
+
+    def _create_modality_heads(
+        self,
+        out_embed_dim,
+        vision_embed_dim,
+        text_embed_dim,
+        audio_embed_dim,
+        depth_embed_dim,
+        thermal_embed_dim,
+        imu_embed_dim,
+    ):
+        modality_heads = {}
+
+        modality_heads[ModalityType.VISION] = nn.Sequential(
+            nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
+            SelectElement(index=0),
+            nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
+        )
+
+        modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
+            proj=nn.Sequential(
+                nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
+                nn.Linear(text_embed_dim, out_embed_dim, bias=False),
+            )
+        )
+
+        modality_heads[ModalityType.AUDIO] = nn.Sequential(
+            nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
+            SelectElement(index=0),
+            nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
+        )
+
+        modality_heads[ModalityType.DEPTH] = nn.Sequential(
+            nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
+            SelectElement(index=0),
+            nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
+        )
+
+        modality_heads[ModalityType.THERMAL] = nn.Sequential(
+            nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
+            SelectElement(index=0),
+            nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
+        )
+
+        modality_heads[ModalityType.IMU] = nn.Sequential(
+            nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
+            SelectElement(index=0),
+            nn.Dropout(p=0.5),
+            nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
+        )
+
+        return nn.ModuleDict(modality_heads)
+
+    def _create_modality_postprocessors(self, out_embed_dim):
+        modality_postprocessors = {}
+
+        modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
+        modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
+            Normalize(dim=-1), LearnableLogitScaling(learnable=True)
+        )
+        modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
+            Normalize(dim=-1),
+            LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
+        )
+        modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
+            Normalize(dim=-1),
+            LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
+        )
+        modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
+            Normalize(dim=-1),
+            LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
+        )
+        modality_postprocessors[ModalityType.IMU] = nn.Sequential(
+            Normalize(dim=-1),
+            LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
+        )
+
+        return nn.ModuleDict(modality_postprocessors)
+
+    def forward(self, inputs):
+        """Embed each modality in ``inputs``.
+
+        Args:
+            inputs: dict mapping a ModalityType value to its input tensor.
+                Tensors with 5 or more dims are treated as (B, S, ...) stacks
+                of S clips whose embeddings are averaged per sample. Entries
+                whose value is None are skipped.
+
+        Returns:
+            dict mapping each processed modality key to its embedding tensor.
+        """
+        outputs = {}
+        for modality_key, modality_value in inputs.items():
+            if modality_value is None:
+                # Nothing to embed for this modality.
+                continue
+            # Audio and Video inputs consist of multiple clips
+            reduce_list = modality_value.ndim >= 5
+            if reduce_list:
+                B, S = modality_value.shape[:2]
+                modality_value = modality_value.reshape(
+                    B * S, *modality_value.shape[2:]
+                )
+
+            modality_value = self.modality_preprocessors[modality_key](
+                **{modality_key: modality_value}
+            )
+            trunk_inputs = modality_value["trunk"]
+            head_inputs = modality_value["head"]
+            modality_value = self.modality_trunks[modality_key](**trunk_inputs)
+            modality_value = self.modality_heads[modality_key](
+                modality_value, **head_inputs
+            )
+            modality_value = self.modality_postprocessors[modality_key](
+                modality_value
+            )
+
+            if reduce_list:
+                modality_value = modality_value.reshape(B, S, -1)
+                modality_value = modality_value.mean(dim=1)
+
+            outputs[modality_key] = modality_value
+
+        return outputs
+
+
+def imagebind_huge(pretrained=False):
+    """Build the ImageBind "huge" model, optionally loading pretrained weights.
+
+    When ``pretrained`` is true the checkpoint is downloaded to
+    ``.checkpoints/imagebind_huge.pth`` if it is not already there.
+    """
+    model = ImageBindModel(
+        vision_embed_dim=1280,
+        vision_num_blocks=32,
+        vision_num_heads=16,
+        text_embed_dim=1024,
+        text_num_blocks=24,
+        text_num_heads=16,
+        out_embed_dim=1024,
+        audio_drop_path=0.1,
+        imu_drop_path=0.7,
+    )
+
+    if pretrained:
+        if not os.path.exists(".checkpoints/imagebind_huge.pth"):
+            print(
+                "Downloading imagebind weights to .checkpoints/imagebind_huge.pth ..."
+            )
+            os.makedirs(".checkpoints", exist_ok=True)
+            torch.hub.download_url_to_file(
+                "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
+                ".checkpoints/imagebind_huge.pth",
+                progress=True,
+            )
+
+        # Map to CPU so loading does not depend on the device the checkpoint
+        # was saved from; callers can move the model afterwards with .to().
+        model.load_state_dict(
+            torch.load(".checkpoints/imagebind_huge.pth", map_location="cpu")
+        )
+
+    return model
diff --git a/models/multimodal_preprocessors.py b/models/multimodal_preprocessors.py
new file mode 100644
index 0000000..b19711d
--- /dev/null
+++ b/models/multimodal_preprocessors.py
@@ -0,0 +1,687 @@
+#!/usr/bin/env python3
+# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+ +import gzip +import html +import io +import math +from functools import lru_cache +from typing import Callable, List, Optional + +import ftfy + +import numpy as np +import regex as re +import torch +import torch.nn as nn +from iopath.common.file_io import g_pathmgr +from timm.models.layers import trunc_normal_ + +from models.helpers import cast_if_src_dtype, VerboseNNModule + + +def get_sinusoid_encoding_table(n_position, d_hid): + """Sinusoid position encoding table""" + + # TODO: make it with torch instead of numpy + def get_position_angle_vec(position): + return [ + position / np.power(10000, 2 * (hid_j // 2) / d_hid) + for hid_j in range(d_hid) + ] + + sinusoid_table = np.array( + [get_position_angle_vec(pos_i) for pos_i in range(n_position)] + ) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +def interpolate_pos_encoding_2d(target_spatial_size, pos_embed): + N = pos_embed.shape[1] + if N == target_spatial_size: + return pos_embed + dim = pos_embed.shape[-1] + # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32 + pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32) + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( + 0, 3, 1, 2 + ), + scale_factor=math.sqrt(target_spatial_size / N), + mode="bicubic", + ) + if updated: + pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16) + pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return pos_embed + + +def interpolate_pos_encoding( + npatch_per_img, + pos_embed, + patches_layout, + input_shape=None, + first_patch_idx=1, +): + assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none" + N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists + if npatch_per_img == N: + 
class PatchEmbedGeneric(nn.Module):
    """
    PatchEmbed from Hydra

    Wraps a projection stem (one module, or several chained in order) that
    turns an input into a feature map, then flattens that map into a token
    sequence and optionally normalizes it.
    """

    def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None):
        """
        Args:
            proj_stem: Sequence of modules forming the projection stem.
            norm_layer: Optional module applied to the flattened tokens.
        """
        super().__init__()

        if len(proj_stem) > 1:
            self.proj = nn.Sequential(*proj_stem)
        else:
            # Special case to be able to load pre-trained models that were
            # trained with a standard stem
            self.proj = proj_stem[0]
        self.norm_layer = norm_layer

    def get_patch_layout(self, img_size):
        """Probe the stem with an all-zero input to discover its output layout.

        Args:
            img_size: Per-sample input shape without the batch dimension.
                Any sequence of ints is accepted (list or tuple).

        Returns:
            ``(patches_layout, num_patches, embed_dim)`` where
            ``patches_layout`` is the stem output shape after the channel
            dimension, ``num_patches`` is the product of that layout and
            ``embed_dim`` is the stem's output channel count.
        """
        with torch.no_grad():
            # Unpack instead of list-concatenating so tuple sizes work too;
            # `[1] + img_size` raises TypeError when img_size is a tuple.
            dummy_img = torch.zeros([1, *img_size])
            dummy_out = self.proj(dummy_img)
        embed_dim = dummy_out.shape[1]
        patches_layout = tuple(dummy_out.shape[2:])
        num_patches = np.prod(patches_layout)
        return patches_layout, num_patches, embed_dim

    def forward(self, x):
        """Project ``x`` and flatten everything after the channel dim into tokens."""
        x = self.proj(x)
        # B C (T) H W -> B (T)HW C
        x = x.flatten(2).transpose(1, 2)
        if self.norm_layer is not None:
            x = self.norm_layer(x)
        return x
pos_embed_fn( + patches_layout=self.patches_layout, + num_cls_tokens=num_cls_tokens, + num_patches=self.num_patches, + embed_dim=self.embed_dim, + ) + if self.num_cls_tokens > 0: + self.cls_token = nn.Parameter( + torch.zeros(1, self.num_cls_tokens, self.embed_dim) + ) + if self.use_type_embed: + self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + + self.init_parameters(init_param_style) + + @torch.no_grad() + def init_parameters(self, init_param_style): + if init_param_style == "openclip": + # OpenCLIP style initialization + scale = self.embed_dim**-0.5 + if self.use_pos_embed: + nn.init.normal_(self.pos_embedding_helper.pos_embed) + self.pos_embedding_helper.pos_embed *= scale + + if self.num_cls_tokens > 0: + nn.init.normal_(self.cls_token) + self.cls_token *= scale + elif init_param_style == "vit": + self.cls_token.data.fill_(0) + else: + raise ValueError(f"Unknown init {init_param_style}") + + if self.use_type_embed: + nn.init.normal_(self.type_embed) + + def tokenize_input_and_cls_pos(self, input, stem, mask): + # tokens is of shape B x L x D + tokens = stem(input) + assert tokens.ndim == 3 + assert tokens.shape[2] == self.embed_dim + B = tokens.shape[0] + if self.num_cls_tokens > 0: + class_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole class_tokens impl from Phil Wang, thanks + tokens = torch.cat((class_tokens, tokens), dim=1) + if self.use_pos_embed: + pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens) + tokens = tokens + pos_embed + if self.use_type_embed: + tokens = tokens + self.type_embed.expand(B, -1, -1) + return tokens + + def forward(self, vision=None, depth=None, patch_mask=None): + if patch_mask is not None: + raise NotImplementedError() + + if vision is not None: + vision_tokens = self.tokenize_input_and_cls_pos( + vision, self.rgbt_stem, patch_mask + ) + + if depth is not None: + depth_tokens = self.tokenize_input_and_cls_pos( + depth, self.depth_stem, patch_mask + ) + + # aggregate tokens + if 
def build_causal_attention_mask(context_length):
    """Return a square additive causal mask of side ``context_length``.

    Entries strictly above the main diagonal are ``-inf``; the diagonal and
    everything below it are ``0``. The tensor does not require gradients.
    """
    side = context_length
    # Start from an all -inf square, then keep only the strictly-upper
    # triangle; triu_ zeroes the diagonal and the lower part in place.
    return torch.empty(side, side, requires_grad=False).fill_(float("-inf")).triu_(1)
class Im2Video(nn.Module):
    """Convert an image into a trivial video."""

    def __init__(self, time_dim=2):
        """
        Args:
            time_dim: Index at which the singleton time axis is inserted.
        """
        super().__init__()
        self.time_dim = time_dim

    def forward(self, x):
        """Insert a length-1 time axis into 4-D input; pass 5-D input through.

        Raises:
            ValueError: If ``x`` is neither 4-D nor 5-D.
        """
        if x.ndim == 4:
            # B, C, H, W -> B, C, T, H, W
            return x.unsqueeze(self.time_dim)
        elif x.ndim == 5:
            return x
        else:
            raise ValueError(f"Dimension incorrect {x.shape}")


class PadIm2Video(Im2Video):
    """Turn an image into a video of ``ntimes`` frames by repeating or zero-padding."""

    def __init__(self, ntimes, pad_type, time_dim=2):
        """
        Args:
            ntimes: Target number of frames; must be positive.
            pad_type: ``"repeat"`` to copy the single frame ``ntimes`` times,
                or ``"zero"`` to append all-zero frames after it.
            time_dim: Index of the time axis.
        """
        super().__init__(time_dim=time_dim)
        assert ntimes > 0
        assert pad_type in ["zero", "repeat"]
        self.ntimes = ntimes
        self.pad_type = pad_type

    def forward(self, x):
        """Expand a single-frame input to ``ntimes`` frames along the time axis.

        Input that already has more than one frame is returned unchanged.
        """
        x = super().forward(x)
        if x.shape[self.time_dim] == 1:
            if self.pad_type == "repeat":
                new_shape = [1] * len(x.shape)
                new_shape[self.time_dim] = self.ntimes
                x = x.repeat(new_shape)
            elif self.pad_type == "zero":
                padarg = [0, 0] * len(x.shape)
                # F.pad lists (left, right) pairs starting from the LAST
                # dimension, so the pair for axis `d` sits at index
                # 2 * (ndim - 1 - d). Indexing with 2 * d is only correct
                # when the time axis happens to be the middle one.
                time_axis = self.time_dim % x.ndim
                right_pad_slot = 2 * (x.ndim - 1 - time_axis) + 1
                padarg[right_pad_slot] = self.ntimes - x.shape[self.time_dim]
                x = nn.functional.pad(x, padarg)
        return x
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer (modified from github.com/openai/CLIP).

    Words are split into byte-derived symbols, the last symbol of each word
    carries the ``</w>`` end-of-word marker, and symbol pairs are merged in
    the order given by the merges file.
    """

    def __init__(self, bpe_path: str, context_length=77):
        """
        Args:
            bpe_path: Path to the gzip-compressed merges file, opened
                through ``g_pathmgr``.
            context_length: Default row length of the tensors produced by
                ``__call__``.
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        with g_pathmgr.open(bpe_path, "rb") as fh:
            bpe_bytes = io.BytesIO(fh.read())
        with gzip.open(bpe_bytes) as gz:
            merges = gz.read().decode("utf-8").split("\n")
        # Skip the header line and keep exactly as many merges as the
        # 49152-entry vocabulary has room for.
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        # Every byte symbol exists twice: bare, and as a word-final symbol.
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {
            "<|startoftext|>": "<|startoftext|>",
            "<|endoftext|>": "<|endoftext|>",
        }
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )
        self.context_length = context_length

    def bpe(self, token):
        """Apply the learned merges to one token.

        Args:
            token: A non-empty string of byte-encoder symbols.

        Returns:
            The merged symbols joined by single spaces; the final symbol
            ends with ``</w>``. Results are memoized in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        # Adjacent symbol pairs of the current word.
        pairs = set(zip(word, word[1:]))

        if not pairs:
            return token + "</w>"

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: copy the tail verbatim.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = set(zip(word, word[1:]))
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lowercase ``text`` and return its list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        """Map token ids back to text; each ``</w>`` marker becomes a space."""
        text = "".join([self.decoder[token] for token in tokens])
        text = (
            bytearray([self.byte_decoder[c] for c in text])
            .decode("utf-8", errors="replace")
            .replace("</w>", " ")
        )
        return text

    def __call__(self, texts, context_length=None):
        """Tokenize one string or a list of strings into a zero-padded id tensor.

        Args:
            texts: A string or a list of strings.
            context_length: Row length; falls back to ``self.context_length``
                when falsy.

        Returns:
            A ``torch.long`` tensor of shape ``(len(texts), context_length)``,
            or a 1-D tensor of length ``context_length`` for a single text.
        """
        if not context_length:
            context_length = self.context_length

        if isinstance(texts, str):
            texts = [texts]

        sot_token = self.encoder["<|startoftext|>"]
        eot_token = self.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            # NOTE(review): truncation can cut off the trailing end-of-text
            # token for over-long inputs — confirm consumers tolerate that.
            tokens = tokens[:context_length]
            result[i, : len(tokens)] = torch.tensor(tokens)

        if len(result) == 1:
            return result[0]
        return result
torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim) + ) + + if self.num_cls_tokens > 0: + self.cls_token = nn.Parameter( + torch.zeros(1, self.num_cls_tokens, self.embed_dim) + ) + + self.init_parameters(init_param_style) + + @torch.no_grad() + def init_parameters(self, init_param_style): + nn.init.normal_(self.pos_embed, std=0.01) + + if init_param_style == "openclip": + # OpenCLIP style initialization + scale = self.embed_dim**-0.5 + + if self.num_cls_tokens > 0: + nn.init.normal_(self.cls_token) + self.cls_token *= scale + elif init_param_style == "vit": + self.cls_token.data.fill_(0) + else: + raise ValueError(f"Unknown init {init_param_style}") + + def tokenize_input_and_cls_pos(self, input, stem): + # tokens is of shape B x L x D + tokens = stem.norm_layer(stem.proj(input)) + assert tokens.ndim == 3 + assert tokens.shape[2] == self.embed_dim + B = tokens.shape[0] + if self.num_cls_tokens > 0: + class_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole class_tokens impl from Phil Wang, thanks + tokens = torch.cat((class_tokens, tokens), dim=1) + if self.use_pos_embed: + tokens = tokens + self.pos_embed + return tokens + + def forward(self, imu): + # Patchify + imu = imu.unfold( + -1, + self.kernel_size, + self.kernel_size, + ).permute(0, 2, 1, 3) + imu = imu.reshape(imu.size(0), imu.size(1), -1) + + imu_tokens = self.tokenize_input_and_cls_pos( + imu, + self.imu_stem, + ) + + return_dict = { + "trunk": { + "tokens": imu_tokens, + }, + "head": {}, + } + return return_dict diff --git a/models/transformer.py b/models/transformer.py new file mode 100644 index 0000000..98902ac --- /dev/null +++ b/models/transformer.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
class Attention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        """
        Args:
            dim: Channel size of the input and output tokens.
            num_heads: Number of attention heads the channels are split into.
            qkv_bias: Whether the fused qkv projection has a bias term.
            qk_scale: Multiplier for the attention logits; a falsy value
                selects ``(dim // num_heads) ** -0.5``.
            attn_drop: Dropout probability on the attention weights.
            proj_drop: Dropout probability on the projected output.
        """
        super().__init__()
        self.num_heads = num_heads
        per_head_dim = dim // num_heads
        # An explicit qk_scale allows compatibility with weights trained
        # under a different scaling; otherwise use 1/sqrt(head_dim).
        self.scale = qk_scale if qk_scale else per_head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over the token axis of a ``(B, N, C)`` input; returns ``(B, N, C)``."""
        batch, tokens, channels = x.shape
        heads = self.num_heads

        fused = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads)
        # -> (3, B, heads, N, head_dim); index rather than tuple-unpack the
        # tensor so the module stays scriptable.
        fused = fused.permute(2, 0, 3, 1, 4)
        query, key, value = fused[0], fused[1], fused[2]

        weights = (query @ key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        mixed = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))
class MultiheadAttention(nn.MultiheadAttention):
    """``nn.MultiheadAttention`` specialised to self-attention.

    Takes a single input used as query, key and value, and returns only the
    attended output (attention weights are not computed).
    """

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        attended, _weights = super().forward(
            x,
            x,
            x,
            need_weights=False,
            attn_mask=attn_mask,
        )
        return attended
class SimpleTransformer(nn.Module):
    """Stack of identical blocks with optional pre/post layers.

    Supports an attention mask passed to every block, per-block DropPath
    rates, LayerScale, FFN dropout, and optional activation checkpointing.
    """

    def __init__(
        self,
        attn_target: Callable,
        embed_dim: int,
        num_blocks: int,
        block: Callable = BlockWithMasking,
        pre_transformer_layer: Callable = None,
        post_transformer_layer: Callable = None,
        drop_path_rate: float = 0.0,
        drop_path_type: str = "progressive",
        norm_layer: Callable = _LAYER_NORM,
        mlp_ratio: int = 4,
        ffn_dropout_rate: float = 0.0,
        layer_scale_type: str = None,  # from cait; possible values are None, "per_channel", "scalar"
        layer_scale_init_value: float = 1e-4,  # from cait; float
        weight_init_style: str = "jax",  # possible values jax or pytorch
    ):
        """
        Args:
            attn_target: Factory forwarded to every block as ``attn_target``.
            embed_dim: Token channel size, forwarded to every block as ``dim``.
            num_blocks: Number of blocks to stack.
            block: Factory that builds one block.
            pre_transformer_layer: Optional callable applied before the blocks.
            post_transformer_layer: Optional callable applied after the blocks.
            drop_path_rate: Final ("progressive") or constant ("uniform")
                DropPath rate.
            drop_path_type: ``"progressive"`` ramps the rate linearly from 0
                to ``drop_path_rate`` across the blocks; ``"uniform"`` uses
                ``drop_path_rate`` for every block.
            norm_layer: Normalization factory forwarded to every block.
            mlp_ratio: Forwarded to every block.
            ffn_dropout_rate: Forwarded to every block.
            layer_scale_type: Forwarded to every block.
            layer_scale_init_value: Forwarded to every block.
            weight_init_style: ``"jax"`` (Xavier uniform) or ``"pytorch"``
                (truncated normal, std 0.02) for ``nn.Linear`` weights.

        Raises:
            ValueError: If ``drop_path_type`` is not a known schedule.
        """
        super().__init__()
        self.pre_transformer_layer = pre_transformer_layer

        if drop_path_type == "progressive":
            drop_rates = torch.linspace(0, drop_path_rate, num_blocks).tolist()
        elif drop_path_type == "uniform":
            drop_rates = [drop_path_rate] * num_blocks
        else:
            raise ValueError(f"Unknown drop_path_type: {drop_path_type}")

        # Settings shared by every block; only the DropPath rate varies.
        shared_kwargs = dict(
            dim=embed_dim,
            attn_target=attn_target,
            mlp_ratio=mlp_ratio,
            ffn_dropout_rate=ffn_dropout_rate,
            norm_layer=norm_layer,
            layer_scale_type=layer_scale_type,
            layer_scale_init_value=layer_scale_init_value,
        )
        self.blocks = nn.Sequential(
            *[block(drop_path=rate, **shared_kwargs) for rate in drop_rates]
        )
        self.post_transformer_layer = post_transformer_layer
        self.weight_init_style = weight_init_style
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule; only ``nn.Linear`` and ``nn.LayerNorm`` are touched."""
        if isinstance(m, nn.Linear):
            style = self.weight_init_style
            if style == "jax":
                # Based on MAE and official Jax ViT implementation
                torch.nn.init.xavier_uniform_(m.weight)
            elif style == "pytorch":
                # PyTorch ViT uses trunc_normal_
                trunc_normal_(m.weight, std=0.02)

            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return

        if isinstance(m, (nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(
        self,
        tokens: torch.Tensor,
        attn_mask: torch.Tensor = None,
        use_checkpoint: bool = False,
        checkpoint_every_n: int = 1,
        checkpoint_blk_ids: List[int] = None,
    ):
        """
        Inputs
        - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation)
        - attn: mask of shape L x L
        - use_checkpoint: run the selected blocks under activation checkpointing
        - checkpoint_every_n: when no ids are given, checkpoint every n-th block starting at 0
        - checkpoint_blk_ids: explicit indices of the blocks to checkpoint

        Output
        - x: data of shape N x L x D (or L x N x D depending on the attention implementation)
        """
        if self.pre_transformer_layer:
            tokens = self.pre_transformer_layer(tokens)

        if use_checkpoint and checkpoint_blk_ids is None:
            checkpoint_blk_ids = [
                index
                for index in range(len(self.blocks))
                if index % checkpoint_every_n == 0
            ]
        if checkpoint_blk_ids:
            checkpoint_blk_ids = set(checkpoint_blk_ids)

        for index, layer in enumerate(self.blocks):
            checkpoint_this_block = use_checkpoint and index in checkpoint_blk_ids
            if checkpoint_this_block:
                tokens = checkpoint.checkpoint(
                    layer, tokens, attn_mask, use_reentrant=False
                )
            else:
                tokens = layer(tokens, attn_mask=attn_mask)

        if self.post_transformer_layer:
            tokens = self.post_transformer_layer(tokens)
        return tokens