Batch ensemble ddpg #1633
base: pytorch
Conversation
Force-pushed from 8930ae0 to ad5f40e.
@@ -139,6 +148,17 @@ def __init__(self,
                gradient dqda element-wise between ``[-dqda_clipping, dqda_clipping]``.
                Does not perform clipping if ``dqda_clipping == 0``.
            action_l2 (float): weight of squared action l2-norm on actor loss.
            use_batch_ensemble (bool): whether to use BatchEnsemble FC and Conv2D
Ideally, we should make these batch-ensemble related parameters transparent to ddpg_algorithm; in the ideal case, ddpg_algorithm would not use any batch_ensemble related parameters at all.
That's a good point. Currently ddpg needs use_batch_ensemble to do some post-processing when forwarding the critic networks during training. Let me think over whether there is an alternative way to work around this.
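For concreteness, here is a minimal sketch of the kind of ensemble post-processing being discussed. It is not the code in this PR; the helper name, the flat `[ensemble_size * B]` output layout, and the min/mean reduction are all assumptions made for illustration.

```python
import torch

def reduce_ensemble_q(q_values: torch.Tensor,
                      ensemble_size: int,
                      reduction: str = "min") -> torch.Tensor:
    # Hypothetical helper: with BatchEnsemble layers the critic evaluates the
    # batch once per ensemble member, so a batch of size B may come back as
    # ensemble_size * B Q-values. Fold out the ensemble dimension and reduce
    # it before the TD target / actor loss is computed.
    q = q_values.reshape(ensemble_size, -1)   # [ensemble_size, B] (assumed layout)
    if reduction == "min":                    # pessimistic ensemble target
        return q.min(dim=0).values            # [B]
    return q.mean(dim=0)                      # [B]
```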
@@ -281,14 +318,39 @@ def _update_random_action(spec, noisy_action):
        if self._rollout_random_action > 0:
            nest.map_structure(_update_random_action, self._action_spec,
                               pred_step.output)
        return pred_step

        if self.need_full_rollout_state():
We want the algorithm to use the same ensemble_id during an entire episode. This means it should store the ensemble_id in the state and use that same ensemble_id when calling actor_network.
Oh yes, good point. I think that is the reason why I had to tweak ddpg_algorithm_test to pass the toy unittest. Updated.
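A minimal sketch of the suggestion above, assuming a TF-Agents-style step_type where FIRST == 0; the function name and shapes are hypothetical, not the code that was actually added.

```python
import torch

def sample_or_carry_ensemble_id(step_type: torch.Tensor,
                                prev_ensemble_id: torch.Tensor,
                                ensemble_size: int) -> torch.Tensor:
    # Draw a fresh ensemble_id only at the first step of an episode and carry
    # the previous one otherwise, so actor_network is called with the same
    # ensemble member for the whole episode. The id would live in the rollout
    # state so that it is propagated from step to step.
    is_first = (step_type == 0)  # assuming StepType.FIRST == 0
    new_id = torch.randint(ensemble_size, prev_ensemble_id.shape,
                           device=prev_ensemble_id.device)
    return torch.where(is_first, new_id, prev_ensemble_id)
```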
Update `actor_network`, `critic_network`, and `ddpg_algorithm` to work with batch_ensemble layers.
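For readers unfamiliar with BatchEnsemble (Wen et al., 2020): each layer keeps a single shared weight and gives ensemble member i only rank-1 factors r_i, s_i, so member i's effective weight is W * r_i s_i^T. The sketch below illustrates the idea for an FC layer; the class name, initialization, and forward signature are illustrative assumptions, not the layers used by this PR.

```python
import torch
import torch.nn as nn

class BatchEnsembleLinear(nn.Module):
    """Illustrative BatchEnsemble FC layer (not the ALF implementation)."""

    def __init__(self, in_features, out_features, ensemble_size):
        super().__init__()
        # One weight matrix shared by all ensemble members.
        self.shared = nn.Linear(in_features, out_features, bias=False)
        # Per-member rank-1 factors (initialized near 1) and per-member biases.
        self.r = nn.Parameter(1.0 + 0.1 * torch.randn(ensemble_size, out_features))
        self.s = nn.Parameter(1.0 + 0.1 * torch.randn(ensemble_size, in_features))
        self.bias = nn.Parameter(torch.zeros(ensemble_size, out_features))

    def forward(self, x, ensemble_id):
        # x: [B, in_features]; ensemble_id: [B] integer member index per row.
        s = self.s[ensemble_id]      # [B, in_features]
        r = self.r[ensemble_id]      # [B, out_features]
        b = self.bias[ensemble_id]   # [B, out_features]
        # (x * s) @ W^T * r + b  ==  x @ (W * r s^T)^T + b  for each row.
        return self.shared(x * s) * r + b
```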