diff --git a/discussion/robust_inverse_graphics/diffusion.py b/discussion/robust_inverse_graphics/diffusion.py new file mode 100644 index 0000000000..786b241c11 --- /dev/null +++ b/discussion/robust_inverse_graphics/diffusion.py @@ -0,0 +1,355 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r"""Diffusion/score matching utilities. + +In this file, we are generally dealing with the following joint model: + +Q(x) Q(z_0 | x) \prod_{t=1}^T Q(z_t | z_{t - 1}) + +T could be infinite depending on the model used. The goal is to learn a model +for Q(x) which we only have access to via samples from it. We do this by +learning P(z_{t - 1} | z_t; t), which we parameterize using a trainable +`denoise_fn(z_t, f(t))` with some function `f` (often log signal-to-noise +ratio). +""" + +import enum +from typing import Any, Callable + +from flax import struct +import jax +import jax.numpy as jnp +from discussion.robust_inverse_graphics import saving +from fun_mc import using_jax as fun_mc + + +__all__ = [ + "linear_log_snr", + "variance_preserving_forward_process", + "vdm_diffusion_loss", + "vdm_sample", + "VDMDiffusionLossExtra", + "VDMSampleExtra", + "DenoiseOutputType", +] + + +Extra = Any +LogSnrFn = Callable[[jnp.ndarray], jnp.ndarray] +DenoiseFn = Callable[[jnp.ndarray, jnp.ndarray], tuple[jnp.ndarray, Extra]] + + +class DenoiseOutputType(enum.Enum): + """How to interpret the output of `denoise_fn`. 
+ + NOISE: The output is the predicted noise. + ANGULAR_VELOCITY: The output is the angular velocity, defined as: + `alpha_t * noise - sigma_t * x`. + DIRECT: The output is the denoised value. + """ + + NOISE = "noise" + ANGULAR_VELOCITY = "angular_velocity" + DIRECT = "direct" + + +def variance_preserving_forward_process( + z_0: jnp.ndarray, noise: jnp.ndarray, log_snr_t: jnp.ndarray +) -> jnp.ndarray: + """Variance preserving forward process. + + This produces a sample from Q(z_t | z_0) given the desired level of noise and + randomness. + + Args: + z_0: Un-noised inputs. + noise: Noise. + log_snr_t: Log signal-to-noise ratio at time t. + + Returns: + Value of z_t. + """ + var_t = jax.nn.sigmoid(-log_snr_t) + alpha_t = jnp.sqrt(jax.nn.sigmoid(log_snr_t)) # sqrt(1 - var_t) + return alpha_t * z_0 + jnp.sqrt(var_t) * noise + + +@saving.register +@struct.dataclass +class VDMDiffusionLossExtra: + """Extra outputs from `vdm_diffusion_loss`. + + Attributes: + noise: The added noise. + recon_noise: The reconstructed noise (only set if `denoise_output` is NOISE. + target: Target value for the loss to reconstruct. + recon: Output of `denoise_fn`. + extra: Extra outputs from `denoise_fn`. + """ + + noise: jnp.ndarray + recon_noise: jnp.ndarray | None + target: jnp.ndarray + recon: jnp.ndarray + extra: Extra + + +def vdm_diffusion_loss( + t: jnp.ndarray, + num_steps: int | None, + x: jnp.ndarray, + log_snr_fn: LogSnrFn, + denoise_fn: DenoiseFn, + seed: jax.Array, + denoise_output: DenoiseOutputType = DenoiseOutputType.NOISE, +) -> tuple[jnp.ndarray, VDMDiffusionLossExtra]: + r"""The diffusion loss of the variational diffusion model (VDM). + + This uses the parameterization from [1]. The typical procedure minimizes the + expectation of this function, averaging across examples (`z_0`) and times + (sampled uniformly from [0, 1]). 
+ + When `denoise_output` is NOISE, and when the loss is minimized, + `denoise_fn(z_t, log_snr_t) \propto -grad log Q(z_t; log_snr_t)` where `z_t` + is sampled from the forward process (`variance_preserving_forward_process`) + and `Q(.; log_snr_t)` is the marginal density of `z_t`. + + Args: + t: Time in [0, 1] + num_steps: If None, use continuous time parameterization. Otherwise, + discretize `t` to this many bins. + x: Un-noised inputs. + log_snr_fn: Takes in time in [0, 1] and returns the log signal-to-noise + ratio. + denoise_fn: Function that denoises `z_t` given the `log_snr_t`. Its output + is interpreted based on the value of `denoise_output`. + seed: Random seed. + denoise_output: How to interpret the output of `denoise_fn`. + + Returns: + A tuple of the loss and `VDMDiffusionLossExtra` extra outputs. + + #### References + + [1] Kingma, D. P., Salimans, T., Poole, B., & Ho, J. (2021). Variational + Diffusion Models. In arXiv [cs.LG]. arXiv. http://arxiv.org/abs/2107.00630 + """ + + if num_steps is not None: + t = jnp.ceil(t * num_steps) / num_steps + + log_snr_t = log_snr_fn(t) + noise = jax.random.normal(seed, x.shape) + z_t = variance_preserving_forward_process(x, noise, log_snr_t) + + recon, extra = denoise_fn(z_t, log_snr_t) + + match denoise_output: + case DenoiseOutputType.NOISE: + target = noise + recon_noise = recon + sq_error = 0.5 * jnp.square(target - recon).sum() + if num_steps is None: + log_snr_t_grad = jax.grad(log_snr_fn)(t) + loss = -log_snr_t_grad * sq_error + else: + s = t - (1.0 / num_steps) + log_snr_s = log_snr_fn(s) + loss = num_steps * jnp.expm1(log_snr_s - log_snr_t) * sq_error + case DenoiseOutputType.ANGULAR_VELOCITY: + # Plug in x_hat = alpha_t * z_t - sigma_t * v into equation (13) or (15) + # and simplify to get the loss being (SNR(s) - SNR(t)) sigma_t**2 * MSE + # for discrete case and SNR'(t) * sigma_t**2 * MSE for the continuous + # case. 
+ recon_noise = None + var_t = jax.nn.sigmoid(-log_snr_t) + sigma_t = jnp.sqrt(var_t) + alpha_t_2 = jax.nn.sigmoid(log_snr_t) + alpha_t = jnp.sqrt(alpha_t_2) + v = alpha_t * noise - sigma_t * x + + target = v + sq_error = 0.5 * jnp.square(target - recon).sum() + if num_steps is None: + log_snr_t_grad = jax.grad(log_snr_fn)(t) + loss = -alpha_t_2 * log_snr_t_grad * sq_error + else: + s = t - (1.0 / num_steps) + log_snr_s = log_snr_fn(s) + loss = ( + num_steps * jnp.expm1(log_snr_s - log_snr_t) * alpha_t_2 * sq_error + ) + case DenoiseOutputType.DIRECT: + recon_noise = None + target = x + sq_error = 0.5 * jnp.square(target - recon).sum() + if num_steps is None: + snr_t_grad = jax.grad(lambda t: jnp.exp(log_snr_fn(t)))(t) + loss = -snr_t_grad * sq_error + else: + s = t - (1.0 / num_steps) + snr_t = jnp.exp(log_snr_t) + # TODO(siege): Not sure this is more stable than doing snr_s - snr_t + # directly. + log_snr_s = log_snr_fn(s) + loss = num_steps * snr_t * jnp.expm1(log_snr_s - log_snr_t) * sq_error + case _: + raise ValueError(f"Unknown denoise_output: {denoise_output}") + + return loss, VDMDiffusionLossExtra( + noise=noise, + recon_noise=recon_noise, + target=target, + recon=recon, + extra=extra, + ) + + +def _vdm_sample_step( + z_t: jnp.ndarray, + step: jnp.ndarray, + num_steps: int, + log_snr_fn: LogSnrFn, + denoise_fn: DenoiseFn, + seed: jax.Array, + denoise_output: DenoiseOutputType, + t_start: jnp.ndarray, +) -> tuple[jnp.ndarray, Extra]: + """One step of the sampling process.""" + t = t_start * (step / num_steps) + s = t_start * ((step - 1) / num_steps) + + log_snr_t = log_snr_fn(t) + log_snr_s = log_snr_fn(s) + recon, extra = denoise_fn(z_t, log_snr_t) + + zeta = jax.random.normal(seed, z_t.shape) + + alpha_s_2 = jax.nn.sigmoid(log_snr_s) + alpha_s = jnp.sqrt(alpha_s_2) + alpha_t_2 = jax.nn.sigmoid(log_snr_t) + alpha_t = jnp.sqrt(alpha_t_2) + var_t_s_div_var_t = -jnp.expm1(log_snr_t - log_snr_s) + var_s = jax.nn.sigmoid(-log_snr_s) + var_t = 
jax.nn.sigmoid(-log_snr_t) + sigma_t = jnp.sqrt(var_t) + + match denoise_output: + case DenoiseOutputType.NOISE: + recon_noise = recon + mu = jnp.sqrt(alpha_s_2 / alpha_t_2) * ( + z_t - sigma_t * var_t_s_div_var_t * recon_noise + ) + case DenoiseOutputType.ANGULAR_VELOCITY: + # We use the expression for q(z_s | z_t, x) directly with x_hat + # substituted for x. + # TODO(siege): Try simplifying this further for better numerics. + x_hat = alpha_t * z_t - sigma_t * recon + alpha_t_s = alpha_t / alpha_s + + mu = alpha_t_s * var_s / var_t * z_t + alpha_s * var_t_s_div_var_t * x_hat + case DenoiseOutputType.DIRECT: + x_hat = recon + alpha_t_s = alpha_t / alpha_s + + mu = alpha_t_s * var_s / var_t * z_t + alpha_s * var_t_s_div_var_t * x_hat + case _: + raise ValueError(f"Unknown denoise_output: {denoise_output}") + sigma = jnp.sqrt(var_t_s_div_var_t * var_s) + z_s = mu + sigma * zeta + return z_s, extra + + +@saving.register +@struct.dataclass +class VDMSampleExtra: + """Extra outputs from `vdm_sample`. + + Attributes: + z_s: A trace of samples. + """ + + z_s: jnp.ndarray | None + + +def vdm_sample( + z_t: jnp.ndarray, + num_steps: int, + log_snr_fn: LogSnrFn, + denoise_fn: DenoiseFn, + seed: jax.Array, + trace_z_s: bool = False, + denoise_output: DenoiseOutputType = DenoiseOutputType.NOISE, + t_start: jnp.ndarray | float = 1.0, +) -> tuple[jnp.ndarray, VDMSampleExtra]: + """Generates a sample from the variational diffusion model (VDM). + + This uses the sampler from [1]. See `vdm_diffusion_loss` for the requirements + on `denoise_fn`. + + Args: + z_t: The initial noised sample. Should have the same distribution as + `variance_preserving_forward_process(x, noise, log_snr_fn(t_start))`. + num_steps: Number of steps to take. The more steps taken, then more accurate + the sample. 1000 is a common value. + log_snr_fn: Takes in time in [0, 1] and returns the log signal-to-noise + ratio. + denoise_fn: Function that denoises `z_t` given the `log_snr_t`. 
Its output + is interpreted based on the value of `denoise_output`. + seed: Random seed. + trace_z_s: Whether to trace intermediate samples. + denoise_output: How to interpret the output of `denoise_fn`. + t_start: The value of t in z_t. Typically this is 1, signifying that z_t is + a sample from a standard normal. + + Returns: + A tuple of the sample and `VDMSampleExtra` extra outputs. + + + #### References + + [1] Kingma, D. P., Salimans, T., Poole, B., & Ho, J. (2021). Variational + Diffusion Models. In arXiv [cs.LG]. arXiv. http://arxiv.org/abs/2107.00630 + """ + + def body(z_t, step, seed): + sample_seed, seed = jax.random.split(seed) + z_s, _ = _vdm_sample_step( + z_t=z_t, + step=step, + num_steps=num_steps, + log_snr_fn=log_snr_fn, + denoise_fn=denoise_fn, + seed=sample_seed, + denoise_output=denoise_output, + t_start=t_start, + ) + if trace_z_s: + trace = {"z_s": z_s} + else: + trace = {} + return (z_s, step - 1, seed), trace + + (z_0, _, _), trace = fun_mc.trace((z_t, num_steps, seed), body, num_steps) + + return z_0, VDMSampleExtra(z_s=trace.get("z_s")) + + +def linear_log_snr( + t: jnp.ndarray, + log_snr_start: jax.typing.ArrayLike = 6.0, + log_snr_end: jax.typing.ArrayLike = -6.0, +) -> jnp.ndarray: + """Linear log signal-to-noise ratio function.""" + return log_snr_start + (log_snr_end - log_snr_start) * t # pytype: disable=bad-return-type # numpy-scalars diff --git a/discussion/robust_inverse_graphics/diffusion_test.py b/discussion/robust_inverse_graphics/diffusion_test.py new file mode 100644 index 0000000000..37b68bd925 --- /dev/null +++ b/discussion/robust_inverse_graphics/diffusion_test.py @@ -0,0 +1,158 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools

