[Feature] Log each entropy for composite distributions in PPO #2707

Open
louisfaury wants to merge 3 commits into main from lf/ppo-log-composite-entropies

Conversation

louisfaury (Contributor) commented Jan 20, 2025

Description

This PR enables PPO to log the entropy of each individual head of a composite policy separately.

Concretely, for a composite distribution with, say, a discrete head and a continuous head, the loss output td_out is augmented with the detached per-head entropy values.

>>> loss = PPO(td)
>>> loss.keys()
[..., entropy, discrete-head-entropy, continuous-head-entropy, ...]
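
For context, here is a minimal sketch of how these extra detached entries could be forwarded to a logger. It assumes loss_module, td and logger already exist and that the key names follow the illustrative output above; none of this is part of the PR itself.

# Sketch only: iterate the flat loss output and log every entropy-like scalar.
td_out = loss_module(td)
for key, value in td_out.items():
    if key.endswith("entropy"):
        logger.log_scalar(key, value.item())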

Motivation and Context

This is an extremely useful debugging tool when training composite policies.

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.


for head_key, head_entropy in entropy.items(
    include_nested=True, leaves_only=True
):
    td_out.set("-".join(head_key), head_entropy.detach().mean())
louisfaury (Contributor, Author) commented Jan 20, 2025:

Choosing under which key to log each individual factor's entropy was a bit of a headache. The way I personally use CompositeDistribution yields tensordicts that look like:

action: {
   head_1: {
        action: ...
        entropy: ...
   }
   head_2: {
        action: ...
        entropy: ...
   }
}

which means that using head_key[-1] to log each entropy is not really a viable solution (all the factor entropies would be logged under the same name, entropy). I'm not sure how to get a one-size-fits-all solution here, and I'm happy for suggestions. The current solution ensures that there is no collision, at the price of very verbose keys.
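
To illustrate, a small standalone sketch (using only tensordict; the head names mirror the structure above and are purely illustrative) of why the joined keys avoid the collision that head_key[-1] would cause:

import torch
from tensordict import TensorDict

# Entropy TensorDict mirroring the structure above (values are placeholders).
entropy = TensorDict(
    {
        "action": {
            "head_1": {"entropy": torch.randn(4)},
            "head_2": {"entropy": torch.randn(4)},
        }
    },
    batch_size=[4],
)

for head_key, head_entropy in entropy.items(include_nested=True, leaves_only=True):
    print("-".join(head_key), "vs", head_key[-1])
# action-head_1-entropy vs entropy
# action-head_2-entropy vs entropy  -> head_key[-1] alone collides across heads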

Contributor replied:

Wouldn't the most generic solution just be to log the entropy TD as it comes?
Why do we need to rename it?
BTW it seems to me that what you're doing here amounts to

tensordict.flatten_keys("-").detach().mean()

Nit: this isn't collision-safe, I think (but flatten_keys will tell you if there are any collisions):
e.g. ("key-one", "entropy") and ("key", "one", "entropy") will collide
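
A tiny sketch of that collision, using the illustrative keys from the nit:

# Two distinct nested keys produce the same "-"-joined name:
for key in (("key-one", "entropy"), ("key", "one", "entropy")):
    print("-".join(key))
# key-one-entropy
# key-one-entropy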

Contributor replied:

also - are we 100% sure all keys are nested? I think so (that's how CompositeDist works), but maybe we could put a safeguard check here to make sure an error is raised if that assumption is violated (e.g., users have their own dist class that returns {"entropy", ("nested", "entropy")} keys).
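
One possible shape for such a safeguard, purely as a sketch (the check and error message are illustrative, not what the PR currently does):

for head_key, head_entropy in entropy.items(
    include_nested=True, leaves_only=True
):
    if isinstance(head_key, str):
        # A flat, non-nested entropy key would defeat the per-head naming below.
        raise ValueError(
            f"Expected nested entropy keys from the composite distribution, got {head_key!r}."
        )
    td_out.set("-".join(head_key), head_entropy.detach().mean())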

louisfaury (Contributor, Author) replied:

I followed your recommendation: I added a composite_entropy key to the loss td. Two remarks:

  1. The composite entropy is not logged under entropy to avoid a BC break (users currently expect a Tensor there),
  2. I did not detach() the composite entropy; this lets the user compute a custom entropy bonus from the composite entropy (e.g. a different penalty per head).

Wdyt?
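
As an illustration of why keeping it undetached is handy: a user could weight each head differently when building the entropy bonus. This is only a sketch; the composite_entropy key comes from remark 1 above, while the coefficients, head names and loss_td are made up.

# Hypothetical per-head entropy coefficients.
coeffs = {"head_1": 0.01, "head_2": 0.001}

composite_entropy = loss_td.get("composite_entropy")
entropy_bonus = sum(
    # head_key[-2] is the head name, assuming keys of the form (..., head_name, "entropy").
    coeffs[head_key[-2]] * head_entropy.mean()
    for head_key, head_entropy in composite_entropy.items(
        include_nested=True, leaves_only=True
    )
)
# entropy_bonus can then be subtracted from the total loss before backpropagation.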

@vmoens vmoens force-pushed the lf/ppo-log-composite-entropies branch from 71bd4bd to 06c3d94 on January 21, 2025 at 11:49
@vmoens vmoens added the enhancement (New feature or request) label on Jan 21, 2025
vmoens (Contributor) left a review:

LGTM, just a couple of comments to address before we merge it

LMK what you think would be the best way to log the entropies in the output data structure; I think flattening may be a bit surprising


Contributor commented on the same lines of torchrl/objectives/ppo.py:

Ditto: we could simply not flatten and let users do that, no?

