Unsafe variable naming in JointDistributionCoroutine #1827

Closed
chrism0dwk opened this issue Jul 30, 2024 · 5 comments

@chrism0dwk
Contributor

The JointDistributionCoroutine class is a fantastic way of constructing a joint probability distribution, particularly with random variables that depend on each other. The class allows you to name your variables and returns a joint sample as a StructTuple (which AFAICS is just a namedtuple under the hood).

However, you have to be careful how you name your variables. Consider this trivial example, in which a variable is inadvertently named count:

import tensorflow_probability as tfp
tfd = tfp.distributions

@tfd.JointDistributionCoroutine
def model():
    yield tfd.Uniform(low=0., high=1., name="count")

For this model, calling model.sample() appears to return a StructTuple with a count field, as expected:

>>> sample = model.sample()
>>> sample
StructTuple(
  count=<tf.Tensor: shape=(), dtype=float32, numpy=0.22241592>
)

Great! But what happens if I try to access count?

>>> sample.count
<function structtuple.<locals>.StructTuple.count(value, /)>

Oh dear. It looks like we get the bound tuple.count method back instead. For the same reason, this wrecks any tfd.JointDistributionCoroutine method that depends on accessing StructTuple attributes. For example:

>>> model.log_prob(sample)
TypeError: Cannot convert the argument `type_value`: <built-in method count of StructTuple object at 0x78954f1e0140> to a TensorFlow DType.
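Here is a minimal, TFP-free sketch of the failure mode. It assumes (as the later comment about __getattr__ in this thread suggests) that field access falls back to __getattr__; the class NaiveStructTuple is hypothetical and is not TFP's actual StructTuple. Because Python calls __getattr__ only after normal lookup fails, the inherited tuple.count method is found first and the field value is never reached.

```python
class NaiveStructTuple(tuple):
    """Hypothetical stand-in for StructTuple; NOT TFP's implementation."""

    _fields = ("count",)

    def __new__(cls, *values):
        return super().__new__(cls, values)

    def __getattr__(self, name):
        # Python calls __getattr__ only when normal attribute lookup fails.
        # `count` is found on the tuple base class first, so this never runs for it.
        fields = type(self)._fields
        if name in fields:
            return self[fields.index(name)]
        raise AttributeError(name)


sample = NaiveStructTuple(0.22)
print(sample[0])            # 0.22: the value itself is stored correctly
print(sample.count)         # a bound tuple.count method, not 0.22
print(sample.count(0.22))   # 1: tuple.count counts occurrences of 0.22
```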

A fix might just be to document the reserved names in the "Name resolution" paragraph of this class's documentation. But I wonder if there's something deeper we could do.

Chris

chrism0dwk changed the title from "JointDistributionCoroutine variable naming issue" to "Unsafe variable naming in JointDistributionCoroutine" on Jul 30, 2024
@ColCarroll
Collaborator

Thanks, Chris. It looks like StructTuple is a custom class that subclasses the built-in tuple, whose only two public attributes are count and index. Unfortunately, I could imagine using both of these names in a graphical model!

I think we should just add those two names to the existing name-validation check. If someone wants a variable named __gt__, they deserve whatever happens down the line 😀

What do you think?
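The proposed check might look roughly like the following sketch. The function validate_variable_names and the constant RESERVED_NAMES are hypothetical illustrations, not TFP's actual validation code; the idea is simply to reject names that collide with tuple's public methods before a StructTuple is ever built.

```python
# Hypothetical: tuple's only public attributes, which a field name would collide with.
RESERVED_NAMES = frozenset({"count", "index"})


def validate_variable_names(names):
    """Hypothetical check: raise ValueError if a name collides with a tuple method."""
    names = list(names)
    clashes = sorted(set(names) & RESERVED_NAMES)
    if clashes:
        raise ValueError(
            f"Variable names {clashes} are reserved because they clash with "
            "tuple methods; please choose different names.")
    return names


print(validate_variable_names(["rate", "scale"]))  # ['rate', 'scale']
try:
    validate_variable_names(["rate", "count"])
except ValueError as err:
    print(err)
```

The trade-off of this approach is that it forbids names that namedtuple itself would accept.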

@chrism0dwk
Contributor Author

@ColCarroll sounds like a pragmatic way forward. I've suggested a solution in #1828. See what you think...

@SiegeLordEx
Member

The goal of structtuple was to mimic namedtuple's behavior, which neither forbids these names, nor balks at using them. It sounds like we didn't quite mimic its behavior correctly.
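A quick check with the standard library's collections.namedtuple illustrates this: count and index are accepted as field names, and the generated field accessors shadow the inherited tuple methods. The inherited methods can still be reached explicitly through the tuple base class.

```python
from collections import namedtuple

# namedtuple accepts field names that coincide with tuple methods.
Sample = namedtuple("Sample", ["count", "index"])
s = Sample(count=0.22, index=3)

print(s.count)                # 0.22: the field shadows tuple.count
print(s.index)                # 3: the field shadows tuple.index
print(tuple.count(s, 0.22))   # 1: the tuple method is still reachable explicitly
```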

@SiegeLordEx
Member

I'm guessing we need to use __getattribute__ rather than __getattr__ so that the field names correctly shadow the parent class members.
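A minimal sketch of that idea, using a hypothetical FieldFirstTuple class (not TFP's actual code): because __getattribute__ runs on every attribute access, before the class hierarchy is searched, field names take precedence over tuple.count and tuple.index, matching namedtuple's behavior.

```python
class FieldFirstTuple(tuple):
    """Hypothetical tuple subclass whose field names win over tuple methods."""

    _fields = ("count", "index")

    def __new__(cls, *values):
        if len(values) != len(cls._fields):
            raise TypeError(
                f"Expected {len(cls._fields)} values, got {len(values)}.")
        return super().__new__(cls, values)

    def __getattribute__(self, name):
        # Called for every attribute lookup, so check field names first.
        # type(self)._fields is a class-level lookup and does not recurse here.
        fields = type(self)._fields
        if name in fields:
            return tuple.__getitem__(self, fields.index(name))
        # Fall back to normal lookup (tuple methods, _fields, etc.).
        return super().__getattribute__(name)


s = FieldFirstTuple(0.22, 3)
print(s.count)               # 0.22
print(s.index)               # 3
print(tuple.count(s, 0.22))  # 1
```

namedtuple reaches the same result differently, by placing one accessor per field in the class namespace, which shadows the parent's methods without running custom code on every attribute access.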

@ColCarroll
Collaborator

#1828

3 participants