diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e33fcb3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+log
+__pycache__
\ No newline at end of file
diff --git a/Lab2-DQN.pdf b/Lab2-DQN.pdf
new file mode 100644
index 0000000..6f1d199
Binary files /dev/null and b/Lab2-DQN.pdf differ
diff --git a/Lab2-Guide.pdf b/Lab2-Guide.pdf
new file mode 100644
index 0000000..2f84afd
Binary files /dev/null and b/Lab2-Guide.pdf differ
diff --git a/base_agent.py b/base_agent.py
new file mode 100644
index 0000000..f3bb5ed
--- /dev/null
+++ b/base_agent.py
@@ -0,0 +1,139 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import os
+import time
+from collections import deque
+from torch.utils.tensorboard import SummaryWriter
+from replay_buffer.replay_buffer import ReplayMemory
+from abc import ABC, abstractmethod
+
+
+class DQNBaseAgent(ABC):
+    def __init__(self, config):
+        self.gpu = config["gpu"]
+        self.device = torch.device("cuda" if self.gpu and torch.cuda.is_available() else "cpu")
+        self.total_time_step = 0
+        self.training_steps = int(config["training_steps"])
+        self.batch_size = int(config["batch_size"])
+        self.epsilon = 1.0
+        self.eps_min = config["eps_min"]
+        self.eps_decay = config["eps_decay"]
+        self.eval_epsilon = config["eval_epsilon"]
+        self.warmup_steps = config["warmup_steps"]
+        self.eval_interval = config["eval_interval"]
+        self.eval_episode = config["eval_episode"]
+        self.gamma = config["gamma"]
+        self.update_freq = config["update_freq"]
+        self.update_target_freq = config["update_target_freq"]
+
+        self.replay_buffer = ReplayMemory(int(config["replay_buffer_capacity"]))
+        self.writer = SummaryWriter(config["logdir"])
+
+    @abstractmethod
+    def decide_agent_actions(self, observation, epsilon=0.0, action_space=None):
+        ### TODO ###
+        # get action from behavior net, with epsilon-greedy selection
+        raise NotImplementedError
+
+    def update(self):
+        if self.total_time_step % self.update_freq == 0:
+            self.update_behavior_network()
+        if self.total_time_step % self.update_target_freq == 0:
+            self.update_target_network()
+
+    @abstractmethod
+    def update_behavior_network(self):
+        # sample a minibatch of transitions
+        state, action, reward, next_state, done = self.replay_buffer.sample(self.batch_size, self.device)
+        ### TODO ###
+        # calculate the loss and update the behavior network
+
+    def update_target_network(self):
+        # hard update: copy the behavior network weights into the target network
+        self.target_net.load_state_dict(self.behavior_net.state_dict())
+
+    def epsilon_decay(self):
+        # linear decay from 1.0 down to eps_min over eps_decay steps
+        self.epsilon -= (1 - self.eps_min) / self.eps_decay
+        self.epsilon = max(self.epsilon, self.eps_min)
+
+    def train(self):
+        episode_idx = 0
+        while self.total_time_step <= self.training_steps:
+            observation, info = self.env.reset()
+            episode_reward = 0
+            episode_len = 0
+            episode_idx += 1
+            while True:
+                # during warmup, act uniformly at random (epsilon = 1.0)
+                if self.total_time_step < self.warmup_steps:
+                    action = self.decide_agent_actions(observation, 1.0, self.env.action_space)
+                else:
+                    action = self.decide_agent_actions(observation, self.epsilon, self.env.action_space)
+                    self.epsilon_decay()
+
+                next_observation, reward, terminate, truncate, info = self.env.step(action)
+                self.replay_buffer.append(observation, [action], [reward], next_observation, [int(terminate)])
+
+                if self.total_time_step >= self.warmup_steps:
+                    self.update()
+
+                episode_reward += reward
+                episode_len += 1
+
+                if terminate or truncate:
+                    self.writer.add_scalar('Train/Episode Reward', episode_reward, self.total_time_step)
+                    self.writer.add_scalar('Train/Episode Len', episode_len, self.total_time_step)
print(f"[{self.total_time_step}/{self.training_steps}] episode: {episode_idx} episode reward: {episode_reward} episode len: {episode_len} epsilon: {self.epsilon}") + break + + observation = next_observation + self.total_time_step += 1 + + if episode_idx % self.eval_interval == 0: + # save model checkpoint + avg_score = self.evaluate() + self.save(os.path.join(self.writer.log_dir, f"model_{self.total_time_step}_{int(avg_score)}.pth")) + self.writer.add_scalar('Evaluate/Episode Reward', avg_score, self.total_time_step) + + def evaluate(self): + print("==============================================") + print("Evaluating...") + all_rewards = [] + for i in range(self.eval_episode): + observation, info = self.test_env.reset() + total_reward = 0 + while True: + self.test_env.render() + action = self.decide_agent_actions(observation, self.eval_epsilon, self.test_env.action_space) + next_observation, reward, terminate, truncate, info = self.test_env.step(action) + total_reward += reward + if terminate or truncate: + print(f"episode {i+1} reward: {total_reward}") + all_rewards.append(total_reward) + break + + observation = next_observation + + + avg = sum(all_rewards) / self.eval_episode + print(f"average score: {avg}") + print("==============================================") + return avg + + # save model + def save(self, save_path): + torch.save(self.behavior_net.state_dict(), save_path) + + # load model + def load(self, load_path): + self.behavior_net.load_state_dict(torch.load(load_path)) + + # load model weights and evaluate + def load_and_evaluate(self, load_path): + self.load(load_path) + self.evaluate() + + + + diff --git a/dqn_agent_atari.py b/dqn_agent_atari.py new file mode 100644 index 0000000..9195b4d --- /dev/null +++ b/dqn_agent_atari.py @@ -0,0 +1,122 @@ +import torch +import torch.nn as nn +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from base_agent import DQNBaseAgent +from models.atari_model import AtariNetDQN +import gym +import random +from gym.wrappers import FrameStack +import cv2 +from replay_buffer.replay_buffer import ReplayMemory +import sys + +def transform(frames): + new_frames=[] + for img in frames: + img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + img = img[:172,:] + img = cv2.resize(img,(84,84)) + new_frames.append(img) + return np.array(new_frames) + +class MyReplayMemory(ReplayMemory): + def __init__(self, capacity, action_space_n): + super().__init__(capacity) + self.action_space_n = action_space_n + def append(self, *transition): + state, action, reward, next_state, done = transition + + state=transform(state) + next_state=transform(next_state) + + # cv2.imwrite("1.png",state[-1]) + + self.buffer.append((state, action, reward, next_state, done)) + + def sample(self, batch_size, device): + transitions = random.sample(self.buffer, batch_size) + state, action, reward, next_state, done = zip(*transitions) + return ( + torch.tensor(np.array(state),dtype=torch.float,device=device), + torch.tensor(action,dtype=torch.int64,device=device), + torch.tensor(reward,dtype=torch.float,device=device), + torch.tensor(np.array(next_state),dtype=torch.float,device=device), + 1 - torch.tensor(done,dtype=torch.float,device=device) + ) + +class AtariDQNAgent(DQNBaseAgent): + def __init__(self, config): + super(AtariDQNAgent, self).__init__(config) + + ### TODO ### + # initialize env + # self.env = ??? 
+
+        self.test_env = FrameStack(gym.make(config['env_id'], render_mode='human'), 4)
+        self.env = FrameStack(gym.make(config['env_id']), 4)
+
+        self.replay_buffer = MyReplayMemory(int(config["replay_buffer_capacity"]), self.env.action_space.n)
+
+        # initialize behavior network and target network
+        self.behavior_net = AtariNetDQN(self.env.action_space.n)
+        self.behavior_net.to(self.device)
+        self.target_net = AtariNetDQN(self.env.action_space.n)
+        self.target_net.to(self.device)
+
+        # optionally resume from a checkpoint passed on the command line
+        if len(sys.argv) > 1:
+            self.load(sys.argv[1])
+        self.target_net.load_state_dict(self.behavior_net.state_dict())
+
+        # initialize optimizer
+        self.lr = config["learning_rate"]
+        self.optim = torch.optim.Adam(self.behavior_net.parameters(), lr=self.lr, eps=1.5e-4)
+
+    def decide_agent_actions(self, observation, epsilon=0.0, action_space: gym.Space = None):
+        ### TODO ###
+        # get action from behavior net, with epsilon-greedy selection
+        if random.random() < epsilon:
+            return random.randint(0, action_space.n - 1)
+
+        with torch.no_grad():
+            x = torch.tensor(np.array([transform(observation)]), dtype=torch.float, device=self.device)
+            y = self.behavior_net(x)
+            return int(torch.argmax(y))
+
+    def update_behavior_network(self):
+        # sample a minibatch of transitions; "yet" is the (1 - done) mask
+        state, action, reward, next_state, yet = self.replay_buffer.sample(self.batch_size, self.device)
+        self.behavior_net.train()
+
+        ### TODO ###
+        # calculate the loss and update the behavior network
+        # 1. get max_a Q(s',a) from target net
+        # 2. calculate Q_target = r + gamma * max_a Q(s',a)
+        # 3. get Q(s,a) from behavior net
+        # 4. calculate loss between Q(s,a) and Q_target
+        # 5. update behavior net
+
+        with torch.no_grad():
+            q_next = self.target_net(next_state)
+            q_next: torch.Tensor = torch.max(q_next, dim=1)[0]
+            q_next = q_next.reshape(self.batch_size, 1)
+
+            # if episode terminates at next_state, then q_target = reward
+            q_target = self.gamma * q_next * yet + reward
+
+        q_value: torch.Tensor = self.behavior_net(state)
+        q_value = q_value.gather(1, action)
+
+        criterion = torch.nn.MSELoss()
+        loss = criterion(q_value, q_target)
+
+        self.writer.add_scalar('DQN/Loss', loss.item(), self.total_time_step)
+
+        self.optim.zero_grad()
+        loss.backward()
+        self.optim.step()
+
+        self.behavior_net.eval()
+
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..0e193f4
--- /dev/null
+++ b/main.py
@@ -0,0 +1,24 @@
+from dqn_agent_atari import AtariDQNAgent
+
+if __name__ == '__main__':
+    # my hyperparameters, you can change them as you like
+    config = {
+        "gpu": True,
+        "training_steps": 1e8,
+        "gamma": 0.99,
+        "batch_size": 32,
+        "eps_min": 0.1,
+        "warmup_steps": 20000,
+        "eps_decay": 1000000,
+        "eval_epsilon": 0.01,
+        "replay_buffer_capacity": 100000,
+        "logdir": 'log/DQN/',
+        "update_freq": 4,
+        "update_target_freq": 10000,
+        "learning_rate": 0.0000625,
+        "eval_interval": 100,
+        "eval_episode": 5,
+        "env_id": 'ALE/MsPacman-v5',
+    }
+    agent = AtariDQNAgent(config)
+    agent.train()
\ No newline at end of file
diff --git a/models/atari_model.py b/models/atari_model.py
new file mode 100644
index 0000000..6f0fda0
--- /dev/null
+++ b/models/atari_model.py
@@ -0,0 +1,43 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class AtariNetDQN(nn.Module):
+    def __init__(self, num_classes=4, init_weights=True):
+        super(AtariNetDQN, self).__init__()
+        # standard DQN convolutional trunk for 4 stacked 84x84 frames
+        self.cnn = nn.Sequential(
+            nn.Conv2d(4, 32, kernel_size=8, stride=4),
+            nn.ReLU(True),
+            nn.Conv2d(32, 64, kernel_size=4, stride=2),
+            nn.ReLU(True),
+            nn.Conv2d(64, 64, kernel_size=3, stride=1),
+            nn.ReLU(True)
+        )
+        self.classifier = nn.Sequential(
+            nn.Linear(7 * 7 * 64, 512),
+            nn.ReLU(True),
+            nn.Linear(512, num_classes)
+        )
+
+        if init_weights:
+            self._initialize_weights()
+
+    def forward(self, x):
+        # scale raw pixel values to [0, 1]
+        x = x.float() / 255.
+        x = self.cnn(x)
+        x = torch.flatten(x, start_dim=1)
+        x = self.classifier(x)
+        return x
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0.0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1.0)
+                nn.init.constant_(m.bias, 0.0)
+            elif isinstance(m, nn.Linear):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                nn.init.constant_(m.bias, 0.0)
+
diff --git a/replay_buffer/replay_buffer.py b/replay_buffer/replay_buffer.py
new file mode 100644
index 0000000..62f55f6
--- /dev/null
+++ b/replay_buffer/replay_buffer.py
@@ -0,0 +1,21 @@
+import numpy as np
+import torch
+from collections import deque
+import random
+
+class ReplayMemory(object):
+    def __init__(self, capacity):
+        self.buffer = deque(maxlen=capacity)
+
+    def __len__(self):
+        return len(self.buffer)
+
+    def append(self, *transition):
+        """Saves a transition"""
+        self.buffer.append(tuple(map(tuple, transition)))
+
+    def sample(self, batch_size, device):
+        """Sample a batch of transitions"""
+        transitions = random.sample(self.buffer, batch_size)
+        return (torch.tensor(np.asarray(x), dtype=torch.float, device=device) for x in zip(*transitions))
\ No newline at end of file
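
Usage sketch, with the assumptions flagged: the script name evaluate_checkpoint.py and the checkpoint path below are hypothetical, and the config is copied verbatim from main.py. Training runs with "python main.py"; because AtariDQNAgent.__init__ loads sys.argv[1] when an argument is given, "python main.py log/DQN/model_<step>_<score>.pth" would continue training from saved weights. A checkpoint can also be scored on its own through the load_and_evaluate helper in base_agent.py:

# evaluate_checkpoint.py: a minimal evaluation-only sketch (hypothetical file,
# not part of the diff above); the checkpoint path below is a placeholder.
from dqn_agent_atari import AtariDQNAgent

if __name__ == '__main__':
    # same keys as in main.py; DQNBaseAgent.__init__ reads all of them
    config = {
        "gpu": True,
        "training_steps": 1e8,
        "gamma": 0.99,
        "batch_size": 32,
        "eps_min": 0.1,
        "warmup_steps": 20000,
        "eps_decay": 1000000,
        "eval_epsilon": 0.01,
        "replay_buffer_capacity": 100000,
        "logdir": 'log/DQN/',
        "update_freq": 4,
        "update_target_freq": 10000,
        "learning_rate": 0.0000625,
        "eval_interval": 100,
        "eval_episode": 5,
        "env_id": 'ALE/MsPacman-v5',
    }
    agent = AtariDQNAgent(config)
    # loads the behavior-net weights, then runs eval_episode rollouts on
    # test_env with epsilon = eval_epsilon and prints the average score
    agent.load_and_evaluate("log/DQN/model_1000000_2500.pth")

Note that test_env is constructed with render_mode='human', so evaluation opens a game window while it runs.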