[Feature] Log each entropy for composite distributions in PPO #2707
Conversation
torchrl/objectives/ppo.py (outdated)

    for head_key, head_entropy in entropy.items(
        include_nested=True, leaves_only=True
    ):
        td_out.set("-".join(head_key), head_entropy.detach().mean())
Choosing under which key to log each individual factor's entropy was a bit of a headache. The way I personally use `CompositeDistribution` yields tensordicts that look like:

    action: {
        head_1: {
            action: ...
            entropy: ...
        }
        head_2: {
            action: ...
            entropy: ...
        }
    }

which means that using `head_key[-1]` to log each entropy is not really a viable solution (all the factor entropies would be logged under the same name, `entropy`). I'm not sure how to get a one-size-fits-all here, and I'm happy for suggestions. The current solution ensures that there is no collision, at the price of producing very verbose keys.
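To make the naming issue concrete, here is a minimal runnable sketch; the head names and the exact layout of the entropy tensordict are assumptions based on the example above:

```python
import torch
from tensordict import TensorDict

# Hypothetical entropy tensordict mirroring the layout above (values are dummies).
entropy = TensorDict({}, batch_size=[])
entropy.set(("action", "head_1", "entropy"), torch.zeros(()))
entropy.set(("action", "head_2", "entropy"), torch.zeros(()))

for head_key, _ in entropy.items(include_nested=True, leaves_only=True):
    print(head_key[-1])        # "entropy" for both heads: the names collide
    print("-".join(head_key))  # "action-head_1-entropy", "action-head_2-entropy": verbose but unique
```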
Wouldn't the most generic solution just be to log the entropy TD as it comes? Why do we need to rename it?

BTW, it seems to me that what you're doing here amounts to `tensordict.flatten_keys("-").detach().mean()`.

Nit: this isn't collision-safe, I think (although `flatten_keys` will tell you if there are any collisions): e.g. `("key-one", "entropy")` and `("key", "one", "entropy")` will collide.
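A minimal sketch of that collision (key names are the hypothetical ones from the comment above):

```python
import torch
from tensordict import TensorDict

td = TensorDict({}, batch_size=[])
td.set(("key-one", "entropy"), torch.zeros(()))
td.set(("key", "one", "entropy"), torch.zeros(()))

# Both paths map to the same string "key-one-entropy", so flattening with "-"
# is ambiguous; flatten_keys is expected to complain about the duplicate here.
td.flatten_keys("-")
```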
Also, are we 100% sure all keys are nested? I think so (that's how `CompositeDist` works), but maybe we could put a safeguard check here to make sure an error is raised if that assumption is violated (e.g. users have their own dist class that returns both `"entropy"` and `("nested", "entropy")` keys).
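A possible shape for that safeguard, just a sketch of the suggestion rather than code from the PR (the error message is illustrative):

```python
# Raise if an entropy leaf is not nested under a head, instead of silently
# producing ambiguous log names.
for head_key, head_entropy in entropy.items(include_nested=True, leaves_only=True):
    if isinstance(head_key, str):
        raise ValueError(
            "Expected all entropy entries of a composite distribution to be "
            f"nested, but found the flat key {head_key!r}."
        )
    td_out.set("-".join(head_key), head_entropy.detach().mean())
```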
I followed your recommendation: I added a `composite_entropy` key to the loss td. Two remarks:
- The composite entropy is not logged under `entropy` to avoid breaking BC (users currently expect a Tensor).
- I did not `detach()` the composite entropy; this allows the user to compute a custom entropy bonus when using a composite entropy (e.g. a different penalty per head).

Wdyt?
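For instance, a hedged sketch of that use case, with hypothetical head names and key paths (`loss_module` and `data` are placeholders):

```python
# Weight each head's entropy bonus differently, using the non-detached
# composite entropy returned by the loss module.
loss_td = loss_module(data)
composite_entropy = loss_td.get("composite_entropy")

# Per-head coefficients instead of a single global entropy penalty.
bonus = (
    1e-2 * composite_entropy.get(("head_1", "entropy")).mean()
    + 1e-3 * composite_entropy.get(("head_2", "entropy")).mean()
)

loss = loss_td.get("loss_objective") + loss_td.get("loss_critic") - bonus
loss.backward()
```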
LGTM, just a couple of comments to address before we merge it.
LMK what you think would be the best way to log the entropies in the output data structure; I think flattening may be a bit surprising.
torchrl/objectives/ppo.py (outdated)
Ditto: we could just not flatten and let users do that themselves, no?
Description
This PR enables PPO to log the entropy of each individual head of a composite policy separately. Concretely, for a composite distribution with, say, a nested discrete head and a continuous head, `td_out` is augmented with the corresponding detached entropy values.

Motivation and Context
This is an extremely useful debugging tool when training composite policies.
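As a rough illustration of the behaviour described above (a sketch only: `ppo_loss` and `batch` are placeholders, and only the `composite_entropy` key is taken from this PR's discussion):

```python
# Inspect the per-head entropy diagnostics after a loss call.
loss_td = ppo_loss(batch)

print(loss_td.get("entropy"))            # aggregate entropy, as before
print(loss_td.get("composite_entropy"))  # nested per-head entropies for debugging
```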