DQN
paul90317 committed Oct 11, 2023
Initial commit 70c157e (0 parents)
Showing 8 changed files with 351 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
log
__pycache__
Binary file added Lab2-DQN.pdf
Binary file added Lab2-Guide.pdf
139 changes: 139 additions & 0 deletions base_agent.py
import torch
import torch.nn as nn
import numpy as np
import os
import time
from collections import deque
from torch.utils.tensorboard import SummaryWriter
from replay_buffer.replay_buffer import ReplayMemory
from abc import ABC, abstractmethod


class DQNBaseAgent(ABC):
    def __init__(self, config):
        self.gpu = config["gpu"]
        self.device = torch.device("cuda" if self.gpu and torch.cuda.is_available() else "cpu")
        self.total_time_step = 0
        self.training_steps = int(config["training_steps"])
        self.batch_size = int(config["batch_size"])
        self.epsilon = 1.0
        self.eps_min = config["eps_min"]
        self.eps_decay = config["eps_decay"]
        self.eval_epsilon = config["eval_epsilon"]
        self.warmup_steps = config["warmup_steps"]
        self.eval_interval = config["eval_interval"]
        self.eval_episode = config["eval_episode"]
        self.gamma = config["gamma"]
        self.update_freq = config["update_freq"]
        self.update_target_freq = config["update_target_freq"]

        self.replay_buffer = ReplayMemory(int(config["replay_buffer_capacity"]))
        self.writer = SummaryWriter(config["logdir"])

    @abstractmethod
    def decide_agent_actions(self, observation, epsilon=0.0, action_space=None):
        ### TODO ###
        # get action from behavior net, with epsilon-greedy selection

        raise NotImplementedError

    def update(self):
        if self.total_time_step % self.update_freq == 0:
            self.update_behavior_network()
        if self.total_time_step % self.update_target_freq == 0:
            self.update_target_network()

    @abstractmethod
    def update_behavior_network(self):
        # sample a minibatch of transitions
        state, action, reward, next_state, done = self.replay_buffer.sample(self.batch_size, self.device)
        ### TODO ###
        # calculate the loss and update the behavior network

    def update_target_network(self):
        self.target_net.load_state_dict(self.behavior_net.state_dict())

    def epsilon_decay(self):
        # linearly anneal epsilon from 1.0 down to eps_min over eps_decay steps
        self.epsilon -= (1 - self.eps_min) / self.eps_decay
        self.epsilon = max(self.epsilon, self.eps_min)

    def train(self):
        episode_idx = 0
        while self.total_time_step <= self.training_steps:
            observation, info = self.env.reset()
            episode_reward = 0
            episode_len = 0
            episode_idx += 1
            while True:
                if self.total_time_step < self.warmup_steps:
                    action = self.decide_agent_actions(observation, 1.0, self.env.action_space)
                else:
                    action = self.decide_agent_actions(observation, self.epsilon, self.env.action_space)
                    self.epsilon_decay()

                next_observation, reward, terminate, truncate, info = self.env.step(action)
                self.replay_buffer.append(observation, [action], [reward], next_observation, [int(terminate)])

                if self.total_time_step >= self.warmup_steps:
                    self.update()

                episode_reward += reward
                episode_len += 1

                if terminate or truncate:
                    self.writer.add_scalar('Train/Episode Reward', episode_reward, self.total_time_step)
                    self.writer.add_scalar('Train/Episode Len', episode_len, self.total_time_step)
                    print(f"[{self.total_time_step}/{self.training_steps}] episode: {episode_idx} episode reward: {episode_reward} episode len: {episode_len} epsilon: {self.epsilon}")
                    break

                observation = next_observation
                self.total_time_step += 1

            if episode_idx % self.eval_interval == 0:
                # save model checkpoint
                avg_score = self.evaluate()
                self.save(os.path.join(self.writer.log_dir, f"model_{self.total_time_step}_{int(avg_score)}.pth"))
                self.writer.add_scalar('Evaluate/Episode Reward', avg_score, self.total_time_step)

    def evaluate(self):
        print("==============================================")
        print("Evaluating...")
        all_rewards = []
        for i in range(self.eval_episode):
            observation, info = self.test_env.reset()
            total_reward = 0
            while True:
                self.test_env.render()
                action = self.decide_agent_actions(observation, self.eval_epsilon, self.test_env.action_space)
                next_observation, reward, terminate, truncate, info = self.test_env.step(action)
                total_reward += reward
                if terminate or truncate:
                    print(f"episode {i+1} reward: {total_reward}")
                    all_rewards.append(total_reward)
                    break

                observation = next_observation

        avg = sum(all_rewards) / self.eval_episode
        print(f"average score: {avg}")
        print("==============================================")
        return avg

    # save model
    def save(self, save_path):
        torch.save(self.behavior_net.state_dict(), save_path)

    # load model
    def load(self, load_path):
        self.behavior_net.load_state_dict(torch.load(load_path))

    # load model weights and evaluate
    def load_and_evaluate(self, load_path):
        self.load(load_path)
        self.evaluate()




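epsilon_decay() above anneals epsilon linearly: each call subtracts (1 - eps_min) / eps_decay, so epsilon falls from 1.0 to eps_min after exactly eps_decay calls and then stays there. A standalone sketch of the same schedule, using the eps_min and eps_decay values from main.py purely for illustration:

# Standalone sketch of the linear epsilon schedule in DQNBaseAgent.epsilon_decay().
# eps_min and eps_decay mirror main.py; the printed checkpoints are illustrative only.
eps, eps_min, eps_decay = 1.0, 0.1, 1_000_000
for step in range(1, 2_000_001):
    eps = max(eps - (1 - eps_min) / eps_decay, eps_min)
    if step in (250_000, 500_000, 1_000_000, 2_000_000):
        print(step, round(eps, 3))  # 0.775, 0.55, 0.1, 0.1
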
122 changes: 122 additions & 0 deletions dqn_agent_atari.py
import torch
import torch.nn as nn
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from base_agent import DQNBaseAgent
from models.atari_model import AtariNetDQN
import gym
import random
from gym.wrappers import FrameStack
import cv2
from replay_buffer.replay_buffer import ReplayMemory
import sys

def transform(frames):
    # preprocess each frame: RGB -> grayscale, crop to the first 172 rows, resize to 84x84
    new_frames = []
    for img in frames:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        img = img[:172, :]
        img = cv2.resize(img, (84, 84))
        new_frames.append(img)
    return np.array(new_frames)

class MyReplayMemory(ReplayMemory):
    def __init__(self, capacity, action_space_n):
        super().__init__(capacity)
        self.action_space_n = action_space_n

    def append(self, *transition):
        state, action, reward, next_state, done = transition

        state = transform(state)
        next_state = transform(next_state)

        # cv2.imwrite("1.png", state[-1])

        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size, device):
        transitions = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*transitions)
        return (
            torch.tensor(np.array(state), dtype=torch.float, device=device),
            torch.tensor(action, dtype=torch.int64, device=device),
            torch.tensor(reward, dtype=torch.float, device=device),
            torch.tensor(np.array(next_state), dtype=torch.float, device=device),
            1 - torch.tensor(done, dtype=torch.float, device=device)
        )

class AtariDQNAgent(DQNBaseAgent):
    def __init__(self, config):
        super(AtariDQNAgent, self).__init__(config)

        ### TODO ###
        # initialize env
        # self.env = ???

        self.test_env = FrameStack(gym.make(config['env_id'], render_mode='human'), 4)
        self.env = FrameStack(gym.make(config['env_id']), 4)

        self.replay_buffer = MyReplayMemory(int(config["replay_buffer_capacity"]), self.env.action_space.n)

        # initialize behavior network and target network
        self.behavior_net = AtariNetDQN(self.env.action_space.n)
        self.behavior_net.to(self.device)
        self.target_net = AtariNetDQN(self.env.action_space.n)
        self.target_net.to(self.device)

        # optionally resume from a checkpoint passed on the command line
        if len(sys.argv) > 1:
            self.load(sys.argv[1])
        self.target_net.load_state_dict(self.behavior_net.state_dict())

        # initialize optimizer
        self.lr = config["learning_rate"]
        self.optim = torch.optim.Adam(self.behavior_net.parameters(), lr=self.lr, eps=1.5e-4)

    def decide_agent_actions(self, observation, epsilon=0.0, action_space: gym.Space = None):
        ### TODO ###
        # get action from behavior net, with epsilon-greedy selection

        if random.random() < epsilon:
            return random.randint(0, action_space.n - 1)

        with torch.no_grad():
            x = torch.tensor(np.array([transform(observation)]), dtype=torch.float, device=self.device)
            y = self.behavior_net(x)
            return int(torch.argmax(y))

    def update_behavior_network(self):
        # sample a minibatch of transitions
        state, action, reward, next_state, yet = self.replay_buffer.sample(self.batch_size, self.device)
        self.behavior_net.train()

        ### TODO ###
        # calculate the loss and update the behavior network
        # 1. get max_a Q(s',a) from target net
        # 2. calculate Q_target = r + gamma * max_a Q(s',a)
        # 3. get Q(s,a) from behavior net
        # 4. calculate loss between Q(s,a) and Q_target
        # 5. update behavior net

        with torch.no_grad():
            q_next = self.target_net(next_state)
            q_next: torch.Tensor = torch.max(q_next, dim=1)[0]
            q_next = q_next.reshape(self.batch_size, 1)

            # if episode terminates at next_state, then q_target = reward
            q_target = self.gamma * q_next * yet + reward

        q_value: torch.Tensor = self.behavior_net(state)
        q_value = q_value.gather(1, action)

        criterion = torch.nn.MSELoss()
        loss = criterion(q_value, q_target)

        self.writer.add_scalar('DQN/Loss', loss.item(), self.total_time_step)

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

        self.behavior_net.eval()


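The transform() helper above is where raw frames become network input: each RGB frame is converted to grayscale, cropped to its first 172 rows, and resized to 84x84, and decide_agent_actions() then adds a batch dimension. A standalone sanity-check sketch with dummy frames (the 210x160x3 shape assumes the raw ALE/MsPacman-v5 observation size):

# Sketch: check the shapes produced by transform() on dummy frames.
# Assumes raw frames of shape (210, 160, 3), as ALE/MsPacman-v5 returns.
import numpy as np
from dqn_agent_atari import transform

frames = [np.zeros((210, 160, 3), dtype=np.uint8) for _ in range(4)]  # 4 stacked RGB frames
stacked = transform(frames)
print(stacked.shape)              # (4, 84, 84)
print(np.array([stacked]).shape)  # (1, 4, 84, 84), the shape fed to the behavior net
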
24 changes: 24 additions & 0 deletions main.py
from dqn_agent_atari import AtariDQNAgent

if __name__ == '__main__':
    # my hyperparameters; you can change them as you like
    config = {
        "gpu": True,
        "training_steps": 1e8,
        "gamma": 0.99,
        "batch_size": 32,
        "eps_min": 0.1,
        "warmup_steps": 20000,
        "eps_decay": 1000000,
        "eval_epsilon": 0.01,
        "replay_buffer_capacity": 100000,
        "logdir": 'log/DQN/',
        "update_freq": 4,
        "update_target_freq": 10000,
        "learning_rate": 0.0000625,
        "eval_interval": 100,
        "eval_episode": 5,
        "env_id": 'ALE/MsPacman-v5',
    }
    agent = AtariDQNAgent(config)
    agent.train()
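
Training writes checkpoints named model_<step>_<score>.pth into the log directory; note that AtariDQNAgent.__init__ also reads sys.argv[1], so passing a checkpoint path to main.py resumes training from those weights. For evaluation only, a small driver along these lines can be used (a minimal sketch; the checkpoint filename is a placeholder and the config mirrors main.py):

# Sketch: evaluate a saved checkpoint without further training.
# The checkpoint path is a placeholder; the config must match the training run.
from dqn_agent_atari import AtariDQNAgent

config = {
    "gpu": True,
    "training_steps": 1e8,
    "gamma": 0.99,
    "batch_size": 32,
    "eps_min": 0.1,
    "warmup_steps": 20000,
    "eps_decay": 1000000,
    "eval_epsilon": 0.01,
    "replay_buffer_capacity": 100000,
    "logdir": 'log/DQN/',
    "update_freq": 4,
    "update_target_freq": 10000,
    "learning_rate": 0.0000625,
    "eval_interval": 100,
    "eval_episode": 5,
    "env_id": 'ALE/MsPacman-v5',
}

agent = AtariDQNAgent(config)
agent.load_and_evaluate('log/DQN/model_1000000_1500.pth')  # placeholder checkpoint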
43 changes: 43 additions & 0 deletions models/atari_model.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class AtariNetDQN(nn.Module):
    def __init__(self, num_classes=4, init_weights=True):
        super(AtariNetDQN, self).__init__()
        self.cnn = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4),
                                 nn.ReLU(True),
                                 nn.Conv2d(32, 64, kernel_size=4, stride=2),
                                 nn.ReLU(True),
                                 nn.Conv2d(64, 64, kernel_size=3, stride=1),
                                 nn.ReLU(True)
                                 )
        self.classifier = nn.Sequential(nn.Linear(7*7*64, 512),
                                        nn.ReLU(True),
                                        nn.Linear(512, num_classes)
                                        )

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = x.float() / 255.
        x = self.cnn(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(m.bias, 0.0)

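The three conv layers shrink an 84x84 input to 20x20, then 9x9, then 7x7, which is where the 7*7*64 = 3136 flatten size in the classifier comes from. A quick shape-check sketch with dummy input (the action count of 9 assumes ALE/MsPacman-v5's default action space):

# Sketch: verify AtariNetDQN's input/output shapes with a dummy batch.
import torch
from models.atari_model import AtariNetDQN

net = AtariNetDQN(num_classes=9)   # 9 = assumed action count for ALE/MsPacman-v5
x = torch.zeros(2, 4, 84, 84)      # batch of 2, each 4 stacked 84x84 frames
print(net(x).shape)                # torch.Size([2, 9])
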
21 changes: 21 additions & 0 deletions replay_buffer/replay_buffer.py
import numpy as np
import torch
from collections import deque
import random

class ReplayMemory(object):
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)

    def append(self, *transition):
        """Saves a transition"""
        self.buffer.append(tuple(map(tuple, transition)))

    def sample(self, batch_size, device):
        """Sample a batch of transitions"""
        transitions = random.sample(self.buffer, batch_size)
        return (torch.tensor(np.asarray(x), dtype=torch.float, device=device) for x in zip(*transitions))

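For reference, a minimal usage sketch of this base buffer with dummy transitions (each field is passed as a sequence because append() maps tuple() over every element):

# Sketch: basic ReplayMemory usage with dummy transitions.
import torch
from replay_buffer.replay_buffer import ReplayMemory

memory = ReplayMemory(capacity=100)
for i in range(10):
    # (state, action, reward, next_state, done), each wrapped in a sequence
    memory.append([float(i), 0.0], [1], [0.5], [float(i) + 1.0, 0.0], [0])

state, action, reward, next_state, done = memory.sample(batch_size=4, device=torch.device("cpu"))
print(state.shape, action.shape, done.shape)  # torch.Size([4, 2]) torch.Size([4, 1]) torch.Size([4, 1])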