Skip to content

Commit

Permalink
Dueling
Browse files Browse the repository at this point in the history
  • Loading branch information
paul90317 committed Oct 14, 2023
1 parent 2131c1c commit 8faea9e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 9 deletions.
4 changes: 3 additions & 1 deletion base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.utils.tensorboard import SummaryWriter
from replay_buffer.replay_buffer import ReplayMemory
from abc import ABC, abstractmethod

import random

class DQNBaseAgent(ABC):
def __init__(self, config):
Expand Down Expand Up @@ -102,6 +102,8 @@ def evaluate(self):
all_rewards = []
for i in range(self.eval_episode):
observation, info = self.test_env.reset()
# observation, info = self.test_env.reset(seed=0)
# random.seed(0)
total_reward = 0
while True:
self.test_env.render()
Expand Down
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
# my hyperparameters, you can change it as you like
config = {
"gpu": True,
"training_steps": 1e8,
"training_steps": 7000000,
"gamma": 0.99,
"batch_size": 32,
"eps_min": 0.1,
"warmup_steps": 20000,
"eps_decay": 1000000,
"eval_epsilon": 0.01,
"replay_buffer_capacity": 100000,
"logdir": 'log/DDQN/',
"logdir": 'log/Dueling/',
"update_freq": 4,
"update_target_freq": 10000,
"learning_rate": 0.0000625,
Expand Down
27 changes: 21 additions & 6 deletions models/atari_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,18 @@ def __init__(self, num_classes=4, init_weights=True):
nn.Conv2d(64, 64, kernel_size=3, stride=1),
nn.ReLU(True)
)
self.classifier = nn.Sequential(nn.Linear(7*7*64, 512),
nn.ReLU(True),
nn.Linear(512, num_classes)
)
# Dueling DQN architecture
self.value_stream = nn.Sequential(
nn.Linear(7*7*64, 512),
nn.ReLU(True),
nn.Linear(512, 1)
)

self.advantage_stream = nn.Sequential(
nn.Linear(7*7*64, 512),
nn.ReLU(True),
nn.Linear(512, num_classes)
)

if init_weights:
self._initialize_weights()
Expand All @@ -25,8 +33,15 @@ def forward(self, x):
x = x.float() / 255.
x = self.cnn(x)
x = torch.flatten(x, start_dim=1)
x = self.classifier(x)
return x

# Compute value and advantage
value = self.value_stream(x)
advantage = self.advantage_stream(x)

# Combine value and advantage to get Q-values
q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))

return q_values

def _initialize_weights(self):
for m in self.modules():
Expand Down

0 comments on commit 8faea9e

Please sign in to comment.