CriticNetwork.py
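
"""DDPG-style critic network in TensorFlow 1.x.

The critic maps an image observation and an action to a scalar estimate of
Q(s, a). A structurally identical target network is soft-updated toward the
online network at rate TAU.
"""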
import tensorflow as tf

HIDDEN1_UNITS = 300
HIDDEN2_UNITS = 600


class CriticNetwork(object):
    def __init__(self, sess, state_shape, action_size, BATCH_SIZE, TAU, LEARNING_RATE):
        self.sess = sess
        self.BATCH_SIZE = BATCH_SIZE
        self.TAU = TAU
        self.LEARNING_RATE = LEARNING_RATE
        self.action_size = action_size
        # Placeholders shared by the online and target networks.
        self.state = tf.placeholder(tf.float32, shape=[None] + state_shape)
        self.action = tf.placeholder(tf.float32, shape=[None, action_size])
        self.y = tf.placeholder(tf.float32, shape=[None, 1])
        # Build the online and target networks in separate variable scopes so
        # their weight lists can be collected independently; a bare
        # get_collection call here would return both sets at once.
        with tf.variable_scope("critic"):
            self.output = self.create_critic_network()
        with tf.variable_scope("target_critic"):
            self.target_output = self.create_critic_network()
        self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="critic")
        self.target_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="target_critic")
        self.action_grads = tf.gradients(self.output, self.action)  # gradients for the policy update
        self.loss = tf.losses.mean_squared_error(self.y, self.output)
        self.optimize = tf.train.AdamOptimizer(LEARNING_RATE).minimize(self.loss, var_list=self.weights)
        # Soft-update ops, built once: target <- TAU * online + (1 - TAU) * target.
        self.update_target_ops = [tf.assign(t, self.TAU * w + (1.0 - self.TAU) * t)
                                  for w, t in zip(self.weights, self.target_weights)]
        self.sess.run(tf.global_variables_initializer())
        # Start the target network as an exact copy of the online network.
        self.sess.run([tf.assign(t, w) for w, t in zip(self.weights, self.target_weights)])

    def gradients(self, states, actions):
        # dQ/da, fed to the actor's policy-gradient step.
        return self.sess.run(self.action_grads, feed_dict={
            self.state: states,
            self.action: actions
        })[0]

    def predict(self, states, actions):
        return self.sess.run(self.output, feed_dict={
            self.state: states,
            self.action: actions
        })

    def target_predict(self, states, actions):
        return self.sess.run(self.target_output, feed_dict={
            self.state: states,
            self.action: actions
        })

    def train(self, states, actions, y):
        print("training critic network", states.shape, actions.shape)
        n_iterations = 10
        for _ in range(n_iterations):
            self.sess.run(self.optimize, feed_dict={
                self.state: states,
                self.action: actions,
                self.y: y  # the regression targets must be fed, or the loss cannot be evaluated
            })

    def target_train(self):
        # Run the soft-update ops built once in __init__; tf.assign cannot be
        # applied to a Python list of variables in a single call.
        self.sess.run(self.update_target_ops)

    def create_critic_network(self):
        """Build the Q-network; the state is assumed to be the raw image observation."""
        print("Now we build the critic model")
        conv1 = tf.layers.conv2d(self.state, 32, kernel_size=(3, 3), padding='valid', activation=tf.nn.relu)
        conv2 = tf.layers.conv2d(conv1, 64, kernel_size=(3, 3), padding='valid', activation=tf.nn.relu)
        conv3 = tf.layers.conv2d(conv2, 128, kernel_size=(3, 3), padding='valid', activation=tf.nn.relu)
        conv3_flat = tf.contrib.layers.flatten(conv3)
        dense1 = tf.layers.dense(conv3_flat, HIDDEN1_UNITS, activation=tf.nn.relu)
        a_dense1 = tf.layers.dense(self.action, HIDDEN1_UNITS, activation=tf.nn.relu)
        # Both branches are already rank-2 here, so a plain concat merges the
        # state and action features without an extra flatten.
        merged = tf.concat([dense1, a_dense1], axis=1)
        dense2 = tf.layers.dense(merged, HIDDEN2_UNITS, activation=tf.nn.relu)
        return tf.layers.dense(dense2, 1)
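

# --- Hypothetical usage sketch (not part of the original file) ---
# Assumes 84x84 RGB observations and a 2-dimensional action; the shapes and
# hyperparameters below are illustrative placeholders only.
if __name__ == "__main__":
    import numpy as np

    sess = tf.Session()
    critic = CriticNetwork(sess, state_shape=[84, 84, 3], action_size=2,
                           BATCH_SIZE=32, TAU=0.001, LEARNING_RATE=0.001)
    states = np.random.rand(32, 84, 84, 3).astype(np.float32)
    actions = np.random.rand(32, 2).astype(np.float32)
    targets = np.random.rand(32, 1).astype(np.float32)  # stand-in for r + gamma * Q'
    critic.train(states, actions, targets)
    critic.target_train()
    q_values = critic.predict(states, actions)
    print("Q-value batch shape:", q_values.shape)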