
Commit

removed stop_gradient
maxjcohen committed Jan 20, 2018
1 parent 5460d7d commit 68f95b6
Showing 1 changed file with 7 additions and 41 deletions.
capsulelayers.py: 48 changes (7 additions & 41 deletions)
@@ -126,31 +126,10 @@ def call(self, inputs, training=None):
         # inputs_hat.shape = [None, num_capsule, input_num_capsule, dim_capsule]
         inputs_hat = K.map_fn(lambda x: K.batch_dot(x, self.W, [2, 3]), elems=inputs_tiled)

-        """
-        # Begin: routing algorithm V1, dynamic ------------------------------------------------------------#
-        # The prior for coupling coefficient, initialized as zeros.
-        b = K.zeros(shape=[self.batch_size, self.num_capsule, self.input_num_capsule])
-        def body(i, b, outputs):
-            c = tf.nn.softmax(b, dim=1)  # dim=2 is the num_capsule dimension
-            outputs = squash(K.batch_dot(c, inputs_hat, [2, 2]))
-            if i != 1:
-                b = b + K.batch_dot(outputs, inputs_hat, [2, 3])
-            return [i-1, b, outputs]
-        cond = lambda i, b, inputs_hat: i > 0
-        loop_vars = [K.constant(self.routings), b, K.sum(inputs_hat, 2, keepdims=False)]
-        shape_invariants = [tf.TensorShape([]),
-                            tf.TensorShape([None, self.num_capsule, self.input_num_capsule]),
-                            tf.TensorShape([None, self.num_capsule, self.dim_capsule])]
-        _, _, outputs = tf.while_loop(cond, body, loop_vars, shape_invariants)
-        # End: routing algorithm V1, dynamic ------------------------------------------------------------#
-        """
-        # Begin: Routing algorithm ---------------------------------------------------------------------#
-        # In forward pass, `inputs_hat_stopped` = `inputs_hat`;
-        # In backward, no gradient can flow from `inputs_hat_stopped` back to `inputs_hat`.
-        inputs_hat_stopped = K.stop_gradient(inputs_hat)
+        # inputs_hat_stopped = K.stop_gradient(inputs_hat)

         # The prior for coupling coefficient, initialized as zeros.
         # b.shape = [None, self.num_capsule, self.input_num_capsule].
         # b = tf.zeros(shape=[K.shape(inputs_hat)[0], self.num_capsule, self.input_num_capsule])
@@ -161,24 +140,11 @@ def body(i, b, outputs):
             # c.shape=[batch_size, num_capsule, input_num_capsule]
             c = tf.nn.softmax(b, dim=1)

-            # At last iteration, use `inputs_hat` to compute `outputs` in order to backpropagate gradient
-            if i == self.routings - 1:
-                # c.shape = [batch_size, num_capsule, input_num_capsule]
-                # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
-                # The first two dimensions as `batch` dimension,
-                # then matmal: [input_num_capsule] x [input_num_capsule, dim_capsule] -> [dim_capsule].
-                # outputs.shape=[None, num_capsule, dim_capsule]
-                outputs = squash(K.batch_dot(c, inputs_hat, [2, 2]))  # [None, 10, 16]
-            else:  # Otherwise, use `inputs_hat_stopped` to update `b`. No gradients flow on this path.
-                outputs = squash(K.batch_dot(c, inputs_hat_stopped, [2, 2]))
-
-                # outputs.shape = [None, num_capsule, dim_capsule]
-                # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
-                # The first two dimensions as `batch` dimension,
-                # then matmal: [dim_capsule] x [input_num_capsule, dim_capsule]^T -> [input_num_capsule].
-                # b.shape=[batch_size, num_capsule, input_num_capsule]
-                b += K.batch_dot(outputs, inputs_hat_stopped, [2, 3])
-        # End: Routing algorithm -----------------------------------------------------------------------#
+            # outputs.shape=[None, num_capsule, dim_capsule]
+            outputs = squash(K.batch_dot(c, inputs_hat, [2, 2]))  # [None, 10, 16]
+
+            # b.shape=[batch_size, num_capsule, input_num_capsule]
+            b += K.batch_dot(outputs, inputs_hat, [2, 3])

         return outputs
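
For reference, the routing step this diff simplifies can be sketched in isolation. The snippet below is a minimal, self-contained illustration written against the plain TensorFlow 2 API rather than the repository's Keras backend calls; the `route` and `squash` helpers, the shapes, and the `use_stop_gradient` flag are illustrative assumptions, with the flag standing in for the behaviour the commit removes. It is a behavioural sketch, not a line-for-line port of the layer.

    # Minimal sketch of dynamic routing (assumption: plain TensorFlow 2 API,
    # illustrative shapes; not the repository's exact CapsuleLayer code).
    import tensorflow as tf

    def squash(s, axis=-1, eps=1e-7):
        # squash(s) = ||s||^2 / (1 + ||s||^2) * s / ||s||, as in the CapsNet paper.
        sq_norm = tf.reduce_sum(tf.square(s), axis=axis, keepdims=True)
        return sq_norm / (1.0 + sq_norm) * s / tf.sqrt(sq_norm + eps)

    def route(inputs_hat, routings=3, use_stop_gradient=False):
        # inputs_hat: [batch, num_capsule, input_num_capsule, dim_capsule]
        # Prior logits for the coupling coefficients, initialized to zeros.
        b = tf.zeros(tf.shape(inputs_hat)[:3])

        # Behaviour removed by this commit: a gradient-stopped copy of the
        # predictions, used everywhere except the final iteration.
        inputs_hat_stopped = tf.stop_gradient(inputs_hat) if use_stop_gradient else inputs_hat

        for i in range(routings):
            # Coupling coefficients: softmax of the logits over the output capsules.
            c = tf.nn.softmax(b, axis=1)
            last = (i == routings - 1)
            src = inputs_hat if (last or not use_stop_gradient) else inputs_hat_stopped
            # Weighted sum of predictions, then squash:
            # [b, n, i] x [b, n, i, d] -> [b, n, d]
            outputs = squash(tf.einsum('bni,bnid->bnd', c, src))
            # Agreement between outputs and predictions refines the logits.
            b += tf.einsum('bnd,bnid->bni', outputs, src)
        return outputs

    # Usage: route 1152 input capsules to 10 output capsules of dimension 16.
    v = route(tf.random.normal([2, 10, 1152, 16]), routings=3)  # post-commit behaviour
    print(v.shape)  # (2, 10, 16)

With `use_stop_gradient=True` the intermediate iterations route through a gradient-stopped copy of the predictions, as in the deleted lines; with the default `False` the loop matches the post-commit code above, and gradients flow through every routing iteration.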