from absl.testing import parameterized
import jax
import jax.numpy as jnp

from discussion.robust_inverse_graphics import diffusion
from discussion.robust_inverse_graphics.util import test_util


def _gaussian_denoise(
    z_t, log_snr_t, loc_0, scale_0, denoise_output, noise_offset=None
):
  """Closed-form denoiser for data distributed as Normal(loc_0, scale_0).

  Since P(x) is a gaussian, we know the solution in closed form. See e.g.
  Appendix L from Kingma et al. 2021 and also the form of the forward process.

  Args:
    z_t: Noised value.
    log_snr_t: Log signal-to-noise ratio at time t.
    loc_0: Mean of P(x).
    scale_0: Standard deviation of P(x).
    denoise_output: Which `DenoiseOutputType` to return.
    noise_offset: If not None, added to the predicted noise before it is
      converted to the requested output type.

  Returns:
    A tuple of the prediction and an empty extra.
  """
  var_t = jax.nn.sigmoid(-log_snr_t)
  sigma_t = jnp.sqrt(var_t)
  alpha_t = jnp.sqrt(jax.nn.sigmoid(log_snr_t))  # sqrt(1 - var_t)

  # Marginal of z_t under the forward process.
  marginal_loc = alpha_t * loc_0
  marginal_var = (alpha_t * scale_0) ** 2 + var_t
  score = (marginal_loc - z_t) / marginal_var
  recon_noise = -score * sigma_t
  if noise_offset is not None:
    recon_noise = recon_noise + noise_offset

  if denoise_output is diffusion.DenoiseOutputType.NOISE:
    return recon_noise, ()
  if denoise_output is diffusion.DenoiseOutputType.ANGULAR_VELOCITY:
    predicted_x = (z_t - sigma_t * recon_noise) / alpha_t
    return (alpha_t * z_t - predicted_x) / sigma_t, ()
  if denoise_output is diffusion.DenoiseOutputType.DIRECT:
    predicted_x = (z_t - sigma_t * recon_noise) / alpha_t
    return predicted_x, ()
  return None


class DiffusionTest(test_util.TestCase):

  def test_variance_preserving_forward_process(self):
    x_seed, noise_seed = jax.random.split(self.test_seed())

    x = 1.0 + 2 * jax.random.normal(x_seed, [10000])
    noise = jax.random.normal(noise_seed, [10000])

    def forward(log_snr):
      return diffusion.variance_preserving_forward_process(
          x, noise, log_snr_t=jnp.array(log_snr)
      )

    z_small = forward(-20.0)
    z_large = forward(20.0)

    # At the SNR extremes, the outputs should either be all noise or all inputs.
    self.assertAllClose(z_small.mean(), 0.0, atol=1e-1)
    self.assertAllClose(z_small.std(), 1.0, rtol=1e-1)
    self.assertAllClose(z_large.mean(), 1.0, atol=1e-1)
    self.assertAllClose(z_large.std(), 2.0, rtol=1e-1)

  @parameterized.named_parameters(
      ("_continuous_noise", None, diffusion.DenoiseOutputType.NOISE),
      ("_discretized_noise", 1000, diffusion.DenoiseOutputType.NOISE),
      ("_continuous_v", None, diffusion.DenoiseOutputType.ANGULAR_VELOCITY),
      ("_discretized_v", 1000, diffusion.DenoiseOutputType.ANGULAR_VELOCITY),
      ("_continuous_direct", None, diffusion.DenoiseOutputType.DIRECT),
      ("_discretized_direct", 1000, diffusion.DenoiseOutputType.DIRECT),
  )
  def test_vdm_diffusion_loss(self, num_steps, denoise_output):
    # P(x) is a gaussian with these parameters.
    loc_0 = 1.0
    scale_0 = 3.0
    num_examples = 100000

    def example_loss(w, x, t, seed):
      # `w` shifts the optimal noise prediction; the loss should be stationary
      # at w == 0.
      def denoise_fn(z_t, log_snr_t):
        return _gaussian_denoise(
            z_t, log_snr_t, loc_0, scale_0, denoise_output, noise_offset=w
        )

      loss, _ = diffusion.vdm_diffusion_loss(
          t=t,
          num_steps=num_steps,
          x=x,
          log_snr_fn=diffusion.linear_log_snr,
          denoise_fn=denoise_fn,
          seed=seed,
          denoise_output=denoise_output,
      )
      return loss

    x_seed, seed = jax.random.split(self.test_seed(), 2)

    x = jax.random.normal(x_seed, [num_examples]) * scale_0 + loc_0
    t = jnp.linspace(0., 1., num_examples)
    seeds = jax.random.split(seed, num_examples)

    def mean_loss(w):
      per_example = jax.vmap(functools.partial(example_loss, w))(x, t, seeds)
      return per_example.mean()

    grad = jax.grad(mean_loss)(jnp.zeros([]))

    self.assertAllClose(grad, 0., atol=2e-1)

  @parameterized.named_parameters(
      ("_noise", diffusion.DenoiseOutputType.NOISE),
      ("_v", diffusion.DenoiseOutputType.ANGULAR_VELOCITY),
      ("_direct", diffusion.DenoiseOutputType.DIRECT),
  )
  def test_vdm_sample(self, denoise_output):
    # P(x) is a gaussian with these parameters.
    loc_0 = 1.0
    scale_0 = 3.0

    def denoise_fn(z_t, log_snr_t):
      return _gaussian_denoise(z_t, log_snr_t, loc_0, scale_0, denoise_output)

    init_seed, sample_seed = jax.random.split(self.test_seed())
    z_t = jax.random.normal(init_seed, [10000])

    sample, _ = diffusion.vdm_sample(
        z_t=z_t,
        num_steps=1000,
        log_snr_fn=diffusion.linear_log_snr,
        denoise_fn=denoise_fn,
        seed=sample_seed,
        denoise_output=denoise_output,
    )
    self.assertAllClose(sample.mean(), loc_0, atol=2e-1)
    self.assertAllClose(sample.std(), scale_0, rtol=2e-1)


if __name__ == "__main__":
  test_util.main()


