from typing import Dict, List, Union

from numpy import ndarray
from torch import Tensor

import json
from collections import defaultdict

import cv2
import numpy as np
import torch
import torch.nn.functional as tF

from kiui.cam import orbit_camera, undo_orbit_camera

from src.data.utils.chunk_dataset import ChunkedDataset
from src.options import Options
from src.utils import normalize_normals, unproject_depth


class GObjaverseParquetDataset(ChunkedDataset):
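    """Multi-view dataset for GObjaverse renders stored in parquet chunks.

    Each raw sample maps view-indexed keys (e.g., "00003.png", "00003.json",
    "00003_nd.png") to encoded bytes. `get_trainable_data_from_raw_data` samples
    `num_views` views, converts the cameras from Blender world / OpenCV camera to
    the conventions used downstream, and optionally loads albedo, normal, depth,
    world-coordinate, metallic-roughness, and Canny-edge maps, plus precomputed
    text-prompt embeddings.
    """
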
    def __init__(self, opt: Options, training: bool = True, *args, **kwargs):
        self.opt = opt
        self.training = training

        # Default camera intrinsics
        self.fxfycxcy = torch.tensor([opt.fxfy, opt.fxfy, 0.5, 0.5], dtype=torch.float32)  # (4,)

        # Precomputed null-prompt (negative) embeddings; missing files are tolerated
        if opt.prompt_embed_dir is not None:
            try:
                self.negative_prompt_embed = torch.from_numpy(np.load(f"{opt.prompt_embed_dir}/null.npy")).float()
            except FileNotFoundError:
                self.negative_prompt_embed = None
            try:
                self.negative_pooled_prompt_embed = torch.from_numpy(np.load(f"{opt.prompt_embed_dir}/null_pooled.npy")).float()
            except FileNotFoundError:
                self.negative_pooled_prompt_embed = None
            try:
                self.negative_prompt_attention_mask = torch.from_numpy(np.load(f"{opt.prompt_embed_dir}/null_attention_mask.npy")).float()
            except FileNotFoundError:
                self.negative_prompt_attention_mask = None

            if "xl" in opt.pretrained_model_name_or_path:  # SDXL: zero out negative prompt embedding
                if self.negative_prompt_embed is not None and self.negative_pooled_prompt_embed is not None:
                    self.negative_prompt_embed = torch.zeros_like(self.negative_prompt_embed)
                    self.negative_pooled_prompt_embed = torch.zeros_like(self.negative_pooled_prompt_embed)

        super().__init__(*args, **kwargs)

    def __len__(self):
        return self.opt.dataset_size

    def get_trainable_data_from_raw_data(self, raw_data_list) -> Dict[str, Tensor]:  # only `sample["__key__"]` is in str type
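        """Assemble one training sample from a raw parquet record.

        Returns a dict of tensors keyed by modality ("image", "mask", "C2W",
        "cam_pose", "fxfycxcy", plus optional maps and prompt embeddings),
        each stacked along a leading view dimension `V`.
        """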
        assert len(raw_data_list) == 1
        sample: Dict[str, bytes] = raw_data_list[0]

        V, V_in = self.opt.num_views, self.opt.num_input_views
        assert V >= V_in

        if self.opt.load_even_views or not self.training:
            _pick_func = self._pick_even_view_indices
        else:
            _pick_func = self._pick_random_view_indices

        # Randomly sample `V_in` views (some objects may not appear in the dataset)
        random_idxs = _pick_func(V_in)
        _num_tries = 0
        while not self._check_views_exist(sample, random_idxs):
            random_idxs = _pick_func(V_in)
            _num_tries += 1
            if _num_tries > 100:  # TODO: make `100` configurable
                raise ValueError(f"Cannot find {V_in} views in {sample['__key__']}")

        except_idxs = random_idxs + [24, 39]  # filter duplicated views; hard-coded for GObjaverse
        if self.opt.exclude_topdown_views:
            except_idxs += [25, 26]

        # Randomly sample `V` views (some views may not appear in the dataset)
        for i in np.random.permutation(40):  # `40` is hard-coded for GObjaverse
            if len(random_idxs) >= V:
                break
            if f"{i:05d}.png" in sample and i not in except_idxs:
                try:
                    _ = np.frombuffer(sample[f"{i:05d}.png"], np.uint8)
                    assert sample[f"{i:05d}.json"] is not None
                    random_idxs.append(i)
                except (TypeError, KeyError, AssertionError):  # e.g., the PNG bytes are `None`, or the JSON key is missing
                    pass

        # Randomly repeat views if not enough views
        while len(random_idxs) < V:
            random_idxs.append(np.random.choice(random_idxs))
        return_dict = defaultdict(list)
        init_azi = None
        for vid in random_idxs:
            return_dict["fxfycxcy"].append(self.fxfycxcy)  # (V, 4); fixed intrinsics for GObjaverse

            image = self._load_png(sample[f"{vid:05d}.png"])  # (4, 512, 512)
            mask = image[3:4]  # (1, 512, 512)
            image = image[:3] * mask + (1. - mask)  # (3, 512, 512), to white bg
            return_dict["image"].append(image)  # (V, 3, H, W)
            return_dict["mask"].append(mask)  # (V, 1, H, W)

            if self.opt.load_canny:
                gray = cv2.cvtColor(image.permute(1, 2, 0).numpy(), cv2.COLOR_RGB2GRAY)
                canny = cv2.Canny((gray * 255.).astype(np.uint8), 100., 200.)
                canny = torch.from_numpy(canny).unsqueeze(0).float().repeat(3, 1, 1) / 255.  # (3, 512, 512) in [0, 1]
                canny = -canny + 1.  # 0->1, 1->0, i.e., white bg
                return_dict["canny"].append(canny)  # (V, 3, H, W)

            c2w = self._load_camera_from_json(sample[f"{vid:05d}.json"])
            # Blender world + OpenCV cam -> OpenGL world & cam; https://kit.kiui.moe/camera
            c2w[1] *= -1
            c2w[[1, 2]] = c2w[[2, 1]]
            c2w[:3, 1:3] *= -1  # invert up and forward direction
            return_dict["original_C2W"].append(torch.from_numpy(c2w).float())  # (V, 4, 4); for normal normalization only

            # Relative azimuth w.r.t. the first view
            ele, azi, dis = undo_orbit_camera(c2w)  # elevation: [-90, 90] from +y(-90) to -y(90)
            if init_azi is None:
                init_azi = azi
            azi = (azi - init_azi) % 360.  # azimuth: [0, 360] from +z(0) to +x(90)
            # To avoid numerical errors for elevation +/- 90 (GObjaverse index 25 (up) & 26 (down))
            ele_sign = ele >= 0
            ele = abs(ele) - 1e-8
            ele = ele * (1. if ele_sign else -1.)
            new_c2w = torch.from_numpy(orbit_camera(ele, azi, dis)).float()
            return_dict["C2W"].append(new_c2w)  # (V, 4, 4)
            return_dict["cam_pose"].append(torch.tensor(
                [np.deg2rad(ele), np.deg2rad(azi), dis], dtype=torch.float32))  # (V, 3)

            # Albedo
            if self.opt.load_albedo:
                albedo = self._load_png(sample[f"{vid:05d}_albedo.png"])  # (3, 512, 512)
                albedo = albedo * mask + (1. - mask)  # (3, 512, 512), to white bg
                return_dict["albedo"].append(albedo)  # (V, 3, H, W)

            # Normal & Depth (packed together in the 16-bit `_nd.png`)
            if self.opt.load_normal or self.opt.load_coord or self.opt.load_depth:
                nd = self._load_png(sample[f"{vid:05d}_nd.png"], uint16=True)  # (4, 512, 512)
                if self.opt.load_normal:
                    normal = nd[:3] * 2. - 1.  # (3, 512, 512) in [-1, 1]
                    normal[0, ...] *= -1  # to OpenGL world convention
                    return_dict["normal"].append(normal)  # (V, 3, H, W)
                if self.opt.load_coord or self.opt.load_depth:
                    depth = nd[3] * 5.  # (512, 512); NOTE: depth is scaled by 1/5 in my data preprocessing
                    return_dict["depth"].append(depth)  # (V, H, W)

            # Metal & Roughness
            if self.opt.load_mr:
                mr = self._load_png(sample[f"{vid:05d}_mr.png"])  # (3, 512, 512); (metallic, roughness, padding)
                mr = mr * mask + (1. - mask)  # (3, 512, 512), to white bg
                return_dict["mr"].append(mr)  # (V, 3, H, W)
        for key in return_dict.keys():
            return_dict[key] = torch.stack(return_dict[key], dim=0)

        if self.opt.load_normal:
            # Normalize normals by the first view and transform the first view to a fixed azimuth (i.e., 0)
            # Ensure `normals` and `original_C2W` are in the same camera convention
            normals = normalize_normals(return_dict["normal"].unsqueeze(0), return_dict["original_C2W"].unsqueeze(0), i=0).squeeze(0)
            normals = torch.einsum("brc,bvchw->bvrhw", return_dict["C2W"][0, :3, :3].unsqueeze(0), normals.unsqueeze(0)).squeeze(0)
            normals = normals * 0.5 + 0.5  # [0, 1]
            normals = normals * return_dict["mask"] + (1. - return_dict["mask"])  # (V, 3, 512, 512), to white bg
            return_dict["normal"] = normals
        return_dict.pop("original_C2W")  # original C2W is only used for normal normalization

        # OpenGL to COLMAP camera for Gaussian renderer
        return_dict["C2W"][:, :3, 1:3] *= -1

        # Whether to scale the object w.r.t. the first view to a fixed size
        if self.opt.norm_camera:
            scale = self.opt.norm_radius / torch.norm(return_dict["C2W"][0, :3, 3], dim=-1)
        else:
            scale = 1.
        return_dict["C2W"][:, :3, 3] *= scale
        return_dict["cam_pose"][:, 2] *= scale
        if self.opt.load_coord:
            # Unproject depth maps to 3D world coordinates
            coords = unproject_depth(return_dict["depth"].unsqueeze(0) * scale,
                return_dict["C2W"].unsqueeze(0), return_dict["fxfycxcy"].unsqueeze(0)).squeeze(0)
            coords = coords * 0.5 + 0.5  # [0, 1]
            coords = coords * return_dict["mask"] + (1. - return_dict["mask"])  # (V, 3, 512, 512), to white bg
            return_dict["coord"] = coords
            if not self.opt.load_depth:
                return_dict.pop("depth")  # depth was only loaded for unprojection

        if self.opt.load_depth:
            depths = return_dict["depth"].unsqueeze(1) * return_dict["mask"]  # (V, 1, 512, 512), to black bg
            assert depths.min() == 0.
            if self.opt.normalize_depth:
                H, W = depths.shape[-2:]
                depths = depths.reshape(V, -1)
                depths_max = depths.max(dim=-1, keepdim=True).values
                depths = depths / depths_max  # [0, 1]
                depths = depths.reshape(V, 1, H, W)
                depths = -depths + 1.  # 0->1, 1->0, i.e., white bg
            return_dict["depth"] = depths.repeat(1, 3, 1, 1)
        # Resize to the input resolution
        for key in ["image", "mask", "albedo", "normal", "coord", "depth", "mr", "canny"]:
            if key in return_dict.keys():
                return_dict[key] = tF.interpolate(
                    return_dict[key], size=(self.opt.input_res, self.opt.input_res),
                    mode="bilinear", align_corners=False, antialias=True
                )

        # Re-apply masks to normal, coord and depth (GObjaverse renders anti-aliased normal & depth)
        if self.opt.load_normal:
            return_dict["normal"] = return_dict["normal"] * return_dict["mask"] + (1. - return_dict["mask"])
        if self.opt.load_coord:
            return_dict["coord"] = return_dict["coord"] * return_dict["mask"] + (1. - return_dict["mask"])
        if self.opt.load_depth:
            return_dict["depth"] = return_dict["depth"] * return_dict["mask"] + (1. - return_dict["mask"])
        # Load precomputed caption embeddings
        if self.opt.prompt_embed_dir is not None:
            uid = sample["uid"].decode("utf-8").split("/")[-1].split(".")[0]
            return_dict["prompt_embed"] = torch.from_numpy(np.load(f"{self.opt.prompt_embed_dir}/{uid}.npy"))
            if "xl" in self.opt.pretrained_model_name_or_path or "3" in self.opt.pretrained_model_name_or_path:  # SDXL or SD3
                return_dict["pooled_prompt_embed"] = torch.from_numpy(np.load(f"{self.opt.prompt_embed_dir}/{uid}_pooled.npy"))
            if "PixArt" in self.opt.pretrained_model_name_or_path:  # PixArt-alpha, PixArt-Sigma
                return_dict["prompt_attention_mask"] = torch.from_numpy(np.load(f"{self.opt.prompt_embed_dir}/{uid}_attention_mask.npy"))

        for key in return_dict.keys():
            assert isinstance(return_dict[key], Tensor), f"Value of the key [{key}] is not a Tensor, but {type(return_dict[key])}."
        return dict(return_dict)

    def _load_png(self, png_bytes: bytes, uint16: bool = False) -> Tensor:
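        """Decode PNG bytes into a float32 tensor of shape (C, H, W) in [0, 1]."""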
        png = np.frombuffer(png_bytes, np.uint8)
        png = cv2.imdecode(png, cv2.IMREAD_UNCHANGED)  # (H, W, C) ndarray in [0, 255] or [0, 65535]
        png = png.astype(np.float32) / (65535. if uint16 else 255.)  # (H, W, C) in [0, 1]
        png[:, :, :3] = png[:, :, :3][..., ::-1]  # BGR -> RGB
        png_tensor = torch.from_numpy(png).nan_to_num_(0.)  # there are NaNs in GObjaverse gt normal
        return png_tensor.permute(2, 0, 1)  # (C, H, W) in [0, 1]

    def _load_camera_from_json(self, json_bytes: bytes) -> ndarray:
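        """Build a 4x4 camera-to-world matrix (OpenCV convention) from a GObjaverse camera JSON."""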
        json_dict = json.loads(json_bytes)
        # In OpenCV convention
        c2w = np.eye(4)  # float64
        c2w[:3, 0] = np.array(json_dict["x"])
        c2w[:3, 1] = np.array(json_dict["y"])
        c2w[:3, 2] = np.array(json_dict["z"])
        c2w[:3, 3] = np.array(json_dict["origin"])
        return c2w

    def _pick_even_view_indices(self, num_views: int = 4) -> List[int]:
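        """Pick `num_views` evenly spaced view indices on one GObjaverse orbit ring.

        With probability 2/3, indices come from the 24-view ring (0~23); otherwise
        from the 12-view ring (27~38). Views are spaced `ring_size // num_views`
        apart along the chosen ring.
        """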
        assert 12 % num_views == 0  # `12` for even-view sampling in GObjaverse
        if np.random.rand() < 2/3:
            index0 = np.random.choice(range(24))  # 0~23: 24 views in ele from [5, 30]; hard-coded for GObjaverse
            return [(index0 + (24 // num_views)*i) % 24 for i in range(num_views)]
        else:
            index0 = np.random.choice(range(12))  # 27~38: 12 views in ele from [-5, 5]; hard-coded for GObjaverse
            return [((index0 + (12 // num_views)*i) % 12 + 27) for i in range(num_views)]

    def _pick_random_view_indices(self, num_views: int = 4) -> List[int]:
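        """Sample `num_views` distinct view indices uniformly at random, optionally excluding the top/down views (25 & 26)."""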
        assert num_views <= 40  # `40` is hard-coded for GObjaverse
        indices = (set(range(40)) - {25, 26}) if self.opt.exclude_topdown_views else set(range(40))
        return np.random.choice(list(indices), num_views, replace=False).tolist()

    def _check_views_exist(self, sample: Dict[str, Union[str, bytes]], vids: List[int]) -> bool:
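        """Return True iff every view in `vids` has non-None PNG and JSON entries in `sample`."""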
        for vid in vids:
            if f"{vid:05d}.png" not in sample:
                return False
            try:
                assert sample[f"{vid:05d}.png"] is not None and sample[f"{vid:05d}.json"] is not None
            except (AssertionError, KeyError):  # the view data is `None` or the JSON key is missing
                return False
        return True
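

# A minimal usage sketch (hypothetical option values; `Options` defaults and the
# `ChunkedDataset` constructor arguments are defined elsewhere in this repo):
#
#   opt = Options()                # configure num_views, num_input_views, input_res, ...
#   dataset = GObjaverseParquetDataset(opt, training=True)
#   batch = dataset.get_trainable_data_from_raw_data([raw_sample])
#   batch["image"].shape           # (V, 3, input_res, input_res), values in [0, 1]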