Skip to content

Commit

Permalink
Make spinoffs/autobnn a pip installable package.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 619748042
  • Loading branch information
ursk authored and tensorflower-gardener committed Mar 28, 2024
1 parent a77f8dd commit 377d688
Show file tree
Hide file tree
Showing 24 changed files with 259 additions and 89 deletions.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ py_library(
":operators",
":training_util",
":util",
"//tensorflow_probability/python/internal:all_util",
# tensorflow_probability/python/internal:all_util dep,
],
)

Expand All @@ -49,7 +49,7 @@ py_library(
# flax:core dep,
# jax dep,
# jaxtyping dep,
"//tensorflow_probability/python/distributions:distribution.jax",
# tensorflow_probability/python/distributions:distribution.jax dep,
],
)

Expand All @@ -62,8 +62,8 @@ py_test(
# google/protobuf:use_fast_cpp_protos dep,
# jax dep,
"//tensorflow_probability:jax",
"//tensorflow_probability/python/distributions:lognormal.jax",
"//tensorflow_probability/python/distributions:normal.jax",
# tensorflow_probability/python/distributions:lognormal.jax dep,
# tensorflow_probability/python/distributions:normal.jax dep,
],
)

Expand Down Expand Up @@ -118,7 +118,7 @@ py_test(
":estimators",
":kernels",
":operators",
"//tensorflow_probability/python/internal:test_util",
# tensorflow_probability/python/internal:test_util dep,
],
)

Expand All @@ -130,10 +130,10 @@ py_library(
# flax dep,
# flax:core dep,
# jax dep,
"//tensorflow_probability/python/distributions:lognormal.jax",
"//tensorflow_probability/python/distributions:normal.jax",
"//tensorflow_probability/python/distributions:student_t.jax",
"//tensorflow_probability/python/distributions:uniform.jax",
# tensorflow_probability/python/distributions:lognormal.jax dep,
# tensorflow_probability/python/distributions:normal.jax dep,
# tensorflow_probability/python/distributions:student_t.jax dep,
# tensorflow_probability/python/distributions:uniform.jax dep,
],
)

Expand All @@ -147,7 +147,7 @@ py_test(
# absl/testing:parameterized dep,
# google/protobuf:use_fast_cpp_protos dep,
# jax dep,
"//tensorflow_probability/python/distributions:lognormal.jax",
# tensorflow_probability/python/distributions:lognormal.jax dep,
],
)

Expand All @@ -158,14 +158,14 @@ py_library(
# flax:core dep,
# jax dep,
# jaxtyping dep,
"//tensorflow_probability/python/bijectors:softplus.jax",
"//tensorflow_probability/python/distributions:distribution.jax",
"//tensorflow_probability/python/distributions:inflated.jax",
"//tensorflow_probability/python/distributions:logistic.jax",
"//tensorflow_probability/python/distributions:lognormal.jax",
"//tensorflow_probability/python/distributions:negative_binomial.jax",
"//tensorflow_probability/python/distributions:normal.jax",
"//tensorflow_probability/python/distributions:transformed_distribution.jax",
# tensorflow_probability/python/bijectors:softplus.jax dep,
# tensorflow_probability/python/distributions:distribution.jax dep,
# tensorflow_probability/python/distributions:inflated.jax dep,
# tensorflow_probability/python/distributions:logistic.jax dep,
# tensorflow_probability/python/distributions:lognormal.jax dep,
# tensorflow_probability/python/distributions:negative_binomial.jax dep,
# tensorflow_probability/python/distributions:normal.jax dep,
# tensorflow_probability/python/distributions:transformed_distribution.jax dep,
],
)