# --- New file: discussion/robust_inverse_graphics/models.py ---
# Copyright 2024 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Probabilistic models.""" + +import abc +from collections.abc import Callable +import dataclasses +import functools +import operator +from typing import Any + +from flax import linen as nn +from flax import struct +import jax +import jax.numpy as jnp +from discussion.robust_inverse_graphics import diffusion +from discussion.robust_inverse_graphics import saving +from discussion.robust_inverse_graphics.nerf import rendering +from discussion.robust_inverse_graphics.util import tree_util +import tensorflow_probability.substrates.jax as tfp + +tfd = tfp.distributions +pbroadcast = tfp.internal.distribute_lib.rwb_pbroadcast +reduce_mean = tfp.internal.distribute_lib.reduce_mean + +__all__ = [ + 'chunked_render_latents', + 'Dataset', + 'EvaluationGuideParams', + 'Example', + 'ExamplesIterator', + 'Extra', + 'Guide', + 'GuideParams', + 'GuideSampleExtra', + 'Latents', + 'LikelihoodInputs', + 'make_mean_field_guide', + 'make_nerf_model', + 'make_probnerf_guide', + 'make_probnerf_model', + 'MeanFieldGuideParams', + 'Model', + 'ModelParams', + 'ParameterizedGuide', + 'ParameterizedModel', + 'rgb_log_likelihood', + 'sinh_arcsinh_rgb_log_likelihood', + 'RGBExtra', + 'UnprocessedExample', +] + +ModelParams = Any +GuideParams = Any +Latents = Any +Extra = Any +UnprocessedExample = Any + + +@saving.register +@struct.dataclass +class GuideSampleExtra: + log_prob: jax.Array | None = None + log_prob_stop_grad_params: jax.Array | None = None + dist_params: Any | None = None + + +# TODO(siege): Annotate these fields in the class itself. +# The selectors need to live outside, so that different DataclassViews compare +# equal. 
+def _fields_with_view_axis_selector(f: str) -> bool: + return f not in ['scene_id', 'scene_embedding'] + + +def _fields_with_ray_axis_selector(f: str) -> bool: + return f in ['rays', 'rgb', 'depth', 'segmentation'] + + +@struct.dataclass +class Example: + """Inference data/example. + + Attributes: + rays: The rays used to generate the image(s). + rgb: The colors of the pixels of the image(s). Shape: [3] + time: Optional time value, for video data. Shape: [] + depth: Optional depth value, increasing with forward distance from the + camera. Shape: [] + segmentation: Optional integer-valued segmentation mask, with distinct + values identifying distinct objects in the scene. Shape: [] + camera_world_matrix: Optional camera world matrix. Shape: [4, 4] + camera_intrinsics: Optional pinhole camera intrinsic parameters. Shape: [3, + 3] + scene_id: Optional scene index. Shape: [] + scene_embedding: Optional scene embedding. Shape [...] + """ + + rays: rendering.Ray | None = None + rgb: Any | None = None + time: Any | None = None + depth: Any | None = None + segmentation: Any | None = None + camera_world_matrix: Any | None = None + camera_intrinsics: Any | None = None + scene_id: Any | None = None + scene_embedding: Any | None = None + + def fields_with_view_axis(self) -> tree_util.DataclassView['Example']: + """Returns a view with only fields that typically have the view axis.""" + return tree_util.DataclassView(self, _fields_with_view_axis_selector) + + def fields_with_ray_axis(self) -> tree_util.DataclassView['Example']: + """Returns a view with only fields that typically have the ray axis.""" + return tree_util.DataclassView(self, _fields_with_ray_axis_selector) + + @classmethod + def test_example( + cls, num_views: int = 5, im_height: int = 128, im_width: int = 128 + ) -> 'Example': + shape = num_views, im_height, im_width, 3 + ones = jnp.ones(shape) + zeros = jnp.zeros(shape) + return cls(rays=rendering.Ray(zeros, ones, ones), rgb=zeros) + + @classmethod + def 
view_axis(cls) -> 'Example': + """Returns the example with fields replaced with the view axis location.""" + axes = {} + for f in dataclasses.fields(cls): + if f.name in ['scene_id', 'scene_embedding']: + axis = None + else: + axis = 0 + axes[f.name] = axis + return cls(**axes) + + +class ExamplesIterator(metaclass=abc.ABCMeta): + """A sized iterator for examples.""" + + @abc.abstractmethod + def __next__(self) -> UnprocessedExample: + pass + + def __iter__(self) -> 'ExamplesIterator': + return self + + @abc.abstractmethod + def _size(self) -> int: + """Epoch size.""" + + @property + def size(self) -> int: + return self._size() + + @abc.abstractmethod + def save(self, checkpoint_dir: str): + """Save the checkpoint iteration state.""" + + @abc.abstractmethod + def load(self, checkpoint_dir: str): + """Load the checkpoint iteration state.""" + + +@dataclasses.dataclass +class Dataset: + train_examples_fn: Callable[[], ExamplesIterator] + test_examples_fn: Callable[[], ExamplesIterator] + process_example_fn: Callable[[UnprocessedExample], Example] = lambda x: x + + +@struct.dataclass +class RGBExtra: + """Extras from `rgb_log_likelihood`. + + S below refers to the shape of RGB, which is at least a vector. + + Attributes: + per_channel_ll: Per-channel likelihood. Shape: [S] + rgb_loss: L2 reconstruction loss. Shape: [] + """ + + per_channel_ll: jax.Array + rgb_loss: jax.Array + + +def rgb_log_likelihood( + recon_rgb: rendering.RGB, + rgb: rendering.RGB, + obs_scale: jax.Array, +) -> tuple[jax.Array, RGBExtra]: + """RGB per-channel likelihood with Gaussian noise. + + S below refers to the shape of RGB, which is at least a vector. + + Args: + recon_rgb: Reconstructed RGB. Shape: [S, 3] + rgb: Target RGB. Shape: [S, 3] + obs_scale: Observation scale. Shape: [] + + Returns: + A tuple of: + The likelihood value. + RGBExtra, for extra outputs. 
+ """ + rgb_loss = ((recon_rgb - rgb) ** 2).mean() + per_channel_ll = tfd.Normal(recon_rgb, obs_scale).log_prob(rgb) + ll = per_channel_ll.sum() + + extra = RGBExtra( + per_channel_ll=per_channel_ll, + rgb_loss=rgb_loss, + ) + + return ll, extra + + +def sinh_arcsinh_rgb_log_likelihood( + recon_rgb: rendering.RGB, + rgb: rendering.RGB, + obs_scale: jax.Array, + reversion_to_mean: jax.Array | float = 0.8, + obs_scale_factor: jax.Array | float = 2.1, + tailweight: jax.Array | float = 1.1, +) -> tuple[jax.Array, RGBExtra]: + """RGB per-channel likelihood with a skewed, SinhArcsinh distributed noise. + + S below refers to the shape of RGB, which is at least a vector. + + Args: + recon_rgb: Reconstructed RGB. Shape: [S, 3] + rgb: Target RGB. Shape: [S, 3] + obs_scale: Observation scale. Shape: [] + reversion_to_mean: Reversion to mean factor. Shape: [] + obs_scale_factor: Observation scale factor. Shape: [] + tailweight: Tailweight. Shape: [] + + Returns: + A tuple of: + The likelihood value. + RGBExtra, for extra outputs. 
+ """ + rgb_loss = ((recon_rgb - rgb) ** 2).mean() + per_channel_ll = tfd.SinhArcsinh( + 0.5 + reversion_to_mean * (recon_rgb - 0.5), + obs_scale * obs_scale_factor, + (0.5 - recon_rgb), + tailweight, + ).log_prob(rgb) + ll = per_channel_ll.sum() + + extra = RGBExtra( + per_channel_ll=per_channel_ll, + rgb_loss=rgb_loss, + ) + + return ll, extra + + +@functools.partial(jax.jit, static_argnames=('render_latents_fn', 'num_chunks')) +def chunked_render_latents( + render_latents_fn: Callable[ + [Latents, Example, jax.Array], + tuple[rendering.RGB, Extra], + ], + latents: Latents, + example: 'Example', + seed: jax.Array, + num_chunks: int = 40, +) -> tuple[rendering.RGB, Extra]: + """Renders latents in a chunked way, to save on memory.""" + example_view = jax.tree.map( + lambda x: x.reshape((num_chunks, x.shape[0] // num_chunks) + x.shape[1:]), + example.fields_with_ray_axis(), + ) + + rgb, extra = jax.lax.map( + lambda example_seed: render_latents_fn( # pylint: disable=g-long-lambda + latents, example_seed[0].value, example_seed[1] + ), + (example_view, jax.random.split(seed, num_chunks)), + ) + + return jax.tree.map(lambda x: x.reshape((-1,) + x.shape[2:]), (rgb, extra)) + + +@struct.dataclass +class LikelihoodInputs: + """Additional inputs to the likelihood function. + + Attributes: + example: The example. + seed: The random seed. + step: The inference step. 
+ """ + + example: Example | None = None + seed: jax.Array | None = None + step: jax.Array | None = None + + +@saving.register +@struct.dataclass +class LikelihoodExtra: + psnr: jax.Array | None = None + render_extra: Any | None = None + rgb_mse: jax.Array | None = None + + +@struct.dataclass +class Model: + """A model.""" + + init_latents_fn: Callable[[jax.Array], tuple[Latents, Extra]] + render_latents_fn: Callable[ + [Latents, Example, jax.Array], + tuple[rendering.RGB, Extra], + ] + log_likelihood_fn: Callable[ + [Latents, LikelihoodInputs], tuple[jax.Array, Extra] + ] + prior_log_prob_fn: Callable[[Latents], tuple[jax.Array, Extra]] + prior_sample_fn: Callable[[jax.Array], tuple[Latents, Extra]] + reduce_extra_fn: ( + Callable[[Extra, Extra, str | None], tuple[Extra, Extra]] | None + ) = None + + +@struct.dataclass +class SSDNeRFModel(Model): + """SSDNeRF model.""" + + denoise_fn: Callable[[jax.Array, jax.Array], tuple[jax.Array, Any]] | None = ( + None + ) + denoise_output: diffusion.DenoiseOutputType | None = None + log_snr_fn: Callable[[jax.Array], jax.Array] | None = None + num_rays: int | None = None + near: float | None = None + far: float | None = None + num_samples: tuple[int, ...] | None = None + obs_scales: tuple[float, ...] 
| None = None + prior_map_bwd_fn: Callable[[jax.Array], jax.Array] = lambda x: x + prior_map_fwd_fn: Callable[[jax.Array], jax.Array] = lambda x: x + + +@struct.dataclass +class GenSRTModel(Model): + """GenSRT model.""" + + denoise_fn: Callable[[jax.Array, jax.Array], tuple[jax.Array, Any]] | None = ( + None + ) + denoise_output: diffusion.DenoiseOutputType | None = None + log_snr_fn: Callable[[jax.Array], jax.Array] | None = None + num_rays: int | None = None + + +def make_nerf_model( + nerf: nn.Module, + near: float, + far: float, + num_samples: tuple[int, ...], + obs_scales: tuple[float, ...], + anneal_nerf: bool = False, + num_rays: int = 0, + rgb_log_likelihood_fn: Callable[ + [rendering.RGB, rendering.RGB, jax.Array], + tuple[jax.Array, RGBExtra], + ] = rgb_log_likelihood, + ray_warp_fn: str | Callable[[jax.Array], jax.Array] | None = None, +) -> Model: + """Creates a model from a NeRF. + + Args: + nerf: A NeRF model. The call method should have the signature of + `(RaySample) -> ((Density, RGB), Extra)` + near: The near plane. + far: The far plane. + num_samples: Number of samples for Mip-NeRF rendering. + obs_scales: Scales for the likelihoods for each Mip-NeRF level. + anneal_nerf: If True, also pass in a `step` argument to the NeRF model, + typically for annealing purposes. + num_rays: If non-zero, subsample rays to this amount. + rgb_log_likelihood_fn: Pixel-level likelihood function to use. + ray_warp_fn: Ray warp function. See Equation 11 in + https://arxiv.org/abs/2111.12077. + + Returns: + A Model. 
+ """ + + def init_latents_fn(seed): + if anneal_nerf: + kwargs = {'step': 0} + else: + kwargs = {} + return ( + nerf.init( + seed, + rendering.RaySample( + position=jnp.zeros(3), + covariance=jnp.ones((3, 3)), + viewdir=jnp.ones(3), + ), + **kwargs, + ), + (), + ) + + def render_latents_fn(latents, example, seed, step=None): + if anneal_nerf: + kwargs = {'step': step} + else: + kwargs = {} + return rendering.render_rf( + rf_fn=functools.partial(nerf.apply, latents, **kwargs), + rays=example.rays, + near=near, + far=far, + num_samples=num_samples, + seed=seed, + ray_warp_fn=ray_warp_fn, + ) + + def log_likelihood_fn(latents, inputs, return_render_extra=False): + seed = inputs.seed + + if num_rays > 0: + seed, subsample_seed = jax.random.split(seed) + rays, rgb = subsample_rays(inputs.example, num_rays, subsample_seed) + total_num_rays = functools.reduce( + operator.mul, inputs.example.rgb.shape[:-1] + ) + ll_factor = total_num_rays / num_rays + else: + rays = inputs.example.rays + rgb = inputs.example.rgb + ll_factor = 1 + + _, extra = render_latents_fn(latents, Example(rays=rays), seed, inputs.step) + ll = 0.0 + for extra_l, obs_scale in zip(extra.levels, obs_scales): + one_ll, ll_extra = rgb_log_likelihood_fn( + extra_l.rgb, + rgb, + obs_scale, + ) + ll += one_ll + # Grab the ll_extra from the last level, as that corresponds to the final + # reconstruction. + rgb_mse = ll_extra.rgb_loss + if not return_render_extra: + extra = None + return ll * ll_factor, LikelihoodExtra( + psnr=None, rgb_mse=rgb_mse, render_extra=extra + ) + + def prior_log_prob_fn(latents): + del latents # Unused. 
+ return 0.0, () + + def prior_sample_fn(seed): + return init_latents_fn(seed) + + def reduce_extra_fn(prior_extra, likelihood_extra, example_axis_name=None): + rgb_mse = reduce_mean( + likelihood_extra.rgb_mse, named_axis=example_axis_name + ) + return prior_extra, likelihood_extra.replace( + rgb_mse=rgb_mse, psnr=-10 * jnp.log10(rgb_mse) + ) + + return Model( + init_latents_fn=init_latents_fn, + render_latents_fn=render_latents_fn, + log_likelihood_fn=log_likelihood_fn, + prior_log_prob_fn=prior_log_prob_fn, + prior_sample_fn=prior_sample_fn, + reduce_extra_fn=reduce_extra_fn, + ) + + +@struct.dataclass +class ParameterizedModel: + """A parameterized model.""" + + init_params_fn: Callable[[jax.Array], tuple[ModelParams, Extra]] + prior_sample_fn: Callable[[ModelParams, jax.Array], tuple[Latents, Extra]] + prior_log_prob_fn: Callable[[ModelParams, Latents], tuple[jax.Array, Extra]] + render_latents_fn: Callable[ + [ModelParams, Latents, Example, jax.Array], + tuple[rendering.RGB, Extra], + ] + log_likelihood_fn: Callable[ + [ModelParams, Latents, LikelihoodInputs], + tuple[jax.Array, Extra], + ] + reduce_extra_fn: ( + Callable[ + [Extra, Extra, str | None], + tuple[Extra, Extra], + ] + | None + ) = None + + +@struct.dataclass +class ParameterizedSSDNeRFModel(ParameterizedModel): + """A parameterized SSDNeRF model.""" + + denoise_fn: ( + Callable[[jax.Array, jax.Array, jax.Array], tuple[jax.Array, Any]] | None + ) = None + log_snr_fn: Callable[[jax.Array], jax.Array] | None = None + num_rays: int | None = None + near: float | None = None + far: float | None = None + num_samples: tuple[int, ...] | None = None + obs_scales: tuple[float, ...] 
| None = None + denoise_output: diffusion.DenoiseOutputType | None = None + prior_map_fwd_fn: Callable[[jax.Array], jax.Array] = lambda x: x + prior_map_bwd_fn: Callable[[jax.Array], jax.Array] = lambda x: x + + +@struct.dataclass +class ParameterizedGenSRTModel(ParameterizedModel): + """A parameterized GenSRT model.""" + + denoise_fn: ( + Callable[[jax.Array, jax.Array, jax.Array], tuple[jax.Array, Any]] | None + ) = None + log_snr_fn: Callable[[jax.Array], jax.Array] | None = None + num_rays: int | None = None + denoise_output: diffusion.DenoiseOutputType | None = None + + +@struct.dataclass +class Guide: + """A guide.""" + + guide_sample_fn: Callable[ + [Example, jax.Array], + tuple[Latents, GuideSampleExtra], + ] + guide_log_prob_fn: ( + Callable[ + [Example, Latents], + tuple[jax.Array, Extra], + ] + | None + ) = None + reduce_extra_fn: ( + Callable[ + [GuideSampleExtra, Extra, str | None], + tuple[GuideSampleExtra, Extra], + ] + | None + ) = None + + +@struct.dataclass +class ParameterizedGuide: + """A parameterized guide.""" + + init_params_fn: Callable[[jax.Array], tuple[GuideParams, Extra]] + guide_sample_fn: Callable[ + [GuideParams, Example, jax.Array], + tuple[Latents, GuideSampleExtra], + ] + guide_log_prob_fn: ( + Callable[ + [GuideParams, Example, Latents], + tuple[jax.Array, Extra], + ] + | None + ) = None + reduce_extra_fn: ( + Callable[ + [GuideSampleExtra, Extra, str | None], + tuple[GuideSampleExtra, Extra], + ] + | None + ) = None + + +@saving.register +@struct.dataclass +class ProbNeRFModelParams: + hypernet_params: Any + realnvp_params: Any + corruption_params: Any | None = None + + +def make_probnerf_model( + nerf: nn.Module, + hypernet: nn.Module, + realnvp: nn.Module, + corruption_nerf: nn.Module | None = None, + num_rays: int = 1024, + near: float = 0.2, + far: float = 1.5, + num_samples: tuple[int, ...] = (48, 48), + obs_scales: tuple[float, ...] 
= (1.0, 0.1), + rgb_log_likelihood_fn: Callable[ + [rendering.RGB, rendering.RGB, jax.Array], + tuple[jax.Array, RGBExtra], + ] = rgb_log_likelihood, +) -> ParameterizedModel: + """Creates a ProbNeRF model [1]. + + Args: + nerf: Base NeRF. + hypernet: Hypernetwork mapping from z to nerf weights. + realnvp: RealNVP prior. + corruption_nerf: Optional corruption nerf. + num_rays: Number of rays to subsample after view subsampling. + near: The near plane. + far: The far plane. + num_samples: Number of samples for Mip-NeRF rendering. + obs_scales: Scales for the likelihoods for each Mip-NeRF level. + rgb_log_likelihood_fn: Pixel-level likelihood function to use. + + Returns: + A ProbNeRF model. + + #### References + + [1] Hoffman, M. D., Le, T. A., Sountsov, P., Suter, C., Lee, B., Mansinghka, + V. K., & Saurous, R. A. (2023). ProbNeRF: Uncertainty-Aware Inference of 3D + Shapes from 2D Images. International Conference on Artificial Intelligence + and Statistics. https://arxiv.org/abs/2210.17415 + """ + num_latent = realnvp.ndims + + def init_params_fn(seed): + hypernet_seed, corruption_nerf_seed, realnvp_seed = jax.random.split( + seed, 3 + ) + + if corruption_nerf is None: + corruption_params = None + else: + corruption_params = corruption_nerf.init( + corruption_nerf_seed, + rendering.RaySample( + position=jnp.zeros(3), + covariance=jnp.ones((3, 3)), + viewdir=jnp.ones(3), + ), + ) + + return ( + ProbNeRFModelParams( + hypernet_params=hypernet.init( + hypernet_seed, jnp.zeros((num_latent,)) + ), + realnvp_params=realnvp.init(realnvp_seed, jnp.zeros((num_latent,))), + corruption_params=corruption_params, + ), + (), + ) + + def prior_sample_fn(params, seed): + pulled_back_latents = tfd.MultivariateNormalDiag( + 0.0, jnp.ones(num_latent) + ).sample(seed=seed) + latents, _ = realnvp.apply( + params.realnvp_params, pulled_back_latents, forward=True + ) + return latents, () + + def prior_log_prob_fn(params, latents): + pulled_back_latents, ildj = realnvp.apply( + 
params.realnvp_params, latents, forward=False + ) + return ( + tfd.MultivariateNormalDiag(0.0, jnp.ones(latents.shape[-1])).log_prob( + pulled_back_latents + ) + + ildj, + (), + ) + + @jax.jit + def mipnerf_render_rays_from_weights( + nerf_params, corruption_params, rays, seed + ): + if corruption_params is None: + rf = lambda ray_sample: nerf.apply(nerf_params, ray_sample) + else: + if corruption_nerf is None: + raise ValueError( + 'corruption_nerf is None, but corruption_params are not?' + ) + if nerf_params is None: + rf = lambda ray_sample: corruption_nerf.apply( + corruption_params, ray_sample + ) + else: + rf = lambda ray_sample: jax.tree.map( # pylint: disable=g-long-lambda + lambda *x: jnp.stack(x, 0), + nerf.apply(nerf_params, ray_sample), + corruption_nerf.apply(corruption_params, ray_sample), + ) + return rendering.render_rf( + rf_fn=rf, + rays=rays, + near=near, + far=far, + num_samples=num_samples, + seed=seed, + ) + + def render_latents_fn( + params, + latents, + example, + seed, + render_corruption=True, + render_parts=('scene', 'corruption'), + ): + """Renders latents from views specified by example. + + Args: + params: ProbNeRF parameters. + latents: ProbNeRF latent. + example: Example. + seed: Random seed. + render_corruption: Whether to render the corruption model. + render_parts: Which parts to render (can be 'scene' or 'corruption'). + + Returns: + rendered_rgb: Rendered RGB. + extra: Extra. + """ + render_parts = set(render_parts) + if not render_corruption: + render_parts.remove('corruption') + nerf_weights = ( + hypernet.apply(params.hypernet_params, latents) + if 'scene' in render_parts + else None + ) + corruption_params = ( + params.corruption_params if 'corruption' in render_parts else None + ) + rendered_rgb, extra = mipnerf_render_rays_from_weights( + nerf_weights, corruption_params, example.rays, seed + ) + return rendered_rgb, extra + + def log_likelihood_fn(params, latents, inputs): + """Computes the likelihood. 
+ + Assumes view subsampling. Subsamples rays. + + Args: + params: ProbNeRF model params. + latents: Latents. + inputs: Likelihood inputs. Shape (num_views, ...). + + Returns: + ll: Evidence lower bound, averaged over the batch. Shape (). + extra: Extra, averaged over the batch. Shape (). + """ + subsample_seed, render_seed = jax.random.split(inputs.seed) + + subsampled_rays, subsampled_rgb = subsample_rays( + inputs.example, num_rays, subsample_seed + ) + + _, extra = render_latents_fn( + params, latents, Example(rays=subsampled_rays), render_seed + ) + ll = 0.0 + for extra_l, mipnerf_obs_scale in zip(extra.levels, obs_scales): + one_ll, ll_extra = rgb_log_likelihood_fn( + extra_l.rgb, + subsampled_rgb, + mipnerf_obs_scale, + ) + ll += one_ll + # Grab the ll_extra from the last level, as that corresponds to the + # final reconstruction. + + total_num_rays = functools.reduce( + operator.mul, inputs.example.rgb.shape[:-1] + ) + rgb_mse = ll_extra.rgb_loss + return ll * total_num_rays / num_rays, LikelihoodExtra( + psnr=None, rgb_mse=rgb_mse + ) + + def reduce_extra_fn(prior_extra, likelihood_extra, example_axis_name=None): + rgb_mse = reduce_mean( + likelihood_extra.rgb_mse, named_axis=example_axis_name + ) + return prior_extra, likelihood_extra.replace( + rgb_mse=rgb_mse, psnr=-10 * jnp.log10(rgb_mse) + ) + + return ParameterizedModel( + init_params_fn=init_params_fn, + prior_sample_fn=prior_sample_fn, + prior_log_prob_fn=prior_log_prob_fn, + render_latents_fn=render_latents_fn, + log_likelihood_fn=log_likelihood_fn, + reduce_extra_fn=reduce_extra_fn, + ) + + +def make_probnerf_guide( + guide: nn.Module, + im_height: int = 128, + im_width: int = 128, +) -> ParameterizedGuide: + """Creates a ProbNeRF guide [1]. + + Assumes view subsampling. + + Args: + guide: Variational approximation to the posterior. + im_height: Image height. + im_width: Image width. + + Returns: + A ProbNeRF guide. + + #### References + + [1] Hoffman, M. D., Le, T. 
A., Sountsov, P., Suter, C., Lee, B., Mansinghka, + V. K., & Saurous, R. A. (2023). ProbNeRF: Uncertainty-Aware Inference of 3D + Shapes from 2D Images. International Conference on Artificial Intelligence + and Statistics. https://arxiv.org/abs/2210.17415 + """ + + def init_params_fn(seed): + num_views = 10 + dummy_rgb = jnp.zeros((num_views, im_height, im_width, 3)) + dummy_camera_world_matrix = jnp.zeros((num_views, 4, 4)) + return ( + guide.init( + seed, + dummy_rgb, + dummy_camera_world_matrix, + jax.random.PRNGKey(0), + ), + (), + ) + + def guide_sample_fn(params, example, seed): + latents, (_, _), log_prob_stop_grad_params = guide.apply( + params, example.rgb, example.camera_world_matrix, seed + ) + return latents, GuideSampleExtra( + log_prob_stop_grad_params=log_prob_stop_grad_params + ) + + def reduce_extra_fn(sample_extra, log_prob_extra, example_axis_name=None): + del example_axis_name + return sample_extra, log_prob_extra + + return ParameterizedGuide( + init_params_fn=init_params_fn, + guide_sample_fn=guide_sample_fn, + reduce_extra_fn=reduce_extra_fn, + ) + + +@saving.register +@struct.dataclass +class EvaluationGuideParams: + loc: jax.Array + log_scale: jax.Array + + +def make_probnerf_evaluation_guide( + init_loc_fn: Callable[[jax.Array], tuple[jax.Array, Extra]], +) -> ParameterizedGuide: + """Creates a guide for evaluating a ProbNeRF on a single example. + + Used for maximizing an ELBO which is a lower bound to p(image | model params) + with respect to model params, marginalizing over ProbNeRF latents. + + Model params don't necessarily correspond to ProbNeRF model params. For + example, we can fix ProbNeRF model params while defining optimizable model + params to be parameters of another NeRF representing the corruption. + + Args: + init_loc_fn: Callable to sample the initial locations. + + Returns: + A guide q(latents; guide_params) over ProbNeRF scene latents which ignores + the example. Parameterized as a diagonal multivariate Normal. 
+ """ + + def init_params_fn(seed): + loc_seed, log_scale_seed = jax.random.split(seed) + init_loc, _ = init_loc_fn(loc_seed) + num_latent = init_loc.shape[0] + + # Sample from N(-5, 1) since sampling from N(0, 1) makes RealNVP's + # log prob nan out. + init_log_scale = jax.random.normal(log_scale_seed, (num_latent,)) - 5.0 + + return EvaluationGuideParams(init_loc, init_log_scale), () + + def guide_sample_fn(params, example, seed): + del example + dist = tfd.MultivariateNormalDiag( + loc=params.loc, scale_diag=jnp.exp(params.log_scale) + ) + dist_stop_grad_params = tfd.MultivariateNormalDiag( + loc=jax.lax.stop_gradient(params.loc), + scale_diag=jax.lax.stop_gradient(jnp.exp(params.log_scale)), + ) + latents = dist.sample(seed=seed) + log_prob_stop_grad_params = dist_stop_grad_params.log_prob(latents) + return latents, GuideSampleExtra( + log_prob_stop_grad_params=log_prob_stop_grad_params + ) + + return ParameterizedGuide( + init_params_fn=init_params_fn, + guide_sample_fn=guide_sample_fn, + ) + + +@saving.register +@struct.dataclass +class MeanFieldGuideParams: + loc: Any + isp_scale: Any + + +def make_mean_field_guide( + init_loc_fn: Callable[[jax.Array], tuple[Latents, Any]], + init_scale: float | jax.Array = 1e-2, +) -> ParameterizedGuide: + """Creates a mean field guide for a general model. + + Args: + init_loc_fn: Callable to sample the initial locations. + init_scale: Initial scale multiplier. + + Returns: + A guide q(latents; guide_params). Parameterized as a diagonal multivariate + Normal. 
+ """ + + def init_params_fn(seed): + init_loc, _ = init_loc_fn(seed) + init_isp_scale = tfp.math.softplus_inverse(init_scale) + isp_scale = jax.tree.map( + lambda l: jnp.full(l.shape, init_isp_scale, dtype=l.dtype), + init_loc, + ) + + return MeanFieldGuideParams(loc=init_loc, isp_scale=isp_scale), () + + def guide_sample_fn(params, example, seed): + del example + + leaves, treedef = jax.tree_util.tree_flatten(params.loc) + num_seeds = len(leaves) + seeds = jax.tree_util.tree_unflatten( + treedef, jax.random.split(seed, num_seeds) + ) + + def sample_part(loc, isp_scale, seed): + return tfd.Normal(loc, jax.nn.softplus(isp_scale)).sample(seed=seed) + + def log_prob_part(latent, loc, isp_scale): + return ( + tfd.Normal( + jax.lax.stop_gradient(loc), + jax.lax.stop_gradient(jax.nn.softplus(isp_scale)), + ) + .log_prob(latent) + .sum() + ) + + latents = jax.tree.map(sample_part, params.loc, params.isp_scale, seeds) + log_prob_stop_grad_params = jax.tree.map( + log_prob_part, latents, params.loc, params.isp_scale + ) + log_prob_stop_grad_params = functools.reduce( + lambda x, y: x + y, jax.tree_util.tree_leaves(log_prob_stop_grad_params) + ) + + return latents, GuideSampleExtra( + log_prob_stop_grad_params=log_prob_stop_grad_params + ) + + return ParameterizedGuide( + init_params_fn=init_params_fn, + guide_sample_fn=guide_sample_fn, + ) + + +def subsample_rays( + example: Example, + num_rays: int, + seed: jax.Array, +) -> tuple[rendering.Ray, rendering.RGB]: + """Subsample rays from an example.""" + # Flatten inputs + assert example.rgb is not None + rays_shape = example.rgb.shape[:-1] + rays_ndims = len(rays_shape) + total_num_rays = functools.reduce(operator.mul, rays_shape) + flat_rays = jax.tree.map( + lambda x: x.reshape((-1,) + x.shape[rays_ndims:]), + example.rays, + ) + flat_rgb = example.rgb.reshape((-1, 3)) + + # Subsample rays + indices = jax.random.choice(seed, total_num_rays, (num_rays,), False) + take_fn = lambda x: jax.tree.map(lambda y: y[indices], x) + 
subsampled_rays = take_fn(flat_rays) + subsampled_rgb = take_fn(flat_rgb) + return subsampled_rays, subsampled_rgb diff --git a/discussion/robust_inverse_graphics/models_test.py b/discussion/robust_inverse_graphics/models_test.py new file mode 100644 index 0000000000..ce13ac0e2e --- /dev/null +++ b/discussion/robust_inverse_graphics/models_test.py @@ -0,0 +1,245 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools +from typing import Any + +from flax import linen as nn +import jax +import jax.numpy as jnp + +from discussion.robust_inverse_graphics import models +from discussion.robust_inverse_graphics import probnerf +from discussion.robust_inverse_graphics.nerf import rendering +from discussion.robust_inverse_graphics.util import test_util + + +class TestNerf(nn.Module): + + @nn.compact + def __call__( + self, ray_sample: rendering.RaySample + ) -> tuple[tuple[rendering.Density, rendering.RGB], Any]: + init_fn = functools.partial(jax.random.uniform, shape=[3]) + rgb = jax.nn.sigmoid(self.param('rgb', init_fn)) + density = jnp.ones([]) + + return (density, rgb), () + + +class ModelsTest(test_util.TestCase): + + def test_fields_with_view_axis(self): + rays_shape = 2, 3 + ones = jnp.ones(rays_shape) + zeros = jnp.zeros(rays_shape) + example_1 = models.Example( + rays=rendering.Ray(zeros, ones, ones), rgb=zeros, + scene_id=jnp.zeros([]), + ) + example_2 = jax.tree.map(lambda x: x, example_1) + example = 
jax.tree.map( + lambda x_1, x_2: None, + example_1.fields_with_view_axis(), + example_2.fields_with_view_axis(), + ).value + self.assertIsNone(example.rays.origin) + self.assertIsNone(example.rgb) + self.assertEqual(example.scene_id, 0) + + def test_rgb_log_likelihood(self): + ll, extra = models.rgb_log_likelihood( + recon_rgb=jnp.zeros(3), + rgb=jnp.zeros(3), + obs_scale=jnp.ones([]), + ) + + self.assertEqual(ll.shape, []) + self.assertEqual(extra.rgb_loss.shape, []) + self.assertEqual(extra.per_channel_ll.shape, [3]) + + def test_sinh_arcsinh_rgb_log_likelihood(self): + ll, extra = models.sinh_arcsinh_rgb_log_likelihood( + recon_rgb=jnp.zeros(3), + rgb=jnp.zeros(3), + obs_scale=jnp.ones([]), + ) + + self.assertEqual(ll.shape, []) + self.assertEqual(extra.rgb_loss.shape, []) + self.assertEqual(extra.per_channel_ll.shape, [3]) + + def test_chunked_render_latents_fn(self): + rgb = jnp.linspace(0.0, 1.0, 40 * 3).reshape([-1, 3]) + example = models.Example(rgb=rgb, scene_id=jnp.zeros([])) + latents = jnp.array([1.0]) + + def render_latents_fn(latents, example, _): + self.assertEqual(example.rgb.shape, (2, 3)) + return example.rgb + latents, () + + rendered_rgb, _ = models.chunked_render_latents( + render_latents_fn=render_latents_fn, + latents=latents, + example=example, + seed=jax.random.PRNGKey(0), + num_chunks=20, + ) + self.assertAllClose(rgb + latents, rendered_rgb) + + def test_make_nerf_model(self): + model_fn = functools.partial( + models.make_nerf_model, + TestNerf(), + near=1.0, + far=2.0, + num_samples=(4, 4), + obs_scales=(1.0, 1.0), + ) + model = model_fn() + + init_latents, _ = model.init_latents_fn(jax.random.PRNGKey(0)) + example = models.Example( + rgb=jnp.zeros([16, 3]), + rays=rendering.Ray( + origin=jnp.zeros([16, 3]), + direction=jnp.ones([16, 3]), + viewdir=jnp.ones([16, 3]), + radius=jnp.ones([16]), + ), + ) + ll, _ = model.log_likelihood_fn( + init_latents, + models.LikelihoodInputs(example=example, seed=jax.random.PRNGKey(1)), + ) + 
self.assertEqual(ll.shape, []) + + def test_make_probnerf_model(self): + num_latent = 2 + grid_size = 3 + obs_scale = 0.1 + total_num_views = 10 + im_height, im_width = 4, 4 + num_rays = total_num_views * im_width * im_height + + nerf = probnerf.TwoPartNerf(grid_size=grid_size) + nerf_variables = nerf.init( + jax.random.PRNGKey(0), + rendering.RaySample.test_sample(), + ) + hypernet = probnerf.DecoderHypernet(nerf_variables) + realnvp = probnerf.RealNVPStack(num_latent) + + model = models.make_probnerf_model( + nerf, + hypernet, + realnvp, + num_rays=num_rays, + obs_scales=(obs_scale,), + ) + + init_params, _ = model.init_params_fn(jax.random.PRNGKey(0)) + + ( + rgb_seed, + origin_seed, + direction_seed, + camera_world_matrix_seed, + latents_seed, + ) = jax.random.split(jax.random.PRNGKey(0), 5) + rgb = jax.random.normal(rgb_seed, [total_num_views, im_height, im_width, 3]) + origin = jax.random.normal( + origin_seed, [total_num_views, im_height, im_width, 3] + ) + direction = jax.random.normal( + direction_seed, [total_num_views, im_height, im_width, 3] + ) + viewdir = direction / jnp.linalg.norm(direction, axis=-1, keepdims=True) + camera_world_matrix = jax.random.normal( + camera_world_matrix_seed, [total_num_views, 4, 4] + ) + latents = jax.random.normal(latents_seed, [num_latent]) + example = models.Example( + rgb=rgb, + rays=rendering.Ray( + origin=origin, + direction=direction, + viewdir=viewdir, + radius=jnp.ones([total_num_views, im_height, im_width]), + ), + camera_world_matrix=camera_world_matrix, + ) + + ll = model.log_likelihood_fn( + init_params, + latents, + models.LikelihoodInputs(example=example, seed=jax.random.PRNGKey(1)), + )[0] + self.assertEqual(ll.shape, []) + + def test_make_probnerf_guide(self): + num_latent = 2 + grid_size = 3 + total_num_views = 10 + im_height, im_width = 4, 4 + + nerf = probnerf.TwoPartNerf(grid_size=grid_size) + nerf_variables = nerf.init( + jax.random.PRNGKey(0), + rendering.RaySample.test_sample(), + ) + guide = 
probnerf.Guide(num_latent, nerf_variables) + probnerf_guide = models.make_probnerf_guide(guide, im_height, im_width) + + init_params, _ = probnerf_guide.init_params_fn(jax.random.PRNGKey(0)) + + ( + rgb_seed, + origin_seed, + direction_seed, + camera_world_matrix_seed, + ) = jax.random.split(jax.random.PRNGKey(0), 4) + rgb = jax.random.normal(rgb_seed, [total_num_views, im_height, im_width, 3]) + origin = jax.random.normal( + origin_seed, [total_num_views, im_height, im_width, 3] + ) + direction = jax.random.normal( + direction_seed, [total_num_views, im_height, im_width, 3] + ) + viewdir = direction / jnp.linalg.norm(direction, axis=-1, keepdims=True) + camera_world_matrix = jax.random.normal( + camera_world_matrix_seed, [total_num_views, 4, 4] + ) + example = models.Example( + rgb=rgb, + rays=rendering.Ray( + origin=origin, + direction=direction, + viewdir=viewdir, + ), + camera_world_matrix=camera_world_matrix, + ) + + latents, extra = probnerf_guide.guide_sample_fn( + init_params, + example, + jax.random.PRNGKey(1), + ) + + self.assertEqual(latents.shape, (num_latent,)) + assert extra.log_prob_stop_grad_params + self.assertEqual(extra.log_prob_stop_grad_params.shape, ()) + + +if __name__ == '__main__': + test_util.main() diff --git a/discussion/robust_inverse_graphics/probnerf.py b/discussion/robust_inverse_graphics/probnerf.py new file mode 100644 index 0000000000..a4d7f71aab --- /dev/null +++ b/discussion/robust_inverse_graphics/probnerf.py @@ -0,0 +1,381 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
"""ProbNeRF."""

from typing import Any

import flax.linen as nn
import jax
from jax import lax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
import tensorflow_probability.substrates.jax as tfp


tfd = tfp.distributions


__all__ = [
    'TwoPartNerf',
    'DecoderHypernet',
    'RealNVPStack',
    'Guide',
]


# Taken from
#
# NOTE(review): the source reference that followed "Taken from" is missing
# here; restore the attribution.
def posenc(x, min_deg, max_deg, legacy_posenc_order=False):
  """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].

  Instead of computing [sin(x), cos(x)], we use the trig identity
  cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).

  Args:
    x: jnp.ndarray, variables to be encoded. Note that x should be in [-pi, pi].
    min_deg: int, the minimum (inclusive) degree of the encoding.
    max_deg: int, the maximum (exclusive) degree of the encoding.
    legacy_posenc_order: bool, keep the same ordering as the original tf code.

  Returns:
    encoded: jnp.ndarray, encoded variables. When `min_deg == max_deg` this is
      `x` itself; otherwise `x` concatenated with the sinusoidal features along
      the last axis.
  """
  if min_deg == max_deg:
    return x
  scales = jnp.array([2**i for i in range(min_deg, max_deg)])
  if legacy_posenc_order:
    xb = x[..., None, :] * scales[:, None]
    four_feat = jnp.reshape(
        jnp.sin(jnp.stack([xb, xb + 0.5 * jnp.pi], -2)),
        list(x.shape[:-1]) + [-1],
    )
  else:
    xb = jnp.reshape(
        (x[..., None, :] * scales[:, None]), list(x.shape[:-1]) + [-1]
    )
    four_feat = jnp.sin(jnp.concatenate([xb, xb + 0.5 * jnp.pi], axis=-1))
  return jnp.concatenate([x] + [four_feat], axis=-1)


class DensityNerf(nn.Module):
  """Density NeRF.

  Maps a position to a scalar density through a positional encoding followed
  by two 64-unit ReLU layers. Also returns the last hidden activation.
  """

  scale: float = 1.0

  min_degree: int = 0
  max_degree: int = 10

  @nn.compact
  def __call__(self, position):
    """Returns `(density, hidden)` for a single position vector."""
    pos_encoding = posenc(
        position / self.scale * jnp.pi, self.min_degree, self.max_degree
    )
    x = pos_encoding
    x = nn.Dense(64)(x)
    x = nn.relu(x)
    x = nn.Dense(64)(x)
    x = nn.relu(x)
    # Softplus keeps the density non-negative; `[0]` drops the size-1 output
    # axis of the final Dense layer.
    density = nn.softplus(nn.Dense(1)(x))[0]
    return density, x


class AppearanceNerf(nn.Module):
  """Appearance NeRF.

  Maps encoded position, encoded view direction and the density to an RGB
  value squashed into (0, 1) by a sigmoid.
  """

  scale: float = 1.0

  min_degree: int = 0
  max_degree: int = 10

  @nn.compact
  def __call__(self, position, viewdir, density):
    """Returns the RGB prediction; `density` must have a trailing axis."""
    pos_encoding = posenc(
        position / self.scale * jnp.pi, self.min_degree, self.max_degree
    )
    view_encoding = posenc(viewdir * jnp.pi, self.min_degree, self.max_degree)
    x = jnp.concatenate([pos_encoding, view_encoding, density], -1)
    x = nn.Dense(64)(x)
    x = nn.relu(x)
    x = nn.Dense(64)(x)
    x = nn.relu(x)
    rgb = nn.sigmoid(nn.Dense(3)(x))
    return rgb


class TwoPartNerf(nn.Module):
  """Combines the density and appearance NeRF."""

  grid_size: int = 128
  scale: float = 1.0

  min_degree: int = 0
  max_degree: int = 10
  appearance_min_degree: int = 0
  appearance_max_degree: int = 10

  @nn.compact
  def __call__(self, ray_sample):
    """Returns `((alpha, rgb), ())` for a sample with position and viewdir."""
    density, _ = DensityNerf(self.scale, self.min_degree, self.max_degree)(
        ray_sample.position
    )
    rgb = AppearanceNerf(
        self.scale, self.appearance_min_degree, self.appearance_max_degree
    )(ray_sample.position, ray_sample.viewdir, density[jnp.newaxis])

    # For foam rendering, convert density to alpha assuming constant density.
    return (-jnp.expm1(-density / self.grid_size), rgb), ()


def _map_to_params(flat_params, template):
  """Reshapes consecutive slices of `flat_params` into the tree of `template`.

  Args:
    flat_params: 1-D array holding at least as many entries as `template` has
      elements in total.
    template: Pytree whose leaves expose `.shape`; only the shapes and the tree
      structure are used.

  Returns:
    A pytree with the structure of `template` whose leaves are slices of
    `flat_params` reshaped to the corresponding template leaf shapes.
  """
  leaves, treedef = jax.tree_util.tree_flatten(template)
  new_leaves = []
  param_index = 0
  for p in leaves:
    # `np.prod(())` is the float 1.0, which is not a valid slice bound; cast so
    # that scalar leaves are handled as well.
    num_params = int(np.prod(p.shape))
    new_leaves.append(
        flat_params[param_index : param_index + num_params].reshape(p.shape)
    )
    param_index += num_params
  return jax.tree_util.tree_unflatten(treedef, new_leaves)


def _map_hidden_to_params(h, template):
  """Projects `h` with a Dense layer and reshapes the result like `template`.

  Must be called from inside a Flax module method, since it creates an
  `nn.Dense` submodule.
  """
  leaves = jax.tree_util.tree_leaves(template)
  # Cast for the same reason as in `_map_to_params`: the Dense feature count
  # must be an integer even when the template contains scalar leaves.
  total_num_params = sum(int(np.prod(p.shape)) for p in leaves)
  flat_params = nn.Dense(total_num_params)(h)
  return _map_to_params(flat_params, template)


class Hypernet(nn.Module):
  """Hypernet mapping from latent to NeRF weights."""

  # The parameter shapes for the network the hypernetwork makes.
  output_template: Any
  width: int = 512
  depth: int = 2
  num_outputs: int = 1

  @nn.compact
  def __call__(self, latent):
    """Maps from latent vector to parameters to a neural net.

    Args:
      latent: Latent vector.

    Returns:
      A single parameter tree shaped like `output_template` when
      `num_outputs == 1`, otherwise a tuple of `num_outputs` such trees, each
      produced by its own output layer.
    """

    for _ in range(self.depth):
      latent = nn.relu(nn.Dense(self.width)(latent))
    outputs = tuple([
        _map_hidden_to_params(latent, self.output_template)
        for _ in range(self.num_outputs)
    ])
    if self.num_outputs == 1:
      return outputs[0]
    else:
      return outputs


class DecoderHypernet(nn.Module):
  """Decoder hypernet."""

  # The parameter shapes for the decoder.
  decoder_template: Any

  density_width: int = 512
  density_depth: int = 2

  appearance_width: int = 512
  appearance_depth: int = 2

  @nn.compact
  def __call__(self, latent):
    """Maps from latent vector to parameters to `TwoPartNerf`.

    The latent is split into a density half and an appearance half. The
    appearance hypernet is conditioned on both halves.

    Args:
      latent: Latent vector; its size must be even so it can be split in two.

    Returns:
      A `FrozenDict` with the structure of `decoder_template` whose
      `DensityNerf_0` and `AppearanceNerf_0` entries are replaced by the
      generated parameters.
    """

    # Rescale latent elementwise to make it easier to match a N(0, I) prior.
    latent = latent * jnp.exp(
        self.param('latent_scale', nn.initializers.zeros, latent.shape[-1])
    )
    density_latent, appearance_latent = latent.reshape([2, -1])

    # Each hypernet is built with two outputs of which only the first is used.
    # NOTE(review): the second output still creates parameters; presumably
    # kept for compatibility with existing checkpoints - confirm before
    # changing.
    density_params, _ = Hypernet(
        self.decoder_template['params']['DensityNerf_0'],
        self.density_width,
        self.density_depth,
        2,
    )(density_latent)

    appearance_latent = jnp.concatenate([density_latent, appearance_latent], -1)
    appearance_params, _ = Hypernet(
        self.decoder_template['params']['AppearanceNerf_0'],
        self.appearance_width,
        self.appearance_depth,
        2,
    )(appearance_latent)

    params = self.decoder_template.unfreeze()
    params['params']['DensityNerf_0'] = density_params
    params['params']['AppearanceNerf_0'] = appearance_params

    return nn.FrozenDict(params)


def _split(x):
  """Splits the last axis of `x` into its first and second halves.

  Assumes the size of the last axis is even.

  Returns:
    A tuple `(x1, x2)`, each with the leading shape of `x` and half of its
    last axis.
  """
  flat_split_x = jnp.transpose(x.reshape([-1, 2, x.shape[-1] // 2]), [1, 0, 2])
  split_x = flat_split_x.reshape([2, *x.shape[:-1], x.shape[-1] // 2])
  x1, x2 = split_x
  return x1, x2


class RealNVPLayer(nn.Module):
  """One RealNVP layer.

  Applies two affine coupling steps (the second half conditioned on the first,
  then the first conditioned on the updated second) followed by a fixed
  permutation of the features. The inverse direction undoes these steps in
  reverse order.
  """

  hidden_width: int
  ndims: int

  def setup(self):
    kernel_init = nn.initializers.normal(0.01)
    self.hid1 = nn.Dense(self.hidden_width, kernel_init=kernel_init)
    self.hid2 = nn.Dense(self.hidden_width, kernel_init=kernel_init)
    self.shift_and_scale1 = nn.Dense(self.ndims, kernel_init=kernel_init)
    self.shift_and_scale2 = nn.Dense(self.ndims, kernel_init=kernel_init)

  @nn.compact
  def __call__(self, x, forward=True):
    """Applies the layer or its inverse.

    Args:
      x: Input whose last axis holds the features.
      forward: Whether to apply the forward map (True) or its inverse (False).

    Returns:
      A tuple `(y, ldj)` of the transformed value and the log-determinant of
      the Jacobian of the applied direction.
    """
    # The permutation is derived from a fixed key, so it is the same on every
    # call for a given feature size.
    permutation = jax.random.permutation(jax.random.PRNGKey(0), x.shape[-1])
    if not forward:
      # The forward pass permutes last, so the inverse un-permutes first.
      permutation = jnp.argsort(permutation)
      x = x[..., permutation]

    x1, x2 = _split(x)
    ldj = 0.0
    if forward:
      h = nn.relu(self.hid1(x1))
      shift, log_scale = _split(self.shift_and_scale1(h))
      x2 = x2 * jnp.exp(log_scale)
      x2 = x2 + shift
      ldj += log_scale.sum(-1)

      h = nn.relu(self.hid2(x2))
      shift, log_scale = _split(self.shift_and_scale2(h))
      x1 = x1 * jnp.exp(log_scale)
      x1 = x1 + shift
      ldj += log_scale.sum(-1)

      x = jnp.concatenate([x1, x2], -1)
      x = x[..., permutation]
    else:
      h = nn.relu(self.hid2(x2))
      shift, log_scale = _split(self.shift_and_scale2(h))
      x1 = x1 - shift
      x1 = x1 * jnp.exp(-log_scale)
      ldj -= log_scale.sum(-1)

      h = nn.relu(self.hid1(x1))
      shift, log_scale = _split(self.shift_and_scale1(h))
      x2 = x2 - shift
      x2 = x2 * jnp.exp(-log_scale)
      ldj -= log_scale.sum(-1)

      x = jnp.concatenate([x1, x2], -1)
    return x, ldj


class RealNVPStack(nn.Module):
  """Stack of RealNVPs."""

  ndims: int
  hidden_width: int = 512
  depth: int = 2

  def setup(self):
    self.layers = [
        RealNVPLayer(self.hidden_width, self.ndims) for _ in range(self.depth)
    ]

  @nn.compact
  def __call__(self, x, forward=True):
    """Applies all layers (in reverse order when `forward` is False).

    Returns:
      A tuple `(y, ldj)` where `ldj` is the sum of the per-layer
      log-determinants.
    """
    ldj = 0.0
    layers = self.layers if forward else self.layers[::-1]
    for layer in layers:
      x, new_ldj = layer(x, forward=forward)
      ldj += new_ldj
    return x, ldj


class Guide(nn.Module):
  """Guide / recognition model.

  Encodes every view (image plus camera matrix) into a per-view Gaussian over
  the latent, combines the per-view Gaussians and a learned prior Gaussian by
  precision weighting, and samples the latent from the result.
  """

  latent_dim: int
  # Not read by `__call__`.
  decoder_template: Any

  @nn.compact
  def __call__(self, images, cameras, seed):
    """Samples a latent given the observed views.

    Args:
      images: Batch of images with the view axis first.
      cameras: Per-view camera data with the view axis first; flattened per
        view before encoding.
      seed: Random seed for sampling the latent.

    Returns:
      A tuple `(z, (z_locs, z_log_scales), z_log_prob)`: the sampled latent,
      the location and log-scale of the aggregated Normal, and the log
      probability of `z` summed over latent dimensions, computed with
      gradients to the distribution parameters stopped.
    """
    # Encode images
    h = images
    h = nn.Conv(16, (3, 3), (2, 2))(h)
    h = nn.relu(h)
    h = nn.Conv(32, (3, 3), (2, 2))(h)
    h = nn.relu(h)
    h = nn.Conv(64, (3, 3), (2, 2))(h)
    h = nn.relu(h)
    h = nn.Conv(128, (3, 3), (2, 2))(h)
    h = nn.relu(h)
    h = nn.Conv(256, (3, 3), (2, 2))(h)
    h = nn.avg_pool(h, (2, 2), (2, 2))
    image_h = h.reshape([h.shape[0], -1])

    h = cameras.reshape([cameras.shape[0], -1])
    h = nn.Dense(512)(h)
    h = nn.relu(h)
    camera_h = nn.Dense(512)(h)

    h = jnp.concatenate([image_h, camera_h], -1)

    # Probabilistic aggregation.
    def aggregate(locs, log_precisions):
      # Cap the log-precisions at 10 before exponentiating.
      log_precisions = jnp.minimum(10.0, log_precisions)
      precisions = jnp.exp(log_precisions)
      loc = (locs * precisions).sum(0) / precisions.sum(0)
      # The total precision is the sum of precisions, so
      # log(scale) = -0.5 * log(sum(precisions)).
      log_scale = -0.5 * jsp.special.logsumexp(log_precisions, 0)
      return loc, log_scale

    z_latent_params = nn.Dense(2 * self.latent_dim)(h).reshape(
        [-1, 2, self.latent_dim]
    )
    z_locs, z_log_precisions = (
        z_latent_params[..., 0, :],
        z_latent_params[..., 1, :],
    )

    # Add potential for the realnvp prior.
    prior_loc = self.param('prior_loc', nn.initializers.zeros, self.latent_dim)
    prior_log_precision = self.param(
        'prior_log_precision', nn.initializers.zeros, self.latent_dim
    )
    # `0 * z_locs[:1] + ...` broadcasts the prior parameters to a single row.
    z_locs = jnp.concatenate([z_locs, 0 * z_locs[:1] + prior_loc], 0)
    z_log_precisions = jnp.concatenate(
        [z_log_precisions, 0 * z_locs[:1] + prior_log_precision], 0
    )

    z_locs, z_log_scales = aggregate(z_locs, z_log_precisions)

    z = tfd.Normal(z_locs, jnp.exp(z_log_scales)).sample((), seed)
    z_log_prob = (
        tfd.Normal(
            lax.stop_gradient(z_locs), lax.stop_gradient(jnp.exp(z_log_scales))
        )
        .log_prob(z)
        .sum()
    )

    return (z, (z_locs, z_log_scales), z_log_prob)


# Patch metadata marking the start of the next file in this patch:
# diff --git a/discussion/robust_inverse_graphics/saving.py b/discussion/robust_inverse_graphics/saving.py
# new file mode 100644
# index 0000000000..f37f5a1ff2
# --- /dev/null
# +++ b/discussion/robust_inverse_graphics/saving.py
# @@ -0,0 +1,72 @@
# Copyright 2024 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Saving/loading code."""

from typing import Any, BinaryIO, Mapping, TypeVar

import immutabledict
from discussion.robust_inverse_graphics.util import tree2
from fun_mc import using_jax as fun_mc


try:
  # This module doesn't exist at the time static analysis is done.
  # pylint: disable=g-import-not-at-top
  from fun_mc.dynamic.backend_jax import fun_mc_lib  # pytype: disable=import-error
except ImportError:
  # NOTE(review): if this import fails at runtime, the `fun_mc_lib` reference
  # below raises NameError while this module is being imported.
  pass

__all__ = [
    'enable_interactive_mode',
    'load',
    'register',
    'save',
]

# Module-level registry shared by `register`, `save` and `load`. The fun_mc
# types below are registered when this module is imported.
_registry = tree2.Registry(allow_unknown_types=True)
_registry.auto_register_type('_TraceMaskHolder')(fun_mc_lib._TraceMaskHolder)  # pylint: disable=protected-access
_registry.auto_register_type('AdamState')(fun_mc.AdamState)
_registry.auto_register_type('InterruptibleTraceState')(
    fun_mc.InterruptibleTraceState
)


T = TypeVar('T')


def enable_interactive_mode():
  """Enables interactive mode (for notebook use).

  Sets `interactive_mode` on the module-level registry.
  """
  _registry.interactive_mode = True


def register(tree_type: type[T]) -> type[T]:
  """Registers a RobustVision type.

  The type is registered in the module-level registry under the name
  `rig.<class name>`. Usable as a class decorator.

  Args:
    tree_type: The class to register.

  Returns:
    The value returned by the registry's registration decorator for
    `tree_type`.
  """
  return _registry.auto_register_type(f'rig.{tree_type.__name__}')(tree_type)


def save(
    tree: Any,
    path: str | BinaryIO,
    options: Mapping[str, Any] = immutabledict.immutabledict({}),
):
  """Saves a tree.

  Args:
    tree: The tree to save.
    path: Destination, either a path string or a binary file object.
    options: Options forwarded unchanged to the registry's `save_tree`.
  """
  _registry.save_tree(tree, path, options)


def load(
    path: str | BinaryIO,
    options: Mapping[str, Any] = immutabledict.immutabledict({}),
) -> Any:
  """Loads a tree.

  Args:
    path: Source, either a path string or a binary file object.
    options: Options forwarded unchanged to the registry's `load_tree`.

  Returns:
    The value returned by the registry's `load_tree`.
  """
  return _registry.load_tree(path, options)


# Patch metadata marking the start of the next file in this patch:
# diff --git a/discussion/robust_inverse_graphics/saving_test.py b/discussion/robust_inverse_graphics/saving_test.py
# new file mode 100644
# index 0000000000..8cf1d79237
# --- /dev/null
# +++ b/discussion/robust_inverse_graphics/saving_test.py
# @@ -0,0 +1,50 @@
# Copyright 2024 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os

from discussion.robust_inverse_graphics import saving
from discussion.robust_inverse_graphics.util import test_util
from fun_mc import using_jax as fun_mc


class SavingTest(test_util.TestCase):
  """Round-trip tests for `saving.save` / `saving.load`."""

  def test_interruptible_trace_state_saving(self):
    """A partially run interruptible trace survives a save/load cycle."""

    def advance(x, y):
      # Returns the updated pair both as the next state and as the traced value.
      next_x = x + 1.0
      next_y = y + 2.0
      return (next_x, next_y), (next_x, next_y)

    initial_state = fun_mc.interruptible_trace_init(
        (0.0, 0.0), fn=advance, num_steps=5
    )
    step_fn = functools.partial(fun_mc.interruptible_trace_step, fn=advance)
    # Run only 4 of the 5 steps the trace was initialised for.
    traced_state, _ = fun_mc.trace(
        state=initial_state,
        fn=step_fn,
        num_steps=4,
    )

    tree_path = os.path.join(self.create_tempdir(), 'test.tree2')

    saving.save(traced_state, tree_path)
    restored_state = saving.load(tree_path)

    # This line would only work if the right type got loaded.
    x_trace, y_trace = restored_state.trace()

    expected_x = [1.0, 2.0, 3.0, 4.0]
    expected_y = [2.0, 4.0, 6.0, 8.0]
    self.assertAllEqual(expected_x, x_trace)
    self.assertAllEqual(expected_y, y_trace)


if __name__ == '__main__':
  test_util.main()