From 2386281b883b213c20c39d8c0271e360da11c64a Mon Sep 17 00:00:00 2001 From: sergioajini Date: Thu, 2 Aug 2018 16:11:26 +0430 Subject: [PATCH 1/2] no warning (gym & pytorch0.4 warnings) --- envs.py | 6 ++-- main.py | 9 +++-- model.py | 8 ++--- test.py | 102 ++++++++++++++++++++++++++++++------------------------- train.py | 42 +++++++++++++++++++---- 5 files changed, 103 insertions(+), 64 deletions(-) diff --git a/envs.py b/envs.py index 0724432..4e921fe 100644 --- a/envs.py +++ b/envs.py @@ -29,9 +29,9 @@ def _process_frame42(frame): class AtariRescale42x42(gym.ObservationWrapper): def __init__(self, env=None): super(AtariRescale42x42, self).__init__(env) - self.observation_space = Box(0.0, 1.0, [1, 42, 42]) + self.observation_space = Box(0.0, 1.0, [1, 42, 42], dtype=np.float32) - def _observation(self, observation): + def observation(self, observation): return _process_frame42(observation) @@ -43,7 +43,7 @@ def __init__(self, env=None): self.alpha = 0.9999 self.num_steps = 0 - def _observation(self, observation): + def observation(self, observation): self.num_steps += 1 self.state_mean = self.state_mean * self.alpha + \ observation.mean() * (1 - self.alpha) diff --git a/main.py b/main.py index 369fdf9..80719fd 100644 --- a/main.py +++ b/main.py @@ -11,6 +11,7 @@ from model import ActorCritic from test import test from train import train +import time # Based on # https://github.com/pytorch/examples/tree/master/mnist_hogwild @@ -43,6 +44,7 @@ if __name__ == '__main__': + mp.set_start_method('spawn') os.environ['OMP_NUM_THREADS'] = '1' os.environ['CUDA_VISIBLE_DEVICES'] = "" @@ -50,8 +52,7 @@ torch.manual_seed(args.seed) env = create_atari_env(args.env_name) - shared_model = ActorCritic( - env.observation_space.shape[0], env.action_space) + shared_model = ActorCritic(env.observation_space.shape[0], env.action_space) shared_model.share_memory() if args.no_shared: @@ -70,8 +71,10 @@ processes.append(p) for rank in range(0, args.num_processes): - p = mp.Process(target=train, args=(rank, args, shared_model, counter, lock, optimizer)) + p = mp.Process(target=train, args=(rank, args, shared_model, counter, lock, optimizer, False)) p.start() processes.append(p) + time.sleep(5) + for p in processes: p.join() diff --git a/model.py b/model.py index 2cdff4d..32ec683 100644 --- a/model.py +++ b/model.py @@ -57,10 +57,10 @@ def __init__(self, num_inputs, action_space): def forward(self, inputs): inputs, (hx, cx) = inputs - x = F.elu(self.conv1(inputs)) - x = F.elu(self.conv2(x)) - x = F.elu(self.conv3(x)) - x = F.elu(self.conv4(x)) + x = F.relu(self.conv1(inputs), inplace=True) + x = F.relu(self.conv2(x), inplace=True) + x = F.relu(self.conv3(x), inplace=True) + x = F.relu(self.conv4(x), inplace=True) x = x.view(-1, 32 * 3 * 3) hx, cx = self.lstm(x, (hx, cx)) diff --git a/test.py b/test.py index 054649e..30eb58e 100644 --- a/test.py +++ b/test.py @@ -1,3 +1,4 @@ +import os import time from collections import deque @@ -7,63 +8,70 @@ from envs import create_atari_env from model import ActorCritic +from tensorboardX import SummaryWriter def test(rank, args, shared_model, counter): - torch.manual_seed(args.seed + rank) + with torch.no_grad(): + my_writer = SummaryWriter(log_dir='log') + t0 = time.time() + torch.manual_seed(args.seed + rank) - env = create_atari_env(args.env_name) - env.seed(args.seed + rank) + env = create_atari_env(args.env_name) + env.seed(args.seed + rank) - model = ActorCritic(env.observation_space.shape[0], env.action_space) + model = ActorCritic(env.observation_space.shape[0], 
env.action_space) + model.eval() - model.eval() + state = env.reset() + state = torch.from_numpy(state) + reward_sum = 0 + done = True - state = env.reset() - state = torch.from_numpy(state) - reward_sum = 0 - done = True + start_time = time.time() - start_time = time.time() + # a quick hack to prevent the agent from stucking + actions = deque(maxlen=100) + episode_length = 0 + while True: + episode_length += 1 + # Sync with the shared model + if done: + model.load_state_dict(shared_model.state_dict()) + cx = Variable(torch.zeros(1, 256)) + hx = Variable(torch.zeros(1, 256)) + else: + cx = Variable(cx.data) + hx = Variable(hx.data) - # a quick hack to prevent the agent from stucking - actions = deque(maxlen=100) - episode_length = 0 - while True: - episode_length += 1 - # Sync with the shared model - if done: - model.load_state_dict(shared_model.state_dict()) - cx = Variable(torch.zeros(1, 256), volatile=True) - hx = Variable(torch.zeros(1, 256), volatile=True) - else: - cx = Variable(cx.data, volatile=True) - hx = Variable(hx.data, volatile=True) + value, logit, (hx, cx) = model((Variable(state.unsqueeze(0)), (hx, cx))) + prob = F.softmax(logit, dim=1) + action = prob.max(1, keepdim=True)[1].data.numpy() - value, logit, (hx, cx) = model((Variable( - state.unsqueeze(0), volatile=True), (hx, cx))) - prob = F.softmax(logit) - action = prob.max(1, keepdim=True)[1].data.numpy() + state, reward, done, _ = env.step(action[0, 0]) + done = done or episode_length >= args.max_episode_length + reward_sum += reward - state, reward, done, _ = env.step(action[0, 0]) - done = done or episode_length >= args.max_episode_length - reward_sum += reward + # a quick hack to prevent the agent from stucking + actions.append(action[0, 0]) + if actions.count(actions[0]) == actions.maxlen: + done = True - # a quick hack to prevent the agent from stucking - actions.append(action[0, 0]) - if actions.count(actions[0]) == actions.maxlen: - done = True + if done: + my_writer.add_scalar('episode_reward', reward_sum, counter.value) + my_writer.add_scalar('episode_length', episode_length, counter.value) + my_writer.add_scalar('FPS', counter.value / (time.time() - start_time), counter.value) + print("Time {}, num steps {}, FPS {:.0f}, episode reward {}, episode length {}".format( + time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - start_time)), + counter.value, counter.value / (time.time() - start_time), + reward_sum, episode_length)) + reward_sum = 0 + episode_length = 0 + actions.clear() + state = env.reset() + time.sleep(4 * 60) + if (time.time() - t0) > (15 * 60): + torch.save(model.state_dict(), os.path.join('models', 'epoch_{}.pth').format(str(int(time.time())))) + t0 = time.time() - if done: - print("Time {}, num steps {}, FPS {:.0f}, episode reward {}, episode length {}".format( - time.strftime("%Hh %Mm %Ss", - time.gmtime(time.time() - start_time)), - counter.value, counter.value / (time.time() - start_time), - reward_sum, episode_length)) - reward_sum = 0 - episode_length = 0 - actions.clear() - state = env.reset() - time.sleep(60) - - state = torch.from_numpy(state) + state = torch.from_numpy(state) diff --git a/train.py b/train.py index e3f0143..83c6dde 100644 --- a/train.py +++ b/train.py @@ -15,18 +15,23 @@ def ensure_shared_grads(model, shared_model): shared_param._grad = param.grad -def train(rank, args, shared_model, counter, lock, optimizer=None): +def train(rank, args, shared_model, counter, lock, optimizer=None, DEBUG=False): + if DEBUG: + print('rank: {}'.format(rank)) torch.manual_seed(args.seed + 
rank) env = create_atari_env(args.env_name) env.seed(args.seed + rank) model = ActorCritic(env.observation_space.shape[0], env.action_space) + model.train() + if DEBUG: + print('agent{:03d}: model created'.format(rank)) if optimizer is None: optimizer = optim.Adam(shared_model.parameters(), lr=args.lr) - - model.train() + if DEBUG: + print('agent{:03d}: optimizer created'.format(rank)) state = env.reset() state = torch.from_numpy(state) @@ -34,6 +39,8 @@ def train(rank, args, shared_model, counter, lock, optimizer=None): episode_length = 0 while True: + if DEBUG: + print('agent{:03d}: while loop'.format(rank)) # Sync with the shared model model.load_state_dict(shared_model.state_dict()) if done: @@ -49,24 +56,45 @@ def train(rank, args, shared_model, counter, lock, optimizer=None): entropies = [] for step in range(args.num_steps): + if DEBUG: + print('agent{:03d}: for loop p1'.format(rank)) + episode_length += 1 - value, logit, (hx, cx) = model((Variable(state.unsqueeze(0)), - (hx, cx))) - prob = F.softmax(logit) - log_prob = F.log_softmax(logit) + if DEBUG: + print('agent{:03d}: for loop p1.1'.format(rank)) + print(state.unsqueeze(0).size()) + with lock: + value, logit, (hx, cx) = model((Variable(state.unsqueeze(0)), (hx, cx))) + if DEBUG: + print('agent{:03d}: for loop p2'.format(rank)) + # prob = F.softmax(logit) + prob = F.softmax(logit, dim=1) + log_prob = F.log_softmax(logit, dim=1) + if DEBUG: + print('agent{:03d}: for loop p3'.format(rank)) entropy = -(log_prob * prob).sum(1, keepdim=True) entropies.append(entropy) + if DEBUG: + print('agent{:03d}: for loop p4'.format(rank)) action = prob.multinomial(num_samples=1).data log_prob = log_prob.gather(1, Variable(action)) + if DEBUG: + print('agent{:03d}: for loop p5'.format(rank)) state, reward, done, _ = env.step(action.numpy()) done = done or episode_length >= args.max_episode_length reward = max(min(reward, 1), -1) + if DEBUG: + print('agent{:03d}: for loop p6'.format(rank)) with lock: counter.value += 1 + if DEBUG: + print('agent{:03d}: counter plus {:09d}'.format(rank, counter.value)) + if DEBUG: + print('agent{:03d}: for loop p7'.format(rank)) if done: episode_length = 0 state = env.reset() From e2699f63d04eed0841188e80a0bbbefef808602a Mon Sep 17 00:00:00 2001 From: sergioajini Date: Thu, 2 Aug 2018 16:20:38 +0430 Subject: [PATCH 2/2] generate video --- visualize.py | 80 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 visualize.py diff --git a/visualize.py b/visualize.py new file mode 100644 index 0000000..7d8a4ff --- /dev/null +++ b/visualize.py @@ -0,0 +1,80 @@ +import os +from collections import deque + +import torch +import torch.nn.functional as F +from torch.autograd import Variable + +import gym +import imageio + +from envs import create_atari_env +from model import ActorCritic + + +def visualize(env_name, model_path, render=False): + with torch.no_grad(): + torch.manual_seed(0) + + env = create_atari_env(env_name) + env_orig = gym.make(env_name) + env.seed(0) + env_orig.seed(0) + + model = ActorCritic(env.observation_space.shape[0], env.action_space) + model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage)) + model.eval() + + states = [] + state, state_orig = env.reset(), env_orig.reset() + states.append(state_orig) + state = torch.from_numpy(state) + reward_sum = 0 + + # a quick hack to prevent the agent from stucking + actions = deque(maxlen=1000) + episode_length = 0 + + cx = Variable(torch.zeros(1, 256)) + hx = 
Variable(torch.zeros(1, 256))
+        while True:
+            episode_length += 1
+
+            value, logit, (hx, cx) = model((Variable(state.unsqueeze(0)), (hx, cx)))
+            prob = F.softmax(logit, dim=1)
+            action = prob.max(1, keepdim=True)[1].data.numpy()
+
+            state, reward, done, _ = env.step(action[0, 0])
+            state_orig, _, _, _ = env_orig.step(action[0, 0])
+            states.append(state_orig)
+            state = torch.from_numpy(state)
+            if render:
+                env.render()
+            done = done or episode_length >= 10000
+            reward_sum += reward
+
+            # a quick hack to prevent the agent from getting stuck
+            actions.append(action[0, 0])
+            if actions.count(actions[0]) == actions.maxlen:
+                print('stuck in infinite loop')
+                done = True
+
+            if done:
+                print('episode_length: {}'.format(episode_length))
+                print('reward_sum: {}'.format(reward_sum))
+                break
+
+            cx = Variable(cx.data)
+            hx = Variable(hx.data)
+
+        env.close(), env_orig.close()
+        return states
+
+if __name__ == '__main__':
+    MODEL_DIR = 'models'
+    VIDEO_DIR = 'videos'
+    for model_name in sorted(os.listdir(MODEL_DIR)):
+        print('play with model {}'.format(model_name))
+        result = visualize('BreakoutNoFrameskip-v4', os.path.join(MODEL_DIR, model_name))
+        imageio.mimwrite(os.path.join(VIDEO_DIR, model_name.replace('.pth', '.mp4')), result, fps=60)
+        print('-' * 40)
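
Taken together, the two patches form a small pipeline: the patched test.py logs episode_reward, episode_length and FPS to ./log through tensorboardX and saves a checkpoint under ./models at most once every 15 minutes, while visualize.py replays each checkpoint found in ./models and writes an .mp4 into ./videos. Neither script creates ./models or ./videos on its own, so a driver along the following lines can be useful. This is only a sketch: driver.py, ENV_NAME and the directory constants are illustrative, and the environment id must match the env name (args.env_name in main.py) that the checkpoints were trained with.

# driver.py -- hypothetical helper, not part of the patches above.
# Assumes visualize.py (and the envs.py / model.py it imports) are importable,
# and that ENV_NAME matches the args.env_name used for training.
import os

import imageio

from visualize import visualize

ENV_NAME = 'BreakoutNoFrameskip-v4'   # assumption: replace with the training env id
MODEL_DIR = 'models'                  # test.py saves epoch_<timestamp>.pth here
VIDEO_DIR = 'videos'                  # the generated .mp4 files go here

# torch.save() in test.py and imageio.mimwrite() below both fail if these
# directories do not exist, so create them up front.
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(VIDEO_DIR, exist_ok=True)

for model_name in sorted(os.listdir(MODEL_DIR)):
    frames = visualize(ENV_NAME, os.path.join(MODEL_DIR, model_name))
    out_path = os.path.join(VIDEO_DIR, model_name.replace('.pth', '.mp4'))
    imageio.mimwrite(out_path, frames, fps=60)
    print('wrote {} ({} frames)'.format(out_path, len(frames)))

The scalars written to ./log can be inspected with tensorboard --logdir log while training is running.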