No warnings (gym & PyTorch 0.4 warnings) #54

Open · wants to merge 2 commits into master
6 changes: 3 additions & 3 deletions envs.py
@@ -29,9 +29,9 @@ def _process_frame42(frame):
class AtariRescale42x42(gym.ObservationWrapper):
def __init__(self, env=None):
super(AtariRescale42x42, self).__init__(env)
self.observation_space = Box(0.0, 1.0, [1, 42, 42])
self.observation_space = Box(0.0, 1.0, [1, 42, 42], dtype=np.float32)

def _observation(self, observation):
def observation(self, observation):
return _process_frame42(observation)


@@ -43,7 +43,7 @@ def __init__(self, env=None):
self.alpha = 0.9999
self.num_steps = 0

def _observation(self, observation):
def observation(self, observation):
self.num_steps += 1
self.state_mean = self.state_mean * self.alpha + \
observation.mean() * (1 - self.alpha)
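The envs.py change tracks two gym API updates: ObservationWrapper subclasses now override observation() rather than the underscore-prefixed _observation(), and Box takes an explicit dtype so gym stops warning about an autodetected precision. A minimal sketch of that wrapper pattern, assuming the class name and the pass-through observation() body are illustrative (the real wrapper calls _process_frame42()):

```python
import gym
import numpy as np
from gym.spaces import Box


class Rescale42x42(gym.ObservationWrapper):
    def __init__(self, env=None):
        super(Rescale42x42, self).__init__(env)
        # Explicit dtype silences gym's "autodetected dtype" warning.
        self.observation_space = Box(0.0, 1.0, [1, 42, 42], dtype=np.float32)

    def observation(self, observation):
        # Newer gym calls observation(), not _observation(); the real wrapper
        # rescales the frame to 1x42x42 here.
        return observation
```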
9 changes: 6 additions & 3 deletions main.py
@@ -11,6 +11,7 @@
from model import ActorCritic
from test import test
from train import train
import time

# Based on
# https://github.com/pytorch/examples/tree/master/mnist_hogwild
@@ -43,15 +44,15 @@


if __name__ == '__main__':
mp.set_start_method('spawn')
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['CUDA_VISIBLE_DEVICES'] = ""

args = parser.parse_args()

torch.manual_seed(args.seed)
env = create_atari_env(args.env_name)
shared_model = ActorCritic(
env.observation_space.shape[0], env.action_space)
shared_model = ActorCritic(env.observation_space.shape[0], env.action_space)
shared_model.share_memory()

if args.no_shared:
@@ -70,8 +71,10 @@
processes.append(p)

for rank in range(0, args.num_processes):
p = mp.Process(target=train, args=(rank, args, shared_model, counter, lock, optimizer))
p = mp.Process(target=train, args=(rank, args, shared_model, counter, lock, optimizer, False))
p.start()
processes.append(p)
time.sleep(5)

for p in processes:
p.join()
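The main.py edits change process start-up rather than training behaviour: the start method is forced to 'spawn', each worker is pinned to one OpenMP thread with the GPU hidden, and the training workers are launched five seconds apart. A stripped-down sketch of that launch pattern, with worker() standing in for the repo's train():

```python
import os
import time

import torch.multiprocessing as mp


def worker(rank):
    # Stand-in for train(rank, args, shared_model, ...).
    print('worker', rank, 'started')


if __name__ == '__main__':
    mp.set_start_method('spawn')             # must be set before creating processes
    os.environ['OMP_NUM_THREADS'] = '1'      # one CPU thread per worker process
    os.environ['CUDA_VISIBLE_DEVICES'] = ''  # keep the A3C workers on the CPU

    processes = []
    for rank in range(4):
        p = mp.Process(target=worker, args=(rank,))
        p.start()
        processes.append(p)
        time.sleep(5)  # stagger start-up, as the diff does for the train workers
    for p in processes:
        p.join()
```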
8 changes: 4 additions & 4 deletions model.py
@@ -57,10 +57,10 @@ def __init__(self, num_inputs, action_space):

def forward(self, inputs):
inputs, (hx, cx) = inputs
x = F.elu(self.conv1(inputs))
x = F.elu(self.conv2(x))
x = F.elu(self.conv3(x))
x = F.elu(self.conv4(x))
x = F.relu(self.conv1(inputs), inplace=True)
x = F.relu(self.conv2(x), inplace=True)
x = F.relu(self.conv3(x), inplace=True)
x = F.relu(self.conv4(x), inplace=True)

x = x.view(-1, 32 * 3 * 3)
hx, cx = self.lstm(x, (hx, cx))
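model.py swaps the ELU activations in the convolutional stack for in-place ReLUs. In-place ReLU reuses the convolution output's storage instead of allocating a new tensor per activation; the usual caveat is that it must not overwrite a tensor autograd still needs for the backward pass. A toy sketch of the pattern, with illustrative layer shapes rather than the repo's exact architecture:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvTower(nn.Module):
    def __init__(self, num_inputs):
        super(ConvTower, self).__init__()
        self.conv1 = nn.Conv2d(num_inputs, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)

    def forward(self, x):
        # inplace=True writes the activation back into the conv output's buffer.
        x = F.relu(self.conv1(x), inplace=True)
        x = F.relu(self.conv2(x), inplace=True)
        return x


# Quick shape check with a fake 1x42x42 observation batch.
out = ConvTower(1)(torch.zeros(1, 1, 42, 42))
```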
102 changes: 55 additions & 47 deletions test.py
@@ -1,3 +1,4 @@
import os
import time
from collections import deque

@@ -7,63 +8,70 @@

from envs import create_atari_env
from model import ActorCritic
from tensorboardX import SummaryWriter


def test(rank, args, shared_model, counter):
torch.manual_seed(args.seed + rank)
with torch.no_grad():
my_writer = SummaryWriter(log_dir='log')
t0 = time.time()
torch.manual_seed(args.seed + rank)

env = create_atari_env(args.env_name)
env.seed(args.seed + rank)
env = create_atari_env(args.env_name)
env.seed(args.seed + rank)

model = ActorCritic(env.observation_space.shape[0], env.action_space)
model = ActorCritic(env.observation_space.shape[0], env.action_space)
model.eval()

model.eval()
state = env.reset()
state = torch.from_numpy(state)
reward_sum = 0
done = True

state = env.reset()
state = torch.from_numpy(state)
reward_sum = 0
done = True
start_time = time.time()

start_time = time.time()
# a quick hack to prevent the agent from getting stuck
actions = deque(maxlen=100)
episode_length = 0
while True:
episode_length += 1
# Sync with the shared model
if done:
model.load_state_dict(shared_model.state_dict())
cx = Variable(torch.zeros(1, 256))
hx = Variable(torch.zeros(1, 256))
else:
cx = Variable(cx.data)
hx = Variable(hx.data)

# a quick hack to prevent the agent from getting stuck
actions = deque(maxlen=100)
episode_length = 0
while True:
episode_length += 1
# Sync with the shared model
if done:
model.load_state_dict(shared_model.state_dict())
cx = Variable(torch.zeros(1, 256), volatile=True)
hx = Variable(torch.zeros(1, 256), volatile=True)
else:
cx = Variable(cx.data, volatile=True)
hx = Variable(hx.data, volatile=True)
value, logit, (hx, cx) = model((Variable(state.unsqueeze(0)), (hx, cx)))
prob = F.softmax(logit, dim=1)
action = prob.max(1, keepdim=True)[1].data.numpy()

value, logit, (hx, cx) = model((Variable(
state.unsqueeze(0), volatile=True), (hx, cx)))
prob = F.softmax(logit)
action = prob.max(1, keepdim=True)[1].data.numpy()
state, reward, done, _ = env.step(action[0, 0])
done = done or episode_length >= args.max_episode_length
reward_sum += reward

state, reward, done, _ = env.step(action[0, 0])
done = done or episode_length >= args.max_episode_length
reward_sum += reward
# a quick hack to prevent the agent from getting stuck
actions.append(action[0, 0])
if actions.count(actions[0]) == actions.maxlen:
done = True

# a quick hack to prevent the agent from getting stuck
actions.append(action[0, 0])
if actions.count(actions[0]) == actions.maxlen:
done = True
if done:
my_writer.add_scalar('episode_reward', reward_sum, counter.value)
my_writer.add_scalar('episode_length', episode_length, counter.value)
my_writer.add_scalar('FPS', counter.value / (time.time() - start_time), counter.value)
print("Time {}, num steps {}, FPS {:.0f}, episode reward {}, episode length {}".format(
time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - start_time)),
counter.value, counter.value / (time.time() - start_time),
reward_sum, episode_length))
reward_sum = 0
episode_length = 0
actions.clear()
state = env.reset()
time.sleep(4 * 60)
if (time.time() - t0) > (15 * 60):
torch.save(model.state_dict(), os.path.join('models', 'epoch_{}.pth').format(str(int(time.time()))))
t0 = time.time()

if done:
print("Time {}, num steps {}, FPS {:.0f}, episode reward {}, episode length {}".format(
time.strftime("%Hh %Mm %Ss",
time.gmtime(time.time() - start_time)),
counter.value, counter.value / (time.time() - start_time),
reward_sum, episode_length))
reward_sum = 0
episode_length = 0
actions.clear()
state = env.reset()
time.sleep(60)

state = torch.from_numpy(state)
state = torch.from_numpy(state)
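test.py carries most of the PyTorch 0.4 migration: the evaluation loop runs inside torch.no_grad() instead of building volatile=True Variables, softmax gets an explicit dim, and the test process now logs episode_reward, episode_length, and FPS through tensorboardX's SummaryWriter and saves a checkpoint to models/ roughly every 15 minutes. A compact sketch of the 0.4-style greedy evaluation step (the function name and tensor handling here are illustrative, not the PR's exact code):

```python
import torch
import torch.nn.functional as F


def greedy_action(model, state, hx, cx):
    # Replaces Variable(..., volatile=True): no graph is built inside no_grad().
    with torch.no_grad():
        value, logit, (hx, cx) = model((state.unsqueeze(0), (hx, cx)))
        prob = F.softmax(logit, dim=1)          # explicit dim, no 0.4 warning
        action = prob.max(1, keepdim=True)[1]   # greedy action for evaluation
    return action[0, 0].item(), (hx, cx)
```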
42 changes: 35 additions & 7 deletions train.py
@@ -15,25 +15,32 @@ def ensure_shared_grads(model, shared_model):
shared_param._grad = param.grad


def train(rank, args, shared_model, counter, lock, optimizer=None):
def train(rank, args, shared_model, counter, lock, optimizer=None, DEBUG=False):
if DEBUG:
print('rank: {}'.format(rank))
torch.manual_seed(args.seed + rank)

env = create_atari_env(args.env_name)
env.seed(args.seed + rank)

model = ActorCritic(env.observation_space.shape[0], env.action_space)
model.train()
if DEBUG:
print('agent{:03d}: model created'.format(rank))

if optimizer is None:
optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)

model.train()
if DEBUG:
print('agent{:03d}: optimizer created'.format(rank))

state = env.reset()
state = torch.from_numpy(state)
done = True

episode_length = 0
while True:
if DEBUG:
print('agent{:03d}: while loop'.format(rank))
# Sync with the shared model
model.load_state_dict(shared_model.state_dict())
if done:
@@ -49,24 +56,45 @@ def train(rank, args, shared_model, counter, lock, optimizer=None):
entropies = []

for step in range(args.num_steps):
if DEBUG:
print('agent{:03d}: for loop p1'.format(rank))

episode_length += 1
value, logit, (hx, cx) = model((Variable(state.unsqueeze(0)),
(hx, cx)))
prob = F.softmax(logit)
log_prob = F.log_softmax(logit)
if DEBUG:
print('agent{:03d}: for loop p1.1'.format(rank))
print(state.unsqueeze(0).size())
with lock:
value, logit, (hx, cx) = model((Variable(state.unsqueeze(0)), (hx, cx)))
if DEBUG:
print('agent{:03d}: for loop p2'.format(rank))
# prob = F.softmax(logit)
prob = F.softmax(logit, dim=1)
log_prob = F.log_softmax(logit, dim=1)
if DEBUG:
print('agent{:03d}: for loop p3'.format(rank))
entropy = -(log_prob * prob).sum(1, keepdim=True)
entropies.append(entropy)
if DEBUG:
print('agent{:03d}: for loop p4'.format(rank))

action = prob.multinomial(num_samples=1).data
log_prob = log_prob.gather(1, Variable(action))
if DEBUG:
print('agent{:03d}: for loop p5'.format(rank))

state, reward, done, _ = env.step(action.numpy())
done = done or episode_length >= args.max_episode_length
reward = max(min(reward, 1), -1)
if DEBUG:
print('agent{:03d}: for loop p6'.format(rank))

with lock:
counter.value += 1
if DEBUG:
print('agent{:03d}: counter plus {:09d}'.format(rank, counter.value))

if DEBUG:
print('agent{:03d}: for loop p7'.format(rank))
if done:
episode_length = 0
state = env.reset()
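train.py keeps the same A3C loop but gains an optional DEBUG flag (threaded through from main.py) that prints per-step progress for a given worker rank, and it moves the sampling step to the 0.4 API: explicit dims for softmax/log_softmax, with the action drawn via multinomial(num_samples=1). A self-contained sketch of that step, with fake logits standing in for the policy head's output:

```python
import torch
import torch.nn.functional as F

logit = torch.randn(1, 6)                           # stand-in policy logits, 6 actions
prob = F.softmax(logit, dim=1)
log_prob = F.log_softmax(logit, dim=1)
entropy = -(log_prob * prob).sum(1, keepdim=True)   # entropy bonus term
action = prob.multinomial(num_samples=1)            # sample one action index
log_prob_a = log_prob.gather(1, action)             # log-prob of the sampled action
```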
80 changes: 80 additions & 0 deletions visualize.py
@@ -0,0 +1,80 @@
import os
from collections import deque

import torch
import torch.nn.functional as F
from torch.autograd import Variable

import gym
import imageio

from envs import create_atari_env
from model import ActorCritic


def visualize(env_name, model_path, render=False):
with torch.no_grad():
torch.manual_seed(0)

env = create_atari_env(env_name)
env_orig = gym.make(env_name)
env.seed(0)
env_orig.seed(0)

model = ActorCritic(env.observation_space.shape[0], env.action_space)
model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
model.eval()

states = []
state, state_orig = env.reset(), env_orig.reset()
states.append(state_orig)
state = torch.from_numpy(state)
reward_sum = 0

# a quick hack to prevent the agent from getting stuck
actions = deque(maxlen=1000)
episode_length = 0

cx = Variable(torch.zeros(1, 256))
hx = Variable(torch.zeros(1, 256))
while True:
episode_length += 1

value, logit, (hx, cx) = model((Variable(state.unsqueeze(0)), (hx, cx)))
prob = F.softmax(logit, dim=1)
action = prob.max(1, keepdim=True)[1].data.numpy()

state, reward, done, _ = env.step(action[0, 0])
state_orig, _, _, _ = env_orig.step(action[0, 0])
states.append(state_orig)
state = torch.from_numpy(state)
if render:
env.render()
done = done or episode_length >= 10000
reward_sum += reward

# a quick hack to prevent the agent from getting stuck
actions.append(action[0, 0])
if actions.count(actions[0]) == actions.maxlen:
print('stuck in infinite loop')
done = True

if done:
print('episode_length: {}'.format(episode_length))
print('reward_sum: {}'.format(reward_sum))
break

cx = Variable(cx.data)
hx = Variable(hx.data)

env.close(), env_orig.close()
return states

if __name__ == '__main__':
MODEL_DIR = 'models'
VIDEO_DIR = 'videos'
for model_name in sorted(os.listdir(MODEL_DIR)):
print('play with model {}'.format(model_name))
result = visualize('BreakoutNoFrameskip-v4', os.path.join(MODEL_DIR, model_name))
imageio.mimwrite(os.path.join(VIDEO_DIR, model_name.replace('.pth', '.mp4')), result, fps=60)
print('-' * 40)
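The new visualize.py rolls out a saved checkpoint greedily, collects frames from the unwrapped env, and writes them out as mp4 via imageio; the __main__ block sweeps every checkpoint in models/. For a single checkpoint it can also be called directly, roughly as below (the checkpoint and output paths are examples only):

```python
import imageio

from visualize import visualize

frames = visualize('BreakoutNoFrameskip-v4', 'models/epoch_1234567890.pth', render=False)
imageio.mimwrite('videos/epoch_1234567890.mp4', frames, fps=60)
```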