---
title: "Text generation with LSTM"
output:
  html_notebook:
    theme: cerulean
    highlight: textmate
---
```{r setup, include=FALSE}
knitr::opts_chunk$set(warning = FALSE, message = FALSE)
```
***
This notebook contains the code samples found in Chapter 8, Section 1 of [Deep Learning with R](https://www.manning.com/books/deep-learning-with-r). Note that the original text features far more content, in particular further explanations and figures: in this notebook, you will only find source code and related comments.
***
## Implementing character-level LSTM text generation
Let's put these ideas into practice in a Keras implementation. The first thing we need is a lot of text data that we can use to learn a language model. You could use any sufficiently large text file or set of text files -- Wikipedia, the Lord of the Rings, etc. In this example we will use some of the writings of Nietzsche, the late-19th-century German philosopher (translated to English). The language model we learn will thus be a model of Nietzsche's writing style and topics of choice specifically, rather than a more generic model of the English language.
## Preparing the data
Let's start by downloading the corpus and converting it to lowercase:
```{r}
library(keras)
library(stringr)
path <- get_file(
  "nietzsche.txt",
  origin = "https://s3.amazonaws.com/text-datasets/nietzsche.txt"
)
text <- tolower(readChar(path, file.info(path)$size))
cat("Corpus length:", nchar(text), "\n")
```
Next, you'll extract partially overlapping sequences of length `maxlen`, one-hot encode them, and pack them in a 3D array `x` of shape `(sequences, maxlen, unique_characters)`. Simultaneously, you'll prepare an array `y` containing the corresponding targets: the one-hot-encoded characters that come after each extracted sequence.
```{r}
maxlen <- 60  # Length of extracted character sequences
step <- 3     # We sample a new sequence every `step` characters
text_indexes <- seq(1, nchar(text) - maxlen, by = step)
# This holds our extracted sequences
sentences <- str_sub(text, text_indexes, text_indexes + maxlen - 1)
# This holds the targets (the follow-up characters)
next_chars <- str_sub(text, text_indexes + maxlen, text_indexes + maxlen)
cat("Number of sequences:", length(sentences), "\n")
# List of unique characters in the corpus
chars <- unique(sort(strsplit(text, "")[[1]]))
cat("Unique characters:", length(chars), "\n")
# Dictionary mapping unique characters to their index in `chars`
char_indices <- seq_along(chars)
names(char_indices) <- chars
# Next, one-hot encode the characters into binary arrays.
cat("Vectorization...\n")
x <- array(0L, dim = c(length(sentences), maxlen, length(chars)))
y <- array(0L, dim = c(length(sentences), length(chars)))
for (i in seq_along(sentences)) {
  sentence <- strsplit(sentences[[i]], "")[[1]]
  for (t in seq_along(sentence)) {
    char <- sentence[[t]]
    x[i, t, char_indices[[char]]] <- 1
  }
  next_char <- next_chars[[i]]
  y[i, char_indices[[next_char]]] <- 1
}
```
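
To make the windowing concrete, here is a minimal sketch on a toy string. The names `toy_text`, `toy_maxlen`, `toy_step` and the tiny sizes are assumptions chosen for illustration only (they are not the corpus settings used above), and base R `substring()` stands in for `str_sub()` so the snippet runs without extra packages.

```r
# Toy illustration with hypothetical values (not the corpus settings above)
toy_text <- "hello world"
toy_maxlen <- 4
toy_step <- 3

# Start position of each window; every window must leave room for one target
toy_starts <- seq(1, nchar(toy_text) - toy_maxlen, by = toy_step)

# `substring()` is vectorized over start and stop positions
toy_inputs <- substring(toy_text, toy_starts, toy_starts + toy_maxlen - 1)
toy_targets <- substring(toy_text, toy_starts + toy_maxlen, toy_starts + toy_maxlen)
print(data.frame(input = toy_inputs, target = toy_targets))

# One-hot encode the first window: one row per timestep, one column per character
toy_chars <- sort(unique(strsplit(toy_text, "")[[1]]))
toy_onehot <- outer(strsplit(toy_inputs[1], "")[[1]], toy_chars, "==") * 1L
print(dim(toy_onehot))
```

With these toy settings the windows are "hell", "lo w" and "worl", with targets "o", "o" and "d". Each row of `toy_onehot` contains exactly one 1, which is the same layout as a single slice `x[i, , ]` of the full array.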
## Building the network
This network is a single LSTM layer followed by a dense classifier with a softmax over all possible characters. Note that recurrent neural networks aren't the only way to generate sequence data; 1D convnets have also proven extremely successful at this task in recent times.
```{r}
model <- keras_model_sequential() %>%
  layer_lstm(units = 128, input_shape = c(maxlen, length(chars))) %>%
  layer_dense(units = length(chars), activation = "softmax")
```
Since our targets are one-hot encoded, we will use `categorical_crossentropy` as the loss to train the model:
```{r}
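# `lr` is the argument name in the keras release this notebook was written for;
# more recent releases of the package name it `learning_rate`.
optimizer <- optimizer_rmsprop(lr = 0.01)
model %>% compile(
  loss = "categorical_crossentropy",
  optimizer = optimizer
)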
```
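
As a rough illustration of what this loss measures for a single one-hot target, the sketch below computes the cross-entropy by hand in base R. The three-character vocabulary, the two prediction vectors, and the helper name `toy_crossentropy` are assumptions made up for this example; Keras's own implementation additionally clips probabilities and averages over the batch.

```r
# Hypothetical 3-character vocabulary; the true next character is the second one
y_true <- c(0, 1, 0)

# Two made-up softmax outputs: one confident and correct, one unsure
p_confident <- c(0.05, 0.90, 0.05)
p_unsure <- c(0.40, 0.30, 0.30)

# Cross-entropy for one sample: minus the log-probability of the true class
toy_crossentropy <- function(y_true, y_pred) {
  -sum(y_true * log(y_pred))
}

loss_confident <- toy_crossentropy(y_true, p_confident)
loss_unsure <- toy_crossentropy(y_true, p_unsure)
cat("confident:", loss_confident, " unsure:", loss_unsure, "\n")
```

Because the target is one-hot, only the probability assigned to the true character contributes, so the loss shrinks as the model puts more mass on the correct next character.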
## Training the language model and sampling from it
Given a trained model and a seed text snippet, we generate new text by repeatedly:
* 1) Drawing from the model a probability distribution over the next character given the text available so far
* 2) Reweighting the distribution to a certain "temperature"
* 3) Sampling the next character at random according to the reweighted distribution
* 4) Adding the new character at the end of the available text
This is the code we use to reweight the original probability distribution coming out of the model, and draw a character index from it (the "sampling function"):
```{r}
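sample_next_char <- function(preds, temperature = 1.0) {
  preds <- as.numeric(preds)
  # Rescale the log-probabilities by the temperature, then renormalize
  preds <- log(preds) / temperature
  exp_preds <- exp(preds)
  preds <- exp_preds / sum(exp_preds)
  # Draw one character index from the reweighted distribution
  which.max(t(rmultinom(1, 1, preds)))
}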
```
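
To see what the reweighting step does on its own, the sketch below applies it to a made-up three-character distribution. The helper `reweight` and the vector `toy_preds` are hypothetical names introduced here; the arithmetic mirrors the first lines of `sample_next_char` without the random draw. Dividing log-probabilities by the temperature is equivalent to raising each probability to the power `1 / temperature` and renormalizing.

```r
# Reweighting only (no sampling); mirrors the arithmetic in the sampling function
reweight <- function(preds, temperature = 1.0) {
  logits <- log(preds) / temperature
  exp_logits <- exp(logits)
  exp_logits / sum(exp_logits)
}

# Hypothetical model output over three characters
toy_preds <- c(a = 0.6, b = 0.3, c = 0.1)

for (temperature in c(0.2, 0.5, 1.0, 1.2)) {
  cat("temperature", temperature, ":",
      round(reweight(toy_preds, temperature), 3), "\n")
}
```

Temperatures below 1 sharpen the distribution toward the most likely character, a temperature of 1 leaves it unchanged, and temperatures above 1 flatten it so that unlikely characters are drawn more often.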
Finally, the following loop repeatedly trains the model and generates text. After every epoch, you generate text using a range of different temperatures. This lets you see how the generated text evolves as the model begins to converge, as well as the impact of temperature on the sampling strategy.
```{r}
for (epoch in 1:60) {
  cat("epoch", epoch, "\n")
  # Fit the model for 1 epoch on the available training data
  model %>% fit(x, y, batch_size = 128, epochs = 1)
  # Select a text seed at random
  start_index <- sample(1:(nchar(text) - maxlen - 1), 1)
  seed_text <- str_sub(text, start_index, start_index + maxlen - 1)
  cat("--- Generating with seed:", seed_text, "\n\n")
  for (temperature in c(0.2, 0.5, 1.0, 1.2)) {
    cat("------ temperature:", temperature, "\n")
    cat(seed_text, "\n")
    generated_text <- seed_text
    # We generate 400 characters
    for (i in 1:400) {
      # One-hot encode the current window of `maxlen` characters
      sampled <- array(0, dim = c(1, maxlen, length(chars)))
      generated_chars <- strsplit(generated_text, "")[[1]]
      for (t in seq_along(generated_chars)) {
        char <- generated_chars[[t]]
        sampled[1, t, char_indices[[char]]] <- 1
      }
      # Predict the next-character distribution and sample from it
      preds <- model %>% predict(sampled, verbose = 0)
      next_index <- sample_next_char(preds[1, ], temperature)
      next_char <- chars[[next_index]]
      # Append the new character and drop the oldest one
      generated_text <- paste0(generated_text, next_char)
      generated_text <- substring(generated_text, 2)
      cat(next_char)
    }
    cat("\n\n")
  }
}
```
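
The last two assignments in the inner loop keep the model input at a fixed length: each sampled character is appended and the oldest one is dropped. A minimal sketch of that sliding window, using a hypothetical 5-character window and three hard-coded characters standing in for sampled ones:

```r
# Toy stand-ins (hypothetical): a 5-character window and three "sampled" characters
window <- "abcde"
window_len <- nchar(window)

for (next_char in c("f", "g", "h")) {
  window <- paste0(window, next_char)  # append the new character
  window <- substring(window, 2)       # drop the oldest character
  cat(window, "\n")
}
```

The window length never changes, which is why the one-hot array passed to `predict()` can always have the shape `(1, maxlen, length(chars))`.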
In the generated samples, a low temperature results in extremely repetitive and predictable text, but one whose local structure is highly realistic: in particular, all words (a word being a local pattern of characters) are real English words. With higher temperatures, the generated text becomes more interesting, surprising, even creative; it may sometimes invent completely new words that sound somewhat plausible (such as "eterned" or "troveration"). With the highest temperature, the local structure starts breaking down and most words look like semi-random strings of characters. Without a doubt, 0.5 is the most interesting temperature for text generation in this specific setup. Always experiment with multiple sampling strategies! A clever balance between learned structure and randomness is what makes generation interesting.
Note that by training a bigger model, for longer, on more data, you can achieve generated samples that look much more coherent and realistic than ours. But of course, don't expect to ever generate any meaningful text, other than by random chance: all we are doing is sampling data from a statistical model of which characters come after which characters. Language is a communication channel, and there is a distinction between what communications are about and the statistical structure of the messages in which communications are encoded. To illustrate this distinction, here is a thought experiment: what if human language did a better job of compressing communications, much like our computers do with most of our digital communications? Then language would be no less meaningful, yet it would lack any intrinsic statistical structure, thus making it impossible to learn a language model like the one we just trained.
## Takeaways
* We can generate discrete sequence data by training a model to predict the next token(s) given previous tokens.
* In the case of text, such a model is called a "language model" and could be based on either words or characters.
* Sampling the next token requires a balance between adhering to what the model judges likely and introducing randomness.
* One way to handle this is the notion of _softmax temperature_. Always experiment with different temperatures to find the "right" one.