149 lines
5.3 KiB
Python
149 lines
5.3 KiB
Python
import tempfile
|
|
|
|
import numpy as np
|
|
import torch
|
|
import trimesh
|
|
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
|
|
from shap_e.diffusion.sample import sample_latents
|
|
from shap_e.models.download import load_config, load_model
|
|
from shap_e.models.nn.camera import (DifferentiableCameraBatch,
|
|
DifferentiableProjectiveCamera)
|
|
from shap_e.models.transmitter.base import Transmitter, VectorDecoder
|
|
from shap_e.rendering.torch_mesh import TorchMesh
|
|
from shap_e.util.collections import AttrDict
|
|
from shap_e.util.image_util import load_image
|
|
|
|
|
|
# Copied from https://github.com/openai/shap-e/blob/d99cedaea18e0989e340163dbaeb4b109fa9e8ec/shap_e/util/notebooks.py#L15-L42
|
|
def create_pan_cameras(size: int,
                       device: torch.device) -> DifferentiableCameraBatch:
    """Build a batch of cameras arranged on a ring around the origin.

    One camera is created for each of 20 evenly spaced angles in
    ``[0, 2*pi]`` (``np.linspace`` includes both endpoints). Each camera is
    positioned at ``-4 * z``, where ``z`` is the unit vector handed to the
    camera as its z axis.

    Args:
        size: Used for both the width and the height of every camera.
        device: Device the camera tensors are moved to.

    Returns:
        A camera batch of shape ``(1, 20)`` with ``x_fov`` and ``y_fov``
        both set to 0.7.
    """

    def _camera_frame(theta):
        # Returns (origin, x, y, z) for a single angle on the ring.
        z_axis = np.array([np.sin(theta), np.cos(theta), -0.5])
        z_axis /= np.sqrt(np.sum(z_axis**2))  # normalize to unit length
        x_axis = np.array([np.cos(theta), -np.sin(theta), 0.0])
        y_axis = np.cross(z_axis, x_axis)
        return -z_axis * 4, x_axis, y_axis, z_axis

    def _as_tensor(vectors):
        # Stack the per-camera vectors into one float32 tensor on `device`.
        return torch.from_numpy(np.stack(vectors, axis=0)).float().to(device)

    frames = [
        _camera_frame(theta) for theta in np.linspace(0, 2 * np.pi, num=20)
    ]
    origins, xs, ys, zs = zip(*frames)

    flat_camera = DifferentiableProjectiveCamera(
        origin=_as_tensor(origins),
        x=_as_tensor(xs),
        y=_as_tensor(ys),
        z=_as_tensor(zs),
        width=size,
        height=size,
        x_fov=0.7,
        y_fov=0.7,
    )
    return DifferentiableCameraBatch(
        shape=(1, len(frames)),
        flat_camera=flat_camera,
    )
|
|
|
|
|
|
# Copied from https://github.com/openai/shap-e/blob/8625e7c15526d8510a2292f92165979268d0e945/shap_e/util/notebooks.py#LL64C1-L76C33
|
|
@torch.no_grad()
def decode_latent_mesh(
    xm: Transmitter | VectorDecoder,
    latent: torch.Tensor,
) -> TorchMesh:
    """Decode a single latent into a mesh using the model's renderer.

    The renderer is invoked in ``'stf'`` rendering mode with
    ``render_with_direction=False``; only the first entry of the resulting
    ``raw_meshes`` is returned.

    Args:
        xm: A ``Transmitter`` (its ``encoder`` converts the latent into
            parameters) or a ``VectorDecoder`` (used for that directly).
        latent: One latent; a leading batch axis is added before it is
            converted into parameters.

    Returns:
        The first raw mesh produced by the render call.
    """
    # Cameras are rendered at 2x2: lowest resolution possible.
    cameras = create_pan_cameras(2, latent.device)

    if isinstance(xm, Transmitter):
        param_source = xm.encoder
    else:
        param_source = xm
    params = param_source.bottleneck_to_params(latent[None])

    render_options = AttrDict(rendering_mode='stf',
                              render_with_direction=False)
    rendered = xm.renderer.render_views(
        AttrDict(cameras=cameras),
        params=params,
        options=render_options,
    )
    return rendered.raw_meshes[0]
