Skip to content

Commit

Permalink
Add vectorized interface stubs to pyro.markov (#2172)
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 authored and fritzo committed Nov 16, 2019
1 parent 92bca2b commit 363c8f8
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
13 changes: 9 additions & 4 deletions pyro/poutine/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def _fn(*args, **kwargs):
return wrapper(fn) if fn is not None else wrapper


def markov(fn=None, history=1, keep=False):
def markov(fn=None, history=1, keep=False, dim=None, name=None):
"""
Markov dependency declaration.
Expand All @@ -500,15 +500,20 @@ def markov(fn=None, history=1, keep=False):
when branching: if ``keep=True``, neighboring branches at the same
level can depend on each other; if ``keep=False``, neighboring branches
are independent (conditioned on their share"
:param int dim: An optional dimension to use for this independence index.
Interface stub, behavior not yet implemented.
:param str name: An optional unique name to help inference algorithms match
:func:`pyro.markov` sites between models and guides.
Interface stub, behavior not yet implemented.
"""
if fn is None:
# Used as a decorator with bound args
return MarkovMessenger(history=history, keep=keep)
return MarkovMessenger(history=history, keep=keep, dim=dim, name=name)
if not callable(fn):
# Used as a generator
return MarkovMessenger(history=history, keep=keep).generator(iterable=fn)
return MarkovMessenger(history=history, keep=keep, dim=dim, name=name).generator(iterable=fn)
# Used as a decorator with bound args
return MarkovMessenger(history=history, keep=keep)(fn)
return MarkovMessenger(history=history, keep=keep, dim=dim, name=name)(fn)


class _SeedMessenger(Messenger):
Expand Down
15 changes: 14 additions & 1 deletion pyro/poutine/markov_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,24 @@ class MarkovMessenger(ReentrantMessenger):
when branching: if ``keep=True``, neighboring branches at the same
level can depend on each other; if ``keep=False``, neighboring branches
are independent (conditioned on their shared ancestors).
:param int dim: An optional dimension to use for this independence index.
Interface stub, behavior not yet implemented.
:param str name: An optional unique name to help inference algorithms match
:func:`pyro.markov` sites between models and guides.
Interface stub, behavior not yet implemented.
"""
def __init__(self, history=1, keep=False):
def __init__(self, history=1, keep=False, dim=None, name=None):
assert history >= 0
self.history = history
self.keep = keep
self.dim = dim
self.name = name
if dim is not None:
raise NotImplementedError(
"vectorized markov not yet implemented, try setting dim to None")
if name is not None:
raise NotImplementedError(
"vectorized markov not yet implemented, try setting name to None")
self._iterable = None
self._pos = -1
self._stack = []
Expand Down

0 comments on commit 363c8f8

Please sign in to comment.