Commit

Merge pull request #1191 from jburnim/r0.12

Prepare branch for the TFP 0.12.0rc4 release

jburnim authored Dec 9, 2020
2 parents ed47dda + 3de3fe0 commit cc2c37e
Showing 141 changed files with 8,243 additions and 3,054 deletions.
70 changes: 70 additions & 0 deletions .github/workflows/continuous-integration.yml
@@ -0,0 +1,70 @@
# Copyright 2020 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
name: Tests
on: [push, pull_request]
env:
  TEST_VENV_PATH: ~/test_virtualenv
jobs:
  lints:
    name: Lints
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.7]
    steps:
      - name: Checkout
        uses: actions/checkout@v1
        with:
          fetch-depth: 20
      - name: Setup Python
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Setup virtualenv
        run: |
          sudo apt install virtualenv
          virtualenv -p python${{ matrix.python-version }} ${TEST_VENV_PATH}
      - name: Lints
        run: |
          source ${TEST_VENV_PATH}/bin/activate
          ./testing/run_github_lints.sh
  tests:
    name: Tests
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.7]
        shard: [0, 1, 2, 3, 4]
    env:
      TEST_VENV_PATH: ~/test_virtualenv
      SHARD: ${{ matrix.shard }}
      NUM_SHARDS: 5
    steps:
      - name: Checkout
        uses: actions/checkout@v1
        with:
          fetch-depth: 1
      - name: Setup Python
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Setup virtualenv
        run: |
          sudo apt install virtualenv
          virtualenv -p python${{ matrix.python-version }} ${TEST_VENV_PATH}
      - name: Tests
        run: |
          source ${TEST_VENV_PATH}/bin/activate
          ./testing/run_github_tests.sh
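The shard selection itself lives in testing/run_github_tests.sh, which this diff does not touch. As a rough sketch only (the real script may well differ), a round-robin split keyed on the SHARD and NUM_SHARDS variables defined above could look like:

# Hypothetical illustration of SHARD/NUM_SHARDS-based test selection;
# not the actual testing/run_github_tests.sh logic.
import os

def targets_for_shard(all_targets):
  shard = int(os.environ.get('SHARD', '0'))
  num_shards = int(os.environ.get('NUM_SHARDS', '1'))
  # Sort first so that every shard sees the same global ordering and the
  # shards form a disjoint cover of the test set.
  return [t for i, t in enumerate(sorted(all_targets))
          if i % num_shards == shard]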
56 changes: 0 additions & 56 deletions .travis.yml

This file was deleted.

19 changes: 9 additions & 10 deletions CONTRIBUTING.md
@@ -32,20 +32,19 @@ repository (with credit to the original author) and closes the pull request.

## Continuous Integration

-We use [Travis CI](https://travis-ci.org/tensorflow/probability) to do automated
-style checking and run unit-tests (discussed in more detail below). A build
-will be triggered when you open a pull request, or update the pull request by
-adding a commit, rebasing etc.
+We use [GitHub Actions](https://github.com/tensorflow/probability/actions) to do
+automated style checking and run unit-tests (discussed in more detail below). A
+build will be triggered when you open a pull request, or update the pull request
+by adding a commit, rebasing etc.
 
-We test against TensorFlow nightly on Python 2.7 and 3.6. We shard our tests
+We test against TensorFlow nightly on Python 3.7. We shard our tests
 across several build jobs (identified by the `SHARD` environment variable).
-Linting, in particular, is only done on the first shard, so look at that shard's
-logs for lint errors if any.
+Lints are also done in a separate job.
 
 All pull-requests will need to pass the automated lint and unit-tests before
-being merged. As Travis-CI tests can take a bit of time, see the following
-sections on how to run the lint checks and unit-tests locally while you're
-developing your change.
+being merged. As the tests can take a bit of time, see the following sections
+on how to run the lint checks and unit-tests locally while you're developing
+your change.

## Style

2 changes: 1 addition & 1 deletion discussion/fun_mcmc/prefab.py
@@ -347,7 +347,7 @@ def kernel(adaptive_hmc_state):
         hmc_state.state,
         axis=tuple(range(chain_ndims)) if chain_ndims else None,
         window_size=int(np.prod(hmc_state.target_log_prob.shape)) *
-        variance_window_steps)
+        variance_window_steps)  # pytype: disable=wrong-arg-types
 
     if num_adaptation_steps is not None:
       # Take care of adaptation for variance and step size.
5 changes: 1 addition & 4 deletions spinoffs/inference_gym/inference_gym/BUILD
@@ -16,7 +16,7 @@
 # A package for target densities and benchmarking of inference algorithms
 # against the same.
 
-# [internal] load pytype.bzl (pytype_library, pytype_strict_library)
+# [internal] load pytype.bzl (pytype_strict_library)
 # [internal] load dummy dependency
 
 package(
@@ -42,7 +42,6 @@ py_library(
     ],
 )
 
-# pytype
 py_library(
     name = "using_numpy",
     srcs = ["using_numpy.py"],
@@ -56,7 +55,6 @@ py_library(
     ],
 )
 
-# pytype
 py_library(
     name = "using_jax",
     srcs = ["using_jax.py"],
@@ -71,7 +69,6 @@ py_library(
     ],
 )
 
-# pytype
 py_library(
     name = "using_tensorflow",
     srcs = ["using_tensorflow.py"],
15 changes: 10 additions & 5 deletions spinoffs/oryx/oryx/core/interpreters/harvest.py
@@ -333,22 +333,27 @@ def process_higher_order_primitive(self, primitive, f, tracers, params,
     if is_map:
       # TODO(sharadmv): figure out if invars are mapped or unmapped
       params = params.copy()
+      out_axes_thunk = params['out_axes_thunk']
+      @jax_util.as_hashable_function(closure=('harvest', out_axes_thunk))
+      def new_out_axes_thunk():
+        out_axes = out_axes_thunk()
+        assert all(out_axis == 0 for out_axis in out_axes)
+        return (0,) * out_tree().num_leaves
       new_params = dict(
           params,
-          in_axes=(0,) * len(tree_util.tree_leaves(plants)) +
-          params['in_axes'])
+          in_axes=(0,) * len(tree_util.tree_leaves(plants)) + params['in_axes'],
+          out_axes_thunk=new_out_axes_thunk)
     else:
       new_params = dict(params)
     all_args, all_tree = tree_util.tree_flatten((plants, vals))
     num_plants = len(all_args) - len(vals)
     if 'donated_invars' in params:
       new_params['donated_invars'] = ((False,) * num_plants
                                       + params['donated_invars'])
-    f, aux = harvest_eval(f, self, context.settings, all_tree)
+    f, out_tree = harvest_eval(f, self, context.settings, all_tree)
     out_flat = primitive.bind(
         f, *all_args, **new_params, name=jax_util.wrap_name(name, 'harvest'))
-    out_tree = aux()
-    out, reaps = tree_util.tree_unflatten(out_tree, out_flat)
+    out, reaps = tree_util.tree_unflatten(out_tree(), out_flat)
     out_tracers = safe_map(self.pure, out)
     reap_tracers = tree_util.tree_map(self.pure, reaps)
     if primitive is nest_p and reap_tracers:
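Both the change above and the inverse/core.py change below lean on the fact that JAX map primitives cache on their parameters: `out_axes_thunk` participates in that cache key, so a thunk whose hash is its object identity would defeat the cache on every trace. `jax_util.as_hashable_function(closure=...)` pins hashing and equality to a stable closure value instead. A minimal sketch of that pattern, assuming nothing beyond the Python data model (this mirrors, but is not, the JAX source):

class HashableThunk:
  """Wrapper whose hash/equality come from `closure`, not object identity."""

  def __init__(self, fn, closure):
    self._fn = fn
    self.closure = closure

  def __call__(self):
    return self._fn()

  def __eq__(self, other):
    # Two wrappers compare equal whenever their closures do, regardless of
    # which particular lambda or def they wrap.
    return type(other) is HashableThunk and self.closure == other.closure

  def __hash__(self):
    return hash(self.closure)

def as_hashable_thunk(closure):
  """Decorator factory mirroring jax_util.as_hashable_function."""
  return lambda fn: HashableThunk(fn, closure)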
10 changes: 8 additions & 2 deletions spinoffs/oryx/oryx/core/interpreters/inverse/core.py
@@ -178,7 +178,7 @@ def wrapped(*args, **kwargs):
   flat_incells = [InverseAndILDJ.unknown(aval) for aval in flat_forward_avals]
   flat_outcells = safe_map(InverseAndILDJ.new, flat_args)
   env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr.jaxpr,
-                            flat_constcells, flat_incells, flat_outcells)
+                            flat_constcells, flat_incells, flat_outcells)  # pytype: disable=wrong-arg-types
   flat_incells = [env.read(invar) for invar in jaxpr.jaxpr.invars]
   if any(not flat_incell.top() for flat_incell in flat_incells):
     raise ValueError('Cannot invert function.')
@@ -332,7 +332,7 @@ def hop_inverse_rule(prim):
 def initial_ildj(incells, outcells, *, jaxpr, num_consts, **_):
   const_cells, incells = jax_util.split_list(incells, [num_consts])
   env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr, const_cells,
-                            incells, outcells)
+                            incells, outcells)  # pytype: disable=wrong-arg-types
   new_incells = [env.read(invar) for invar in jaxpr.invars]
   new_outcells = [env.read(outvar) for outvar in jaxpr.outvars]
   return const_cells + new_incells, new_outcells, None
@@ -377,6 +377,12 @@ def remove_slice(cell):
   new_params = dict(params, in_axes=new_in_axes)
   if 'donated_invars' in params:
     new_params['donated_invars'] = (False,) * len(flat_vals)
+  if 'out_axes' in params:
+    assert all(out_axis == 0 for out_axis in params['out_axes'])
+    new_params['out_axes_thunk'] = jax_util.HashableFunction(
+        lambda: (0,) * aux().num_leaves,
+        closure=('ildj', params['out_axes']))
+    del new_params['out_axes']
   subenv_vals = prim.bind(f, *flat_vals, **new_params)
   subenv_tree = aux()
   subenv = tree_util.tree_unflatten(subenv_tree, subenv_vals)
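For context on the `out_axes` handling added above: while a map primitive is being traced, JAX passes it a lazy `out_axes_thunk` (the output count is not yet known), whereas a staged-out jaxpr equation stores a concrete `out_axes` tuple — the unzip.py change below performs exactly the opposite conversion before `new_eqn_recipe`. Re-binding a staged equation therefore swaps the concrete tuple back into thunk form, sized by the new output count. A hypothetical sketch of just that rewrite, using a plain closure where the code above uses `jax_util.HashableFunction`:

# Hypothetical sketch: convert a staged map equation's concrete `out_axes`
# parameter back into the `out_axes_thunk` form expected by prim.bind.
def to_thunk_form(params, num_outputs):
  new_params = dict(params)
  out_axes = new_params.pop('out_axes')
  assert all(axis == 0 for axis in out_axes), 'only axis-0 outputs handled'
  new_params['out_axes_thunk'] = lambda: (0,) * num_outputs
  return new_params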
8 changes: 8 additions & 0 deletions spinoffs/oryx/oryx/core/interpreters/inverse/inverse_test.py
@@ -249,6 +249,14 @@ def f(x, y):
     onp.testing.assert_allclose(y, np.ones(2))
     onp.testing.assert_allclose(ildj_, 0., atol=1e-6, rtol=1e-6)
 
+  def test_inverse_of_reshape(self):
+    def f(x):
+      return np.reshape(x, (4,))
+    f_inv = core.inverse_and_ildj(f, np.ones((2, 2)))
+    x, ildj_ = f_inv(np.ones(4))
+    onp.testing.assert_allclose(x, np.ones((2, 2)))
+    onp.testing.assert_allclose(ildj_, 0.)
+
   def test_sigmoid_ildj(self):
     def naive_sigmoid(x):
       # This is the default JAX implementation of sigmoid.
3 changes: 1 addition & 2 deletions spinoffs/oryx/oryx/core/interpreters/inverse/rules.py
@@ -166,9 +166,8 @@ def reshape_ildj(incells, outcells, **params):
     ))], None
   elif outcell.top() and not incell.top():
     val = outcell.val
-    ndslice = NDSlice.new(np.reshape(val, incell.aval.shape))
     new_incells = [
-        InverseAndILDJ(incell.aval, [ndslice])
+        InverseAndILDJ.new(np.reshape(val, incell.aval.shape))
     ]
     return new_incells, outcells, None
   return incells, outcells, None
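The new test above doubles as a usage example: reshape is a volume-preserving bijection, so its inverse is simply a reshape back to the traced shape and the inverse log-det-Jacobian is 0, which is what the simplified rule returns via `InverseAndILDJ.new`. A standalone sketch, assuming oryx is installed and that `inverse_and_ildj` is importable from `oryx.core` (the test reaches it through its local `core` module):

import jax.numpy as jnp
from oryx.core import inverse_and_ildj  # assumed import path; see test above

f = lambda x: jnp.reshape(x, (4,))
# Trace with an example input so the inverse knows the unflattened shape.
f_inv = inverse_and_ildj(f, jnp.ones((2, 2)))
x, ildj = f_inv(jnp.ones(4))  # x has shape (2, 2); ildj == 0.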
46 changes: 32 additions & 14 deletions spinoffs/oryx/oryx/core/interpreters/unzip.py
@@ -288,19 +288,29 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
       in_pvals = [pval if pval.is_known() or in_axis is None else
                   unknown(mapped_aval(params['axis_size'], in_axis, pval[0]))
                   for pval, in_axis in zip(in_pvals, params['in_axes'])]
+      out_axes_thunk = params['out_axes_thunk']
+      @jax_util.as_hashable_function(closure=('unzip', out_axes_thunk))
+      def new_out_axes_thunk():
+        out_axes = out_axes_thunk()
+        assert all(out_axis == 0 for out_axis in out_axes)
+        _, num_outputs, _ = aux()
+        return (0,) * num_outputs
+      new_params = dict(params, out_axes_thunk=new_out_axes_thunk)
+    else:
+      new_params = params
     pvs, in_consts = jax_util.unzip2(t.pval for t in tracers)
     keys = tuple(t.is_key() for t in tracers)
     new_settings = UnzipSettings(settings.tag, call_primitive in block_registry)
     fun, aux = unzip_eval(f, self, keys, tuple(pvs), new_settings)
-    out_flat = call_primitive.bind(fun, *in_consts, **params)
-    success, results = aux()
+    out_flat = call_primitive.bind(fun, *in_consts, **new_params)
+    success, _, results = aux()
     if not success:
       out_pvs, out_keys, jaxpr, env = results
       out_pv_consts, consts = jax_util.split_list(out_flat, [len(out_pvs)])
-      out_tracers = self._bound_output_tracers(call_primitive, params, jaxpr,
-                                               consts, env, tracers, out_pvs,
-                                               out_pv_consts, out_keys, name,
-                                               is_map)
+      out_tracers = self._bound_output_tracers(call_primitive, new_params,
+                                               jaxpr, consts, env, tracers,
+                                               out_pvs, out_pv_consts,
+                                               out_keys, name, is_map)
       return out_tracers
     init_name = jax_util.wrap_name(name, 'init')
     apply_name = jax_util.wrap_name(name, 'apply')
@@ -319,15 +329,16 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
                                                      [len(apply_pvs)])
 
     variable_tracers = self._bound_output_tracers(
-        call_primitive, params, init_jaxpr, init_consts, init_env, key_tracers,
-        init_pvs, init_pv_consts, [True] * len(init_pvs), init_name, is_map)
+        call_primitive, new_params, init_jaxpr, init_consts, init_env,
+        key_tracers, init_pvs, init_pv_consts, [True] * len(init_pvs),
+        init_name, is_map)
 
     unflat_variables = tree_util.tree_unflatten(variable_tree, variable_tracers)
     if call_primitive is harvest.nest_p:
       variable_dict = harvest.sow(
           dict(safe_zip(variable_names, unflat_variables)),
           tag=settings.tag,
-          name=params['scope'],
+          name=new_params['scope'],
           mode='strict')
       unflat_variables = tuple(variable_dict[name] for name in variable_names)
     else:
@@ -342,7 +353,7 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
       variable_tracers = tree_util.tree_leaves(unflat_variables)
 
     out_tracers = self._bound_output_tracers(
-        call_primitive, params, apply_jaxpr, apply_consts, apply_env,
+        call_primitive, new_params, apply_jaxpr, apply_consts, apply_env,
         variable_tracers + abstract_tracers, apply_pvs, apply_pv_consts,
         apply_keys, apply_name, is_map)
     return out_tracers
@@ -365,6 +376,11 @@ def _bound_output_tracers(self, primitive, params, jaxpr, consts, env,
           tuple(v for v, t in zip(params['donated_invars'], in_tracers)
                 if not t.pval.is_known()))
       new_params['donated_invars'] = new_donated_invars
+    if is_map:
+      out_axes = params['out_axes_thunk']()
+      assert all(out_axis == 0 for out_axis in out_axes)
+      new_params['out_axes'] = (0,) * len(out_tracers)
+      del new_params['out_axes_thunk']
     eqn = pe.new_eqn_recipe(
         tuple(const_tracers + env_tracers + in_tracers), out_tracers, primitive,
         new_params, source_info_util.current())  # pytype: disable=wrong-arg-types
@@ -442,14 +458,16 @@ def unzip_eval_wrapper(pvs, *consts):
     out = (
         tuple(init_pv_consts) + tuple(init_consts) + tuple(apply_pv_consts) +
         tuple(apply_consts))
-    yield out, (success, ((init_pvs, len(init_consts), apply_pvs),
-                          (init_jaxpr, apply_jaxpr), (init_env,
-                                                      apply_env), metadata))
+    yield out, (success, len(out),
+                ((init_pvs, len(init_consts), apply_pvs),
+                 (init_jaxpr, apply_jaxpr),
+                 (init_env, apply_env),
+                 metadata))
   else:
     jaxpr, (out_pvals, out_keys, consts, env) = result
     out_pvs, out_consts = jax_util.unzip2(out_pvals)
     out = tuple(out_consts) + tuple(consts)
-    yield out, (success, (out_pvs, out_keys, jaxpr, env))
+    yield out, (success, len(out), (out_pvs, out_keys, jaxpr, env))


@lu.transformation
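The recurring `success, _, results = aux()` unpacking works because `unzip_eval_wrapper` now yields a `(success, num_outputs, results)` triple: `new_out_axes_thunk` needs the flat output count, which only exists once the wrapped function has actually run inside `primitive.bind`, and map primitives call `out_axes_thunk` only after tracing, so reading the count back through `aux()` is safe. A toy sketch of that store-then-read pattern, assuming plain Python in place of the `lu.transformation_with_aux` machinery:

# Toy sketch of aux-threading: the wrapped function records how many flat
# outputs it produced; a thunk invoked later reads the count via aux().
def with_output_count(f):
  store = {}

  def wrapped(*args):
    out = f(*args)
    store['num_outputs'] = len(out)  # known only after f has run
    return out

  def aux():
    return store['num_outputs']

  return wrapped, aux

wrapped, aux = with_output_count(lambda x: (x, x + 1))
wrapped(1)               # tracing/executing populates the store
out_axes = (0,) * aux()  # the thunk fires afterwards, so this is safe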
2 changes: 1 addition & 1 deletion spinoffs/oryx/oryx/core/state/module.py
@@ -168,4 +168,4 @@ def variables(self) -> Dict[str, Any]:

 @ppl.log_prob.register(Module)
 def module_log_prob(module, *args, **kwargs):
-  return log_prob.log_prob(module, *args, **kwargs)
+  return log_prob.log_prob(module, *args, **kwargs)  # pytype: disable=wrong-arg-count
