DQN-keras.py
# -*- coding: utf-8 -*-
"""
Created on Sat Oct 14 14:16:10 2017
@author: Ali Darwish
"""
import argparse
import pickle
import random
from collections import deque

import numpy as np
import gym
import gym.spaces
from gym import envs
# import gym_pull
# gym_pull.pull('')
# print(envs.registry.all())
import gym_trackairsim.envs
import gym_trackairsim

from PIL import Image
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, Conv2D, Permute, Concatenate, Dropout
from keras.optimizers import Adam
from keras.callbacks import History, TensorBoard
import keras.backend as K
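
# Note: importing gym_trackairsim above is what (presumably) registers the
# custom 'TrackSimEnv-v1' environment with Gym, so gym.make() in main() can
# create it.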


class DQN:
    def __init__(self, env):
        self.env = env
        self.memory = deque(maxlen=2000)          # replay buffer of (s, a, r, s', done) transitions
        self.gamma = 0.85                         # discount factor
        self.epsilon = 1.0                        # initial exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.005
        self.tau = .125                           # soft-update rate for the target network
        self.model = self.create_model()          # online (prediction) network
        self.target_model = self.create_model()   # target network

    def create_model(self):
        model = Sequential()
        state_shape = self.env.observation_space.shape
        model.add(Dense(24, input_dim=state_shape[0], activation="relu"))
        model.add(Dense(48, activation="relu"))
        model.add(Dense(24, activation="relu"))
        model.add(Dense(self.env.action_space.n))
        model.compile(loss="mean_squared_error",
                      optimizer=Adam(lr=self.learning_rate))
        return model
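
    # Note: this is a small fully connected network whose input_dim is
    # observation_space.shape[0], so it assumes the observation is a flat
    # vector; the convolutional layers imported above are not used here.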

    def act(self, state):
        self.epsilon *= self.epsilon_decay
        self.epsilon = max(self.epsilon_min, self.epsilon)
        if np.random.random() < self.epsilon:
            return self.env.action_space.sample()
        return np.argmax(self.model.predict(state)[0])
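
    # Note: act() implements epsilon-greedy exploration; epsilon decays
    # multiplicatively on every call and is floored at epsilon_min, so random
    # actions become rarer as training proceeds.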

    def remember(self, state, action, reward, new_state, done):
        self.memory.append([state, action, reward, new_state, done])

    def replay(self):
        batch_size = 32
        if len(self.memory) < batch_size:
            return
        samples = random.sample(self.memory, batch_size)
        for sample in samples:
            state, action, reward, new_state, done = sample
            target = self.target_model.predict(state)
            if done:
                target[0][action] = reward
            else:
                Q_future = max(self.target_model.predict(new_state)[0])
                target[0][action] = reward + Q_future * self.gamma
            self.model.fit(state, target, epochs=1, verbose=0)

    def target_train(self):
        weights = self.model.get_weights()
        target_weights = self.target_model.get_weights()
        for i in range(len(target_weights)):
            target_weights[i] = weights[i] * self.tau + target_weights[i] * (1 - self.tau)
        self.target_model.set_weights(target_weights)

    def save_model(self, fn):
        self.model.save(fn)
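

# Summary of the update rules implemented above (for reference):
# - replay() regresses the online network toward the one-step Q-learning target
#       y = r                                        if the episode is done
#       y = r + gamma * max_a' Q_target(s', a')      otherwise
#   where the targets are computed with the more slowly moving target network.
# - target_train() applies a soft (Polyak) update of the target weights:
#       theta_target <- tau * theta_online + (1 - tau) * theta_target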


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['train', 'test'], default='train')
    parser.add_argument('--env-name', type=str, default='TrackSimEnv-v1')
    parser.add_argument('--weights', type=str, default=None)
    args = parser.parse_args()

    # Get the environment and extract the number of actions.
    env = gym.make(args.env_name)

    gamma = 0.9      # unused here; the DQN class defines its own gamma/epsilon
    epsilon = .95

    trials = 1000
    trial_len = 500

    dqn_agent = DQN(env=env)
    steps = []
    for trial in range(trials):
        # reshape assumes a flat, 2-dimensional observation vector
        cur_state = env.reset().reshape(1, 2)
        for step in range(trial_len):
            action = dqn_agent.act(cur_state)
            new_state, reward, done, _ = env.step(action)
            new_state = new_state.reshape(1, 2)
            dqn_agent.remember(cur_state, action, reward, new_state, done)
            dqn_agent.replay()        # train the online network on a replay batch
            dqn_agent.target_train()  # soft-update the target network
            cur_state = new_state
            if done:
                break
        if step >= 199:
            print("Failed to complete in trial {}".format(trial))
            if step % 10 == 0:
                dqn_agent.save_model("trial-{}.model".format(trial))
        else:
            print("Completed in {} trials".format(trial))
            dqn_agent.save_model("success.model")
            break


if __name__ == "__main__":
    main()
#nb_actions = env.action_space.n
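
# Example invocation (sketch; assumes the gym_trackairsim package imported
# above is installed and registers 'TrackSimEnv-v1'):
#   python DQN-keras.py --mode train --env-name TrackSimEnv-v1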