Expand Down Expand Up @@ -216,14 +216,14 @@ py_library(
":likelihoods",
# flax:core dep,
# jax dep,
"//tensorflow_probability/python/bijectors:chain.jax",
"//tensorflow_probability/python/bijectors:scale.jax",
"//tensorflow_probability/python/bijectors:shift.jax",
"//tensorflow_probability/python/distributions:beta.jax",
"//tensorflow_probability/python/distributions:dirichlet.jax",
"//tensorflow_probability/python/distributions:half_normal.jax",
"//tensorflow_probability/python/distributions:normal.jax",
"//tensorflow_probability/python/distributions:transformed_distribution.jax",
# tensorflow_probability/python/bijectors:chain.jax dep,
# tensorflow_probability/python/bijectors:scale.jax dep,
# tensorflow_probability/python/bijectors:shift.jax dep,
# tensorflow_probability/python/distributions:beta.jax dep,
# tensorflow_probability/python/distributions:dirichlet.jax dep,
# tensorflow_probability/python/distributions:half_normal.jax dep,
# tensorflow_probability/python/distributions:normal.jax dep,
# tensorflow_probability/python/distributions:transformed_distribution.jax dep,
],
)

Expand All @@ -242,7 +242,7 @@ py_test(
# google/protobuf:use_fast_cpp_protos dep,
# jax dep,
# numpy dep,
"//tensorflow_probability/python/distributions:distribution.jax",
# tensorflow_probability/python/distributions:distribution.jax dep,
],
)

Expand All @@ -258,7 +258,6 @@ py_library(
# matplotlib dep,
# numpy dep,
# pandas dep,
"//tensorflow_probability/python/experimental/timeseries:metrics",
],
)

Expand All @@ -276,7 +275,7 @@ py_test(
# google/protobuf:use_fast_cpp_protos dep,
# jax dep,
# numpy dep,
"//tensorflow_probability/python/internal:test_util",
# tensorflow_probability/python/internal:test_util dep,
],
)

Expand All @@ -288,7 +287,7 @@ py_library(
# jax dep,
# numpy dep,
# scipy dep,
"//tensorflow_probability/python/distributions:distribution.jax",
# tensorflow_probability/python/distributions:distribution.jax dep,
],
)

