Skip to content

Commit

Permalink
update ply save method
Browse files Browse the repository at this point in the history
  • Loading branch information
botaoye committed Dec 7, 2024
1 parent 5af8e16 commit a427bd7
Showing 1 changed file with 18 additions and 83 deletions.
101 changes: 18 additions & 83 deletions src/model/ply_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,116 +23,51 @@ def construct_list_of_attributes(num_rest: int) -> list[str]:
return attributes


# def export_ply(
# extrinsics: Float[Tensor, "4 4"],
# means: Float[Tensor, "gaussian 3"],
# scales: Float[Tensor, "gaussian 3"],
# rotations: Float[Tensor, "gaussian 4"],
# harmonics: Float[Tensor, "gaussian 3 d_sh"],
# opacities: Float[Tensor, " gaussian"],
# path: Path,
# ):
# # Shift the scene so that the median Gaussian is at the origin.
# means = means - means.median(dim=0).values
#
# # Rescale the scene so that most Gaussians are within range [-1, 1].
# scale_factor = means.abs().quantile(0.95, dim=0).max()
# means = means / scale_factor
# scales = scales / scale_factor
#
# # Define a rotation that makes +Z be the world up vector.
# rotation = [
# [0, 0, 1],
# [-1, 0, 0],
# [0, -1, 0],
# ]
# rotation = torch.tensor(rotation, dtype=torch.float32, device=means.device)
#
# # The Polycam viewer seems to start at a 45 degree angle. Since we want to be
# # looking directly at the object, we compose a 45 degree rotation onto the above
# # rotation.
# adjustment = torch.tensor(
# R.from_rotvec([0, 0, -45], True).as_matrix(),
# dtype=torch.float32,
# device=means.device,
# )
# rotation = adjustment @ rotation
#
# # We also want to see the scene in camera space (as the default view). We therefore
# # compose the w2c rotation onto the above rotation.
# rotation = rotation @ extrinsics[:3, :3].inverse()
#
# # Apply the rotation to the means (Gaussian positions).
# means = einsum(rotation, means, "i j, ... j -> ... i")
#
# # Apply the rotation to the Gaussian rotations.
# rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix()
# rotations = rotation.detach().cpu().numpy() @ rotations
# rotations = R.from_matrix(rotations).as_quat()
# x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g")
# rotations = np.stack((w, x, y, z), axis=-1)
#
# # Since our axes are swizzled for the spherical harmonics, we only export the DC
# # band.
# harmonics_view_invariant = harmonics[..., 0]
#
# dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0)]
# elements = np.empty(means.shape[0], dtype=dtype_full)
# attributes = (
# means.detach().cpu().numpy(),
# torch.zeros_like(means).detach().cpu().numpy(),
# harmonics_view_invariant.detach().cpu().contiguous().numpy(),
# opacities[..., None].detach().cpu().numpy(),
# scales.log().detach().cpu().numpy(),
# rotations,
# )
# attributes = np.concatenate(attributes, axis=1)
# elements[:] = list(map(tuple, attributes))
# path.parent.mkdir(exist_ok=True, parents=True)
# PlyData([PlyElement.describe(elements, "vertex")]).write(path)


def export_ply(
extrinsics: Float[Tensor, "4 4"],
means: Float[Tensor, "gaussian 3"],
scales: Float[Tensor, "gaussian 3"],
rotations: Float[Tensor, "gaussian 4"],
harmonics: Float[Tensor, "gaussian 3 d_sh"],
opacities: Float[Tensor, " gaussian"],
path: Path,
shift_and_scale: bool = False,
save_sh_dc_only: bool = True,
):
# Shift the scene so that the median Gaussian is at the origin.
means = means - means.median(dim=0).values
if shift_and_scale:
# Shift the scene so that the median Gaussian is at the origin.
means = means - means.median(dim=0).values

# Rescale the scene so that most Gaussians are within range [-1, 1].
scale_factor = means.abs().quantile(0.95, dim=0).max()
means = means / scale_factor
scales = scales / scale_factor
# Rescale the scene so that most Gaussians are within range [-1, 1].
scale_factor = means.abs().quantile(0.95, dim=0).max()
means = means / scale_factor
scales = scales / scale_factor

# Apply the rotation to the Gaussian rotations.
rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix()
rotations = R.from_matrix(rotations).as_quat()
x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g")
rotations = np.stack((w, x, y, z), axis=-1)

# Since our axes are swizzled for the spherical harmonics, we only export the DC
# band.
# harmonics_view_invariant = harmonics[..., 0]
# print(harmonics_view_invariant.shape)
# Since current model use SH_degree = 4,
# which require large memory to store, we can only save the DC band to save memory.
f_dc = harmonics[..., 0]
f_rest = harmonics[..., 1:].flatten(start_dim=1)

dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(f_rest.shape[1])]
dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0 if save_sh_dc_only else f_rest.shape[1])]
elements = np.empty(means.shape[0], dtype=dtype_full)
attributes = (
attributes = [
means.detach().cpu().numpy(),
torch.zeros_like(means).detach().cpu().numpy(),
f_dc.detach().cpu().contiguous().numpy(),
f_rest.detach().cpu().contiguous().numpy(),
opacities[..., None].detach().cpu().numpy(),
scales.log().detach().cpu().numpy(),
rotations,
)
]
if save_sh_dc_only:
# remove f_rest from attributes
attributes.pop(3)

attributes = np.concatenate(attributes, axis=1)
elements[:] = list(map(tuple, attributes))
path.parent.mkdir(exist_ok=True, parents=True)
Expand Down

0 comments on commit a427bd7

Please sign in to comment.