|
|
|
|
|
|
class Model:
    """Wrapper around the Shap-E text- and image-conditioned generators.

    The transmitter and the diffusion configuration are loaded when the
    object is constructed. The ``text300M`` and ``image300M`` models are
    loaded the first time they are requested and then kept on the instance.
    """

    # Names accepted by `load_model`.
    _MODEL_NAMES = ('text300M', 'image300M')

    def __init__(self):
        """Select a device and load the transmitter and diffusion config."""
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.xm = load_model('transmitter', device=self.device)
        self.diffusion = diffusion_from_config(load_config('diffusion'))
        # Conditional models are loaded lazily by `load_model`.
        self.model_text = None
        self.model_image = None

    def load_model(self, model_name: str) -> None:
        """Load ``model_name`` onto the device unless it is already loaded.

        Args:
            model_name: Either ``'text300M'`` or ``'image300M'``.

        Raises:
            ValueError: If ``model_name`` is not one of the supported names.
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if model_name not in self._MODEL_NAMES:
            raise ValueError(f'Unknown model name {model_name!r}; '
                             f'expected one of {self._MODEL_NAMES}.')
        if model_name == 'text300M':
            if self.model_text is None:
                self.model_text = load_model(model_name, device=self.device)
        elif self.model_image is None:
            self.model_image = load_model(model_name, device=self.device)

    def to_glb(self, latent: torch.Tensor) -> str:
        """Decode ``latent`` into a mesh and export it as a GLB file.

        The mesh is first written as a PLY file inside a temporary
        directory, reloaded with trimesh, rotated by -pi/2 about the x axis
        and then by pi about the y axis, and finally exported as GLB.

        Args:
            latent: A single latent accepted by ``decode_latent_mesh``.

        Returns:
            Path of the exported ``.glb`` file. The file is created with
            ``delete=False``, so removing it is left to the caller.
        """
        with tempfile.TemporaryDirectory() as work_dir:
            with tempfile.NamedTemporaryFile(suffix='.ply',
                                             dir=work_dir,
                                             delete=False,
                                             mode='w+b') as ply_file:
                decode_latent_mesh(self.xm,
                                   latent).tri_mesh().write_ply(ply_file)
            # The handle is closed (and therefore flushed) before the file
            # is read back by name; otherwise buffered data could be missing.
            mesh = trimesh.load(ply_file.name)
        # Leaving the directory context removes the intermediate PLY file.

        rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
        mesh = mesh.apply_transform(rot)
        rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
        mesh = mesh.apply_transform(rot)

        # Reserve a unique output path and close the handle before trimesh
        # writes to that path.
        with tempfile.NamedTemporaryFile(suffix='.glb',
                                         delete=False) as glb_file:
            glb_path = glb_file.name
        mesh.export(glb_path, file_type='glb')

        return glb_path

    def _sample_latent(self, model, model_kwargs: dict,
                       guidance_scale: float,
                       num_steps: int) -> torch.Tensor:
        """Sample one latent from ``model`` with the shared sampler settings.

        Args:
            model: Conditional model handed to ``sample_latents``.
            model_kwargs: Conditioning inputs forwarded to ``sample_latents``.
            guidance_scale: Guidance scale forwarded to ``sample_latents``.
            num_steps: Forwarded to ``sample_latents`` as ``karras_steps``.

        Returns:
            The first (and only) latent of the sampled batch.
        """
        latents = sample_latents(
            batch_size=1,
            model=model,
            diffusion=self.diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=model_kwargs,
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=num_steps,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )
        return latents[0]

    def run_text(self,
                 prompt: str,
                 seed: int = 0,
                 guidance_scale: float = 15.0,
                 num_steps: int = 64) -> str:
        """Generate a GLB mesh from a text prompt.

        Args:
            prompt: Text passed to the ``text300M`` model.
            seed: Value given to ``torch.manual_seed`` before sampling.
            guidance_scale: Guidance scale used during sampling.
            num_steps: Number of Karras sampling steps.

        Returns:
            Path of the exported ``.glb`` file.
        """
        self.load_model('text300M')
        torch.manual_seed(seed)

        latent = self._sample_latent(self.model_text,
                                     dict(texts=[prompt]),
                                     guidance_scale, num_steps)
        return self.to_glb(latent)

    def run_image(self,
                  image_path: str,
                  seed: int = 0,
                  guidance_scale: float = 3.0,
                  num_steps: int = 64) -> str:
        """Generate a GLB mesh from an image file.

        Args:
            image_path: Path passed to ``load_image``; the loaded image is
                given to the ``image300M`` model.
            seed: Value given to ``torch.manual_seed`` before sampling.
            guidance_scale: Guidance scale used during sampling.
            num_steps: Number of Karras sampling steps.

        Returns:
            Path of the exported ``.glb`` file.
        """
        self.load_model('image300M')
        torch.manual_seed(seed)

        image = load_image(image_path)
        latent = self._sample_latent(self.model_image,
                                     dict(images=[image]),
                                     guidance_scale, num_steps)
        return self.to_glb(latent)
|