Skip to content

Commit

Permalink
More general autoencoder flow
Browse files (browse the repository at this point in the history)
  • Loading branch information
brainsqueeze committed Sep 27, 2021
1 parent 053c0e3 commit b313c8a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,7 +3,7 @@

setup(
name="text2vec",
version="1.1.6",
version="1.1.7",
description="Building blocks for text vectorization and embedding",
author="Dave Hollander",
author_url="https://github.com/brainsqueeze",
Expand Down
16 changes: 10 additions & 6 deletions text2vec/autoencoders.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -96,11 +96,12 @@ def train_step(self, data):
x_enc, context = self.encode_layer(x_enc, mask=enc_mask, training=True)

with tf.name_scope('Decoding'):
targets = decoding_tok[:, 1:] # skip the <s> token with the slice on axis=1
if isinstance(self.embed_layer, TokenEmbed):
encoding_tok = tf.ragged.map_flat_values(self.embed_layer.table.lookup, encoding_tok)
targets = self.embed_layer.slicer(encoding_tok)
targets = tf.ragged.map_flat_values(self.embed_layer.table.lookup, targets)
targets = self.embed_layer.slicer(targets)

decoding_tok, dec_mask, _ = self.embed_layer(decoding_tok)
decoding_tok, dec_mask, _ = self.embed_layer(decoding_tok[:, :-1]) # skip </s>
decoding_tok = self.decode_layer(
x_enc=x_enc,
enc_mask=enc_mask,
Expand All @@ -117,6 +118,7 @@ def train_step(self, data):
labels=targets.to_tensor(default_value=0)
)
loss = loss * dec_mask
loss = tf.math.reduce_sum(loss, axis=1)
loss = tf.reduce_mean(loss)

gradients = tape.gradient(loss, self.trainable_variables)
Expand Down Expand Up @@ -220,11 +222,12 @@ def train_step(self, data):
x_enc, context, *states = self.encode_layer(x_enc, mask=enc_mask, training=True)

with tf.name_scope('Decoding'):
targets = decoding_tok[:, 1:] # skip the <s> token with the slice on axis=1
if isinstance(self.embed_layer, TokenEmbed):
encoding_tok = tf.ragged.map_flat_values(self.embed_layer.table.lookup, encoding_tok)
targets = self.embed_layer.slicer(encoding_tok)
targets = tf.ragged.map_flat_values(self.embed_layer.table.lookup, targets)
targets = self.embed_layer.slicer(targets)

decoding_tok, dec_mask, _ = self.embed_layer(decoding_tok)
decoding_tok, dec_mask, _ = self.embed_layer(decoding_tok[:, :-1])
decoding_tok = self.decode_layer(
x_enc=x_enc,
enc_mask=enc_mask,
Expand All @@ -242,6 +245,7 @@ def train_step(self, data):
labels=targets.to_tensor(default_value=0)
)
loss = loss * dec_mask
loss = tf.math.reduce_sum(loss, axis=1)
loss = tf.reduce_mean(loss)

gradients = tape.gradient(loss, self.trainable_variables)
Expand Down

0 comments on commit b313c8a

Please sign in to comment.