add

This commit is contained in:
jianjiang 2023-04-25 21:51:33 +08:00
parent 7d95652df1
commit 56f65aaf2f
66 changed files with 12347 additions and 0 deletions

33
CLIP/.github/workflows/test.yml vendored Normal file

@ -0,0 +1,33 @@
name: test
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
jobs:
  CLIP-test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8]
        pytorch-version: [1.7.1, 1.9.1, 1.10.1]
        include:
          - python-version: 3.8
            pytorch-version: 1.7.1
            torchvision-version: 0.8.2
          - python-version: 3.8
            pytorch-version: 1.9.1
            torchvision-version: 0.10.1
          - python-version: 3.8
            pytorch-version: 1.10.1
            torchvision-version: 0.11.2
    steps:
      - uses: conda-incubator/setup-miniconda@v2
      - run: conda install -n test python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} torchvision=${{ matrix.torchvision-version }} cpuonly -c pytorch
      - uses: actions/checkout@v2
      - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
      - run: pip install pytest
      - run: pip install .
      - run: pytest

10
CLIP/.gitignore vendored Normal file

@ -0,0 +1,10 @@
__pycache__/
*.py[cod]
*$py.class
*.egg-info
.pytest_cache
.ipynb_checkpoints
thumbs.db
.DS_Store
.idea

BIN
CLIP/CLIP.png Normal file

Binary file not shown.


22
CLIP/LICENSE Normal file

@ -0,0 +1,22 @@
MIT License
Copyright (c) 2021 OpenAI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

1
CLIP/MANIFEST.in Normal file

@ -0,0 +1 @@
include clip/bpe_simple_vocab_16e6.txt.gz

199
CLIP/README.md Normal file

@ -0,0 +1,199 @@
# CLIP
[[Blog]](https://openai.com/blog/clip/) [[Paper]](https://arxiv.org/abs/2103.00020) [[Model Card]](model-card.md) [[Colab]](https://colab.research.google.com/github/openai/clip/blob/master/notebooks/Interacting_with_CLIP.ipynb)
CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing for the task, similarly to the zero-shot capabilities of GPT-2 and 3. We found CLIP matches the performance of the original ResNet50 on ImageNet “zero-shot” without using any of the original 1.28M labeled examples, overcoming several major challenges in computer vision.
## Approach
![CLIP](CLIP.png)
## Usage
First, [install PyTorch 1.7.1](https://pytorch.org/get-started/locally/) (or later) and torchvision, as well as small additional dependencies, and then install this repo as a Python package. On a CUDA GPU machine, the following will do the trick:
```bash
$ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
$ pip install ftfy regex tqdm
$ pip install git+https://github.com/openai/CLIP.git
```
Replace `cudatoolkit=11.0` above with the appropriate CUDA version on your machine or `cpuonly` when installing on a machine without a GPU.
```python
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]
```
## API
The CLIP module `clip` provides the following methods:
#### `clip.available_models()`
Returns the names of the available CLIP models.
#### `clip.load(name, device=..., jit=False)`
Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The `name` argument can also be a path to a local checkpoint.
The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU. When `jit` is `False`, a non-JIT version of the model will be loaded.
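For example, once a model has been downloaded, it can be loaded again directly from the cache (a minimal sketch; the path below is the default download location and assumes the ViT-B/32 weights are already present):
```python
import os
import clip

# Default cache path used by clip.load("ViT-B/32"); adjust if download_root was changed.
checkpoint = os.path.expanduser("~/.cache/clip/ViT-B-32.pt")
model, preprocess = clip.load(checkpoint, device="cpu", jit=False)
```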
#### `clip.tokenize(text: Union[str, List[str]], context_length=77)`
Returns a LongTensor containing tokenized sequences of the given text input(s). This can be used as the input to the model.
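For example (a minimal sketch; the `truncate` flag comes from the `tokenize()` signature in `clip/clip.py` added in this commit):
```python
import clip

tokens = clip.tokenize(["a diagram", "a dog"])
print(tokens.shape)  # torch.Size([2, 77]) -- one row per input, padded to the context length

# Inputs longer than the context length raise a RuntimeError unless truncate=True,
# in which case the sequence is cut off and terminated with the end-of-text token.
long_tokens = clip.tokenize("a photo of " + "a very " * 100 + "long caption", truncate=True)
print(long_tokens.shape)  # torch.Size([1, 77])
```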
---
The model returned by `clip.load()` supports the following methods:
#### `model.encode_image(image: Tensor)`
Given a batch of images, returns the image features encoded by the vision portion of the CLIP model.
#### `model.encode_text(text: Tensor)`
Given a batch of text tokens, returns the text features encoded by the language portion of the CLIP model.
#### `model(image: Tensor, text: Tensor)`
Given a batch of images and a batch of text tokens, returns two Tensors, containing the logit scores corresponding to each image and text input. The values are cosine similarities between the corresponding image and text features, times 100.
## More Examples
### Zero-Shot Prediction
The code below performs zero-shot prediction using CLIP, as shown in Appendix B in the paper. This example takes an image from the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html), and predicts the most likely labels among the 100 textual labels from the dataset.
```python
import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")
```
The output will look like the following (the exact numbers may be slightly different depending on the compute device):
```
Top predictions:
snake: 65.31%
turtle: 12.29%
sweet_pepper: 3.83%
lizard: 1.88%
crocodile: 1.75%
```
Note that this example uses the `encode_image()` and `encode_text()` methods that return the encoded features of given inputs.
### Linear-probe evaluation
The example below uses [scikit-learn](https://scikit-learn.org/) to perform logistic regression on image features.
```python
import os
import clip
import torch

import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)


def get_features(dataset):
    all_features = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
            features = model.encode_image(images.to(device))

            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()

# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)

# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)

# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
print(f"Accuracy = {accuracy:.3f}")
```
Note that the `C` value should be determined via a hyperparameter sweep using a validation split.
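For instance, reusing `train_features` and `train_labels` from the snippet above, a simple sweep might look like this (a rough sketch; the logarithmic grid and the 10% validation split are arbitrary choices, not values from the paper):
```python
# Hedged sketch: choose C on a held-out validation split of the training features.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X_train, X_val, y_train, y_val = train_test_split(
    train_features, train_labels, test_size=0.1, random_state=0
)

best_C, best_acc = None, 0.0
for C in np.logspace(-3, 3, num=7):
    clf = LogisticRegression(random_state=0, C=C, max_iter=1000)
    clf.fit(X_train, y_train)
    acc = clf.score(X_val, y_val)
    if acc > best_acc:
        best_C, best_acc = C, acc

print(f"Best C = {best_C:g} with validation accuracy {100 * best_acc:.2f}%")
```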
## See Also
* [OpenCLIP](https://github.com/mlfoundations/open_clip): includes larger and independently trained CLIP models up to ViT-G/14
* [Hugging Face implementation of CLIP](https://huggingface.co/docs/transformers/model_doc/clip): for easier integration with the HF ecosystem

1
CLIP/clip/__init__.py Normal file

@ -0,0 +1 @@
from .clip import *

BIN
CLIP/clip/bpe_simple_vocab_16e6.txt.gz (Stored with Git LFS) Normal file

Binary file not shown.

237
CLIP/clip/clip.py Normal file

@ -0,0 +1,237 @@
import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List
from pkg_resources import packaging
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm
from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()
_MODELS = {
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}
def _download(url: str, root: str):
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
return download_target
def _convert_image_to_rgb(image):
return image.convert("RGB")
def _transform(n_px):
return Compose([
Resize(n_px, interpolation=BICUBIC),
CenterCrop(n_px),
_convert_image_to_rgb,
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
def available_models() -> List[str]:
"""Returns the names of available CLIP models"""
return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
"""Load a CLIP model
Parameters
----------
name : str
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
device : Union[str, torch.device]
The device to put the loaded model
jit : bool
Whether to load the optimized JIT model or more hackable non-JIT model (default).
download_root: str
path to download the model files; by default, it uses "~/.cache/clip"
Returns
-------
model : torch.nn.Module
The CLIP model
preprocess : Callable[[PIL.Image], torch.Tensor]
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if name in _MODELS:
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
with open(model_path, 'rb') as opened_file:
try:
# loading JIT archive
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
if jit:
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
jit = False
state_dict = torch.load(opened_file, map_location="cpu")
if not jit:
model = build_model(state_dict or model.state_dict()).to(device)
if str(device) == "cpu":
model.float()
return model, _transform(model.visual.input_resolution)
# patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
def patch_device(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("prim::Constant"):
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
node.copyAttributes(device_node)
model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)
# patch dtype to float32 on CPU
if str(device) == "cpu":
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node()
def patch_float(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("aten::to"):
inputs = list(node.inputs())
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
if inputs[i].node()["value"] == 5:
inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)
model.float()
return model, _transform(model.input_resolution.item())
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
"""
Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
truncate: bool
Whether to truncate the text in case its encoding is longer than the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
"""
if isinstance(texts, str):
texts = [texts]
sot_token = _tokenizer.encoder["<|startoftext|>"]
eot_token = _tokenizer.encoder["<|endoftext|>"]
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
else:
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
if truncate:
tokens = tokens[:context_length]
tokens[-1] = eot_token
else:
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
result[i, :len(tokens)] = torch.tensor(tokens)
return result

436
CLIP/clip/model.py Normal file

@ -0,0 +1,436 @@
from collections import OrderedDict
from typing import Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.relu2 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu3 = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion))
]))
def forward(self, x: torch.Tensor):
identity = x
out = self.relu1(self.bn1(self.conv1(x)))
out = self.relu2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu3(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x[:1], key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x.squeeze(0)
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.relu3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
def stem(x):
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model))
]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
def forward(self, x: torch.Tensor):
return self.resblocks(x)
class VisionTransformer(nn.Module):
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
def forward(self, x: torch.Tensor):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x[:, 0, :])
if self.proj is not None:
x = x @ self.proj
return x
class CLIP(nn.Module):
def __init__(self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int
):
super().__init__()
self.context_length = context_length
if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width
)
else:
vision_heads = vision_width // 64
self.visual = VisionTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim
)
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask()
)
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.initialize_parameters()
def initialize_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
if isinstance(self.visual, ModifiedResNet):
if self.visual.attnpool is not None:
std = self.visual.attnpool.c_proj.in_features ** -0.5
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
@property
def dtype(self):
return self.visual.conv1.weight.dtype
def encode_image(self, image):
return self.visual(image.type(self.dtype))
def encode_text(self, text):
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text
def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.half()
if l.bias is not None:
l.bias.data = l.bias.data.half()
if isinstance(l, nn.MultiheadAttention):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.half()
for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
if attr is not None:
attr.data = attr.data.half()
model.apply(_convert_weights_to_fp16)
def build_model(state_dict: dict):
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_resolution = vision_patch_size * grid_size
else:
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
vision_patch_size = None
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
image_resolution = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
model = CLIP(
embed_dim,
image_resolution, vision_layers, vision_width, vision_patch_size,
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
)
for key in ["input_resolution", "context_length", "vocab_size"]:
if key in state_dict:
del state_dict[key]
convert_weights(model)
model.load_state_dict(state_dict)
return model.eval()

132
CLIP/clip/simple_tokenizer.py Normal file

@ -0,0 +1,132 @@
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
@lru_cache()
def default_bpe():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = merges[1:49152-256-2+1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v+'</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
pairs = get_pairs(word)
if not pairs:
return token+'</w>'
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text

12
CLIP/data/country211.md Normal file

@ -0,0 +1,12 @@
# The Country211 Dataset
In the paper, we used an image classification dataset called Country211 to evaluate the model's capability on geolocation. To do so, we filtered the YFCC100m dataset for photos with GPS coordinates corresponding to an [ISO-3166 country code](https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes) and created a balanced dataset by sampling 150 train images, 50 validation images, and 100 test images for each country.
The following command will download an 11 GB archive containing the images and extract them into a subdirectory `country211`:
```bash
wget https://openaipublic.azureedge.net/clip/data/country211.tgz
tar zxvf country211.tgz
```
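Once extracted, the directory can be loaded with standard tooling; below is a hedged sketch using `torchvision.datasets.ImageFolder` together with the CLIP preprocessing transform (the `train` split subdirectory name is an assumption about the archive layout):
```python
# Hedged sketch: assumes the archive extracts into country211/<split>/<country_code>/ folders.
import clip
import torch
from torchvision.datasets import ImageFolder

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

train_set = ImageFolder("country211/train", transform=preprocess)
print(len(train_set), "images across", len(train_set.classes), "countries")
```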
These images are a subset of the YFCC100m dataset. Use of the underlying media files is subject to the Creative Commons licenses chosen by their creators/uploaders. For more information about the YFCC100M dataset, visit [the official website](https://multimediacommons.wordpress.com/yfcc100m-core-dataset/).

3401
CLIP/data/prompts.md Normal file

File diff suppressed because it is too large

11
CLIP/data/rendered-sst2.md Normal file

@ -0,0 +1,11 @@
# The Rendered SST2 Dataset
In the paper, we used an image classification dataset called Rendered SST2 to evaluate the model's capability on optical character recognition. To do so, we rendered the sentences in the [Stanford Sentiment Treebank v2](https://nlp.stanford.edu/sentiment/treebank.html) dataset and used those as the input to the CLIP image encoder.
The following command will download a 131 MB archive containing the images and extract them into a subdirectory `rendered-sst2`:
```bash
wget https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz
tar zxvf rendered-sst2.tgz
```

14
CLIP/data/yfcc100m.md Normal file

@ -0,0 +1,14 @@
# The YFCC100M Subset
In the paper, we performed a dataset ablation using a subset of the YFCC100M dataset and showed that the performance remained largely similar.
The subset contains 14,829,396 images, about 15% of the full dataset, which have been filtered to only keep those with natural language titles and/or descriptions in English.
We provide the list of (line number, photo identifier, photo hash) of each image contained in this subset. These correspond to the first three columns in the dataset's metadata TSV file.
```bash
wget https://openaipublic.azureedge.net/clip/data/yfcc100m_subset_data.tsv.bz2
bunzip2 yfcc100m_subset_data.tsv.bz2
```
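A minimal sketch for reading the decompressed file (assuming it is a plain, headerless TSV with the three fields described above):
```python
# Hedged sketch: iterate over the (line number, photo identifier, photo hash) triples.
import csv

with open("yfcc100m_subset_data.tsv", newline="") as f:
    for line_number, photo_id, photo_hash in csv.reader(f, delimiter="\t"):
        # Match these against the first three columns of the YFCC100M metadata TSV.
        print(line_number, photo_id, photo_hash)
        break
```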
Use of the underlying media files is subject to the Creative Commons licenses chosen by their creators/uploaders. For more information about the YFCC100M dataset, visit [the official website](https://multimediacommons.wordpress.com/yfcc100m-core-dataset/).

42
CLIP/hubconf.py Normal file

@ -0,0 +1,42 @@
from clip.clip import tokenize as _tokenize, load as _load, available_models as _available_models
import re
import string

#dependencies = ["torch", "torchvision", "ftfy", "regex", "tqdm"]

# For compatibility (cannot include special characters in function name)
model_functions = {model: re.sub(f'[{string.punctuation}]', '_', model) for model in _available_models()}

def _create_hub_entrypoint(model):
    def entrypoint(**kwargs):
        return _load(model, **kwargs)

    entrypoint.__doc__ = f"""Loads the {model} CLIP model

    Parameters
    ----------
    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The {model} CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    return entrypoint

def tokenize():
    return _tokenize

_entrypoints = {model_functions[model]: _create_hub_entrypoint(model) for model in _available_models()}

globals().update(_entrypoints)
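With these entrypoints registered, loading through `torch.hub` should look roughly like the sketch below; the repository string is an assumption about where this `hubconf.py` is hosted, and `ViT_B_32` is the underscore-mangled name produced by `model_functions` above:
```python
import torch

# Hedged sketch: kwargs are forwarded to clip.load() via the generated entrypoint.
model, preprocess = torch.hub.load("openai/CLIP:main", "ViT_B_32", device="cpu", jit=False)
```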

120
CLIP/model-card.md Normal file

@ -0,0 +1,120 @@
# Model Card: CLIP
Inspired by [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993) and [Lessons from Archives (Jo & Gebru)](https://arxiv.org/pdf/1912.10389.pdf), we're providing some accompanying information about the multimodal model.
## Model Details
The CLIP model was developed by researchers at OpenAI to learn about what contributes to robustness in computer vision tasks. The model was also developed to test the ability of models to generalize to arbitrary image classification tasks in a zero-shot manner. It was not developed for general model deployment - to deploy models like CLIP, researchers will first need to carefully study their capabilities in relation to the specific context they're being deployed within.
### Model Date
January 2021
### Model Type
The base model uses a ResNet50 with several modifications as an image encoder and uses a masked self-attention Transformer as a text encoder. These encoders are trained to maximize the similarity of (image, text) pairs via a contrastive loss. There is also a variant of the model where the ResNet image encoder is replaced with a Vision Transformer.
### Model Versions
Initially, we've released one CLIP model based on the Vision Transformer architecture equivalent to ViT-B/32, along with the RN50 model, using the architecture equivalent to ResNet-50.
As part of the staged release process, we have also released the RN101 model, as well as RN50x4, a RN50 scaled up 4x according to the [EfficientNet](https://arxiv.org/abs/1905.11946) scaling rule. In July 2021, we additionally released the RN50x16 and ViT-B/16 models, and in January 2022, the RN50x64 and ViT-L/14 models were released. Lastly, the ViT-L/14@336px model was released in April 2022.
Please see the paper linked below for further details about their specification.
### Documents
- [Blog Post](https://openai.com/blog/clip/)
- [CLIP Paper](https://arxiv.org/abs/2103.00020)
## Model Use
### Intended Use
The model is intended as a research output for research communities. We hope that this model will enable researchers to better understand and explore zero-shot, arbitrary image classification. We also hope it can be used for interdisciplinary studies of the potential impact of such models - the CLIP paper includes a discussion of potential downstream impacts to provide an example for this sort of analysis.
#### Primary intended uses
The primary intended users of these models are AI researchers.
We primarily imagine the model will be used by researchers to better understand robustness, generalization, and other capabilities, biases, and constraints of computer vision models.
### Out-of-Scope Use Cases
**Any** deployed use case of the model - whether commercial or not - is currently out of scope. Non-deployed use cases, such as image search in a constrained environment, are also not recommended unless there is thorough in-domain testing of the model with a specific, fixed class taxonomy. This is because our safety assessment demonstrated a high need for task-specific testing, especially given the variability of CLIP's performance with different class taxonomies. This makes untested and unconstrained deployment of the model in any use case currently potentially harmful.
Certain use cases which would fall under the domain of surveillance and facial recognition are always out-of-scope regardless of performance of the model. This is because the use of artificial intelligence for tasks such as these can be premature currently given the lack of testing norms and checks to ensure its fair use.
Since the model has not been purposefully trained in or evaluated on any languages other than English, its use should be limited to English language use cases.
## Data
The model was trained on publicly available image-caption data. This was done through a combination of crawling a handful of websites and using commonly-used pre-existing image datasets such as [YFCC100M](http://projects.dfki.uni-kl.de/yfcc100m/). A large portion of the data comes from our crawling of the internet. This means that the data is more representative of people and societies most connected to the internet which tend to skew towards more developed nations, and younger, male users.
### Data Mission Statement
Our goal with building this dataset was to test out robustness and generalizability in computer vision tasks. As a result, the focus was on gathering large quantities of data from different publicly-available internet data sources. The data was gathered in a mostly non-interventionist manner. However, we only crawled websites that had policies against excessively violent and adult images and allowed us to filter out such content. We do not intend for this dataset to be used as the basis for any commercial or deployed model and will not be releasing the dataset.
## Performance and Limitations
### Performance
We have evaluated the performance of CLIP on a wide range of benchmarks across a variety of computer vision datasets, ranging from OCR to texture recognition to fine-grained classification. The paper describes model performance on the following datasets:
- Food101
- CIFAR10
- CIFAR100
- Birdsnap
- SUN397
- Stanford Cars
- FGVC Aircraft
- VOC2007
- DTD
- Oxford-IIIT Pet dataset
- Caltech101
- Flowers102
- MNIST
- SVHN
- IIIT5K
- Hateful Memes
- SST-2
- UCF101
- Kinetics700
- Country211
- CLEVR Counting
- KITTI Distance
- STL-10
- RareAct
- Flickr30
- MSCOCO
- ImageNet
- ImageNet-A
- ImageNet-R
- ImageNet Sketch
- ObjectNet (ImageNet Overlap)
- Youtube-BB
- ImageNet-Vid
## Limitations
CLIP and our analysis of it have a number of limitations. CLIP currently struggles with respect to certain tasks such as fine-grained classification and counting objects. CLIP also poses issues with regards to fairness and bias, which we discuss in the paper and briefly in the next section. Additionally, our approach to testing CLIP also has an important limitation: in many cases we have used linear probes to evaluate the performance of CLIP, and there is evidence suggesting that linear probes can underestimate model performance.
### Bias and Fairness
We find that the performance of CLIP - and the specific biases it exhibits - can depend significantly on class design and the choices one makes for categories to include and exclude. We tested the risk of certain kinds of denigration with CLIP by classifying images of people from [Fairface](https://arxiv.org/abs/1908.04913) into crime-related and non-human animal categories. We found significant disparities with respect to race and gender. Additionally, we found that these disparities could shift based on how the classes were constructed. (Details captured in the Broader Impacts Section in the paper).
We also tested the performance of CLIP on gender, race and age classification using the Fairface dataset (We default to using race categories as they are constructed in the Fairface dataset.) in order to assess quality of performance across different demographics. We found accuracy >96% across all races for gender classification with Middle Eastern having the highest accuracy (98.4%) and White having the lowest (96.5%). Additionally, CLIP averaged ~93% for racial classification and ~63% for age classification. Our use of evaluations to test for gender, race and age classification as well as denigration harms is simply to evaluate performance of the model across people and surface potential risks and not to demonstrate an endorsement/enthusiasm for such tasks.
## Feedback
### Where to send questions or comments about the model
Please use [this Google Form](https://forms.gle/Uv7afRH5dvY34ZEs9)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

5
CLIP/requirements.txt Normal file

@ -0,0 +1,5 @@
ftfy
regex
tqdm
torch
torchvision

21
CLIP/setup.py Normal file

@ -0,0 +1,21 @@
import os

import pkg_resources
from setuptools import setup, find_packages

setup(
    name="clip",
    py_modules=["clip"],
    version="1.0",
    description="",
    author="OpenAI",
    packages=find_packages(exclude=["tests*"]),
    install_requires=[
        str(r)
        for r in pkg_resources.parse_requirements(
            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
        )
    ],
    include_package_data=True,
    extras_require={'dev': ['pytest']},
)

25
CLIP/tests/test_consistency.py Normal file

@ -0,0 +1,25 @@
import numpy as np
import pytest
import torch
from PIL import Image

import clip


@pytest.mark.parametrize('model_name', clip.available_models())
def test_consistency(model_name):
    device = "cpu"
    jit_model, transform = clip.load(model_name, device=device, jit=True)
    py_model, _ = clip.load(model_name, device=device, jit=False)

    image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device)
    text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

    with torch.no_grad():
        logits_per_image, _ = jit_model(image, text)
        jit_probs = logits_per_image.softmax(dim=-1).cpu().numpy()

        logits_per_image, _ = py_model(image, text)
        py_probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    assert np.allclose(jit_probs, py_probs, atol=0.01, rtol=0.1)

6
point-e/.gitignore vendored Normal file

@ -0,0 +1,6 @@
*.egg-info/
__pycache__/
point_e_model_cache/
.ipynb_checkpoints/
.DS_Store

22
point-e/LICENSE Normal file

@ -0,0 +1,22 @@
MIT License
Copyright (c) 2022 OpenAI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

28
point-e/README.md Normal file

@ -0,0 +1,28 @@
# Point·E
![Animation of four 3D point clouds rotating](point_e/examples/paper_banner.gif)
This is the official code and model release for [Point-E: A System for Generating 3D Point Clouds from Complex Prompts](https://arxiv.org/abs/2212.08751).
# Usage
Install with `pip install -e .`.
To get started with examples, see the following notebooks:
* [image2pointcloud.ipynb](point_e/examples/image2pointcloud.ipynb) - sample a point cloud, conditioned on some example synthetic view images.
* [text2pointcloud.ipynb](point_e/examples/text2pointcloud.ipynb) - use our small, worse quality pure text-to-3D model to produce 3D point clouds directly from text descriptions. This model's capabilities are limited, but it does understand some simple categories and colors.
* [pointcloud2mesh.ipynb](point_e/examples/pointcloud2mesh.ipynb) - try our SDF regression model for producing meshes from point clouds.
For our P-FID and P-IS evaluation scripts, see:
* [evaluate_pfid.py](point_e/evals/scripts/evaluate_pfid.py)
* [evaluate_pis.py](point_e/evals/scripts/evaluate_pis.py)
For our Blender rendering code, see [blender_script.py](point_e/evals/scripts/blender_script.py)
# Samples
You can download the seed images and point clouds corresponding to the paper banner images [here](https://openaipublic.azureedge.net/main/point-e/banner_pcs.zip).
You can download the seed images used for COCO CLIP R-Precision evaluations [here](https://openaipublic.azureedge.net/main/point-e/coco_images.zip).

62
point-e/model-card.md Normal file

@ -0,0 +1,62 @@
# Model Card: Point-E
This is the official codebase for running the point cloud diffusion models and SDF regression models described in [Point-E: A System for Generating 3D Point Clouds from Complex Prompts](https://arxiv.org/abs/2212.08751). These models were trained and released by OpenAI.
Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about how the models were trained and evaluated.
# Model Details
The Point-E models are trained for use as point cloud diffusion models and SDF regression models.
Our image-conditional models are often capable of producing coherent 3D point clouds, given a single rendering of a 3D object. However, the models sometimes fail to do so, either producing incorrect geometry where the rendering is occluded, or producing geometry that is inconsistent with visible parts of the rendering. The resulting point clouds are relatively low-resolution, and are often noisy and contain defects such as outliers or cracks.
Our text-conditional model is sometimes capable of producing 3D point clouds which can be recognized as the provided text description, especially when the text description is simple. However, we find that this model fails to generalize to complex prompts or unusual objects.
## Model Date
December 2022
## Model Versions
* `base40M-imagevec` - a 40 million parameter image to point cloud model that conditions on a single CLIP ViT-L/14 image vector. This model can be used to generate point clouds from rendered images, but does not perform as well as our other models for this task.
* `base40M-textvec` - a 40 million parameter text to point cloud model that conditions on a single CLIP ViT-L/14 text vector. This model can be used to directly generate point clouds from text descriptions, but only works for simple prompts.
* `base40M-uncond` - a 40 million parameter point cloud diffusion model that generates unconditional samples. This is included only as a baseline.
* `base40M` - a 40 million parameter image to point cloud diffusion model that conditions on the latent grid from a CLIP ViT-L/14 model. This model can be used to generate point clouds from rendered images, but is not as good as the larger models trained on the same task.
* `base300M` - a 300 million parameter image to point cloud diffusion model that conditions on the latent grid from a CLIP ViT-L/14 model. This model can be used to generate point clouds from rendered images, but it is slightly worse than `base1B`.
* `base1B` - a 1 billion parameter image to point cloud diffusion model that conditions on the latent grid from a CLIP ViT-L/14 model.
* `upsample` - a 40 million parameter point cloud upsampling model that can optionally condition on an image as well. This takes a point cloud of 1024 points and upsamples it to 4096 points.
* `sdf` - a small model for predicting signed distance functions from 3D point clouds. This can be used to predict meshes from point clouds.
* `pointnet` - a small point cloud classification model used for our P-FID and P-IS evaluation metrics.
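These names are also the keys used by the loading helpers in this codebase. The sketch below assumes the `point_e.models.configs`, `point_e.models.download`, and `point_e.diffusion.configs` helpers behave as in the repository's example notebooks; it is illustrative, not an authoritative API reference.
```python
# Hedged sketch: load one of the checkpoints listed above (module paths and helper
# names are assumptions based on the example notebooks, not guaranteed by this card).
import torch

from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.models.download import load_checkpoint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

name = "base40M-textvec"
model = model_from_config(MODEL_CONFIGS[name], device)
model.load_state_dict(load_checkpoint(name, device))
model.eval()
diffusion = diffusion_from_config(DIFFUSION_CONFIGS[name])
```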
## Paper & samples
[Paper](https://arxiv.org/abs/2212.08751) / [Sample point clouds](point_e/examples/paper_banner.gif)
# Training data
These models were trained on a dataset of several million 3D models. We filtered the dataset to avoid flat objects, and used [CLIP](https://github.com/openai/CLIP/blob/main/model-card.md) to cluster the dataset and downweight clusters of 3D models which appeared to contain mostly unrecognizable objects. We additionally down-weighted clusters which appeared to consist of many similar-looking objects. We processed the resulting dataset into renders (RGB point clouds of 4K points each) and text captions from the associated metadata.
Our SDF regression model was trained on a subset of the above dataset. In particular, we only retained 3D meshes which were manifold (i.e. watertight and free of singularities).
# Evaluated Use
We release these models to help advance research in generative modeling. Due to the limitations and biases of our models, we do not currently recommend it for commercial use. We understand that our models may be used in ways we haven't anticipated, and that it is difficult to define clear boundaries around what constitutes appropriate "research" use. In particular, we caution against using these models in applications where precision is critical, as subtle flaws in the outputs could lead to errors or inaccuracies.
Functionally, these models are trained to be able to perform the following tasks for research purposes, and are evaluated on these tasks:
* Generate 3D point clouds conditioned on single rendered images
* Generate 3D point clouds conditioned on text
* Create 3D meshes from noisy 3D point clouds
Our image-conditional models are intended to produce coherent point clouds, given a representative rendering of a 3D object. However, at their current level of capabilities, the models sometimes fail to generate coherent output, either producing incorrect geometry where the rendering is occluded, or producing geometry that is inconsistent with visible parts of the rendering. The resulting point clouds are relatively low-resolution, and are often noisy and contain defects such as outliers or cracks.
Our text-conditional model is sometimes capable of producing 3D point clouds which can be recognized as the provided text description, especially when the text description is simple. However, we find that this model fails to generalize to complex prompts or unusual objects.
# Performance and Limitations
Our image-conditional models are limited by the text-to-image model that is used to produce synthetic views. If the text-to-image model contains a bias or fails to understand a particular concept, these limitations will be passed down to the image-conditional point cloud model through conditioning images.
While our main focus is on image-conditional models, we also experimented with a text-conditional model. We find that this model can sometimes produce 3D models of people that exhibit gender biases (for example, samples for "a man" tend to be wider and less narrow than samples for "a woman"). We additionally find that this model is sometimes capable of producing violent objects such as guns or tanks, although these generations are always low-quality and unrealistic.
Since our dataset contains many simplistic, cartoonish 3D objects, our models are prone to mimicking this style.
While these models were developed for research purposes, they have potential implications if used more broadly. For example, the ability to generate 3D point clouds from single images could help advance research in computer graphics, virtual reality, and robotics. The text-conditional model could allow users to easily create 3D models from simple descriptions, which could be useful for rapid prototyping or 3D printing.
The combination of these models with 3D printing could potentially be harmful, for example if used to prototype dangerous objects or when parts created by the model are trusted without external validation.
Finally, point cloud models inherit many of the same risks and limitations as image-generation models, including the propensity to produce biased or otherwise harmful content or to carry dual-use risk. More research is needed on how these risks manifest themselves as capabilities improve.

View File

@ -0,0 +1,64 @@
"""
Based on https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py
"""
from typing import Any, Dict
import numpy as np
from .gaussian_diffusion import (
GaussianDiffusion,
SpacedDiffusion,
get_named_beta_schedule,
space_timesteps,
)
BASE_DIFFUSION_CONFIG = {
"channel_biases": [0.0, 0.0, 0.0, -1.0, -1.0, -1.0],
"channel_scales": [2.0, 2.0, 2.0, 0.007843137255, 0.007843137255, 0.007843137255],
"mean_type": "epsilon",
"schedule": "cosine",
"timesteps": 1024,
}
DIFFUSION_CONFIGS = {
"base40M-imagevec": BASE_DIFFUSION_CONFIG,
"base40M-textvec": BASE_DIFFUSION_CONFIG,
"base40M-uncond": BASE_DIFFUSION_CONFIG,
"base40M": BASE_DIFFUSION_CONFIG,
"base300M": BASE_DIFFUSION_CONFIG,
"base1B": BASE_DIFFUSION_CONFIG,
"upsample": {
"channel_biases": [0.0, 0.0, 0.0, -1.0, -1.0, -1.0],
"channel_scales": [2.0, 2.0, 2.0, 0.007843137255, 0.007843137255, 0.007843137255],
"mean_type": "epsilon",
"schedule": "linear",
"timesteps": 1024,
},
}
def diffusion_from_config(config: Dict[str, Any]) -> GaussianDiffusion:
schedule = config["schedule"]
steps = config["timesteps"]
respace = config.get("respacing", None)
mean_type = config.get("mean_type", "epsilon")
betas = get_named_beta_schedule(schedule, steps)
channel_scales = config.get("channel_scales", None)
channel_biases = config.get("channel_biases", None)
if channel_scales is not None:
channel_scales = np.array(channel_scales)
if channel_biases is not None:
channel_biases = np.array(channel_biases)
kwargs = dict(
betas=betas,
model_mean_type=mean_type,
model_var_type="learned_range",
loss_type="mse",
channel_scales=channel_scales,
channel_biases=channel_biases,
)
if respace is None:
return GaussianDiffusion(**kwargs)
else:
return SpacedDiffusion(use_timesteps=space_timesteps(steps, respace), **kwargs)

File diff suppressed because it is too large

View File

@ -0,0 +1,332 @@
"""
Based on: https://github.com/crowsonkb/k-diffusion
Copyright (c) 2022 Katherine Crowson
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import numpy as np
import torch as th
from .gaussian_diffusion import GaussianDiffusion, mean_flat
class KarrasDenoiser:
def __init__(self, sigma_data: float = 0.5):
self.sigma_data = sigma_data
def get_snr(self, sigmas):
return sigmas**-2
def get_sigmas(self, sigmas):
return sigmas
def get_scalings(self, sigma):
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
return c_skip, c_out, c_in
def training_losses(self, model, x_start, sigmas, model_kwargs=None, noise=None):
if model_kwargs is None:
model_kwargs = {}
if noise is None:
noise = th.randn_like(x_start)
terms = {}
dims = x_start.ndim
x_t = x_start + noise * append_dims(sigmas, dims)
c_skip, c_out, _ = [append_dims(x, dims) for x in self.get_scalings(sigmas)]
model_output, denoised = self.denoise(model, x_t, sigmas, **model_kwargs)
target = (x_start - c_skip * x_t) / c_out
terms["mse"] = mean_flat((model_output - target) ** 2)
terms["xs_mse"] = mean_flat((denoised - x_start) ** 2)
if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"]
else:
terms["loss"] = terms["mse"]
return terms
def denoise(self, model, x_t, sigmas, **model_kwargs):
c_skip, c_out, c_in = [append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)]
rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
model_output = model(c_in * x_t, rescaled_t, **model_kwargs)
denoised = c_out * model_output + c_skip * x_t
return model_output, denoised
class GaussianToKarrasDenoiser:
def __init__(self, model, diffusion):
from scipy import interpolate
self.model = model
self.diffusion = diffusion
self.alpha_cumprod_to_t = interpolate.interp1d(
diffusion.alphas_cumprod, np.arange(0, diffusion.num_timesteps)
)
def sigma_to_t(self, sigma):
alpha_cumprod = 1.0 / (sigma**2 + 1)
if alpha_cumprod > self.diffusion.alphas_cumprod[0]:
return 0
elif alpha_cumprod <= self.diffusion.alphas_cumprod[-1]:
return self.diffusion.num_timesteps - 1
else:
return float(self.alpha_cumprod_to_t(alpha_cumprod))
def denoise(self, x_t, sigmas, clip_denoised=True, model_kwargs=None):
t = th.tensor(
[self.sigma_to_t(sigma) for sigma in sigmas.cpu().numpy()],
dtype=th.long,
device=sigmas.device,
)
c_in = append_dims(1.0 / (sigmas**2 + 1) ** 0.5, x_t.ndim)
out = self.diffusion.p_mean_variance(
self.model, x_t * c_in, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
)
return None, out["pred_xstart"]
def karras_sample(*args, **kwargs):
last = None
for x in karras_sample_progressive(*args, **kwargs):
last = x["x"]
return last
def karras_sample_progressive(
diffusion,
model,
shape,
steps,
clip_denoised=True,
progress=False,
model_kwargs=None,
device=None,
sigma_min=0.002,
sigma_max=80, # higher for highres?
rho=7.0,
sampler="heun",
s_churn=0.0,
s_tmin=0.0,
s_tmax=float("inf"),
s_noise=1.0,
guidance_scale=0.0,
):
sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)
x_T = th.randn(*shape, device=device) * sigma_max
sample_fn = {"heun": sample_heun, "dpm": sample_dpm, "ancestral": sample_euler_ancestral}[
sampler
]
if sampler != "ancestral":
sampler_args = dict(s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise)
else:
sampler_args = {}
if isinstance(diffusion, KarrasDenoiser):
def denoiser(x_t, sigma):
_, denoised = diffusion.denoise(model, x_t, sigma, **model_kwargs)
if clip_denoised:
denoised = denoised.clamp(-1, 1)
return denoised
elif isinstance(diffusion, GaussianDiffusion):
model = GaussianToKarrasDenoiser(model, diffusion)
def denoiser(x_t, sigma):
_, denoised = model.denoise(
x_t, sigma, clip_denoised=clip_denoised, model_kwargs=model_kwargs
)
return denoised
else:
raise NotImplementedError
if guidance_scale != 0 and guidance_scale != 1:
def guided_denoiser(x_t, sigma):
x_t = th.cat([x_t, x_t], dim=0)
sigma = th.cat([sigma, sigma], dim=0)
x_0 = denoiser(x_t, sigma)
cond_x_0, uncond_x_0 = th.split(x_0, len(x_0) // 2, dim=0)
x_0 = uncond_x_0 + guidance_scale * (cond_x_0 - uncond_x_0)
return x_0
else:
guided_denoiser = denoiser
for obj in sample_fn(
guided_denoiser,
x_T,
sigmas,
progress=progress,
**sampler_args,
):
if isinstance(diffusion, GaussianDiffusion):
yield diffusion.unscale_out_dict(obj)
else:
yield obj
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
"""Constructs the noise schedule of Karras et al. (2022)."""
ramp = th.linspace(0, 1, n)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return append_zero(sigmas).to(device)
def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / append_dims(sigma, x.ndim)
def get_ancestral_step(sigma_from, sigma_to):
"""Calculates the noise level (sigma_down) to step down to and the amount
of noise to add (sigma_up) when doing an ancestral sampling step."""
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
return sigma_down, sigma_up
@th.no_grad()
def sample_euler_ancestral(model, x, sigmas, progress=False):
"""Ancestral sampling with Euler method steps."""
s_in = x.new_ones([x.shape[0]])
indices = range(len(sigmas) - 1)
if progress:
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
denoised = model(x, sigmas[i] * s_in)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "pred_xstart": denoised}
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
x = x + th.randn_like(x) * sigma_up
yield {"x": x, "pred_xstart": x}
@th.no_grad()
def sample_heun(
denoiser,
x,
sigmas,
progress=False,
s_churn=0.0,
s_tmin=0.0,
s_tmax=float("inf"),
s_noise=1.0,
):
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
s_in = x.new_ones([x.shape[0]])
indices = range(len(sigmas) - 1)
if progress:
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
gamma = (
min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
)
eps = th.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
denoised = denoiser(x, sigma_hat * s_in)
d = to_d(x, sigma_hat, denoised)
yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "pred_xstart": denoised}
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == 0:
# Euler method
x = x + d * dt
else:
# Heun's method
x_2 = x + d * dt
denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
d_prime = (d + d_2) / 2
x = x + d_prime * dt
yield {"x": x, "pred_xstart": denoised}
@th.no_grad()
def sample_dpm(
denoiser,
x,
sigmas,
progress=False,
s_churn=0.0,
s_tmin=0.0,
s_tmax=float("inf"),
s_noise=1.0,
):
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
s_in = x.new_ones([x.shape[0]])
indices = range(len(sigmas) - 1)
if progress:
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
gamma = (
min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
)
eps = th.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
denoised = denoiser(x, sigma_hat * s_in)
d = to_d(x, sigma_hat, denoised)
yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised}
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
dt_1 = sigma_mid - sigma_hat
dt_2 = sigmas[i + 1] - sigma_hat
x_2 = x + d * dt_1
denoised_2 = denoiser(x_2, sigma_mid * s_in)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
yield {"x": x, "pred_xstart": denoised}
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
return x[(...,) + (None,) * dims_to_append]
def append_zero(x):
return th.cat([x, x.new_zeros([1])])

View File

@ -0,0 +1,263 @@
"""
Helpers for sampling from a single- or multi-stage point cloud diffusion model.
"""
from typing import Any, Callable, Dict, Iterator, List, Sequence, Tuple
import torch
import torch.nn as nn
from point_e.util.point_cloud import PointCloud
from .gaussian_diffusion import GaussianDiffusion
from .k_diffusion import karras_sample_progressive
class PointCloudSampler:
"""
A wrapper around a model or stack of models that produces conditional or
unconditional sample tensors.
To modify the sampler arguments of an existing sampler, call with_options().
"""
def __init__(
self,
device: torch.device,
models: Sequence[nn.Module],
diffusions: Sequence[GaussianDiffusion],
num_points: Sequence[int],
aux_channels: Sequence[str],
model_kwargs_key_filter: Sequence[str] = ("*",),
guidance_scale: Sequence[float] = (3.0, 3.0),
clip_denoised: bool = True,
use_karras: Sequence[bool] = (True, True),
karras_steps: Sequence[int] = (64, 64),
sigma_min: Sequence[float] = (1e-3, 1e-3),
sigma_max: Sequence[float] = (120, 160),
s_churn: Sequence[float] = (3, 0),
):
n = len(models)
assert n > 0
if n > 1:
if len(guidance_scale) == 1:
# Don't guide the upsamplers by default.
guidance_scale = list(guidance_scale) + [1.0] * (n - 1)
if len(use_karras) == 1:
use_karras = use_karras * n
if len(karras_steps) == 1:
karras_steps = karras_steps * n
if len(sigma_min) == 1:
sigma_min = sigma_min * n
if len(sigma_max) == 1:
sigma_max = sigma_max * n
if len(s_churn) == 1:
s_churn = s_churn * n
if len(model_kwargs_key_filter) == 1:
model_kwargs_key_filter = model_kwargs_key_filter * n
if len(model_kwargs_key_filter) == 0:
model_kwargs_key_filter = ["*"] * n
assert len(guidance_scale) == n
assert len(use_karras) == n
assert len(karras_steps) == n
assert len(sigma_min) == n
assert len(sigma_max) == n
assert len(s_churn) == n
assert len(model_kwargs_key_filter) == n
self.device = device
self.num_points = num_points
self.aux_channels = aux_channels
self.model_kwargs_key_filter = model_kwargs_key_filter
self.guidance_scale = guidance_scale
self.clip_denoised = clip_denoised
self.use_karras = use_karras
self.karras_steps = karras_steps
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.s_churn = s_churn
self.models = models
self.diffusions = diffusions
@property
def num_stages(self) -> int:
return len(self.models)
def sample_batch(self, batch_size: int, model_kwargs: Dict[str, Any]) -> torch.Tensor:
samples = None
for x in self.sample_batch_progressive(batch_size, model_kwargs):
samples = x
return samples
def sample_batch_progressive(
self, batch_size: int, model_kwargs: Dict[str, Any]
) -> Iterator[torch.Tensor]:
samples = None
for (
model,
diffusion,
stage_num_points,
stage_guidance_scale,
stage_use_karras,
stage_karras_steps,
stage_sigma_min,
stage_sigma_max,
stage_s_churn,
stage_key_filter,
) in zip(
self.models,
self.diffusions,
self.num_points,
self.guidance_scale,
self.use_karras,
self.karras_steps,
self.sigma_min,
self.sigma_max,
self.s_churn,
self.model_kwargs_key_filter,
):
stage_model_kwargs = model_kwargs.copy()
if stage_key_filter != "*":
use_keys = set(stage_key_filter.split(","))
stage_model_kwargs = {k: v for k, v in stage_model_kwargs.items() if k in use_keys}
if samples is not None:
stage_model_kwargs["low_res"] = samples
if hasattr(model, "cached_model_kwargs"):
stage_model_kwargs = model.cached_model_kwargs(batch_size, stage_model_kwargs)
sample_shape = (batch_size, 3 + len(self.aux_channels), stage_num_points)
if stage_guidance_scale != 1 and stage_guidance_scale != 0:
for k, v in stage_model_kwargs.copy().items():
stage_model_kwargs[k] = torch.cat([v, torch.zeros_like(v)], dim=0)
if stage_use_karras:
samples_it = karras_sample_progressive(
diffusion=diffusion,
model=model,
shape=sample_shape,
steps=stage_karras_steps,
clip_denoised=self.clip_denoised,
model_kwargs=stage_model_kwargs,
device=self.device,
sigma_min=stage_sigma_min,
sigma_max=stage_sigma_max,
s_churn=stage_s_churn,
guidance_scale=stage_guidance_scale,
)
else:
internal_batch_size = batch_size
if stage_guidance_scale:
model = self._uncond_guide_model(model, stage_guidance_scale)
internal_batch_size *= 2
samples_it = diffusion.p_sample_loop_progressive(
model,
shape=(internal_batch_size, *sample_shape[1:]),
model_kwargs=stage_model_kwargs,
device=self.device,
clip_denoised=self.clip_denoised,
)
for x in samples_it:
samples = x["pred_xstart"][:batch_size]
if "low_res" in stage_model_kwargs:
samples = torch.cat(
[stage_model_kwargs["low_res"][: len(samples)], samples], dim=-1
)
yield samples
@classmethod
def combine(cls, *samplers: "PointCloudSampler") -> "PointCloudSampler":
assert all(x.device == samplers[0].device for x in samplers[1:])
assert all(x.aux_channels == samplers[0].aux_channels for x in samplers[1:])
assert all(x.clip_denoised == samplers[0].clip_denoised for x in samplers[1:])
return cls(
device=samplers[0].device,
models=[x for y in samplers for x in y.models],
diffusions=[x for y in samplers for x in y.diffusions],
num_points=[x for y in samplers for x in y.num_points],
aux_channels=samplers[0].aux_channels,
model_kwargs_key_filter=[x for y in samplers for x in y.model_kwargs_key_filter],
guidance_scale=[x for y in samplers for x in y.guidance_scale],
clip_denoised=samplers[0].clip_denoised,
use_karras=[x for y in samplers for x in y.use_karras],
karras_steps=[x for y in samplers for x in y.karras_steps],
sigma_min=[x for y in samplers for x in y.sigma_min],
sigma_max=[x for y in samplers for x in y.sigma_max],
s_churn=[x for y in samplers for x in y.s_churn],
)
def _uncond_guide_model(
self, model: Callable[..., torch.Tensor], scale: float
) -> Callable[..., torch.Tensor]:
def model_fn(x_t, ts, **kwargs):
half = x_t[: len(x_t) // 2]
combined = torch.cat([half, half], dim=0)
model_out = model(combined, ts, **kwargs)
eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
return torch.cat([eps, rest], dim=1)
return model_fn
def split_model_output(
self,
output: torch.Tensor,
rescale_colors: bool = False,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
assert (
len(self.aux_channels) + 3 == output.shape[1]
), "there must be three spatial channels before aux"
pos, joined_aux = output[:, :3], output[:, 3:]
aux = {}
for i, name in enumerate(self.aux_channels):
v = joined_aux[:, i]
if name in {"R", "G", "B", "A"}:
v = v.clamp(0, 255).round()
if rescale_colors:
v = v / 255.0
aux[name] = v
return pos, aux
def output_to_point_clouds(self, output: torch.Tensor) -> List[PointCloud]:
res = []
for sample in output:
xyz, aux = self.split_model_output(sample[None], rescale_colors=True)
res.append(
PointCloud(
coords=xyz[0].t().cpu().numpy(),
channels={k: v[0].cpu().numpy() for k, v in aux.items()},
)
)
return res
def with_options(
self,
guidance_scale: float,
clip_denoised: bool,
use_karras: Sequence[bool] = (True, True),
karras_steps: Sequence[int] = (64, 64),
sigma_min: Sequence[float] = (1e-3, 1e-3),
sigma_max: Sequence[float] = (120, 160),
s_churn: Sequence[float] = (3, 0),
) -> "PointCloudSampler":
return PointCloudSampler(
device=self.device,
models=self.models,
diffusions=self.diffusions,
num_points=self.num_points,
aux_channels=self.aux_channels,
model_kwargs_key_filter=self.model_kwargs_key_filter,
guidance_scale=guidance_scale,
clip_denoised=clip_denoised,
use_karras=use_karras,
karras_steps=karras_steps,
sigma_min=sigma_min,
sigma_max=sigma_max,
s_churn=s_churn,
)

View File

@ -0,0 +1,119 @@
from abc import ABC, abstractmethod
from multiprocessing.pool import ThreadPool
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from point_e.models.download import load_checkpoint
from .npz_stream import NpzStreamer
from .pointnet2_cls_ssg import get_model
def get_torch_devices() -> List[Union[str, torch.device]]:
if torch.cuda.is_available():
return [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
else:
return ["cpu"]
class FeatureExtractor(ABC):
@property
@abstractmethod
def supports_predictions(self) -> bool:
pass
@property
@abstractmethod
def feature_dim(self) -> int:
pass
@property
@abstractmethod
def num_classes(self) -> int:
pass
@abstractmethod
def features_and_preds(self, streamer: NpzStreamer) -> Tuple[np.ndarray, np.ndarray]:
"""
For a stream of point cloud batches, compute feature vectors and class
predictions.
:param streamer: a streamer for a sample batch. Typically, arr_0
will contain the XYZ coordinates.
:return: a tuple (features, predictions)
- features: a [B x feature_dim] array of feature vectors.
- predictions: a [B x num_classes] array of probabilities.
"""
class PointNetClassifier(FeatureExtractor):
def __init__(
self,
devices: List[Union[str, torch.device]],
device_batch_size: int = 64,
cache_dir: Optional[str] = None,
):
state_dict = load_checkpoint("pointnet", device=torch.device("cpu"), cache_dir=cache_dir)[
"model_state_dict"
]
self.device_batch_size = device_batch_size
self.devices = devices
self.models = []
for device in devices:
model = get_model(num_class=40, normal_channel=False, width_mult=2)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
self.models.append(model)
@property
def supports_predictions(self) -> bool:
return True
@property
def feature_dim(self) -> int:
return 256
@property
def num_classes(self) -> int:
return 40
def features_and_preds(self, streamer: NpzStreamer) -> Tuple[np.ndarray, np.ndarray]:
batch_size = self.device_batch_size * len(self.devices)
point_clouds = (x["arr_0"] for x in streamer.stream(batch_size, ["arr_0"]))
output_features = []
output_predictions = []
with ThreadPool(len(self.devices)) as pool:
for batch in point_clouds:
batch = normalize_point_clouds(batch)
batches = []
for i, device in zip(range(0, len(batch), self.device_batch_size), self.devices):
batches.append(
torch.from_numpy(batch[i : i + self.device_batch_size])
.permute(0, 2, 1)
.to(dtype=torch.float32, device=device)
)
def compute_features(i_batch):
i, batch = i_batch
with torch.no_grad():
return self.models[i](batch, features=True)
for logits, _, features in pool.imap(compute_features, enumerate(batches)):
output_features.append(features.cpu().numpy())
output_predictions.append(logits.exp().cpu().numpy())
return np.concatenate(output_features, axis=0), np.concatenate(output_predictions, axis=0)
def normalize_point_clouds(pc: np.ndarray) -> np.ndarray:
centroids = np.mean(pc, axis=1, keepdims=True)
pc = pc - centroids
m = np.max(np.sqrt(np.sum(pc**2, axis=-1, keepdims=True)), axis=1, keepdims=True)
pc = pc / m
return pc

View File

@ -0,0 +1,81 @@
"""
Adapted from https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/evaluations/evaluator.py
"""
import warnings
import numpy as np
from scipy import linalg
class InvalidFIDException(Exception):
pass
class FIDStatistics:
def __init__(self, mu: np.ndarray, sigma: np.ndarray):
self.mu = mu
self.sigma = sigma
def frechet_distance(self, other, eps=1e-6):
"""
Compute the Frechet distance between two sets of statistics.
"""
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
mu1, sigma1 = self.mu, self.sigma
mu2, sigma2 = other.mu, other.sigma
mu1 = np.atleast_1d(mu1)
mu2 = np.atleast_1d(mu2)
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)
assert (
mu1.shape == mu2.shape
), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
assert (
sigma1.shape == sigma2.shape
), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
diff = mu1 - mu2
# product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
msg = (
"fid calculation produces singular product; adding %s to diagonal of cov estimates"
% eps
)
warnings.warn(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
# numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError("Imaginary component {}".format(m))
covmean = covmean.real
tr_covmean = np.trace(covmean)
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
def compute_statistics(feats: np.ndarray) -> FIDStatistics:
mu = np.mean(feats, axis=0)
sigma = np.cov(feats, rowvar=False)
return FIDStatistics(mu, sigma)
def compute_inception_score(preds: np.ndarray, split_size: int = 5000) -> float:
# https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
scores = []
for i in range(0, len(preds), split_size):
part = preds[i : i + split_size]
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
kl = np.mean(np.sum(kl, 1))
scores.append(np.exp(kl))
return float(np.mean(scores))

View File

@ -0,0 +1,270 @@
import glob
import io
import os
import re
import zipfile
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Sequence, Tuple
import numpy as np
@dataclass
class NumpyArrayInfo:
"""
Information about an array in an npz file.
"""
name: str
dtype: np.dtype
shape: Tuple[int]
@classmethod
def infos_from_first_file(cls, glob_path: str) -> Dict[str, "NumpyArrayInfo"]:
paths, _ = _npz_paths_and_length(glob_path)
return cls.infos_from_file(paths[0])
@classmethod
def infos_from_file(cls, npz_path: str) -> Dict[str, "NumpyArrayInfo"]:
"""
Extract the info of every array in an npz file.
"""
if not os.path.exists(npz_path):
raise FileNotFoundError(f"batch of samples was not found: {npz_path}")
results = {}
with open(npz_path, "rb") as f:
with zipfile.ZipFile(f, "r") as zip_f:
for name in zip_f.namelist():
if not name.endswith(".npy"):
continue
key_name = name[: -len(".npy")]
with zip_f.open(name, "r") as arr_f:
version = np.lib.format.read_magic(arr_f)
if version == (1, 0):
header = np.lib.format.read_array_header_1_0(arr_f)
elif version == (2, 0):
header = np.lib.format.read_array_header_2_0(arr_f)
else:
raise ValueError(f"unknown numpy array version: {version}")
shape, _, dtype = header
results[key_name] = cls(name=key_name, dtype=dtype, shape=shape)
return results
@property
def elem_shape(self) -> Tuple[int]:
return self.shape[1:]
def validate(self):
if self.name in {"R", "G", "B"}:
if len(self.shape) != 2:
raise ValueError(
f"expecting exactly 2-D shape for '{self.name}' but got: {self.shape}"
)
elif self.name == "arr_0":
if len(self.shape) < 2:
raise ValueError(f"expecting at least 2-D shape but got: {self.shape}")
elif len(self.shape) == 3:
# For audio, we require continuous samples.
if not np.issubdtype(self.dtype, np.floating):
raise ValueError(
f"invalid dtype for audio batch: {self.dtype} (expected float)"
)
elif self.dtype != np.uint8:
raise ValueError(f"invalid dtype for image batch: {self.dtype} (expected uint8)")
class NpzStreamer:
def __init__(self, glob_path: str):
self.paths, self.trunc_length = _npz_paths_and_length(glob_path)
self.infos = NumpyArrayInfo.infos_from_file(self.paths[0])
def keys(self) -> List[str]:
return list(self.infos.keys())
def stream(self, batch_size: int, keys: Sequence[str]) -> Iterator[Dict[str, np.ndarray]]:
cur_batch = None
num_remaining = self.trunc_length
for path in self.paths:
if num_remaining is not None and num_remaining <= 0:
break
with open_npz_arrays(path, keys) as readers:
combined_reader = CombinedReader(keys, readers)
while num_remaining is None or num_remaining > 0:
read_bs = batch_size
if cur_batch is not None:
read_bs -= _dict_batch_size(cur_batch)
if num_remaining is not None:
read_bs = min(read_bs, num_remaining)
batch = combined_reader.read_batch(read_bs)
if batch is None:
break
if num_remaining is not None:
num_remaining -= _dict_batch_size(batch)
if cur_batch is None:
cur_batch = batch
else:
cur_batch = {
# pylint: disable=unsubscriptable-object
k: np.concatenate([cur_batch[k], v], axis=0)
for k, v in batch.items()
}
if _dict_batch_size(cur_batch) == batch_size:
yield cur_batch
cur_batch = None
if cur_batch is not None:
yield cur_batch
def _npz_paths_and_length(glob_path: str) -> Tuple[List[str], Optional[int]]:
# Match slice syntax like path[:100].
count_match = re.match("^(.*)\\[:([0-9]*)\\]$", glob_path)
if count_match:
raw_path = count_match[1]
max_count = int(count_match[2])
else:
raw_path = glob_path
max_count = None
paths = sorted(glob.glob(raw_path))
if not len(paths):
raise ValueError(f"no paths found matching: {glob_path}")
return paths, max_count
class NpzArrayReader(ABC):
@abstractmethod
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
pass
class StreamingNpzArrayReader(NpzArrayReader):
def __init__(self, arr_f, shape, dtype):
self.arr_f = arr_f
self.shape = shape
self.dtype = dtype
self.idx = 0
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
if self.idx >= self.shape[0]:
return None
bs = min(batch_size, self.shape[0] - self.idx)
self.idx += bs
if self.dtype.itemsize == 0:
return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
read_count = bs * np.prod(self.shape[1:])
read_size = int(read_count * self.dtype.itemsize)
data = _read_bytes(self.arr_f, read_size, "array data")
return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
class MemoryNpzArrayReader(NpzArrayReader):
def __init__(self, arr):
self.arr = arr
self.idx = 0
@classmethod
def load(cls, path: str, arr_name: str):
with open(path, "rb") as f:
arr = np.load(f)[arr_name]
return cls(arr)
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
if self.idx >= self.arr.shape[0]:
return None
res = self.arr[self.idx : self.idx + batch_size]
self.idx += batch_size
return res
@contextmanager
def open_npz_arrays(path: str, arr_names: Sequence[str]) -> List[NpzArrayReader]:
if not len(arr_names):
yield []
return
arr_name = arr_names[0]
with open_array(path, arr_name) as arr_f:
version = np.lib.format.read_magic(arr_f)
header = None
if version == (1, 0):
header = np.lib.format.read_array_header_1_0(arr_f)
elif version == (2, 0):
header = np.lib.format.read_array_header_2_0(arr_f)
if header is None:
reader = MemoryNpzArrayReader.load(path, arr_name)
else:
shape, fortran, dtype = header
if fortran or dtype.hasobject:
reader = MemoryNpzArrayReader.load(path, arr_name)
else:
reader = StreamingNpzArrayReader(arr_f, shape, dtype)
with open_npz_arrays(path, arr_names[1:]) as next_readers:
yield [reader] + next_readers
class CombinedReader:
def __init__(self, keys: List[str], readers: List[NpzArrayReader]):
self.keys = keys
self.readers = readers
def read_batch(self, batch_size: int) -> Optional[Dict[str, np.ndarray]]:
batches = [r.read_batch(batch_size) for r in self.readers]
any_none = any(x is None for x in batches)
all_none = all(x is None for x in batches)
if any_none != all_none:
raise RuntimeError("different keys had different numbers of elements")
if any_none:
return None
if any(len(x) != len(batches[0]) for x in batches):
raise RuntimeError("different keys had different numbers of elements")
return dict(zip(self.keys, batches))
def _read_bytes(fp, size, error_template="ran out of data"):
"""
Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
Read from file-like object until size bytes are read.
Raises ValueError if EOF is encountered before size bytes are read.
Non-blocking objects only supported if they derive from io objects.
Required as e.g. ZipExtFile in python 2.6 can return less data than
requested.
"""
data = bytes()
while True:
# io files (default in python3) return None or raise on
# would-block, python2 file will truncate, probably nothing can be
# done about that. note that regular files can't be non-blocking
try:
r = fp.read(size - len(data))
data += r
if len(r) == 0 or len(data) == size:
break
except io.BlockingIOError:
pass
if len(data) != size:
msg = "EOF: reading %s, expected %d bytes got %d"
raise ValueError(msg % (error_template, size, len(data)))
else:
return data
@contextmanager
def open_array(path: str, arr_name: str):
with open(path, "rb") as f:
with zipfile.ZipFile(f, "r") as zip_f:
if f"{arr_name}.npy" not in zip_f.namelist():
raise ValueError(f"missing {arr_name} in npz file")
with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
yield arr_f
def _dict_batch_size(objs: Dict[str, np.ndarray]) -> int:
return len(next(iter(objs.values())))

View File

@ -0,0 +1,101 @@
"""
Based on: https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/eb64fe0b4c24055559cea26299cb485dcb43d8dd/models/pointnet2_cls_ssg.py
MIT License
Copyright (c) 2019 benny
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import torch.nn as nn
import torch.nn.functional as F
from .pointnet2_utils import PointNetSetAbstraction
class get_model(nn.Module):
def __init__(self, num_class, normal_channel=True, width_mult=1):
super(get_model, self).__init__()
self.width_mult = width_mult
in_channel = 6 if normal_channel else 3
self.normal_channel = normal_channel
self.sa1 = PointNetSetAbstraction(
npoint=512,
radius=0.2,
nsample=32,
in_channel=in_channel,
mlp=[64 * width_mult, 64 * width_mult, 128 * width_mult],
group_all=False,
)
self.sa2 = PointNetSetAbstraction(
npoint=128,
radius=0.4,
nsample=64,
in_channel=128 * width_mult + 3,
mlp=[128 * width_mult, 128 * width_mult, 256 * width_mult],
group_all=False,
)
self.sa3 = PointNetSetAbstraction(
npoint=None,
radius=None,
nsample=None,
in_channel=256 * width_mult + 3,
mlp=[256 * width_mult, 512 * width_mult, 1024 * width_mult],
group_all=True,
)
self.fc1 = nn.Linear(1024 * width_mult, 512 * width_mult)
self.bn1 = nn.BatchNorm1d(512 * width_mult)
self.drop1 = nn.Dropout(0.4)
self.fc2 = nn.Linear(512 * width_mult, 256 * width_mult)
self.bn2 = nn.BatchNorm1d(256 * width_mult)
self.drop2 = nn.Dropout(0.4)
self.fc3 = nn.Linear(256 * width_mult, num_class)
def forward(self, xyz, features=False):
B, _, _ = xyz.shape
if self.normal_channel:
norm = xyz[:, 3:, :]
xyz = xyz[:, :3, :]
else:
norm = None
l1_xyz, l1_points = self.sa1(xyz, norm)
l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
x = l3_points.view(B, 1024 * self.width_mult)
x = self.drop1(F.relu(self.bn1(self.fc1(x))))
result_features = self.bn2(self.fc2(x))
x = self.drop2(F.relu(result_features))
x = self.fc3(x)
x = F.log_softmax(x, -1)
if features:
return x, l3_points, result_features
else:
return x, l3_points
class get_loss(nn.Module):
def __init__(self):
super(get_loss, self).__init__()
def forward(self, pred, target, trans_feat):
total_loss = F.nll_loss(pred, target)
return total_loss

View File

@ -0,0 +1,356 @@
"""
Based on: https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/eb64fe0b4c24055559cea26299cb485dcb43d8dd/models/pointnet_utils.py
MIT License
Copyright (c) 2019 benny
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from time import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def timeit(tag, t):
print("{}: {}s".format(tag, time() - t))
return time()
def pc_normalize(pc):
l = pc.shape[0]
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m
return pc
def square_distance(src, dst):
"""
Calculate Euclid distance between each two points.
src^T * dst = xn * xm + yn * ym + zn * zm;
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
Input:
src: source points, [B, N, C]
dst: target points, [B, M, C]
Output:
dist: per-point square distance, [B, N, M]
"""
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src**2, -1).view(B, N, 1)
dist += torch.sum(dst**2, -1).view(B, 1, M)
return dist
def index_points(points, idx):
"""
Input:
points: input points data, [B, N, C]
idx: sample index data, [B, S]
Return:
new_points:, indexed points data, [B, S, C]
"""
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = (
torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
)
new_points = points[batch_indices, idx, :]
return new_points
def farthest_point_sample(xyz, npoint, deterministic=False):
"""
Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint]
"""
device = xyz.device
B, N, C = xyz.shape
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
distance = torch.ones(B, N).to(device) * 1e10
if deterministic:
farthest = torch.arange(0, B, dtype=torch.long).to(device)
else:
farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
batch_indices = torch.arange(B, dtype=torch.long).to(device)
for i in range(npoint):
centroids[:, i] = farthest
centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
dist = torch.sum((xyz - centroid) ** 2, -1)
mask = dist < distance
distance[mask] = dist[mask]
farthest = torch.max(distance, -1)[1]
return centroids
def query_ball_point(radius, nsample, xyz, new_xyz):
"""
Input:
radius: local region radius
nsample: max sample number in local region
xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3]
Return:
group_idx: grouped points index, [B, S, nsample]
"""
device = xyz.device
B, N, C = xyz.shape
_, S, _ = new_xyz.shape
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
sqrdists = square_distance(new_xyz, xyz)
group_idx[sqrdists > radius**2] = N
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
mask = group_idx == N
group_idx[mask] = group_first[mask]
return group_idx
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False, deterministic=False):
"""
Input:
npoint:
radius:
nsample:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, npoint, nsample, 3]
new_points: sampled points data, [B, npoint, nsample, 3+D]
"""
B, N, C = xyz.shape
S = npoint
fps_idx = farthest_point_sample(xyz, npoint, deterministic=deterministic) # [B, npoint, C]
new_xyz = index_points(xyz, fps_idx)
idx = query_ball_point(radius, nsample, xyz, new_xyz)
grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
if points is not None:
grouped_points = index_points(points, idx)
new_points = torch.cat(
[grouped_xyz_norm, grouped_points], dim=-1
) # [B, npoint, nsample, C+D]
else:
new_points = grouped_xyz_norm
if returnfps:
return new_xyz, new_points, grouped_xyz, fps_idx
else:
return new_xyz, new_points
def sample_and_group_all(xyz, points):
"""
Input:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, 1, 3]
new_points: sampled points data, [B, 1, N, 3+D]
"""
device = xyz.device
B, N, C = xyz.shape
new_xyz = torch.zeros(B, 1, C).to(device)
grouped_xyz = xyz.view(B, 1, N, C)
if points is not None:
new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
else:
new_points = grouped_xyz
return new_xyz, new_points
class PointNetSetAbstraction(nn.Module):
def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
super(PointNetSetAbstraction, self).__init__()
self.npoint = npoint
self.radius = radius
self.nsample = nsample
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = in_channel
for out_channel in mlp:
self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
self.mlp_bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
self.group_all = group_all
def forward(self, xyz, points):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points_concat: sample points feature data, [B, D', S]
"""
xyz = xyz.permute(0, 2, 1)
if points is not None:
points = points.permute(0, 2, 1)
if self.group_all:
new_xyz, new_points = sample_and_group_all(xyz, points)
else:
new_xyz, new_points = sample_and_group(
self.npoint, self.radius, self.nsample, xyz, points, deterministic=not self.training
)
# new_xyz: sampled points position data, [B, npoint, C]
# new_points: sampled points data, [B, npoint, nsample, C+D]
new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
new_points = torch.max(new_points, 2)[0]
new_xyz = new_xyz.permute(0, 2, 1)
return new_xyz, new_points
class PointNetSetAbstractionMsg(nn.Module):
def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
super(PointNetSetAbstractionMsg, self).__init__()
self.npoint = npoint
self.radius_list = radius_list
self.nsample_list = nsample_list
self.conv_blocks = nn.ModuleList()
self.bn_blocks = nn.ModuleList()
for i in range(len(mlp_list)):
convs = nn.ModuleList()
bns = nn.ModuleList()
last_channel = in_channel + 3
for out_channel in mlp_list[i]:
convs.append(nn.Conv2d(last_channel, out_channel, 1))
bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
self.conv_blocks.append(convs)
self.bn_blocks.append(bns)
def forward(self, xyz, points):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points_concat: sample points feature data, [B, D', S]
"""
xyz = xyz.permute(0, 2, 1)
if points is not None:
points = points.permute(0, 2, 1)
B, N, C = xyz.shape
S = self.npoint
new_xyz = index_points(xyz, farthest_point_sample(xyz, S, deterministic=not self.training))
new_points_list = []
for i, radius in enumerate(self.radius_list):
K = self.nsample_list[i]
group_idx = query_ball_point(radius, K, xyz, new_xyz)
grouped_xyz = index_points(xyz, group_idx)
grouped_xyz -= new_xyz.view(B, S, 1, C)
if points is not None:
grouped_points = index_points(points, group_idx)
grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
else:
grouped_points = grouped_xyz
grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
for j in range(len(self.conv_blocks[i])):
conv = self.conv_blocks[i][j]
bn = self.bn_blocks[i][j]
grouped_points = F.relu(bn(conv(grouped_points)))
new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
new_points_list.append(new_points)
new_xyz = new_xyz.permute(0, 2, 1)
new_points_concat = torch.cat(new_points_list, dim=1)
return new_xyz, new_points_concat
class PointNetFeaturePropagation(nn.Module):
def __init__(self, in_channel, mlp):
super(PointNetFeaturePropagation, self).__init__()
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = in_channel
for out_channel in mlp:
self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
self.mlp_bns.append(nn.BatchNorm1d(out_channel))
last_channel = out_channel
def forward(self, xyz1, xyz2, points1, points2):
"""
Input:
xyz1: input points position data, [B, C, N]
xyz2: sampled input points position data, [B, C, S]
points1: input points data, [B, D, N]
points2: input points data, [B, D, S]
Return:
new_points: upsampled points data, [B, D', N]
"""
xyz1 = xyz1.permute(0, 2, 1)
xyz2 = xyz2.permute(0, 2, 1)
points2 = points2.permute(0, 2, 1)
B, N, C = xyz1.shape
_, S, _ = xyz2.shape
if S == 1:
interpolated_points = points2.repeat(1, N, 1)
else:
dists = square_distance(xyz1, xyz2)
dists, idx = dists.sort(dim=-1)
dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
dist_recip = 1.0 / (dists + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm
interpolated_points = torch.sum(
index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2
)
if points1 is not None:
points1 = points1.permute(0, 2, 1)
new_points = torch.cat([points1, interpolated_points], dim=-1)
else:
new_points = interpolated_points
new_points = new_points.permute(0, 2, 1)
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
return new_points

View File

@ -0,0 +1,533 @@
"""
Script to run within Blender to render a 3D model as RGBAD images.
Example usage
blender -b -P blender_script.py -- \
--input_path ../../examples/example_data/corgi.ply \
--output_path render_out
Pass `--camera_pose z-circular-elevated` for the rendering used to compute
CLIP R-Precision results.
The output directory will include metadata json files for each rendered view,
as well as a global metadata file for the render. Each image will be saved as
a collection of 16-bit PNG files for each channel (rgbad), as well as a full
grayscale render of the view.
"""
import argparse
import json
import math
import os
import random
import sys
import bpy
from mathutils import Vector
from mathutils.noise import random_unit_vector
MAX_DEPTH = 5.0
FORMAT_VERSION = 6
UNIFORM_LIGHT_DIRECTION = [0.09387503, -0.63953443, -0.7630093]
def clear_scene():
bpy.ops.object.select_all(action="SELECT")
bpy.ops.object.delete()
def clear_lights():
bpy.ops.object.select_all(action="DESELECT")
for obj in bpy.context.scene.objects.values():
if isinstance(obj.data, bpy.types.Light):
obj.select_set(True)
bpy.ops.object.delete()
def import_model(path):
clear_scene()
_, ext = os.path.splitext(path)
ext = ext.lower()
if ext == ".obj":
bpy.ops.import_scene.obj(filepath=path)
elif ext in [".glb", ".gltf"]:
bpy.ops.import_scene.gltf(filepath=path)
elif ext == ".stl":
bpy.ops.import_mesh.stl(filepath=path)
elif ext == ".fbx":
bpy.ops.import_scene.fbx(filepath=path)
elif ext == ".dae":
bpy.ops.wm.collada_import(filepath=path)
elif ext == ".ply":
bpy.ops.import_mesh.ply(filepath=path)
else:
raise RuntimeError(f"unexpected extension: {ext}")
def scene_root_objects():
for obj in bpy.context.scene.objects.values():
if not obj.parent:
yield obj
def scene_bbox(single_obj=None, ignore_matrix=False):
bbox_min = (math.inf,) * 3
bbox_max = (-math.inf,) * 3
found = False
for obj in scene_meshes() if single_obj is None else [single_obj]:
found = True
for coord in obj.bound_box:
coord = Vector(coord)
if not ignore_matrix:
coord = obj.matrix_world @ coord
bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
if not found:
raise RuntimeError("no objects in scene to compute bounding box for")
return Vector(bbox_min), Vector(bbox_max)
def scene_meshes():
for obj in bpy.context.scene.objects.values():
if isinstance(obj.data, (bpy.types.Mesh)):
yield obj
def normalize_scene():
bbox_min, bbox_max = scene_bbox()
scale = 1 / max(bbox_max - bbox_min)
for obj in scene_root_objects():
obj.scale = obj.scale * scale
# Apply scale to matrix_world.
bpy.context.view_layer.update()
bbox_min, bbox_max = scene_bbox()
offset = -(bbox_min + bbox_max) / 2
for obj in scene_root_objects():
obj.matrix_world.translation += offset
bpy.ops.object.select_all(action="DESELECT")
def create_camera():
# https://b3d.interplanety.org/en/how-to-create-camera-through-the-blender-python-api/
camera_data = bpy.data.cameras.new(name="Camera")
camera_object = bpy.data.objects.new("Camera", camera_data)
bpy.context.scene.collection.objects.link(camera_object)
bpy.context.scene.camera = camera_object
def set_camera(direction, camera_dist=2.0):
camera_pos = -camera_dist * direction
bpy.context.scene.camera.location = camera_pos
# https://blender.stackexchange.com/questions/5210/pointing-the-camera-in-a-particular-direction-programmatically
rot_quat = direction.to_track_quat("-Z", "Y")
bpy.context.scene.camera.rotation_euler = rot_quat.to_euler()
bpy.context.view_layer.update()
def randomize_camera(camera_dist=2.0):
direction = random_unit_vector()
set_camera(direction, camera_dist=camera_dist)
def pan_camera(time, axis="Z", camera_dist=2.0, elevation=-0.1):
angle = time * math.pi * 2
direction = [-math.cos(angle), -math.sin(angle), -elevation]
assert axis in ["X", "Y", "Z"]
if axis == "X":
direction = [direction[2], *direction[:2]]
elif axis == "Y":
direction = [direction[0], -elevation, direction[1]]
direction = Vector(direction).normalized()
set_camera(direction, camera_dist=camera_dist)
def place_camera(time, camera_pose_mode="random", camera_dist_min=2.0, camera_dist_max=2.0):
camera_dist = random.uniform(camera_dist_min, camera_dist_max)
if camera_pose_mode == "random":
randomize_camera(camera_dist=camera_dist)
elif camera_pose_mode == "z-circular":
pan_camera(time, axis="Z", camera_dist=camera_dist)
elif camera_pose_mode == "z-circular-elevated":
pan_camera(time, axis="Z", camera_dist=camera_dist, elevation=0.2617993878)
else:
raise ValueError(f"Unknown camera pose mode: {camera_pose_mode}")
def create_light(location, energy=1.0, angle=0.5 * math.pi / 180):
# https://blender.stackexchange.com/questions/215624/how-to-create-a-light-with-the-python-api-in-blender-2-92
light_data = bpy.data.lights.new(name="Light", type="SUN")
light_data.energy = energy
light_data.angle = angle
light_object = bpy.data.objects.new(name="Light", object_data=light_data)
direction = -location
rot_quat = direction.to_track_quat("-Z", "Y")
light_object.rotation_euler = rot_quat.to_euler()
bpy.context.view_layer.update()
bpy.context.collection.objects.link(light_object)
light_object.location = location
def create_random_lights(count=4, distance=2.0, energy=1.5):
clear_lights()
for _ in range(count):
create_light(random_unit_vector() * distance, energy=energy)
def create_camera_light():
clear_lights()
create_light(bpy.context.scene.camera.location, energy=5.0)
def create_uniform_light(backend):
clear_lights()
# Random direction to decorrelate axis-aligned sides.
pos = Vector(UNIFORM_LIGHT_DIRECTION)
angle = 0.0092 if backend == "CYCLES" else math.pi
create_light(pos, energy=5.0, angle=angle)
create_light(-pos, energy=5.0, angle=angle)
def create_vertex_color_shaders():
# By default, Blender will ignore vertex colors in both the
# Eevee and Cycles backends, since these colors aren't
# associated with a material.
#
# What we do here is create a simple material shader and link
# the vertex color to the material color.
for obj in bpy.context.scene.objects.values():
if not isinstance(obj.data, (bpy.types.Mesh)):
continue
if len(obj.data.materials):
# We don't want to override any existing materials.
continue
color_keys = (obj.data.vertex_colors or {}).keys()
if not len(color_keys):
# Many objects will have no materials *or* vertex colors.
continue
mat = bpy.data.materials.new(name="VertexColored")
mat.use_nodes = True
# There should be a Principled BSDF by default.
bsdf_node = None
for node in mat.node_tree.nodes:
if node.type == "BSDF_PRINCIPLED":
bsdf_node = node
assert bsdf_node is not None, "material has no Principled BSDF node to modify"
socket_map = {}
for input in bsdf_node.inputs:
socket_map[input.name] = input
# Make sure nothing lights the object except for the diffuse color.
socket_map["Specular"].default_value = 0.0
socket_map["Roughness"].default_value = 1.0
v_color = mat.node_tree.nodes.new("ShaderNodeVertexColor")
v_color.layer_name = color_keys[0]
mat.node_tree.links.new(v_color.outputs[0], socket_map["Base Color"])
obj.data.materials.append(mat)
def create_default_materials():
for obj in bpy.context.scene.objects.values():
if isinstance(obj.data, (bpy.types.Mesh)):
if not len(obj.data.materials):
mat = bpy.data.materials.new(name="DefaultMaterial")
mat.use_nodes = True
obj.data.materials.append(mat)
def find_materials():
all_materials = set()
for obj in bpy.context.scene.objects.values():
if not isinstance(obj.data, (bpy.types.Mesh)):
continue
for mat in obj.data.materials:
all_materials.add(mat)
return all_materials
def get_socket_value(tree, socket):
default = socket.default_value
if not isinstance(default, float):
default = list(default)
for link in tree.links:
if link.to_socket == socket:
return (link.from_socket, default)
return (None, default)
def clear_socket_input(tree, socket):
for link in list(tree.links):
if link.to_socket == socket:
tree.links.remove(link)
def set_socket_value(tree, socket, socket_and_default):
clear_socket_input(tree, socket)
old_source_socket, default = socket_and_default
if isinstance(default, float) and not isinstance(socket.default_value, float):
# Codepath for setting Emission to a previous alpha value.
socket.default_value = [default] * 3 + [1.0]
else:
socket.default_value = default
if old_source_socket is not None:
tree.links.new(old_source_socket, socket)
def setup_nodes(output_path, capturing_material_alpha: bool = False):
tree = bpy.context.scene.node_tree
links = tree.links
for node in tree.nodes:
tree.nodes.remove(node)
# Helpers to perform math on links and constants.
def node_op(op: str, *args, clamp=False):
node = tree.nodes.new(type="CompositorNodeMath")
node.operation = op
if clamp:
node.use_clamp = True
for i, arg in enumerate(args):
if isinstance(arg, (int, float)):
node.inputs[i].default_value = arg
else:
links.new(arg, node.inputs[i])
return node.outputs[0]
def node_clamp(x, maximum=1.0):
return node_op("MINIMUM", x, maximum)
def node_mul(x, y, **kwargs):
return node_op("MULTIPLY", x, y, **kwargs)
input_node = tree.nodes.new(type="CompositorNodeRLayers")
input_node.scene = bpy.context.scene
input_sockets = {}
for output in input_node.outputs:
input_sockets[output.name] = output
if capturing_material_alpha:
color_socket = input_sockets["Image"]
else:
raw_color_socket = input_sockets["Image"]
# We apply sRGB here so that our fixed-point depth map and material
# alpha values are not sRGB, and so that we perform ambient+diffuse
# lighting in linear RGB space.
color_node = tree.nodes.new(type="CompositorNodeConvertColorSpace")
color_node.from_color_space = "Linear"
color_node.to_color_space = "sRGB"
tree.links.new(raw_color_socket, color_node.inputs[0])
color_socket = color_node.outputs[0]
split_node = tree.nodes.new(type="CompositorNodeSepRGBA")
tree.links.new(color_socket, split_node.inputs[0])
# Create separate file output nodes for every channel we care about.
# The process calling this script must decide how to recombine these
# channels, possibly into a single image.
for i, channel in enumerate("rgba") if not capturing_material_alpha else [(0, "MatAlpha")]:
output_node = tree.nodes.new(type="CompositorNodeOutputFile")
output_node.base_path = f"{output_path}_{channel}"
links.new(split_node.outputs[i], output_node.inputs[0])
if capturing_material_alpha:
# No need to re-write depth here.
return
depth_out = node_clamp(node_mul(input_sockets["Depth"], 1 / MAX_DEPTH))
output_node = tree.nodes.new(type="CompositorNodeOutputFile")
output_node.base_path = f"{output_path}_depth"
links.new(depth_out, output_node.inputs[0])
def render_scene(output_path, fast_mode: bool):
use_workbench = bpy.context.scene.render.engine == "BLENDER_WORKBENCH"
if use_workbench:
# We must use a different engine to compute depth maps.
bpy.context.scene.render.engine = "BLENDER_EEVEE"
bpy.context.scene.eevee.taa_render_samples = 1 # faster, since we discard image.
if fast_mode:
if bpy.context.scene.render.engine == "BLENDER_EEVEE":
bpy.context.scene.eevee.taa_render_samples = 1
elif bpy.context.scene.render.engine == "CYCLES":
bpy.context.scene.cycles.samples = 256
else:
if bpy.context.scene.render.engine == "CYCLES":
# We should still impose a per-frame time limit
# so that we don't timeout completely.
bpy.context.scene.cycles.time_limit = 40
bpy.context.view_layer.update()
bpy.context.scene.use_nodes = True
bpy.context.scene.view_layers["ViewLayer"].use_pass_z = True
bpy.context.scene.view_settings.view_transform = "Raw" # sRGB done in graph nodes
bpy.context.scene.render.film_transparent = True
bpy.context.scene.render.resolution_x = 512
bpy.context.scene.render.resolution_y = 512
bpy.context.scene.render.image_settings.file_format = "PNG"
bpy.context.scene.render.image_settings.color_mode = "BW"
bpy.context.scene.render.image_settings.color_depth = "16"
bpy.context.scene.render.filepath = output_path
setup_nodes(output_path)
bpy.ops.render.render(write_still=True)
# The output images must be moved from their own sub-directories, or
# discarded if we are using workbench for the color.
for channel_name in ["r", "g", "b", "a", "depth"]:
sub_dir = f"{output_path}_{channel_name}"
image_path = os.path.join(sub_dir, os.listdir(sub_dir)[0])
name, ext = os.path.splitext(output_path)
if channel_name == "depth" or not use_workbench:
os.rename(image_path, f"{name}_{channel_name}{ext}")
else:
os.remove(image_path)
os.removedirs(sub_dir)
if use_workbench:
# Re-render RGBA using workbench with texture mode, since this seems
# to show the most reasonable colors when lighting is broken.
bpy.context.scene.use_nodes = False
bpy.context.scene.render.engine = "BLENDER_WORKBENCH"
bpy.context.scene.render.image_settings.color_mode = "RGBA"
bpy.context.scene.render.image_settings.color_depth = "8"
bpy.context.scene.display.shading.color_type = "TEXTURE"
bpy.context.scene.display.shading.light = "FLAT"
if fast_mode:
# Single pass anti-aliasing.
bpy.context.scene.display.render_aa = "FXAA"
os.remove(output_path)
bpy.ops.render.render(write_still=True)
bpy.context.scene.render.image_settings.color_mode = "BW"
bpy.context.scene.render.image_settings.color_depth = "16"
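# Illustrative sketch only (not called anywhere in this script): one way a
# downstream process could recombine the per-channel 16-bit PNGs that
# render_scene() leaves behind (when they are kept) into a single 8-bit RGBA
# image. Assumes numpy and Pillow are available, which is not guaranteed
# inside Blender's bundled Python.
def recombine_channels_example(output_path):
    import numpy as np
    from PIL import Image

    name, ext = os.path.splitext(output_path)
    # render_scene() writes files named {name}_r{ext}, {name}_g{ext}, etc.
    channels = [np.array(Image.open(f"{name}_{c}{ext}")) for c in "rgba"]
    rgba16 = np.stack(channels, axis=-1)
    rgba8 = (rgba16 // 257).astype(np.uint8)  # 16-bit -> 8-bit
    Image.fromarray(rgba8, mode="RGBA").save(f"{name}_rgba{ext}")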
def scene_fov():
x_fov = bpy.context.scene.camera.data.angle_x
y_fov = bpy.context.scene.camera.data.angle_y
width = bpy.context.scene.render.resolution_x
height = bpy.context.scene.render.resolution_y
if bpy.context.scene.camera.data.angle == x_fov:
y_fov = 2 * math.atan(math.tan(x_fov / 2) * height / width)
else:
x_fov = 2 * math.atan(math.tan(y_fov / 2) * width / height)
return x_fov, y_fov
def write_camera_metadata(path):
x_fov, y_fov = scene_fov()
bbox_min, bbox_max = scene_bbox()
matrix = bpy.context.scene.camera.matrix_world
with open(path, "w") as f:
json.dump(
dict(
format_version=FORMAT_VERSION,
max_depth=MAX_DEPTH,
bbox=[list(bbox_min), list(bbox_max)],
origin=list(matrix.col[3])[:3],
x_fov=x_fov,
y_fov=y_fov,
x=list(matrix.col[0])[:3],
y=list(-matrix.col[1])[:3],
z=list(-matrix.col[2])[:3],
),
f,
)
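# Sketch of the consumer side (not used by this script): reading back the
# metadata written above. Key names follow the dict in write_camera_metadata;
# how the x/y/z axes and the FOV are used is left to the dataset consumer.
def read_camera_metadata_example(path):
    with open(path, "r") as f:
        meta = json.load(f)
    origin = meta["origin"]  # camera position in world space
    return origin, meta["x_fov"], meta["y_fov"]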
def save_rendering_dataset(
input_path: str,
output_path: str,
num_images: int,
backend: str,
light_mode: str,
camera_pose: str,
camera_dist_min: float,
camera_dist_max: float,
fast_mode: bool,
):
assert light_mode in ["random", "uniform", "camera"]
assert camera_pose in ["random", "z-circular", "z-circular-elevated"]
import_model(input_path)
bpy.context.scene.render.engine = backend
normalize_scene()
if light_mode == "random":
create_random_lights()
elif light_mode == "uniform":
create_uniform_light(backend)
create_camera()
create_vertex_color_shaders()
for i in range(num_images):
t = i / max(num_images - 1, 1) # same as np.linspace(0, 1, num_images)
place_camera(
t,
camera_pose_mode=camera_pose,
camera_dist_min=camera_dist_min,
camera_dist_max=camera_dist_max,
)
if light_mode == "camera":
create_camera_light()
render_scene(
os.path.join(output_path, f"{i:05}.png"),
fast_mode=fast_mode,
)
write_camera_metadata(os.path.join(output_path, f"{i:05}.json"))
with open(os.path.join(output_path, "info.json"), "w") as f:
info = dict(
backend=backend,
light_mode=light_mode,
fast_mode=fast_mode,
format_version=FORMAT_VERSION,
channels=["R", "G", "B", "A", "D"],
scale=0.5, # The scene is bounded by [-scale, scale].
)
json.dump(info, f)
def main():
try:
dash_index = sys.argv.index("--")
except ValueError as exc:
raise ValueError("arguments must be preceded by '--'") from exc
raw_args = sys.argv[dash_index + 1 :]
parser = argparse.ArgumentParser()
parser.add_argument("--input_path", required=True, type=str)
parser.add_argument("--output_path", required=True, type=str)
parser.add_argument("--num_images", type=int, default=20)
parser.add_argument("--backend", type=str, default="BLENDER_EEVEE")
parser.add_argument("--light_mode", type=str, default="uniform")
parser.add_argument("--camera_pose", type=str, default="random")
parser.add_argument("--camera_dist_min", type=float, default=2.0)
parser.add_argument("--camera_dist_max", type=float, default=2.0)
parser.add_argument("--fast_mode", action="store_true")
args = parser.parse_args(raw_args)
save_rendering_dataset(
input_path=args.input_path,
output_path=args.output_path,
num_images=args.num_images,
backend=args.backend,
light_mode=args.light_mode,
camera_pose=args.camera_pose,
camera_dist_min=args.camera_dist_min,
camera_dist_max=args.camera_dist_max,
fast_mode=args.fast_mode,
)
main()
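Because main() only parses arguments after a literal "--", the script is meant to be run under Blender rather than a plain Python interpreter. A minimal sketch of driving it from ordinary Python, assuming Blender is on PATH and the script is saved as blender_script.py (both names, and the input file, are assumptions rather than part of this diff):

import subprocess

subprocess.check_call([
    "blender", "-b", "-P", "blender_script.py", "--",
    "--input_path", "model.obj",   # any format handled by import_model
    "--output_path", "renders",
    "--num_images", "4",
    "--fast_mode",
])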

View File

@ -0,0 +1,40 @@
"""
Evaluate P-FID between two batches of point clouds.
The point cloud batches should be saved to two npz files, where there
is an arr_0 key of shape [N x K x 3], where K is the number of points in
each cloud and N is the number of clouds.
"""
import argparse
from point_e.evals.feature_extractor import PointNetClassifier, get_torch_devices
from point_e.evals.fid_is import compute_statistics
from point_e.evals.npz_stream import NpzStreamer
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument("batch_1", type=str)
parser.add_argument("batch_2", type=str)
args = parser.parse_args()
print("creating classifier...")
clf = PointNetClassifier(devices=get_torch_devices(), cache_dir=args.cache_dir)
print("computing first batch activations")
features_1, _ = clf.features_and_preds(NpzStreamer(args.batch_1))
stats_1 = compute_statistics(features_1)
del features_1
print("computing second batch activations")
features_2, _ = clf.features_and_preds(NpzStreamer(args.batch_2))
stats_2 = compute_statistics(features_2)
del features_2
print(f"P-FID: {stats_1.frechet_distance(stats_2)}")
if __name__ == "__main__":
main()
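For reference, a batch file in the layout the docstring describes can be written with plain numpy; the shapes and file name below are illustrative only:

import numpy as np

clouds = np.random.rand(8, 1024, 3).astype(np.float32)  # N=8 clouds of K=1024 points each
np.savez("batch_1.npz", arr_0=clouds)                    # stored under the expected arr_0 key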

View File

@ -0,0 +1,31 @@
"""
Evaluate P-IS of a batch of point clouds.
The point cloud batch should be saved to an npz file, where there is an
arr_0 key of shape [N x K x 3], where K is the number of points in each
cloud and N is the number of clouds.
"""
import argparse
from point_e.evals.feature_extractor import PointNetClassifier, get_torch_devices
from point_e.evals.fid_is import compute_inception_score
from point_e.evals.npz_stream import NpzStreamer
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument("batch", type=str)
args = parser.parse_args()
print("creating classifier...")
clf = PointNetClassifier(devices=get_torch_devices(), cache_dir=args.cache_dir)
print("computing batch predictions")
_, preds = clf.features_and_preds(NpzStreamer(args.batch))
print(f"P-IS: {compute_inception_score(preds)}")
if __name__ == "__main__":
main()

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

BIN
point-e/point_e/examples/example_data/pc_corgi.npz (Stored with Git LFS) Normal file

Binary file not shown.

BIN
point-e/point_e/examples/example_data/pc_cube_stack.npz (Stored with Git LFS) Normal file

Binary file not shown.

View File

@ -0,0 +1,115 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"import torch\n",
"from tqdm.auto import tqdm\n",
"\n",
"from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config\n",
"from point_e.diffusion.sampler import PointCloudSampler\n",
"from point_e.models.download import load_checkpoint\n",
"from point_e.models.configs import MODEL_CONFIGS, model_from_config\n",
"from point_e.util.plotting import plot_point_cloud"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"print('creating base model...')\n",
"base_name = 'base40M' # use base300M or base1B for better results\n",
"base_model = model_from_config(MODEL_CONFIGS[base_name], device)\n",
"base_model.eval()\n",
"base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])\n",
"\n",
"print('creating upsample model...')\n",
"upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)\n",
"upsampler_model.eval()\n",
"upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])\n",
"\n",
"print('downloading base checkpoint...')\n",
"base_model.load_state_dict(load_checkpoint(base_name, device))\n",
"\n",
"print('downloading upsampler checkpoint...')\n",
"upsampler_model.load_state_dict(load_checkpoint('upsample', device))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sampler = PointCloudSampler(\n",
" device=device,\n",
" models=[base_model, upsampler_model],\n",
" diffusions=[base_diffusion, upsampler_diffusion],\n",
" num_points=[1024, 4096 - 1024],\n",
" aux_channels=['R', 'G', 'B'],\n",
" guidance_scale=[3.0, 3.0],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load an image to condition on.\n",
"img = Image.open('example_data/cube_stack.jpg')\n",
"\n",
"# Produce a sample from the model.\n",
"samples = None\n",
"for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(images=[img]))):\n",
" samples = x"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pc = sampler.output_to_point_clouds(samples)[0]\n",
"fig = plot_point_cloud(pc, grid_size=3, fixed_bounds=((-0.75, -0.75, -0.75),(0.75, 0.75, 0.75)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.9 64-bit ('3.9.9')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "b270b0f43bc427bcab7703c037711644cc480aac7c1cc8d2940cfaf0b447ee2e"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

View File

@ -0,0 +1,106 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"from tqdm.auto import tqdm\n",
"\n",
"from point_e.models.download import load_checkpoint\n",
"from point_e.models.configs import MODEL_CONFIGS, model_from_config\n",
"from point_e.util.pc_to_mesh import marching_cubes_mesh\n",
"from point_e.util.plotting import plot_point_cloud\n",
"from point_e.util.point_cloud import PointCloud"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"print('creating SDF model...')\n",
"name = 'sdf'\n",
"model = model_from_config(MODEL_CONFIGS[name], device)\n",
"model.eval()\n",
"\n",
"print('loading SDF model...')\n",
"model.load_state_dict(load_checkpoint(name, device))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load a point cloud we want to convert into a mesh.\n",
"pc = PointCloud.load('example_data/pc_corgi.npz')\n",
"\n",
"# Plot the point cloud as a sanity check.\n",
"fig = plot_point_cloud(pc, grid_size=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Produce a mesh (with vertex colors)\n",
"mesh = marching_cubes_mesh(\n",
" pc=pc,\n",
" model=model,\n",
" batch_size=4096,\n",
" grid_size=32, # increase to 128 for resolution used in evals\n",
" progress=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Write the mesh to a PLY file to import into some other program.\n",
"with open('mesh.ply', 'wb') as f:\n",
" mesh.write_ply(f)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.9 64-bit ('3.9.9')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "b270b0f43bc427bcab7703c037711644cc480aac7c1cc8d2940cfaf0b447ee2e"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -0,0 +1,115 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from tqdm.auto import tqdm\n",
"\n",
"from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config\n",
"from point_e.diffusion.sampler import PointCloudSampler\n",
"from point_e.models.download import load_checkpoint\n",
"from point_e.models.configs import MODEL_CONFIGS, model_from_config\n",
"from point_e.util.plotting import plot_point_cloud"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"print('creating base model...')\n",
"base_name = 'base40M-textvec'\n",
"base_model = model_from_config(MODEL_CONFIGS[base_name], device)\n",
"base_model.eval()\n",
"base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])\n",
"\n",
"print('creating upsample model...')\n",
"upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)\n",
"upsampler_model.eval()\n",
"upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])\n",
"\n",
"print('downloading base checkpoint...')\n",
"base_model.load_state_dict(load_checkpoint(base_name, device))\n",
"\n",
"print('downloading upsampler checkpoint...')\n",
"upsampler_model.load_state_dict(load_checkpoint('upsample', device))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sampler = PointCloudSampler(\n",
" device=device,\n",
" models=[base_model, upsampler_model],\n",
" diffusions=[base_diffusion, upsampler_diffusion],\n",
" num_points=[1024, 4096 - 1024],\n",
" aux_channels=['R', 'G', 'B'],\n",
" guidance_scale=[3.0, 0.0],\n",
" model_kwargs_key_filter=('texts', ''), # Do not condition the upsampler at all\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set a prompt to condition on.\n",
"prompt = 'a red motorcycle'\n",
"\n",
"# Produce a sample from the model.\n",
"samples = None\n",
"for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[prompt]))):\n",
" samples = x"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pc = sampler.output_to_point_clouds(samples)[0]\n",
"fig = plot_point_cloud(pc, grid_size=3, fixed_bounds=((-0.75, -0.75, -0.75),(0.75, 0.75, 0.75)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.9 64-bit ('3.9.9')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9 (main, Aug 15 2022, 16:40:41) \n[Clang 13.1.6 (clang-1316.0.21.2.5)]"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "b270b0f43bc427bcab7703c037711644cc480aac7c1cc8d2940cfaf0b447ee2e"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

View File

@ -0,0 +1,60 @@
"""
Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
"""
from typing import Callable, Iterable, Sequence, Union
import torch
def checkpoint(
func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
inputs: Sequence[torch.Tensor],
params: Iterable[torch.Tensor],
flag: bool,
):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
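A small usage sketch (not from the repo) of the trade-off the docstring describes: with flag=True, activations inside fn are recomputed during the backward pass instead of being stored.

import torch
import torch.nn as nn

linear = nn.Linear(16, 16)

def fn(x):
    return linear(x).relu()

x = torch.randn(4, 16, requires_grad=True)
# linear's parameters are used by fn but not passed as explicit arguments.
y = checkpoint(fn, (x,), linear.parameters(), flag=True)
y.sum().backward()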

View File

@ -0,0 +1,134 @@
from typing import Any, Dict
import torch
import torch.nn as nn
from .sdf import CrossAttentionPointCloudSDFModel
from .transformer import (
CLIPImageGridPointDiffusionTransformer,
CLIPImageGridUpsamplePointDiffusionTransformer,
CLIPImagePointDiffusionTransformer,
PointDiffusionTransformer,
UpsamplePointDiffusionTransformer,
)
MODEL_CONFIGS = {
"base40M-imagevec": {
"cond_drop_prob": 0.1,
"heads": 8,
"init_scale": 0.25,
"input_channels": 6,
"layers": 12,
"n_ctx": 1024,
"name": "CLIPImagePointDiffusionTransformer",
"output_channels": 12,
"time_token_cond": True,
"token_cond": True,
"width": 512,
},
"base40M-textvec": {
"cond_drop_prob": 0.1,
"heads": 8,
"init_scale": 0.25,
"input_channels": 6,
"layers": 12,
"n_ctx": 1024,
"name": "CLIPImagePointDiffusionTransformer",
"output_channels": 12,
"time_token_cond": True,
"token_cond": True,
"width": 512,
},
"base40M-uncond": {
"heads": 8,
"init_scale": 0.25,
"input_channels": 6,
"layers": 12,
"n_ctx": 1024,
"name": "PointDiffusionTransformer",
"output_channels": 12,
"time_token_cond": True,
"width": 512,
},
"base40M": {
"cond_drop_prob": 0.1,
"heads": 8,
"init_scale": 0.25,
"input_channels": 6,
"layers": 12,
"n_ctx": 1024,
"name": "CLIPImageGridPointDiffusionTransformer",
"output_channels": 12,
"time_token_cond": True,
"width": 512,
},
"base300M": {
"cond_drop_prob": 0.1,
"heads": 16,
"init_scale": 0.25,
"input_channels": 6,
"layers": 24,
"n_ctx": 1024,
"name": "CLIPImageGridPointDiffusionTransformer",
"output_channels": 12,
"time_token_cond": True,
"width": 1024,
},
"base1B": {
"cond_drop_prob": 0.1,
"heads": 32,
"init_scale": 0.25,
"input_channels": 6,
"layers": 24,
"n_ctx": 1024,
"name": "CLIPImageGridPointDiffusionTransformer",
"output_channels": 12,
"time_token_cond": True,
"width": 2048,
},
"upsample": {
"channel_biases": [0.0, 0.0, 0.0, -1.0, -1.0, -1.0],
"channel_scales": [2.0, 2.0, 2.0, 0.007843137255, 0.007843137255, 0.007843137255],
"cond_ctx": 1024,
"cond_drop_prob": 0.1,
"heads": 8,
"init_scale": 0.25,
"input_channels": 6,
"layers": 12,
"n_ctx": 3072,
"name": "CLIPImageGridUpsamplePointDiffusionTransformer",
"output_channels": 12,
"time_token_cond": True,
"width": 512,
},
"sdf": {
"decoder_heads": 4,
"decoder_layers": 4,
"encoder_heads": 4,
"encoder_layers": 8,
"init_scale": 0.25,
"n_ctx": 4096,
"name": "CrossAttentionPointCloudSDFModel",
"width": 256,
},
}
def model_from_config(config: Dict[str, Any], device: torch.device) -> nn.Module:
config = config.copy()
name = config.pop("name")
if name == "PointDiffusionTransformer":
return PointDiffusionTransformer(device=device, dtype=torch.float32, **config)
elif name == "CLIPImagePointDiffusionTransformer":
return CLIPImagePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
elif name == "CLIPImageGridPointDiffusionTransformer":
return CLIPImageGridPointDiffusionTransformer(device=device, dtype=torch.float32, **config)
elif name == "UpsamplePointDiffusionTransformer":
return UpsamplePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
elif name == "CLIPImageGridUpsamplePointDiffusionTransformer":
return CLIPImageGridUpsamplePointDiffusionTransformer(
device=device, dtype=torch.float32, **config
)
elif name == "CrossAttentionPointCloudSDFModel":
return CrossAttentionPointCloudSDFModel(device=device, dtype=torch.float32, **config)
raise ValueError(f"unknown model name: {name}")

View File

@ -0,0 +1,78 @@
"""
Adapted from: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/download.py
"""
import os
from functools import lru_cache
from typing import Dict, Optional
import requests
import torch
from filelock import FileLock
from tqdm.auto import tqdm
MODEL_PATHS = {
"base40M-imagevec": "https://openaipublic.azureedge.net/main/point-e/base_40m_imagevec.pt",
"base40M-textvec": "https://openaipublic.azureedge.net/main/point-e/base_40m_textvec.pt",
"base40M-uncond": "https://openaipublic.azureedge.net/main/point-e/base_40m_uncond.pt",
"base40M": "https://openaipublic.azureedge.net/main/point-e/base_40m.pt",
"base300M": "https://openaipublic.azureedge.net/main/point-e/base_300m.pt",
"base1B": "https://openaipublic.azureedge.net/main/point-e/base_1b.pt",
"upsample": "https://openaipublic.azureedge.net/main/point-e/upsample_40m.pt",
"sdf": "https://openaipublic.azureedge.net/main/point-e/sdf.pt",
"pointnet": "https://openaipublic.azureedge.net/main/point-e/pointnet.pt",
}
@lru_cache()
def default_cache_dir() -> str:
return os.path.join(os.path.abspath(os.getcwd()), "point_e_model_cache")
def fetch_file_cached(
url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096
) -> str:
"""
Download the file at the given URL into a local file and return the path.
If cache_dir is specified, it will be used to download the files.
Otherwise, default_cache_dir() is used.
"""
if cache_dir is None:
cache_dir = default_cache_dir()
os.makedirs(cache_dir, exist_ok=True)
local_path = os.path.join(cache_dir, url.split("/")[-1])
if os.path.exists(local_path):
return local_path
response = requests.get(url, stream=True)
size = int(response.headers.get("content-length", "0"))
with FileLock(local_path + ".lock"):
if progress:
pbar = tqdm(total=size, unit="iB", unit_scale=True)
tmp_path = local_path + ".tmp"
with open(tmp_path, "wb") as f:
for chunk in response.iter_content(chunk_size):
if progress:
pbar.update(len(chunk))
f.write(chunk)
os.rename(tmp_path, local_path)
if progress:
pbar.close()
return local_path
def load_checkpoint(
checkpoint_name: str,
device: torch.device,
progress: bool = True,
cache_dir: Optional[str] = None,
chunk_size: int = 4096,
) -> Dict[str, torch.Tensor]:
if checkpoint_name not in MODEL_PATHS:
raise ValueError(
f"Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}."
)
path = fetch_file_cached(
MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
)
return torch.load(path, map_location=device)
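A short sketch of pairing load_checkpoint with a model built from the matching config (model_from_config and MODEL_CONFIGS come from point_e.models.configs; the cache directory name here is an assumption, and the weights are downloaded on first use):

import torch
from point_e.models.configs import MODEL_CONFIGS, model_from_config

device = torch.device("cpu")
model = model_from_config(MODEL_CONFIGS["sdf"], device)
model.load_state_dict(load_checkpoint("sdf", device, cache_dir="point_e_cache"))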

View File

@ -0,0 +1,146 @@
import math
from typing import Optional
import torch
import torch.nn as nn
from .checkpoint import checkpoint
from .transformer import MLP, init_linear
class MultiheadCrossAttention(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_data: int,
width: int,
heads: int,
init_scale: float,
data_width: Optional[int] = None,
):
super().__init__()
self.n_data = n_data
self.width = width
self.heads = heads
self.data_width = width if data_width is None else data_width
self.c_q = nn.Linear(width, width, device=device, dtype=dtype)
self.c_kv = nn.Linear(self.data_width, width * 2, device=device, dtype=dtype)
self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
self.attention = QKVMultiheadCrossAttention(
device=device, dtype=dtype, heads=heads, n_data=n_data
)
init_linear(self.c_q, init_scale)
init_linear(self.c_kv, init_scale)
init_linear(self.c_proj, init_scale)
def forward(self, x, data):
x = self.c_q(x)
data = self.c_kv(data)
x = checkpoint(self.attention, (x, data), (), True)
x = self.c_proj(x)
return x
class QKVMultiheadCrossAttention(nn.Module):
def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_data: int):
super().__init__()
self.device = device
self.dtype = dtype
self.heads = heads
self.n_data = n_data
def forward(self, q, kv):
_, n_ctx, _ = q.shape
bs, n_data, width = kv.shape
attn_ch = width // self.heads // 2
scale = 1 / math.sqrt(math.sqrt(attn_ch))
q = q.view(bs, n_ctx, self.heads, -1)
kv = kv.view(bs, n_data, self.heads, -1)
k, v = torch.split(kv, attn_ch, dim=-1)
weight = torch.einsum(
"bthc,bshc->bhts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
wdtype = weight.dtype
weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
class ResidualCrossAttentionBlock(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_data: int,
width: int,
heads: int,
data_width: Optional[int] = None,
init_scale: float = 1.0,
):
super().__init__()
if data_width is None:
data_width = width
self.attn = MultiheadCrossAttention(
device=device,
dtype=dtype,
n_data=n_data,
width=width,
heads=heads,
data_width=data_width,
init_scale=init_scale,
)
self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
def forward(self, x: torch.Tensor, data: torch.Tensor):
x = x + self.attn(self.ln_1(x), self.ln_2(data))
x = x + self.mlp(self.ln_3(x))
return x
class SimplePerceiver(nn.Module):
"""
Only does cross attention
"""
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_data: int,
width: int,
layers: int,
heads: int,
init_scale: float = 0.25,
data_width: Optional[int] = None,
):
super().__init__()
self.width = width
self.layers = layers
init_scale = init_scale * math.sqrt(1.0 / width)
self.resblocks = nn.ModuleList(
[
ResidualCrossAttentionBlock(
device=device,
dtype=dtype,
n_data=n_data,
width=width,
heads=heads,
init_scale=init_scale,
data_width=data_width,
)
for _ in range(layers)
]
)
def forward(self, x: torch.Tensor, data: torch.Tensor):
for block in self.resblocks:
x = block(x, data)
return x
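A shape-only sketch with arbitrary small sizes: x supplies the queries, data supplies the keys and values, and the output keeps x's sequence length.

import torch

perceiver = SimplePerceiver(
    device=torch.device("cpu"), dtype=torch.float32,
    n_data=64, width=128, layers=2, heads=4,
)
x = torch.randn(2, 16, 128)      # [batch, n_queries, width]
data = torch.randn(2, 64, 128)   # [batch, n_data, width]
out = perceiver(x, data)         # -> [2, 16, 128]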

View File

@ -0,0 +1,270 @@
from typing import Iterable, List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from .download import default_cache_dir
ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
class ImageCLIP(nn.Module):
"""
A wrapper around a pre-trained CLIP model that automatically handles
batches of texts, images, and embeddings.
"""
def __init__(
self,
device: torch.device,
dtype: Optional[torch.dtype] = torch.float32,
ensure_used_params: bool = True,
clip_name: str = "ViT-L/14",
cache_dir: Optional[str] = None,
):
super().__init__()
assert clip_name in ["ViT-L/14", "ViT-B/32"]
self.device = device
self.ensure_used_params = ensure_used_params
# Lazy import because of torchvision.
import clip
self.clip_model, self.preprocess = clip.load(
clip_name, device=device, download_root=cache_dir or default_cache_dir()
)
self.clip_name = clip_name
if dtype is not None:
self.clip_model.to(dtype)
self._tokenize = clip.tokenize
@property
def feature_dim(self) -> int:
if self.clip_name == "ViT-L/14":
return 768
else:
return 512
@property
def grid_size(self) -> int:
if self.clip_name == "ViT-L/14":
return 16
else:
return 7
@property
def grid_feature_dim(self) -> int:
if self.clip_name == "ViT-L/14":
return 1024
else:
return 768
def forward(
self,
batch_size: int,
images: Optional[Iterable[Optional[ImageType]]] = None,
texts: Optional[Iterable[Optional[str]]] = None,
embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
) -> torch.Tensor:
"""
Generate a batch of embeddings from a mixture of images, texts,
precomputed embeddings, and possibly empty values.
For each batch element, at most one of images, texts, and embeddings
should have a non-None value. Embeddings from multiple modalities
cannot be mixed for a single batch element. If no modality is provided,
a zero embedding will be used for the batch element.
"""
image_seq = [None] * batch_size if images is None else list(images)
text_seq = [None] * batch_size if texts is None else list(texts)
embedding_seq = [None] * batch_size if embeddings is None else list(embeddings)
assert len(image_seq) == batch_size, "number of images should match batch size"
assert len(text_seq) == batch_size, "number of texts should match batch size"
assert len(embedding_seq) == batch_size, "number of embeddings should match batch size"
if self.ensure_used_params:
return self._static_multimodal_embed(
images=image_seq, texts=text_seq, embeddings=embedding_seq
)
result = torch.zeros((batch_size, self.feature_dim), device=self.device)
index_images = []
index_texts = []
for i, (image, text, emb) in enumerate(zip(image_seq, text_seq, embedding_seq)):
assert (
sum([int(image is not None), int(text is not None), int(emb is not None)]) < 2
), "only one modality may be non-None per batch element"
if image is not None:
index_images.append((i, image))
elif text is not None:
index_texts.append((i, text))
elif emb is not None:
result[i] = emb.to(result)
if len(index_images):
embs = self.embed_images((img for _, img in index_images))
for (i, _), emb in zip(index_images, embs):
result[i] = emb.to(result)
if len(index_texts):
embs = self.embed_text((text for _, text in index_texts))
for (i, _), emb in zip(index_texts, embs):
result[i] = emb.to(result)
return result
def _static_multimodal_embed(
self,
images: List[Optional[ImageType]] = None,
texts: List[Optional[str]] = None,
embeddings: List[Optional[torch.Tensor]] = None,
) -> torch.Tensor:
"""
Like forward(), but always runs all encoders to ensure that
the forward graph looks the same on every rank.
"""
image_emb = self.embed_images(images)
text_emb = self.embed_text(t if t else "" for t in texts)
joined_embs = torch.stack(
[
emb.to(device=self.device, dtype=torch.float32)
if emb is not None
else torch.zeros(self.feature_dim, device=self.device)
for emb in embeddings
],
dim=0,
)
image_flag = torch.tensor([x is not None for x in images], device=self.device)[
:, None
].expand_as(image_emb)
text_flag = torch.tensor([x is not None for x in texts], device=self.device)[
:, None
].expand_as(image_emb)
emb_flag = torch.tensor([x is not None for x in embeddings], device=self.device)[
:, None
].expand_as(image_emb)
return (
image_flag.float() * image_emb
+ text_flag.float() * text_emb
+ emb_flag.float() * joined_embs
+ self.clip_model.logit_scale * 0 # avoid unused parameters
)
def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
"""
:param xs: N images, stored as numpy arrays, tensors, or PIL images.
:return: an [N x D] tensor of features.
"""
clip_inputs = self.images_to_tensor(xs)
results = self.clip_model.encode_image(clip_inputs).float()
return results / torch.linalg.norm(results, dim=-1, keepdim=True)
def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
"""
Embed text prompts as an [N x D] tensor.
"""
enc = self.clip_model.encode_text(
self._tokenize(list(prompts), truncate=True).to(self.device)
).float()
return enc / torch.linalg.norm(enc, dim=-1, keepdim=True)
def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
"""
Embed images into latent grids.
:param xs: an iterable of images to embed.
:return: a tensor of shape [N x C x L], where L = self.grid_size**2.
"""
if self.ensure_used_params:
extra_value = 0.0
for p in self.parameters():
extra_value = extra_value + p.mean() * 0.0
else:
extra_value = 0.0
x = self.images_to_tensor(xs).to(self.clip_model.dtype)
# https://github.com/openai/CLIP/blob/4d120f3ec35b30bd0f992f5d8af2d793aad98d2a/clip/model.py#L225
vt = self.clip_model.visual
x = vt.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat(
[
vt.class_embedding.to(x.dtype)
+ torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
x,
],
dim=1,
) # shape = [*, grid ** 2 + 1, width]
x = x + vt.positional_embedding.to(x.dtype)
x = vt.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = vt.transformer(x)
x = x.permute(1, 2, 0) # LND -> NDL
return x[..., 1:].contiguous().float() + extra_value
def images_to_tensor(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
return torch.stack([self.preprocess(_image_to_pil(x)) for x in xs], dim=0).to(self.device)
class FrozenImageCLIP:
def __init__(self, device: torch.device, **kwargs):
self.model = ImageCLIP(device, dtype=None, ensure_used_params=False, **kwargs)
for parameter in self.model.parameters():
parameter.requires_grad_(False)
@property
def feature_dim(self) -> int:
return self.model.feature_dim
@property
def grid_size(self) -> int:
return self.model.grid_size
@property
def grid_feature_dim(self) -> int:
return self.model.grid_feature_dim
def __call__(
self,
batch_size: int,
images: Optional[Iterable[Optional[ImageType]]] = None,
texts: Optional[Iterable[Optional[str]]] = None,
embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
) -> torch.Tensor:
# We don't do a no_grad() here so that gradients could still
# flow to the input embeddings argument.
# This behavior is currently not used, but it could be.
return self.model(batch_size=batch_size, images=images, texts=texts, embeddings=embeddings)
def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
with torch.no_grad():
return self.model.embed_images(xs)
def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
with torch.no_grad():
return self.model.embed_text(prompts)
def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
with torch.no_grad():
return self.model.embed_images_grid(xs)
def _image_to_pil(obj: Optional[ImageType]) -> Image.Image:
if obj is None:
return Image.fromarray(np.zeros([64, 64, 3], dtype=np.uint8))
if isinstance(obj, np.ndarray):
return Image.fromarray(obj.astype(np.uint8))
elif isinstance(obj, torch.Tensor):
return Image.fromarray(obj.detach().cpu().numpy().astype(np.uint8))
else:
return obj
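A minimal usage sketch of the forward() contract above (at most one modality per batch element; None falls back to a zero embedding). This instantiates the real CLIP ViT-L/14, so it assumes the weights can be downloaded on first use.

import numpy as np
import torch
from PIL import Image

clip = FrozenImageCLIP(torch.device("cpu"))
img = Image.fromarray(np.zeros([64, 64, 3], dtype=np.uint8))
embs = clip(batch_size=3, images=[img, None, None], texts=[None, "a corgi", None])
print(embs.shape)  # [3, clip.feature_dim]; the last row is all zeros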

View File

@ -0,0 +1,139 @@
from abc import abstractmethod
from typing import Dict, Optional
import torch
import torch.nn as nn
from .perceiver import SimplePerceiver
from .transformer import Transformer
class PointCloudSDFModel(nn.Module):
@property
@abstractmethod
def device(self) -> torch.device:
"""
Get the device that should be used for input tensors.
"""
@property
@abstractmethod
def default_batch_size(self) -> int:
"""
Get a reasonable default number of query points for the model.
In some cases, this might be the only supported size.
"""
@abstractmethod
def encode_point_clouds(self, point_clouds: torch.Tensor) -> Dict[str, torch.Tensor]:
"""
Encode a batch of point clouds to cache part of the SDF calculation
done by forward().
:param point_clouds: a batch of [batch x 3 x N] points.
:return: a state representing the encoded point cloud batch.
"""
def forward(
self,
x: torch.Tensor,
point_clouds: Optional[torch.Tensor] = None,
encoded: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
"""
Predict the SDF at the coordinates x, given a batch of point clouds.
Exactly one of point_clouds and encoded should be passed; the other
must be left as None.
:param x: a [batch x 3 x N'] tensor of query points.
:param point_clouds: a [batch x 3 x N] batch of point clouds.
:param encoded: the result of calling encode_point_clouds().
:return: a [batch x N'] tensor of SDF predictions.
"""
assert point_clouds is not None or encoded is not None
assert point_clouds is None or encoded is None
if point_clouds is not None:
encoded = self.encode_point_clouds(point_clouds)
return self.predict_sdf(x, encoded)
@abstractmethod
def predict_sdf(
self, x: torch.Tensor, encoded: Optional[Dict[str, torch.Tensor]]
) -> torch.Tensor:
"""
Predict the SDF at the query points given the encoded point clouds.
Each query point should be treated independently, only conditioning on
the point clouds themselves.
"""
class CrossAttentionPointCloudSDFModel(PointCloudSDFModel):
"""
Encode point clouds using a transformer, and query points using cross
attention to the encoded latents.
"""
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int = 4096,
width: int = 512,
encoder_layers: int = 12,
encoder_heads: int = 8,
decoder_layers: int = 4,
decoder_heads: int = 8,
init_scale: float = 0.25,
):
super().__init__()
self._device = device
self.n_ctx = n_ctx
self.encoder_input_proj = nn.Linear(3, width, device=device, dtype=dtype)
self.encoder = Transformer(
device=device,
dtype=dtype,
n_ctx=n_ctx,
width=width,
layers=encoder_layers,
heads=encoder_heads,
init_scale=init_scale,
)
self.decoder_input_proj = nn.Linear(3, width, device=device, dtype=dtype)
self.decoder = SimplePerceiver(
device=device,
dtype=dtype,
n_data=n_ctx,
width=width,
layers=decoder_layers,
heads=decoder_heads,
init_scale=init_scale,
)
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
self.output_proj = nn.Linear(width, 1, device=device, dtype=dtype)
@property
def device(self) -> torch.device:
return self._device
@property
def default_batch_size(self) -> int:
return self.n_ctx
def encode_point_clouds(self, point_clouds: torch.Tensor) -> Dict[str, torch.Tensor]:
h = self.encoder_input_proj(point_clouds.permute(0, 2, 1))
h = self.encoder(h)
return dict(latents=h)
def predict_sdf(
self, x: torch.Tensor, encoded: Optional[Dict[str, torch.Tensor]]
) -> torch.Tensor:
data = encoded["latents"]
x = self.decoder_input_proj(x.permute(0, 2, 1))
x = self.decoder(x, data)
x = self.ln_post(x)
x = self.output_proj(x)
return x[..., 0]
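A shape-only sketch of the contract described in the base-class docstrings, with deliberately small sizes: point clouds are [batch x 3 x N] with N equal to n_ctx, query points are [batch x 3 x N'], and the prediction is [batch x N'].

import torch

model = CrossAttentionPointCloudSDFModel(
    device=torch.device("cpu"), dtype=torch.float32,
    n_ctx=256, width=64, encoder_layers=2, encoder_heads=2,
    decoder_layers=2, decoder_heads=2,
)
point_clouds = torch.randn(1, 3, 256)
queries = torch.randn(1, 3, 10)
encoded = model.encode_point_clouds(point_clouds)  # reusable across query batches
sdf = model(queries, encoded=encoded)              # -> shape [1, 10]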

View File

@ -0,0 +1,494 @@
"""
Adapted from: https://github.com/openai/openai/blob/55363aa496049423c37124b440e9e30366db3ed6/orc/orc/diffusion/vit.py
"""
import math
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
import torch
import torch.nn as nn
from .checkpoint import checkpoint
from .pretrained_clip import FrozenImageCLIP, ImageCLIP, ImageType
from .util import timestep_embedding
def init_linear(l, stddev):
nn.init.normal_(l.weight, std=stddev)
if l.bias is not None:
nn.init.constant_(l.bias, 0.0)
class MultiheadAttention(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int,
width: int,
heads: int,
init_scale: float,
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.heads = heads
self.c_qkv = nn.Linear(width, width * 3, device=device, dtype=dtype)
self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx)
init_linear(self.c_qkv, init_scale)
init_linear(self.c_proj, init_scale)
def forward(self, x):
x = self.c_qkv(x)
x = checkpoint(self.attention, (x,), (), True)
x = self.c_proj(x)
return x
class MLP(nn.Module):
def __init__(self, *, device: torch.device, dtype: torch.dtype, width: int, init_scale: float):
super().__init__()
self.width = width
self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
self.gelu = nn.GELU()
init_linear(self.c_fc, init_scale)
init_linear(self.c_proj, init_scale)
def forward(self, x):
return self.c_proj(self.gelu(self.c_fc(x)))
class QKVMultiheadAttention(nn.Module):
def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):
super().__init__()
self.device = device
self.dtype = dtype
self.heads = heads
self.n_ctx = n_ctx
def forward(self, qkv):
bs, n_ctx, width = qkv.shape
attn_ch = width // self.heads // 3
scale = 1 / math.sqrt(math.sqrt(attn_ch))
qkv = qkv.view(bs, n_ctx, self.heads, -1)
q, k, v = torch.split(qkv, attn_ch, dim=-1)
weight = torch.einsum(
"bthc,bshc->bhts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
wdtype = weight.dtype
weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int,
width: int,
heads: int,
init_scale: float = 1.0,
):
super().__init__()
self.attn = MultiheadAttention(
device=device,
dtype=dtype,
n_ctx=n_ctx,
width=width,
heads=heads,
init_scale=init_scale,
)
self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
def forward(self, x: torch.Tensor):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int,
width: int,
layers: int,
heads: int,
init_scale: float = 0.25,
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.layers = layers
init_scale = init_scale * math.sqrt(1.0 / width)
self.resblocks = nn.ModuleList(
[
ResidualAttentionBlock(
device=device,
dtype=dtype,
n_ctx=n_ctx,
width=width,
heads=heads,
init_scale=init_scale,
)
for _ in range(layers)
]
)
def forward(self, x: torch.Tensor):
for block in self.resblocks:
x = block(x)
return x
class PointDiffusionTransformer(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
input_channels: int = 3,
output_channels: int = 3,
n_ctx: int = 1024,
width: int = 512,
layers: int = 12,
heads: int = 8,
init_scale: float = 0.25,
time_token_cond: bool = False,
):
super().__init__()
self.input_channels = input_channels
self.output_channels = output_channels
self.n_ctx = n_ctx
self.time_token_cond = time_token_cond
self.time_embed = MLP(
device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
)
self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
self.backbone = Transformer(
device=device,
dtype=dtype,
n_ctx=n_ctx + int(time_token_cond),
width=width,
layers=layers,
heads=heads,
init_scale=init_scale,
)
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
with torch.no_grad():
self.output_proj.weight.zero_()
self.output_proj.bias.zero_()
def forward(self, x: torch.Tensor, t: torch.Tensor):
"""
:param x: an [N x C x T] tensor.
:param t: an [N] tensor.
:return: an [N x C' x T] tensor.
"""
assert x.shape[-1] == self.n_ctx
t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
return self._forward_with_cond(x, [(t_embed, self.time_token_cond)])
def _forward_with_cond(
self, x: torch.Tensor, cond_as_token: List[Tuple[torch.Tensor, bool]]
) -> torch.Tensor:
h = self.input_proj(x.permute(0, 2, 1)) # NCL -> NLC
for emb, as_token in cond_as_token:
if not as_token:
h = h + emb[:, None]
extra_tokens = [
(emb[:, None] if len(emb.shape) == 2 else emb)
for emb, as_token in cond_as_token
if as_token
]
if len(extra_tokens):
h = torch.cat(extra_tokens + [h], dim=1)
h = self.ln_pre(h)
h = self.backbone(h)
h = self.ln_post(h)
if len(extra_tokens):
h = h[:, sum(h.shape[1] for h in extra_tokens) :]
h = self.output_proj(h)
return h.permute(0, 2, 1)
class CLIPImagePointDiffusionTransformer(PointDiffusionTransformer):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int = 1024,
token_cond: bool = False,
cond_drop_prob: float = 0.0,
frozen_clip: bool = True,
cache_dir: Optional[str] = None,
**kwargs,
):
super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + int(token_cond), **kwargs)
self.n_ctx = n_ctx
self.token_cond = token_cond
self.clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device, cache_dir=cache_dir)
self.clip_embed = nn.Linear(
self.clip.feature_dim, self.backbone.width, device=device, dtype=dtype
)
self.cond_drop_prob = cond_drop_prob
def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
with torch.no_grad():
return dict(embeddings=self.clip(batch_size, **model_kwargs))
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
images: Optional[Iterable[Optional[ImageType]]] = None,
texts: Optional[Iterable[Optional[str]]] = None,
embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
):
"""
:param x: an [N x C x T] tensor.
:param t: an [N] tensor.
:param images: a batch of images to condition on.
:param texts: a batch of texts to condition on.
:param embeddings: a batch of CLIP embeddings to condition on.
:return: an [N x C' x T] tensor.
"""
assert x.shape[-1] == self.n_ctx
t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
clip_out = self.clip(batch_size=len(x), images=images, texts=texts, embeddings=embeddings)
assert len(clip_out.shape) == 2 and clip_out.shape[0] == x.shape[0]
if self.training:
mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
clip_out = clip_out * mask[:, None].to(clip_out)
# Rescale the features to have unit variance
clip_out = math.sqrt(clip_out.shape[1]) * clip_out
clip_embed = self.clip_embed(clip_out)
cond = [(clip_embed, self.token_cond), (t_embed, self.time_token_cond)]
return self._forward_with_cond(x, cond)
class CLIPImageGridPointDiffusionTransformer(PointDiffusionTransformer):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int = 1024,
cond_drop_prob: float = 0.0,
frozen_clip: bool = True,
cache_dir: Optional[str] = None,
**kwargs,
):
clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(
device,
cache_dir=cache_dir,
)
super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + clip.grid_size**2, **kwargs)
self.n_ctx = n_ctx
self.clip = clip
self.clip_embed = nn.Sequential(
nn.LayerNorm(
normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype
),
nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),
)
self.cond_drop_prob = cond_drop_prob
def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
_ = batch_size
with torch.no_grad():
return dict(embeddings=self.clip.embed_images_grid(model_kwargs["images"]))
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
images: Optional[Iterable[ImageType]] = None,
embeddings: Optional[Iterable[torch.Tensor]] = None,
):
"""
:param x: an [N x C x T] tensor.
:param t: an [N] tensor.
:param images: a batch of images to condition on.
:param embeddings: a batch of CLIP latent grids to condition on.
:return: an [N x C' x T] tensor.
"""
assert images is not None or embeddings is not None, "must specify images or embeddings"
assert images is None or embeddings is None, "cannot specify both images and embeddings"
assert x.shape[-1] == self.n_ctx
t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
if images is not None:
clip_out = self.clip.embed_images_grid(images)
else:
clip_out = embeddings
if self.training:
mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
clip_out = clip_out * mask[:, None, None].to(clip_out)
clip_out = clip_out.permute(0, 2, 1) # NCL -> NLC
clip_embed = self.clip_embed(clip_out)
cond = [(t_embed, self.time_token_cond), (clip_embed, True)]
return self._forward_with_cond(x, cond)
class UpsamplePointDiffusionTransformer(PointDiffusionTransformer):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
cond_input_channels: Optional[int] = None,
cond_ctx: int = 1024,
n_ctx: int = 4096 - 1024,
channel_scales: Optional[Sequence[float]] = None,
channel_biases: Optional[Sequence[float]] = None,
**kwargs,
):
super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + cond_ctx, **kwargs)
self.n_ctx = n_ctx
self.cond_input_channels = cond_input_channels or self.input_channels
self.cond_point_proj = nn.Linear(
self.cond_input_channels, self.backbone.width, device=device, dtype=dtype
)
self.register_buffer(
"channel_scales",
torch.tensor(channel_scales, dtype=dtype, device=device)
if channel_scales is not None
else None,
)
self.register_buffer(
"channel_biases",
torch.tensor(channel_biases, dtype=dtype, device=device)
if channel_biases is not None
else None,
)
def forward(self, x: torch.Tensor, t: torch.Tensor, *, low_res: torch.Tensor):
"""
:param x: an [N x C1 x T] tensor.
:param t: an [N] tensor.
:param low_res: an [N x C2 x T'] tensor of conditioning points.
:return: an [N x C3 x T] tensor.
"""
assert x.shape[-1] == self.n_ctx
t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
low_res_embed = self._embed_low_res(low_res)
cond = [(t_embed, self.time_token_cond), (low_res_embed, True)]
return self._forward_with_cond(x, cond)
def _embed_low_res(self, x: torch.Tensor) -> torch.Tensor:
if self.channel_scales is not None:
x = x * self.channel_scales[None, :, None]
if self.channel_biases is not None:
x = x + self.channel_biases[None, :, None]
return self.cond_point_proj(x.permute(0, 2, 1))
class CLIPImageGridUpsamplePointDiffusionTransformer(UpsamplePointDiffusionTransformer):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int = 4096 - 1024,
cond_drop_prob: float = 0.0,
frozen_clip: bool = True,
cache_dir: Optional[str] = None,
**kwargs,
):
clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(
device,
cache_dir=cache_dir,
)
super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + clip.grid_size**2, **kwargs)
self.n_ctx = n_ctx
self.clip = clip
self.clip_embed = nn.Sequential(
nn.LayerNorm(
normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype
),
nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),
)
self.cond_drop_prob = cond_drop_prob
def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
if "images" not in model_kwargs:
zero_emb = torch.zeros(
[batch_size, self.clip.grid_feature_dim, self.clip.grid_size**2],
device=next(self.parameters()).device,
)
return dict(embeddings=zero_emb, low_res=model_kwargs["low_res"])
with torch.no_grad():
return dict(
embeddings=self.clip.embed_images_grid(model_kwargs["images"]),
low_res=model_kwargs["low_res"],
)
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
*,
low_res: torch.Tensor,
images: Optional[Iterable[ImageType]] = None,
embeddings: Optional[Iterable[torch.Tensor]] = None,
):
"""
:param x: an [N x C1 x T] tensor.
:param t: an [N] tensor.
:param low_res: an [N x C2 x T'] tensor of conditioning points.
:param images: a batch of images to condition on.
:param embeddings: a batch of CLIP latent grids to condition on.
:return: an [N x C3 x T] tensor.
"""
assert x.shape[-1] == self.n_ctx
t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
low_res_embed = self._embed_low_res(low_res)
if images is not None:
clip_out = self.clip.embed_images_grid(images)
elif embeddings is not None:
clip_out = embeddings
else:
# Support unconditional generation.
clip_out = torch.zeros(
[len(x), self.clip.grid_feature_dim, self.clip.grid_size**2],
dtype=x.dtype,
device=x.device,
)
if self.training:
mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
clip_out = clip_out * mask[:, None, None].to(clip_out)
clip_out = clip_out.permute(0, 2, 1) # NCL -> NLC
clip_embed = self.clip_embed(clip_out)
cond = [(t_embed, self.time_token_cond), (clip_embed, True), (low_res_embed, True)]
return self._forward_with_cond(x, cond)
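A quick shape check of the [N x C x T] convention used throughout this file, using the unconditional transformer with made-up small sizes:

import torch

model = PointDiffusionTransformer(
    device=torch.device("cpu"), dtype=torch.float32,
    input_channels=6, output_channels=12,
    n_ctx=32, width=64, layers=2, heads=2,
)
x = torch.randn(2, 6, 32)         # 2 clouds, 6 channels (xyz + rgb), 32 points
t = torch.randint(0, 1000, (2,))  # diffusion timesteps
out = model(x, t)                 # -> [2, 12, 32]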

View File

@ -0,0 +1,23 @@
import math
import torch
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
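A one-line check of the shape described in the docstring:

import torch

emb = timestep_embedding(torch.tensor([0.0, 10.0, 100.0]), dim=128)
print(emb.shape)  # torch.Size([3, 128])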

View File

View File

@ -0,0 +1,87 @@
from dataclasses import dataclass, field
from typing import BinaryIO, Dict, Optional, Union
import numpy as np
from .ply_util import write_ply
@dataclass
class TriMesh:
"""
A 3D triangle mesh with optional data at the vertices and faces.
"""
# [N x 3] array of vertex coordinates.
verts: np.ndarray
# [M x 3] array of triangles, pointing to indices in verts.
faces: np.ndarray
# [P x 3] array of normal vectors per face.
normals: Optional[np.ndarray] = None
# Extra data per vertex and face.
vertex_channels: Optional[Dict[str, np.ndarray]] = field(default_factory=dict)
face_channels: Optional[Dict[str, np.ndarray]] = field(default_factory=dict)
@classmethod
def load(cls, f: Union[str, BinaryIO]) -> "TriMesh":
"""
Load the mesh from a .npz file.
"""
if isinstance(f, str):
with open(f, "rb") as reader:
return cls.load(reader)
else:
obj = np.load(f)
keys = list(obj.keys())
verts = obj["verts"]
faces = obj["faces"]
normals = obj["normals"] if "normals" in keys else None
vertex_channels = {}
face_channels = {}
for key in keys:
if key.startswith("v_"):
vertex_channels[key[2:]] = obj[key]
elif key.startswith("f_"):
face_channels[key[2:]] = obj[key]
return cls(
verts=verts,
faces=faces,
normals=normals,
vertex_channels=vertex_channels,
face_channels=face_channels,
)
def save(self, f: Union[str, BinaryIO]):
"""
Save the mesh to a .npz file.
"""
if isinstance(f, str):
with open(f, "wb") as writer:
self.save(writer)
else:
obj_dict = dict(verts=self.verts, faces=self.faces)
if self.normals is not None:
obj_dict["normals"] = self.normals
for k, v in self.vertex_channels.items():
obj_dict[f"v_{k}"] = v
for k, v in self.face_channels.items():
obj_dict[f"f_{k}"] = v
np.savez(f, **obj_dict)
def has_vertex_colors(self) -> bool:
return self.vertex_channels is not None and all(x in self.vertex_channels for x in "RGB")
def write_ply(self, raw_f: BinaryIO):
write_ply(
raw_f,
coords=self.verts,
rgb=(
np.stack([self.vertex_channels[x] for x in "RGB"], axis=1)
if self.has_vertex_colors()
else None
),
faces=self.faces,
)
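
A minimal TriMesh usage sketch (assuming the class above is importable; the module path is not shown in this hunk): build one colored triangle, save it as .npz, and export a binary PLY.

import numpy as np

verts = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32)
faces = np.array([[0, 1, 2]], dtype=np.int64)
mesh = TriMesh(
    verts=verts,
    faces=faces,
    vertex_channels={c: np.ones(3, dtype=np.float32) for c in "RGB"},  # colors in [0, 1]
)

mesh.save("triangle.npz")                # round-trips through TriMesh.load("triangle.npz")
with open("triangle.ply", "wb") as f:    # has_vertex_colors() is True, so RGB is written
    mesh.write_ply(f)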


@ -0,0 +1,96 @@
from typing import Dict
import numpy as np
import skimage
import torch
from tqdm.auto import tqdm
from point_e.models.sdf import PointCloudSDFModel
from .mesh import TriMesh
from .point_cloud import PointCloud
def marching_cubes_mesh(
pc: PointCloud,
model: PointCloudSDFModel,
batch_size: int = 4096,
grid_size: int = 128,
side_length: float = 1.02,
fill_vertex_channels: bool = True,
progress: bool = False,
) -> TriMesh:
"""
Run marching cubes on the SDF predicted from a point cloud to produce a
mesh representing the 3D surface.
:param pc: the point cloud to apply marching cubes to.
:param model: the model to use to predict SDF values.
:param batch_size: the number of SDF queries to evaluate per model call.
:param grid_size: the number of samples along each axis. A total of
grid_size**3 function evaluations are performed.
:param side_length: the size of the cube containing the model, which is
assumed to be centered at the origin.
:param fill_vertex_channels: if True, use the nearest neighbor of each mesh
vertex in the point cloud to compute vertex
data (e.g. colors).
:param progress: if True, show a progress bar over the SDF evaluation batches.
:return: a TriMesh for the extracted zero-level isosurface.
"""
voxel_size = side_length / (grid_size - 1)
min_coord = -side_length / 2
def int_coord_to_float(int_coords: torch.Tensor) -> torch.Tensor:
return int_coords.float() * voxel_size + min_coord
with torch.no_grad():
cond = model.encode_point_clouds(
torch.from_numpy(pc.coords).permute(1, 0).to(model.device)[None]
)
indices = range(0, grid_size**3, batch_size)
if progress:
indices = tqdm(indices)
volume = []
for i in indices:
# Decode the flat voxel indices of this batch into (x, y, z) grid coordinates.
batch_indices = torch.arange(
i, min(i + batch_size, grid_size**3), step=1, dtype=torch.int64, device=model.device
)
zs = int_coord_to_float(batch_indices % grid_size)
ys = int_coord_to_float(torch.div(batch_indices, grid_size, rounding_mode="trunc") % grid_size)
xs = int_coord_to_float(torch.div(batch_indices, grid_size**2, rounding_mode="trunc"))
coords = torch.stack([xs, ys, zs], dim=0)
with torch.no_grad():
volume.append(model(coords[None], encoded=cond)[0])
volume_np = torch.cat(volume).view(grid_size, grid_size, grid_size).cpu().numpy()
if np.all(volume_np < 0) or np.all(volume_np > 0):
# The volume is invalid for some reason, which will break
# marching cubes unless we center it.
volume_np -= np.mean(volume_np)
verts, faces, normals, _ = skimage.measure.marching_cubes(
volume=volume_np,
level=0,
allow_degenerate=False,
spacing=(voxel_size,) * 3,
)
# The triangles follow the left-hand rule, but we want to
# follow the right-hand rule.
# This syntax might seem roundabout, but we get incorrect
# results if we do: x[:,0], x[:,1] = x[:,1], x[:,0]
old_f1 = faces[:, 0].copy()
faces[:, 0] = faces[:, 1]
faces[:, 1] = old_f1
verts += min_coord
return TriMesh(
verts=verts,
faces=faces,
normals=normals,
vertex_channels=None if not fill_vertex_channels else _nearest_vertex_channels(pc, verts),
)
def _nearest_vertex_channels(pc: PointCloud, verts: np.ndarray) -> Dict[str, np.ndarray]:
nearest = pc.nearest_points(verts)
return {ch: arr[nearest] for ch, arr in pc.channels.items()}
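
An end-to-end sketch for marching_cubes_mesh, assuming an already-loaded PointCloudSDFModel (checkpoint loading happens elsewhere in the repo); sdf_model and the file names are placeholders.

pc = PointCloud.load("example_pc.npz")   # placeholder input point cloud

mesh = marching_cubes_mesh(
    pc=pc,
    model=sdf_model,        # placeholder: a PointCloudSDFModel already on its device
    batch_size=4096,
    grid_size=128,          # 128**3 SDF evaluations in total
    progress=True,
)
with open("example_mesh.ply", "wb") as f:
    mesh.write_ply(f)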


@ -0,0 +1,64 @@
from typing import Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
from .point_cloud import PointCloud
def plot_point_cloud(
pc: PointCloud,
color: bool = True,
grid_size: int = 1,
fixed_bounds: Optional[Tuple[Tuple[float, float, float], Tuple[float, float, float]]] = (
(-0.75, -0.75, -0.75),
(0.75, 0.75, 0.75),
),
):
"""
Render a point cloud as a matplotlib figure and return it.
:param pc: the PointCloud to plot.
:param color: if True, show the RGB colors from the point cloud.
:param grid_size: plot a grid_size x grid_size grid of views, each rotated
about the z-axis by a different angle.
:param fixed_bounds: if not None, the (min, max) corner coordinates used as
fixed axis limits; otherwise limits are fit to each view.
:return: the matplotlib Figure containing the plot.
"""
fig = plt.figure(figsize=(8, 8))
for i in range(grid_size):
for j in range(grid_size):
ax = fig.add_subplot(grid_size, grid_size, 1 + j + i * grid_size, projection="3d")
color_args = {}
if color:
color_args["c"] = np.stack(
[pc.channels["R"], pc.channels["G"], pc.channels["B"]], axis=-1
)
c = pc.coords
if grid_size > 1:
theta = np.pi * 2 * (i * grid_size + j) / (grid_size**2)
rotation = np.array(
[
[np.cos(theta), -np.sin(theta), 0.0],
[np.sin(theta), np.cos(theta), 0.0],
[0.0, 0.0, 1.0],
]
)
c = c @ rotation
ax.scatter(c[:, 0], c[:, 1], c[:, 2], **color_args)
if fixed_bounds is None:
min_point = c.min(0)
max_point = c.max(0)
size = (max_point - min_point).max() / 2
center = (min_point + max_point) / 2
ax.set_xlim3d(center[0] - size, center[0] + size)
ax.set_ylim3d(center[1] - size, center[1] + size)
ax.set_zlim3d(center[2] - size, center[2] + size)
else:
ax.set_xlim3d(fixed_bounds[0][0], fixed_bounds[1][0])
ax.set_ylim3d(fixed_bounds[0][1], fixed_bounds[1][1])
ax.set_zlim3d(fixed_bounds[0][2], fixed_bounds[1][2])
return fig
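
Illustrative usage of plot_point_cloud with a synthetic colored point cloud (PointCloud is defined later in this commit):

import numpy as np

coords = (np.random.randn(1024, 3) * 0.2).astype(np.float32)
gray = np.random.rand(1024).astype(np.float32)
pc = PointCloud(coords=coords, channels={"R": gray, "G": gray, "B": gray})

fig = plot_point_cloud(pc, grid_size=2)   # 2x2 grid of views rotated about the z-axis
fig.savefig("point_cloud.png")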


@ -0,0 +1,68 @@
import io
import struct
from contextlib import contextmanager
from typing import BinaryIO, Iterator, Optional
import numpy as np
def write_ply(
raw_f: BinaryIO,
coords: np.ndarray,
rgb: Optional[np.ndarray] = None,
faces: Optional[np.ndarray] = None,
):
"""
Write a PLY file for a mesh or a point cloud.
:param coords: an [N x 3] array of floating point coordinates.
:param rgb: an [N x 3] array of vertex colors, in the range [0.0, 1.0].
:param faces: an [M x 3] array of triangles encoded as integer indices into coords.
"""
with buffered_writer(raw_f) as f:
f.write(b"ply\n")
f.write(b"format binary_little_endian 1.0\n")
f.write(bytes(f"element vertex {len(coords)}\n", "ascii"))
f.write(b"property float x\n")
f.write(b"property float y\n")
f.write(b"property float z\n")
if rgb is not None:
f.write(b"property uchar red\n")
f.write(b"property uchar green\n")
f.write(b"property uchar blue\n")
if faces is not None:
f.write(bytes(f"element face {len(faces)}\n", "ascii"))
f.write(b"property list uchar int vertex_index\n")
f.write(b"end_header\n")
if rgb is not None:
rgb = (rgb * 255.499).round().astype(int)
vertices = [
(*coord, *rgb)
for coord, rgb in zip(
coords.tolist(),
rgb.tolist(),
)
]
format = struct.Struct("<3f3B")
for item in vertices:
f.write(format.pack(*item))
else:
format = struct.Struct("<3f")
for vertex in coords.tolist():
f.write(format.pack(*vertex))
if faces is not None:
format = struct.Struct("<B3I")
for tri in faces.tolist():
f.write(format.pack(len(tri), *tri))
@contextmanager
def buffered_writer(raw_f: BinaryIO) -> Iterator[io.BufferedIOBase]:
if isinstance(raw_f, io.BufferedIOBase):
yield raw_f
else:
f = io.BufferedWriter(raw_f)
yield f
f.flush()
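
A small standalone sketch of write_ply for a colored point cloud (no faces); the output file name is arbitrary:

import numpy as np

coords = np.random.rand(100, 3).astype(np.float32)   # 100 points in the unit cube
rgb = np.random.rand(100, 3).astype(np.float32)      # per-point colors in [0.0, 1.0]

with open("points.ply", "wb") as f:
    write_ply(f, coords=coords, rgb=rgb)             # omit faces for a point cloud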


@ -0,0 +1,174 @@
import random
from dataclasses import dataclass
from typing import BinaryIO, Dict, List, Optional, Union
import numpy as np
from .ply_util import write_ply
COLORS = frozenset(["R", "G", "B", "A"])
def preprocess(data, channel):
if channel in COLORS:
return np.round(data * 255.0)
return data
@dataclass
class PointCloud:
"""
An array of points sampled on a surface. Each point may have zero or more
channel attributes.
:param coords: an [N x 3] array of point coordinates.
:param channels: a dict mapping names to [N] arrays of channel values.
"""
coords: np.ndarray
channels: Dict[str, np.ndarray]
@classmethod
def load(cls, f: Union[str, BinaryIO]) -> "PointCloud":
"""
Load the point cloud from a .npz file.
"""
if isinstance(f, str):
with open(f, "rb") as reader:
return cls.load(reader)
else:
obj = np.load(f)
keys = list(obj.keys())
return PointCloud(
coords=obj["coords"],
channels={k: obj[k] for k in keys if k != "coords"},
)
def save(self, f: Union[str, BinaryIO]):
"""
Save the point cloud to a .npz file.
"""
if isinstance(f, str):
with open(f, "wb") as writer:
self.save(writer)
else:
np.savez(f, coords=self.coords, **self.channels)
def write_ply(self, raw_f: BinaryIO):
write_ply(
raw_f,
coords=self.coords,
rgb=(
np.stack([self.channels[x] for x in "RGB"], axis=1)
if all(x in self.channels for x in "RGB")
else None
),
)
def random_sample(self, num_points: int, **subsample_kwargs) -> "PointCloud":
"""
Sample a random subset of this PointCloud.
:param num_points: maximum number of points to sample.
:param subsample_kwargs: arguments to self.subsample().
:return: a reduced PointCloud, or self if num_points is not less than
the current number of points.
"""
if len(self.coords) <= num_points:
return self
indices = np.random.choice(len(self.coords), size=(num_points,), replace=False)
return self.subsample(indices, **subsample_kwargs)
def farthest_point_sample(
self, num_points: int, init_idx: Optional[int] = None, **subsample_kwargs
) -> "PointCloud":
"""
Sample a subset of the point cloud that is evenly distributed in space.
First, a random point is selected. Then each successive point is chosen
such that it is furthest from the currently selected points.
The time complexity of this operation is O(NM), where N is the original
number of points and M is the reduced number. Therefore, performance
can be improved by randomly subsampling points with random_sample()
before running farthest_point_sample().
:param num_points: maximum number of points to sample.
:param init_idx: if specified, the first point to sample.
:param subsample_kwargs: arguments to self.subsample().
:return: a reduced PointCloud, or self if num_points is not less than
the current number of points.
"""
if len(self.coords) <= num_points:
return self
init_idx = random.randrange(len(self.coords)) if init_idx is None else init_idx
indices = np.zeros([num_points], dtype=np.int64)
indices[0] = init_idx
sq_norms = np.sum(self.coords**2, axis=-1)
def compute_dists(idx: int):
# Utilize equality: ||A-B||^2 = ||A||^2 + ||B||^2 - 2*(A @ B).
return sq_norms + sq_norms[idx] - 2 * (self.coords @ self.coords[idx])
cur_dists = compute_dists(init_idx)
for i in range(1, num_points):
idx = np.argmax(cur_dists)
indices[i] = idx
cur_dists = np.minimum(cur_dists, compute_dists(idx))
return self.subsample(indices, **subsample_kwargs)
def subsample(self, indices: np.ndarray, average_neighbors: bool = False) -> "PointCloud":
if not average_neighbors:
return PointCloud(
coords=self.coords[indices],
channels={k: v[indices] for k, v in self.channels.items()},
)
new_coords = self.coords[indices]
neighbor_indices = PointCloud(coords=new_coords, channels={}).nearest_points(self.coords)
# Make sure every point points to itself, which might not
# be the case if points are duplicated or there is rounding
# error.
neighbor_indices[indices] = np.arange(len(indices))
new_channels = {}
for k, v in self.channels.items():
v_sum = np.zeros_like(v[: len(indices)])
v_count = np.zeros_like(v[: len(indices)])
np.add.at(v_sum, neighbor_indices, v)
np.add.at(v_count, neighbor_indices, 1)
new_channels[k] = v_sum / v_count
return PointCloud(coords=new_coords, channels=new_channels)
def select_channels(self, channel_names: List[str]) -> np.ndarray:
data = np.stack([preprocess(self.channels[name], name) for name in channel_names], axis=-1)
return data
def nearest_points(self, points: np.ndarray, batch_size: int = 16384) -> np.ndarray:
"""
For each point in another set of points, compute the point in this
pointcloud which is closest.
:param points: an [N x 3] array of points.
:param batch_size: the number of neighbor distances to compute at once.
Smaller values save memory, while larger values may
make the computation faster.
:return: an [N] array of indices into self.coords.
"""
norms = np.sum(self.coords**2, axis=-1)
all_indices = []
for i in range(0, len(points), batch_size):
batch = points[i : i + batch_size]
dists = norms + np.sum(batch**2, axis=-1)[:, None] - 2 * (batch @ self.coords.T)
all_indices.append(np.argmin(dists, axis=-1))
return np.concatenate(all_indices, axis=0)
def combine(self, other: "PointCloud") -> "PointCloud":
assert self.channels.keys() == other.channels.keys()
return PointCloud(
coords=np.concatenate([self.coords, other.coords], axis=0),
channels={
k: np.concatenate([v, other.channels[k]], axis=0) for k, v in self.channels.items()
},
)
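
A short PointCloud usage sketch: cheap random subsampling first (because farthest_point_sample is O(N*M)), then farthest-point sampling and PLY export. Sizes and file names are arbitrary.

import numpy as np

pc = PointCloud(
    coords=np.random.randn(100_000, 3).astype(np.float32),
    channels={c: np.random.rand(100_000).astype(np.float32) for c in "RGB"},
)

pc = pc.random_sample(20_000)           # cheap reduction before the O(N*M) pass
pc = pc.farthest_point_sample(4096)     # evenly spread subset of 4096 points

with open("sampled.ply", "wb") as f:
    pc.write_ply(f)
print(pc.coords.shape, sorted(pc.channels))   # (4096, 3) ['B', 'G', 'R']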

16
point-e/setup.py Normal file

@ -0,0 +1,16 @@
from setuptools import setup
setup(
name="point-e",
packages=[
"point_e",
"point_e.diffusion",
"point_e.evals",
"point_e.models",
"point_e.util",
],
install_requires=[
# "clip @ git+https://github.com/openai/CLIP.git",
],
author="OpenAI",
)