Add predict script
dennybritz committed Jul 3, 2016
1 parent 7842a27 commit 67b70f7
Showing 6 changed files with 84 additions and 7 deletions.
3 changes: 3 additions & 0 deletions models/dual_encoder.py
@@ -74,6 +74,9 @@ def dual_encoder_model(
   # Apply sigmoid to convert logits to probabilities
   probs = tf.sigmoid(logits)
 
+  if mode == tf.contrib.learn.ModeKeys.INFER:
+    return probs, None
+
   # Calculate the binary cross-entropy loss
   losses = tf.nn.sigmoid_cross_entropy_with_logits(logits, tf.to_float(targets))
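The new INFER branch returns before the loss computation because no targets exist at prediction time. For reference, a minimal NumPy sketch (hypothetical values) of what the sigmoid/cross-entropy pair above computes:

    import numpy as np

    logits = np.array([2.0, -1.0])   # raw scores from the dual encoder
    targets = np.array([1.0, 0.0])   # 1 = true response, 0 = distractor

    probs = 1.0 / (1.0 + np.exp(-logits))  # what tf.sigmoid computes

    # Numerically stable form used by tf.nn.sigmoid_cross_entropy_with_logits:
    # max(x, 0) - x * z + log(1 + exp(-|x|))
    losses = np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))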
8 changes: 7 additions & 1 deletion scripts/prepare_data.py
@@ -25,6 +25,9 @@
 VALIDATION_PATH = os.path.join(FLAGS.input_dir, "valid.csv")
 TEST_PATH = os.path.join(FLAGS.input_dir, "test.csv")
 
+def tokenizer_fn(iterator):
+  return (x.split(" ") for x in iterator)
+
 def create_csv_iter(filename):
   """
   Returns an iterator over a CSV file. Skips the header.
@@ -45,7 +48,7 @@ def create_vocab(input_iter, min_frequency):
   vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor(
     FLAGS.max_sentence_len,
     min_frequency=min_frequency,
-    tokenizer_fn=lambda iterator: (x.split(" ") for x in iterator))
+    tokenizer_fn=tokenizer_fn)
   vocab_processor.fit(input_iter)
   return vocab_processor
@@ -158,6 +161,9 @@ def write_vocabulary(vocab_processor, outfile):
   write_vocabulary(
     vocab, os.path.join(FLAGS.output_dir, "vocabulary.txt"))
 
+  # Save vocab processor
+  vocab.save(os.path.join(FLAGS.output_dir, "vocab_processor.bin"))
+
   # Create validation.tfrecords
   create_tfrecords_file(
     input_filename=VALIDATION_PATH,
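A likely reason for hoisting the tokenizer out of the lambda (inferred from the diff, not stated in the commit message): the newly added vocab.save(...) serializes the processor with pickle, and pickle can serialize module-level functions but not lambdas. A minimal sketch of the distinction:

    import pickle

    def tokenizer_fn(iterator):
      return (x.split(" ") for x in iterator)

    pickle.dumps(tokenizer_fn)  # works: pickled by importable qualified name
    pickle.dumps(lambda it: (x.split(" ") for x in it))  # raises PicklingError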
6 changes: 3 additions & 3 deletions udc_hparams.py
@@ -9,9 +9,9 @@
 
 # Model Parameters
 tf.flags.DEFINE_integer("embedding_dim", 100, "Dimensionality of the embeddings")
-tf.flags.DEFINE_integer("rnn_dim", 128, "Dimensionality of the RNN cell")
-tf.flags.DEFINE_integer("max_context_len", 120, "Truncate contexts to this length")
-tf.flags.DEFINE_integer("max_utterance_len", 60, "Truncate utterance to this length")
+tf.flags.DEFINE_integer("rnn_dim", 256, "Dimensionality of the RNN cell")
+tf.flags.DEFINE_integer("max_context_len", 160, "Truncate contexts to this length")
+tf.flags.DEFINE_integer("max_utterance_len", 80, "Truncate utterance to this length")
 
 # Pre-trained embeddings
 tf.flags.DEFINE_string("glove_path", None, "Path to pre-trained Glove vectors")
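These flag values only take effect through create_hparams(), which the predict script calls below. A minimal sketch of that pattern (the repo's actual create_hparams likely bundles more fields):

    from collections import namedtuple
    import tensorflow as tf

    FLAGS = tf.flags.FLAGS
    HParams = namedtuple("HParams", ["rnn_dim", "max_context_len", "max_utterance_len"])

    def create_hparams():
      # Snapshot flag values into an immutable object passed to the model_fn.
      return HParams(
        rnn_dim=FLAGS.rnn_dim,
        max_context_len=FLAGS.max_context_len,
        max_utterance_len=FLAGS.max_utterance_len)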
14 changes: 12 additions & 2 deletions udc_model.py
@@ -3,7 +3,7 @@
 
 def get_id_feature(features, key, len_key, max_len):
   ids = features[key]
-  ids_len = tf.squeeze(features[len_key])
+  ids_len = tf.squeeze(features[len_key], [1])
   ids_len = tf.minimum(ids_len, tf.constant(max_len, dtype=tf.int64))
   return ids, ids_len

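The added dims argument matters at batch size 1, as in the new predict script: the length feature arrives with shape [batch_size, 1] (see the [1,1] constants in udc_predict.py below), and an unrestricted squeeze would collapse a [1, 1] tensor to a scalar, while squeezing only axis 1 always yields a [batch_size] vector. A quick sketch, shapes only:

    import tensorflow as tf

    lens = tf.constant(7, shape=[1, 1], dtype=tf.int64)  # as built in get_features
    tf.squeeze(lens)       # shape [] -- a scalar, breaks downstream batched ops
    tf.squeeze(lens, [1])  # shape [1] -- one length per example, as intended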
@@ -41,8 +41,18 @@ def model_fn(features, targets, mode):
       train_op = create_train_op(loss, hparams)
       return probs, loss, train_op
 
+    if mode == tf.contrib.learn.ModeKeys.INFER:
+      probs, loss = model_impl(
+        hparams,
+        mode,
+        context,
+        context_len,
+        utterance,
+        utterance_len,
+        None)
+      return probs, 0.0, None
+
-    if mode != tf.contrib.learn.ModeKeys.TRAIN:
+    if mode == tf.contrib.learn.ModeKeys.EVAL:
 
       # We have 10 examples per record, so we accumulate them
       all_contexts = [context]
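For context on the EVAL branch, whose body is truncated here: each test record pairs one context with the ground-truth utterance plus nine distractors, so the accumulation that begins at all_contexts stacks all ten (context, candidate) pairs into a single batch and scores them together. Roughly, and assuming the distractor features follow the distractor_{i} naming used elsewhere in this repository:

    for i in range(9):
      distractor, distractor_len = get_id_feature(
        features,
        "distractor_{}".format(i),
        "distractor_{}_len".format(i),
        hparams.max_utterance_len)
      all_contexts.append(context)
      all_utterances.append(distractor)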
58 changes: 58 additions & 0 deletions udc_predict.py
@@ -0,0 +1,58 @@
+import os
+import time
+import itertools
+import sys
+import numpy as np
+import tensorflow as tf
+import udc_model
+import udc_hparams
+import udc_metrics
+import udc_inputs
+from models.dual_encoder import dual_encoder_model
+from models.helpers import load_vocab
+
+tf.flags.DEFINE_string("model_dir", None, "Directory to load model checkpoints from")
+tf.flags.DEFINE_string("vocab_processor_file", "./data/vocab_processor.bin", "Saved vocabulary processor file")
+FLAGS = tf.flags.FLAGS
+
+if not FLAGS.model_dir:
+  print("You must specify a model directory")
+  sys.exit(1)
+
+def tokenizer_fn(iterator):
+  return (x.split(" ") for x in iterator)
+
+# Load vocabulary
+vp = tf.contrib.learn.preprocessing.VocabularyProcessor.restore(
+  FLAGS.vocab_processor_file)
+
+# Load your own data here
+INPUT_CONTEXT = "Example context"
+POTENTIAL_RESPONSES = ["Response 1", "Response 2"]
+
+def get_features(context, utterance):
+  context_matrix = np.array(list(vp.transform([context])))
+  utterance_matrix = np.array(list(vp.transform([utterance])))
+  context_len = len(context.split(" "))
+  utterance_len = len(utterance.split(" "))
+  features = {
+    "context": tf.convert_to_tensor(context_matrix, dtype=tf.int64),
+    "context_len": tf.constant(context_len, shape=[1,1], dtype=tf.int64),
+    "utterance": tf.convert_to_tensor(utterance_matrix, dtype=tf.int64),
+    "utterance_len": tf.constant(utterance_len, shape=[1,1], dtype=tf.int64),
+  }
+  return features, None
+
+if __name__ == "__main__":
+  hparams = udc_hparams.create_hparams()
+  model_fn = udc_model.create_model_fn(hparams, model_impl=dual_encoder_model)
+  estimator = tf.contrib.learn.Estimator(model_fn=model_fn, model_dir=FLAGS.model_dir)
+
+  # Ugly hack, seems to be a bug in Tensorflow
+  # estimator.predict doesn't work without this line
+  estimator._targets_info = tf.contrib.learn.estimators.tensor_signature.TensorSignature(tf.constant(0, shape=[1,1]))
+
+  print("Context: {}".format(INPUT_CONTEXT))
+  for r in POTENTIAL_RESPONSES:
+    prob = estimator.predict(input_fn=lambda: get_features(INPUT_CONTEXT, r))
+    print("{}: {:g}".format(r, prob[0,0]))
2 changes: 1 addition & 1 deletion udc_test.py
@@ -12,7 +12,7 @@
 tf.flags.DEFINE_string("test_file", "./data/test.tfrecords", "Path of test data in TFRecords format")
 tf.flags.DEFINE_string("model_dir", None, "Directory to load model checkpoints from")
 tf.flags.DEFINE_integer("loglevel", 20, "Tensorflow log level")
-tf.flags.DEFINE_integer("test_batch_size", 16, "Tensorflow log level")
+tf.flags.DEFINE_integer("test_batch_size", 16, "Batch size for testing")
 FLAGS = tf.flags.FLAGS
 
 if not FLAGS.model_dir:
