Skip to content

Commit

Permalink
Fix visualization of GridEncoding outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom94 committed Feb 11, 2022
1 parent 4929074 commit 09f3597
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions include/neural-graphics-primitives/nerf_network.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ class NerfNetwork : public tcnn::Network<float, T> {
}
m_density_network.reset(tcnn::create_network<T>(local_density_network_config));

m_rgb_network_input_width = tcnn::next_multiple(m_dir_encoding->num_encoded_dims() + m_density_network->padded_output_width() - 1, rgb_alignment);
m_rgb_network_input_width = tcnn::next_multiple(m_dir_encoding->num_encoded_dims() + m_density_network->padded_output_width(), rgb_alignment);

json local_rgb_network_config = rgb_network;
local_rgb_network_config["n_input_dims"] = m_rgb_network_input_width;
Expand Down Expand Up @@ -509,11 +509,11 @@ class NerfNetwork : public tcnn::Network<float, T> {

uint32_t width(uint32_t layer) const override {
if (layer == 0) {
return m_forward.density_network_input.m();
return m_pos_encoding->num_encoded_dims();
} else if (layer < m_density_network->num_forward_activations() + 1) {
return m_density_network->width(layer - 1);
} else if (layer == m_density_network->num_forward_activations() + 1) {
return m_forward.rgb_network_input.m();
return m_rgb_network_input_width;
} else {
return m_rgb_network->width(layer - 2 - m_density_network->num_forward_activations());
}
Expand All @@ -523,17 +523,17 @@ class NerfNetwork : public tcnn::Network<float, T> {
return m_density_network->num_forward_activations() + m_rgb_network->num_forward_activations() + 2;
}

const T* forward_activations(uint32_t layer) const override {
std::pair<const T*, tcnn::MatrixLayout> forward_activations(uint32_t layer) const override {
if (!m_forward.density_network_input.data()) {
throw std::runtime_error{"Must call forward() before accessing activations."};
}

if (layer == 0) {
return m_forward.density_network_input.data();
return {m_forward.density_network_input.data(), m_pos_encoding->output_layout()};
} else if (layer < m_density_network->num_forward_activations() + 1) {
return m_density_network->forward_activations(layer - 1);
} else if (layer == m_density_network->num_forward_activations() + 1) {
return m_forward.rgb_network_input.data();
return {m_forward.rgb_network_input.data(), m_dir_encoding->output_layout()};
} else {
return m_rgb_network->forward_activations(layer - 2 - m_density_network->num_forward_activations());
}
Expand Down
2 changes: 1 addition & 1 deletion src/testbed_sdf.cu
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ void Testbed::render_sdf(
if (m_render_mode == ERenderMode::Slice) {
if (m_visualized_dimension == -1) {
distance_function(n_hit, rays_hit.pos, rays_hit.distance, stream);
extract_dimension_pos_neg_kernel<float><<<n_blocks_linear(n_hit*3), n_threads_linear, 0, stream>>>(n_hit*3, 0, 1, 3, rays_hit.distance.data(), (float*)rays_hit.normal.data());
extract_dimension_pos_neg_kernel<float><<<n_blocks_linear(n_hit*3), n_threads_linear, 0, stream>>>(n_hit*3, 0, 1, 3, rays_hit.distance.data(), CM, (float*)rays_hit.normal.data());
} else {
// Store colors in the normal buffer
uint32_t n_elements = next_multiple(n_hit, BATCH_SIZE_MULTIPLE);
Expand Down

0 comments on commit 09f3597

Please sign in to comment.