Expand All @@ -301,6 +300,6 @@ py_test(
# google/protobuf:use_fast_cpp_protos dep,
# jax dep,
# numpy dep,
"//tensorflow_probability/python/internal:test_util",
# tensorflow_probability/python/internal:test_util dep,
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,17 @@
# ============================================================================
"""Package for training GP-like Bayesian Neural Nets w/ composite structure."""

from tensorflow_probability.python.internal import all_util
from tensorflow_probability.spinoffs.autobnn import bnn
from tensorflow_probability.spinoffs.autobnn import bnn_tree
from tensorflow_probability.spinoffs.autobnn import estimators
from tensorflow_probability.spinoffs.autobnn import kernels
from tensorflow_probability.spinoffs.autobnn import likelihoods
from tensorflow_probability.spinoffs.autobnn import models
from tensorflow_probability.spinoffs.autobnn import operators
from tensorflow_probability.spinoffs.autobnn import training_util
from tensorflow_probability.spinoffs.autobnn import util
from autobnn import bnn
from autobnn import bnn_tree
from autobnn import estimators
from autobnn import kernels
from autobnn import likelihoods
from autobnn import models
from autobnn import operators
from autobnn import training_util
from autobnn import util

_allowed_symbols = [
__all__ = [
'bnn',
'bnn_tree',
'estimators',
Expand All @@ -36,5 +35,3 @@
'training_util',
'util',
]

all_util.remove_undocumented(__name__, _allowed_symbols)
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PyTree # pylint: disable=g-importing-member,g-multiple-import
from tensorflow_probability.spinoffs.autobnn import likelihoods
from autobnn import likelihoods
from tensorflow_probability.substrates.jax.distributions import distribution as distribution_lib


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from flax import linen as nn
import jax
import jax.numpy as jnp
from tensorflow_probability.spinoffs.autobnn import bnn
from autobnn import bnn
from tensorflow_probability.substrates.jax.distributions import lognormal as lognormal_lib
from tensorflow_probability.substrates.jax.distributions import normal as normal_lib
from absl.testing import absltest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
from flax import linen as nn
import jax
import jax.numpy as jnp
from tensorflow_probability.spinoffs.autobnn import bnn
from tensorflow_probability.spinoffs.autobnn import kernels
from tensorflow_probability.spinoffs.autobnn import operators
from tensorflow_probability.spinoffs.autobnn import util
from autobnn import bnn
from autobnn import kernels
from autobnn import operators
from autobnn import util

Array = jnp.ndarray

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from flax import linen as nn
import jax
import jax.numpy as jnp
from tensorflow_probability.spinoffs.autobnn import bnn_tree
from tensorflow_probability.spinoffs.autobnn import kernels
from autobnn import bnn_tree
from autobnn import kernels
from absl.testing import absltest


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
import jax
import jax.numpy as jnp
from jaxtyping import ArrayLike, PyTree # pylint: disable=g-importing-member,g-multiple-import
from tensorflow_probability.spinoffs.autobnn import bnn
from tensorflow_probability.spinoffs.autobnn import likelihoods
from tensorflow_probability.spinoffs.autobnn import models
from tensorflow_probability.spinoffs.autobnn import training_util
from autobnn import bnn
from autobnn import likelihoods
from autobnn import models
from autobnn import training_util


class _AutoBnnEstimator:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import jax
import numpy as np
from tensorflow_probability.python.internal import test_util
from tensorflow_probability.spinoffs.autobnn import estimators
from tensorflow_probability.spinoffs.autobnn import kernels
from tensorflow_probability.spinoffs.autobnn import operators
from tensorflow_probability.spinoffs.autobnn import util
from autobnn import estimators
from autobnn import kernels
from autobnn import operators
from autobnn import util


class AutoBNNTest(test_util.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from flax.linen import initializers
import jax
import jax.numpy as jnp
from tensorflow_probability.spinoffs.autobnn import bnn
from autobnn import bnn
from tensorflow_probability.substrates.jax.distributions import lognormal as lognormal_lib
from tensorflow_probability.substrates.jax.distributions import normal as normal_lib
from tensorflow_probability.substrates.jax.distributions import student_t as student_t_lib
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import jax
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.spinoffs.autobnn import kernels
from tensorflow_probability.spinoffs.autobnn import util
from autobnn import kernels
from autobnn import util
from tensorflow_probability.substrates.jax.distributions import lognormal as lognormal_lib

from absl.testing import absltest
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from absl.testing import parameterized
import jax.numpy as jnp
from tensorflow_probability.spinoffs.autobnn import likelihoods
from autobnn import likelihoods
from absl.testing import absltest


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
import functools
from typing import Sequence, Union
import jax.numpy as jnp
from tensorflow_probability.spinoffs.autobnn import bnn
from tensorflow_probability.spinoffs.autobnn import bnn_tree
from tensorflow_probability.spinoffs.autobnn import kernels
from tensorflow_probability.spinoffs.autobnn import likelihoods
from tensorflow_probability.spinoffs.autobnn import operators
from autobnn import bnn
from autobnn import bnn_tree
from autobnn import kernels
from autobnn import likelihoods
from autobnn import operators


Array = jnp.ndarray
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from tensorflow_probability.spinoffs.autobnn import likelihoods
from tensorflow_probability.spinoffs.autobnn import models
from autobnn import likelihoods
from autobnn import models
from absl.testing import absltest


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from flax import linen as nn
import jax
import jax.numpy as jnp
from tensorflow_probability.spinoffs.autobnn import bnn
from tensorflow_probability.spinoffs.autobnn import likelihoods
from autobnn import bnn
from autobnn import likelihoods
from tensorflow_probability.substrates.jax.bijectors import chain as chain_lib
from tensorflow_probability.substrates.jax.bijectors import scale as scale_lib
from tensorflow_probability.substrates.jax.bijectors import shift as shift_lib
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
import jax
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.spinoffs.autobnn import kernels
from tensorflow_probability.spinoffs.autobnn import operators
from tensorflow_probability.spinoffs.autobnn import util
from autobnn import kernels
from autobnn import operators
from autobnn import util
from tensorflow_probability.substrates.jax.distributions import distribution as distribution_lib
from absl.testing import absltest

Expand Down
Loading

0 comments on commit 377d688

Please sign in to comment.