Accelerator API
@@ -90,46 +91,58 @@Guides
-The Accelerator API is a simplified port of the Hugging Face Accelerate library. It allows users to avoid the boilerplate code necessary to write training loops that work correctly on both CPU and GPU. Currently it only handles CPU and single-GPU usage.
-This API is meant to be the most flexible way you can use the luz package. With the Accelerator API, you write the raw torch training loop and, with a few code changes, you automatically handle device placement of the model, optimizers and dataloaders, so you don’t need to add many $to(device="cuda") calls in your code or think about the order in which to create the model and optimizers.
+The Accelerator API is a simplified port of the Hugging Face Accelerate library.
+It allows users to avoid the boilerplate code necessary to write
+training loops that work correctly on both CPU and GPU. Currently it only
+handles CPU and single-GPU usage.
+This API is meant to be the most flexible way you can use the luz
+package. With the Accelerator API, you write the raw torch training loop
+and, with a few code changes, you automatically handle device placement
+of the model, optimizers and dataloaders, so you don’t need to add many
+$to(device="cuda") calls in your code or think about the
+order in which to create the model and optimizers.
Example
-The Accelerator API is best explained by showing an example diff in a raw torch training loop.
-library(torch)
-+ library(luz)
-
-+ acc <- accelerator()
-- device <- "cpu"
-
-data <- tensor_dataset(
- x = torch_randn(100, 10),
- y = torch_rand(100, 1)
-)
-
-dl <- dataloader(data, batch_size = 10)
-
-model <- nn_linear(10, 1)
-- model$to(device = device)
-opt <- optim_adam(model$parameters)
-
-+ c(model, opt, dl) %<-% acc$prepare(model, opt, dl)
-
-model$train()
-coro::loop(for (batch in dl) {
-
- opt$zero_grad()
-
-- preds <- model(batch$x$to(device = device))
-+ preds <- model(batch$x)
-- loss <- nnf_mse_loss(preds, batch$y$to(device = device))
-+ loss <- nnf_mse_loss(preds, batch$y)
-
- loss$backward()
- opt$step()
-})
-With the code changes shown, you no longer need to manually move data and parameters between devices, which makes your code easier to read and less error prone.
-You can find additional documentation using help(accelerator).
+The Accelerator API is best explained by showing an example diff in a
+raw torch training loop.
+
+  library(torch)
++ library(luz)
+
++ acc <- accelerator()
+- device <- "cpu"
+
+  data <- tensor_dataset(
+    x = torch_randn(100, 10),
+    y = torch_rand(100, 1)
+  )
+
+  dl <- dataloader(data, batch_size = 10)
+
+  model <- nn_linear(10, 1)
+- model$to(device = device)
+  opt <- optim_adam(model$parameters)
+
++ c(model, opt, dl) %<-% acc$prepare(model, opt, dl)
+
+  model$train()
+  coro::loop(for (batch in dl) {
+
+    opt$zero_grad()
+
+-   preds <- model(batch$x$to(device = device))
++   preds <- model(batch$x)
+-   loss <- nnf_mse_loss(preds, batch$y$to(device = device))
++   loss <- nnf_mse_loss(preds, batch$y)
+
+    loss$backward()
+    opt$step()
+  })
+With the code changes shown, you no longer need to manually move data
+and parameters between devices, which makes your code easier to read and
+less error prone.
+You can find additional documentation using
+help(accelerator).
Example
diff --git a/articles/accelerator_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/accelerator_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/accelerator_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/checkpoints.html b/articles/checkpoints.html
index ec51faf5..6d11dcfa 100644
--- a/articles/checkpoints.html
+++ b/articles/checkpoints.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Checkpointing your models
@@ -93,15 +94,29 @@ Guides
library(torch)
set.seed(1)
torch::torch_manual_seed(1703)
-When fitting a model takes too long, you might want to save intermediate state to disk so that, if something goes wrong during training (e.g. the process is killed, the network fails, etc.), you can recover from where it stopped.
-You might also want to recover intermediate results to evaluate the model at different points during training, like comparing results after 10 epochs and after 30 epochs.
-This article describes luz features that are built to handle those cases. These features are optional and are enabled once you add specific callbacks to your fit call.
+When fitting a model takes too long, you might want to save intermediate
+state to disk so that, if something goes wrong during training (e.g. the
+process is killed, the network fails, etc.), you can recover from where
+it stopped.
+You might also want to recover intermediate results to evaluate the
+model at different points during training, like comparing results after
+10 epochs and after 30 epochs.
+This article describes luz features that are built to handle those
+cases. These features are optional and are enabled once you add specific
+callbacks to your fit call.
Resuming training runs that crashed
-If you have a long training run that can crash for whatever reason (computer turned off, process killed in a cluster, etc.), we recommend that you add luz_callback_auto_resume() to your list of callbacks.
-luz_callback_auto_resume() will automatically checkpoint the whole state of your model at the end of each epoch. If something fails during training, you can simply rerun the same script without any code changes; the checkpoint will be reloaded and training will start from where it stopped.
-For example, let’s take a randomly generated training dataset and a linear model to show how auto-resume works.
+If you have a long training run that can crash for whatever reason
+(computer turned off, process killed in a cluster, etc.), we recommend
+that you add luz_callback_auto_resume() to your list of
+callbacks.
+luz_callback_auto_resume() will automatically checkpoint
+the whole state of your model at the end of each epoch. If something
+fails during training, you can simply rerun the same script without any
+code changes; the checkpoint will be reloaded and training will
+start from where it stopped.
+For example, let’s take a randomly generated training dataset and a
+linear model to show how auto-resume works.
Here’s the training data:
x <- torch_randn(1000, 10)
@@ -112,7 +127,9 @@ Resuming training runs that crashed
setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
set_hparams(in_features = 10, out_features = 1) %>%
set_opt_hparams(lr = 0.01)
-Let’s now create a callback that simulates a random failure that could happen. This callback will just raise an R error on the 5th epoch.
+Let’s now create a callback that simulates a random failure that
+could happen. This callback will just raise an R error on the 5th
+epoch.
interrupt <- luz_callback(
"interrupt",
@@ -124,7 +141,8 @@ Resuming training runs that crashed
}
}
)
-Let’s now start training, adding the luz_callback_auto_resume() callback:
+Let’s now start training, adding the
+luz_callback_auto_resume() callback:
autoresume <- luz_callback_auto_resume(path = "state.pt")
inter <- interrupt()
@@ -140,14 +158,17 @@ Resuming training runs that crashed
#> on_epoch_end.
#> Caused by error in `self[[callback_nm]]()`:
#> ! Error on epoch 5
-To resume model training exactly from where it stopped you just need to restart fitting, using the exact same model, callbacks, etc:
+To resume model training exactly from where it stopped you just need
+to restart fitting, using the exact same model, callbacks, etc:
-With this, the model fitting process will be continued exactly from where it stopped. Records, optimizer and model state are recovered from the previous run so you can have the full results:
+With this, the model fitting process will be continued exactly from
+where it stopped. Records, optimizer and model state are recovered from
+the previous run so you can have the full results:
plot(results)
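The resume mechanism described in this section can be imitated in plain R. The sketch below is only an illustration of the idea and not how luz implements it: the `run_training()` helper, the `.rds` state file and the stand-in loss values are all assumptions made for the example. State is written at the end of every epoch, a failure is simulated on epoch 5, and a second call picks up from the saved state.

```r
# Hypothetical state file; luz uses its own format and path.
state_path <- tempfile(fileext = ".rds")

run_training <- function(epochs, fail_at = NULL) {
  # Resume from the saved state when one exists, otherwise start fresh.
  state <- if (file.exists(state_path)) {
    readRDS(state_path)
  } else {
    list(epoch = 0L, losses = numeric(0))
  }
  while (state$epoch < epochs) {
    epoch <- state$epoch + 1L
    if (!is.null(fail_at) && epoch == fail_at) {
      stop("Error on epoch ", epoch)
    }
    state$losses <- c(state$losses, 1 / epoch) # stand-in for a real loss
    state$epoch <- epoch
    saveRDS(state, state_path) # checkpoint at the end of every epoch
  }
  unlink(state_path) # this sketch removes the state once the run finishes
  state
}

# The first run crashes on epoch 5 (the simulated failure happens only once).
first <- try(run_training(10, fail_at = 5), silent = TRUE)

# The second run continues from epoch 5 and keeps the earlier records.
second <- run_training(10)
length(second$losses)
#> [1] 10
```

Because the records are part of the saved state, the second run ends with all ten loss values even though it only computed six of them.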
@@ -155,8 +176,12 @@ Resuming training runs that crashed
Checkpointing
-Sometimes you want to have more control over how checkpoints are handled. In this case you can use luz_callback_model_checkpoint()
to save checkpoints to a specified file or directory.
-Let’s use the same example as in the resuming section: We first generate some data.
+Sometimes you want to have more control over how checkpoints are
+handled. In this case you can use
+luz_callback_model_checkpoint()
to save checkpoints to a
+specified file or directory.
+Let’s use the same example as in the resuming section: We first
+generate some data.
x <- torch_randn(1000, 10)
y <- torch_randn(1000, 1)
@@ -166,7 +191,8 @@ Checkpointing
 setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
set_hparams(in_features = 10, out_features = 1) %>%
set_opt_hparams(lr = 0.01)
-
Let’s now fit the model using luz_callback_model_checkpoint()
.
+Let’s now fit the model using
+luz_callback_model_checkpoint()
.
checkpoint <- luz_callback_model_checkpoint(
path = "checkpoints/",
@@ -178,7 +204,12 @@ Checkpointing= list(checkpoint),
verbose = FALSE
)
-You can now see that the checkpoints directory contains files with state dumps for each epoch. By default, luz_callback_model_checkpoint() saves the state at every epoch and formats the file name to include the resulting loss. This can be configured within the path parameter; see ?luz_callback_model_checkpoint for details.
+You can now see that the checkpoints directory contains
+files with state dumps for each epoch. By default,
+luz_callback_model_checkpoint() saves the state at every
+epoch and formats the file name to include the resulting loss. This can be
+configured within the path parameter; see
+?luz_callback_model_checkpoint for details.
fs::dir_ls("checkpoints")
#> checkpoints/epoch-01-train_loss-1.237.pt
@@ -191,11 +222,21 @@ Checkpointing
 #> checkpoints/epoch-08-train_loss-0.998.pt
#> checkpoints/epoch-09-train_loss-1.001.pt
#> checkpoints/epoch-10-train_loss-1.002.pt
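The file names in this listing follow a simple pattern: a zero-padded epoch number followed by the monitored loss rounded to three decimals. As a rough illustration, the same shape can be produced in base R with `sprintf()`. The `checkpoint_name()` helper is hypothetical and is not part of luz; the actual template is controlled through the path parameter.

```r
# Hypothetical helper: builds a file name shaped like the listing above.
checkpoint_name <- function(epoch, train_loss) {
  sprintf("epoch-%02d-train_loss-%.3f.pt", as.integer(epoch), train_loss)
}

checkpoint_name(1, 1.237)
#> [1] "epoch-01-train_loss-1.237.pt"
checkpoint_name(10, 1.002)
#> [1] "epoch-10-train_loss-1.002.pt"
```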
-Finally, you can load a specific checkpoint into the fitted result using luz_load_checkpoint(). Note that loading the checkpoint into a luz_fitted_module modifies the model weights in place.
+Finally, you can load a specific checkpoint into the
+fitted result using luz_load_checkpoint(). Note
+that loading the checkpoint into a luz_fitted_module
+modifies the model weights in place.
luz_load_checkpoint(results, fs::dir_ls("checkpoints")[1])
-You can then start making predictions or evaluate your model using the reloaded weights.
-You might also want to start a new training run from a checkpoint. For this, you can use luz_callback_resume_from_checkpoint(). By default, it will only recover the model weights from the checkpoint file, but you can configure it to restore records, callback and optimizer state too. If a checkpoint directory is passed, then training will resume from the last checkpoint file as returned by fs::dir_ls().
+You can then start making predictions or evaluate your model using
+the reloaded weights.
+You might also want to start a new training run from a checkpoint.
+For this, you can use
+luz_callback_resume_from_checkpoint(). By default, it will
+only recover the model weights from the checkpoint file, but you can
+configure it to restore records, callback and optimizer state too. If a
+checkpoint directory is passed, then training will resume from the last
+checkpoint file as returned by fs::dir_ls().
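Picking "the last checkpoint file" from a sorted listing only finds the most recent epoch if the names sort in epoch order, which is what the zero-padded epoch number provides. The sketch below shows the difference using plain character vectors instead of real files; the file names and the `last_of()` helper are made up for the example.

```r
padded <- sprintf("epoch-%02d.pt", 1:10)
unpadded <- sprintf("epoch-%d.pt", 1:10)

# Lexical sort, independent of the session locale.
last_of <- function(files) tail(sort(files, method = "radix"), 1)

last_of(padded)
#> [1] "epoch-10.pt"
last_of(unpadded)
#> [1] "epoch-9.pt"
```

If you customise the checkpoint file names, keeping a fixed-width epoch number is one way to make sure the last file in the listing is also the latest one.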
Here’s how you would use this callback:
resume <- luz_callback_resume_from_checkpoint(path = "checkpoints/")
@@ -209,8 +250,15 @@ Checkpointing
Custom callbacks state
-Sometimes callbacks also need to keep their internal state in order to allow continuing training exactly from where it stopped. In this case, callbacks can implement the state_dict()
and the load_state_dict()
methods that are automatically called when saving and reloading checkpoints.
-For example, suppose that you have a callback that tracks gradients for weights at every epoch. You want to use the tracked weights to further analyse the training procedure. It could be implemented like:
+Sometimes callbacks also need to keep their internal state in order
+to allow continuing training exactly from where it stopped. In this
+case, callbacks can implement the state_dict()
and the
+load_state_dict()
methods that are automatically called
+when saving and reloading checkpoints.
+For example, suppose that you have a callback that tracks gradients
+for weights at every epoch. You want to use the tracked weights to
+further analyse the training procedure. It could be implemented
+like:
cb_weight_grad <- luz_callback(
"weight_grad",
@@ -225,7 +273,14 @@ Custom callbacks state }
}
)
-In the above example, the gradients field is a state in the callback. If training fails for some reason, gradients will be lost. If it’s important for you to also checkpoint the callback state, you can implement the state_dict() method, which must return a named list of objects that compose the state of the callback, and the load_state_dict() method, which takes the same named list returned by state_dict() and restores the callback state.
+In the above example, the gradients field is a
+state in the callback. If training fails for some
+reason, gradients will be lost. If it’s important for you
+to also checkpoint the callback state, you can implement the
+state_dict() method, which must return a named list of objects
+that compose the state of the callback, and the
+load_state_dict() method, which takes the same named list returned by
+state_dict() and restores the callback state.
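The contract between the two methods can be shown without luz or torch. In the sketch below, `new_tracker()` is a hypothetical closure-based stand-in for a callback, and `saveRDS()`/`readRDS()` stand in for whatever serialization the checkpoint uses. The only point is that whatever `state_dict()` returns must be enough for `load_state_dict()` to rebuild the same state.

```r
# Hypothetical stand-in for a callback that tracks one value per epoch.
new_tracker <- function() {
  gradients <- list()
  list(
    record = function(value) {
      gradients[[length(gradients) + 1]] <<- value
    },
    # Returns a named list holding everything needed to restore the state.
    state_dict = function() {
      list(gradients = gradients)
    },
    # Takes the same named list and restores the state from it.
    load_state_dict = function(state_dict) {
      gradients <<- state_dict$gradients
    }
  )
}

before <- new_tracker()
before$record(0.5)
before$record(0.25)

# Simulate a checkpoint written to disk and read back after a crash.
path <- tempfile(fileext = ".rds")
saveRDS(before$state_dict(), path)

after <- new_tracker()
after$load_state_dict(readRDS(path))
identical(after$state_dict(), before$state_dict())
#> [1] TRUE
```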
The callback above could be reimplemented with:
cb_weight_grad <- luz_callback(
@@ -262,7 +317,7 @@ Custom callbacks state
-Site built with pkgdown 2.0.7.9000.
+Site built with pkgdown 2.0.7.
diff --git a/articles/checkpoints_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/checkpoints_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/checkpoints_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/custom-loop.html b/articles/custom-loop.html
index b27b17d8..658563a7 100644
--- a/articles/custom-loop.html
+++ b/articles/custom-loop.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Custom loops with luz
@@ -91,15 +92,36 @@ Guides
-Luz is a higher level API for torch that is designed to be highly flexible by providing a layered API that allows it to be useful no matter the level of control you need for your training loop.
-In the getting started vignette we have seen the basics of luz and how to quickly modify parts of the training loop using callbacks and custom metrics. In this document we will describe how luz allows the user to get fine-grained control of the training loop.
-Apart from the use of callbacks, there are three more ways that you can use luz (depending on how much control you need):
+Luz is a higher level API for torch that is designed to be highly
+flexible by providing a layered API that allows it to be useful no
+matter the level of control you need for your training loop.
+In the getting started vignette we have seen the basics of luz and
+how to quickly modify parts of the training loop using callbacks and
+custom metrics. In this document we will describe how luz allows the
+user to get fine-grained control of the training loop.
+Apart from the use of callbacks, there are three more ways that you
+can use luz (depending on how much control you need):
-Multiple optimizers or losses: You might be optimizing two loss functions, each with its own optimizer, but you still don’t want to modify the backward(), zero_grad() and step() calls. This is common in models like GANs (Generative Adversarial Networks), where competing neural networks are trained with different losses and optimizers.
-Fully flexible steps: You might want to be in control of how to call backward()
, zero_grad()
and step()
. You might also want to have more control of gradient computation. For example, you might want to use ‘virtual batch sizes’, where you accumulate the gradients for a few steps before updating the weights.
-Completely flexible loops: Your training loop can be anything you want but you still want to use luz to handle device placement of the dataloaders, optimizers and models. See vignette("accelerator")
.
+Multiple optimizers or losses: You might be
+optimizing two loss functions, each with its own optimizer, but you still
+don’t want to modify the backward(),
+zero_grad() and step() calls. This is common
+in models like GANs (Generative Adversarial Networks), where
+competing neural networks are trained with different losses and
+optimizers.
+Fully flexible steps: You might want to be in
+control of how to call backward()
,
+zero_grad()
and step()
. You might also want to
+have more control of gradient computation. For example, you might want
+to use ‘virtual batch sizes’, where you accumulate the gradients for a
+few steps before updating the weights.
+Completely flexible loops: Your training loop
+can be anything you want but you still want to use luz to handle device
+placement of the dataloaders, optimizers and models. See
+vignette("accelerator")
.
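The "virtual batch sizes" idea mentioned above can be checked numerically in base R, without torch. The sketch assumes a linear model with a mean squared error loss and a hand-written gradient (`mse_grad()` is a helper invented for the example). It shows that accumulating scaled gradients over four micro-batches of 10 rows gives the same gradient as one batch of 40 rows.

```r
# Gradient of the mean squared error for a linear model y ~ x %*% w.
mse_grad <- function(w, x, y) {
  residual <- as.vector(x %*% w) - y
  as.vector(2 * t(x) %*% residual / length(y))
}

set.seed(1)
x <- matrix(rnorm(40 * 3), nrow = 40)
y <- rnorm(40)
w <- c(0.5, -0.2, 0.1)

# Reference: the gradient computed from the full batch of 40 rows.
full_grad <- mse_grad(w, x, y)

# Virtual batch: accumulate over 4 micro-batches of 10 rows each, scaling
# each contribution by 1 / n_accum so that the sum is a mean.
n_accum <- 4
chunks <- split(seq_len(nrow(x)), rep(seq_len(n_accum), each = 10))
accum_grad <- numeric(length(w))
for (idx in chunks) {
  accum_grad <- accum_grad +
    mse_grad(w, x[idx, , drop = FALSE], y[idx]) / n_accum
}

# Only at this point would the optimizer step and the gradients be reset.
isTRUE(all.equal(full_grad, accum_grad))
#> [1] TRUE
```

In a training loop this corresponds to back-propagating on every batch but stepping the optimizer and zeroing the gradients only once every `n_accum` batches.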
-Let’s consider a simplified version of the net
that we implemented in the getting started vignette:
+Let’s consider a simplified version of the net
that we
+implemented in the getting started vignette:
net <- nn_module(
"Net",
@@ -128,11 +150,18 @@ Guides
Multiple optimizers
-Suppose we want to do an experiment where we train the first fully connected layer using a learning rate of 0.1 and the second one using a learning rate of 0.01. We will minimize the same nn_cross_entropy_loss()
for both, but for the first layer we want to add L1 regularization on the weights.
-In order to use luz for this, we will implement two methods in the net
module:
+Suppose we want to do an experiment where we train the first fully
+connected layer using a learning rate of 0.1 and the second one using a
+learning rate of 0.01. We will minimize the same
+nn_cross_entropy_loss()
for both, but for the first layer
+we want to add L1 regularization on the weights.
+In order to use luz for this, we will implement two methods in the
+net
module:
-set_optimizers
: returns a named list of optimizers depending on the ctx
.
-loss
: computes the loss depending on the selected optimizer.
+set_optimizers
: returns a named list of optimizers
+depending on the ctx
.
+loss
: computes the loss depending on the selected
+optimizer.
Let’s go to the code:
@@ -163,19 +192,35 @@ Multiple optimizers
 nnf_cross_entropy(pred, target)
}
)
-Notice that the model optimizers will be initialized according to the set_optimizers()
method’s return value (a list). In this case, we are initializing the optimizers using different model parameters and learning rates.
-The loss()
method is responsible for computing the loss that will then be back-propagated to compute gradients and update the weights. This loss()
method can access the ctx
object that will contain an opt_name
field, describing which optimizer is currently being used. Note that this function will be called once for each optimizer for each training and validation step. See help("ctx")
for complete information about the context object.
-We can finally setup
and fit
this module, however we no longer need to specify optimizers and loss functions.
+Notice that the model optimizers will be initialized according to the
+set_optimizers()
method’s return value (a list). In this
+case, we are initializing the optimizers using different model
+parameters and learning rates.
+The loss()
method is responsible for computing the loss
+that will then be back-propagated to compute gradients and update the
+weights. This loss()
method can access the ctx
+object that will contain an opt_name
field, describing
+which optimizer is currently being used. Note that this function will be
+called once for each optimizer for each training and validation step.
+See help("ctx")
for complete information about the context
+object.
+We can finally setup
and fit
this module,
+however we no longer need to specify optimizers and loss functions.
fitted <- net %>%
setup(metrics = list(luz_metric_accuracy)) %>%
fit(train_dl, epochs = 10, valid_data = test_dl)
-Now let’s re-implement this same model using the slightly more flexible approach of overriding the training and validation step.
+Now let’s re-implement this same model using the slightly more
+flexible approach of overriding the training and validation step.
Fully flexible step
-Instead of implementing the loss()
method, we can implement the step()
method. This allows us to flexibly modify what happens when training and validating for each batch in the dataset. You are now responsible for updating the weights by stepping the optimizers and back-propagating the loss.
+Instead of implementing the loss()
method, we can
+implement the step()
method. This allows us to flexibly
+modify what happens when training and validating for each batch in the
+dataset. You are now responsible for updating the weights by stepping
+the optimizers and back-propagating the loss.
The important things to notice here are:
-The step()
method is used for both training and validation. You need to be careful to only modify the weights when training. Again, you can get complete information regarding the context object using help("ctx")
.
-ctx$optimizers
is a named list holding each optimizer that was created when the set_optimizers()
method was called.
-You need to manually track the losses by saving them in a named list in ctx$loss. By convention, we use the same name as the optimizer it refers to. It is good practice to detach() them before saving to reduce memory usage.
-Callbacks that would be called inside the default step() method, like on_train_batch_after_pred, on_train_batch_after_loss, etc., won’t be automatically called. You can still call them manually by adding ctx$call_callbacks("<callback name>") inside your training step. See the code for fit_one_batch() and valid_one_batch() to find all the callbacks that won’t be called.
-If you want luz metrics to work with your custom step()
method, you must assign ctx$pred
with the model predictions as metrics will always be called with metric$update(ctx$pred, ctx$target)
.
+The step()
method is used for both training and
+validation. You need to be careful to only modify the weights when
+training. Again, you can get complete information regarding the context
+object using help("ctx")
.
+ctx$optimizers
is a named list holding each
+optimizer that was created when the set_optimizers()
method
+was called.
+You need to manually track the losses by saving them in a
+named list in ctx$loss. By convention, we use the same name
+as the optimizer it refers to. It is good practice to
+detach() them before saving to reduce memory
+usage.
+Callbacks that would be called inside the default
+step() method, like on_train_batch_after_pred,
+on_train_batch_after_loss, etc., won’t be automatically
+called. You can still call them manually by adding
+ctx$call_callbacks("<callback name>") inside your
+training step. See the code for fit_one_batch() and
+valid_one_batch() to find all the callbacks that won’t be
+called.
+If you want luz metrics to work with your custom
+step()
method, you must assign ctx$pred
with
+the model predictions as metrics will always be called with
+metric$update(ctx$pred, ctx$target)
.
Next steps
-In this article you learned how to customize the step()
of your training loop using luz layered functionality.
-Luz also allows more flexible modifications of the training loop described in the Accelerator vignette (vignette("accelerator")
).
-You should now be able to follow the examples marked with the ‘intermediate’ and ‘advanced’ category in the examples gallery.
+In this article you learned how to customize the step()
+of your training loop using luz layered functionality.
+Luz also allows more flexible modifications of the training loop
+described in the Accelerator vignette
+(vignette("accelerator")
).
+You should now be able to follow the examples marked with the
+‘intermediate’ and ‘advanced’ category in the examples
+gallery.
@@ -248,7 +317,7 @@ Next steps
-Site built with pkgdown 2.0.7.9000.
+Site built with pkgdown 2.0.7.
diff --git a/articles/custom-loop_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/custom-loop_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/custom-loop_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/chargpt.html b/articles/examples/chargpt.html
index 98f09978..270e7efa 100644
--- a/articles/examples/chargpt.html
+++ b/articles/examples/chargpt.html
@@ -77,7 +77,8 @@ Guides
-
+
+
CharGPT
@@ -88,15 +89,24 @@ Guides
-This example is inspired by the chargpt project by Andrej Karpathy. We are going to train a character-level language model on Shakespeare texts.
+This example is inspired by the chargpt
+project by Andrej Karpathy. We are going to train a character-level
+language model on Shakespeare texts.
We first load the libraries that we plan to use:
-Next we define the torch dataset that will pre-process data for the model. It splits the text into a character vector, each element containing exactly one character.
-It then lists all unique characters in the vocab attribute. The order of the characters in the vocabulary is used to encode each character as an integer value, which will be used in the embedding layer.
-The .getitem() method takes chunks of block_size characters and encodes them into their integer representation.
+Next we define the torch dataset that will pre-process data for the
+model. It splits the text into a character vector, each element
+containing exactly one character.
+It then lists all unique characters in the vocab
+attribute. The order of the characters in the vocabulary is used to
+encode each character as an integer value, which will be used in the
+embedding layer.
+The .getitem() method takes chunks of
+block_size characters and encodes them into their integer
+representation.
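The encoding steps described above can be tried in base R on a short string. This is only a sketch of the idea, not the dataset code of this example: the sample text, the `encode()` helper and the choice of a target shifted by one character (a common setup for next-character prediction) are assumptions.

```r
text <- "to be or not to be"

# Split the text into a character vector, one character per element.
chars <- strsplit(text, "")[[1]]

# The vocabulary is the set of unique characters, in order of appearance.
vocab <- unique(chars)

# A character is encoded as its position in the vocabulary.
encode <- function(x) match(x, vocab)

# One item: a chunk of block_size characters and the same chunk shifted by one.
block_size <- 4
i <- 1
x <- encode(chars[i:(i + block_size - 1)])
y <- encode(chars[(i + 1):(i + block_size)])

x
#> [1] 1 2 3 4
y
#> [1] 2 3 4 5
```

Here the vocabulary is `"t" "o" " " "b" "e" "r" "n"`, so the chunk "to b" becomes 1, 2, 3, 4 and its shifted target "o be" becomes 2, 3, 4, 5.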
url <- "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
@@ -124,8 +134,15 @@ Guides
dataset <- char_dataset(readr::read_file(url))
dataset[1] # this allows us to see an element of the dataset
-We then define the neural net we are going to train. Defining a GPT-2 model is quite verbose, so we are going to use the minhub implementation directly. You can find the full model definition here, and this code is entirely self-contained, so you don’t need to install minhub if you don’t want to.
-We also implemented the generate method for the model, which allows one to generate completions using the model. It applies the model in a loop, predicting the next character at each iteration.
+We then define the neural net we are going to train. Defining a GPT-2
+model is quite verbose, so we are going to use the minhub implementation
+directly. You can find the full model definition here,
+and this code is entirely self-contained, so you don’t need to install
+minhub if you don’t want to.
+We also implemented the generate method for the model,
+which allows one to generate completions using the model. It applies the
+model in a loop, predicting the next character at each
+iteration.
model <- torch::nn_module(
initialize = function(vocab_size) {
@@ -155,7 +172,8 @@ Guides
x
}
)
-Next, we implemented a callback that is used for nicely displaying generated samples during the model training:
+Next, we implemented a callback that is used for nicely displaying
+generated samples during the model training:
# samples from the model using the context.
generate <- function(model, vocab, context, ...) {
@@ -203,7 +221,8 @@ Guides
luz_callback_gradient_clip(max_norm = 1)
)
)
-One epoch is reasonable for this dataset and takes ~1h on an M1 MacBook Pro. You can generate new samples with:
+One epoch is reasonable for this dataset and takes ~1h on an M1
+MacBook Pro. You can generate new samples with:
context <- "O God, O God!"
text <- generate(fitted$model, dataset$vocab, context, iter = 100)
@@ -220,7 +239,7 @@ Guides
diff --git a/articles/examples/chargpt_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/chargpt_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/chargpt_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/dogs-vs-cats-binary-classification.html b/articles/examples/dogs-vs-cats-binary-classification.html
index 3e4361be..d4d6932e 100644
--- a/articles/examples/dogs-vs-cats-binary-classification.html
+++ b/articles/examples/dogs-vs-cats-binary-classification.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Binary classification
@@ -175,7 +176,7 @@ Guides
diff --git a/articles/examples/dogs-vs-cats-binary-classification_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/dogs-vs-cats-binary-classification_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/dogs-vs-cats-binary-classification_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/index.html b/articles/examples/index.html
index 4714a06a..980d72f7 100644
--- a/articles/examples/index.html
+++ b/articles/examples/index.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Examples
@@ -88,7 +89,9 @@ Guides
-This gallery of examples uses luz to train and validate a range of common deep learning architectures. The gallery also demonstrates basic and advanced usage of luz.
+This gallery of examples uses luz to train and validate a range of
+common deep learning architectures. The gallery also demonstrates basic
+and advanced usage of luz.
@@ -129,7 +134,8 @@
basic
-Builds an autoencoder for the MNIST dataset. Demonstrates overwriting the predict method
+Builds an autoencoder for the MNIST dataset. Demonstrates overwriting
+the predict method
See code
@@ -145,7 +151,8 @@
Showcases how to create a fully customized training step
-See code
+See
+code
@@ -219,7 +226,8 @@
intermediate
-Implements a UNET model to separate the background of images of cats and dogs.
+Implements a UNET model to separate the background of images of cats and
+dogs.
See code
@@ -240,6 +248,23 @@
+
+
+
+
+
+Training a causal language model from scratch
+
+advanced
+
+Implements datasets and trains a causal language model from scratch
+using R source code.
+
+See code
+
+
+
+
@@ -274,7 +299,7 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/mnist-autoencoder.html b/articles/examples/mnist-autoencoder.html
index 4e2ff1e5..8827238a 100644
--- a/articles/examples/mnist-autoencoder.html
+++ b/articles/examples/mnist-autoencoder.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Autoencoder
@@ -180,7 +181,7 @@ Guides
diff --git a/articles/examples/mnist-autoencoder_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/mnist-autoencoder_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/mnist-autoencoder_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/mnist-cnn-virtual-batch-size.html b/articles/examples/mnist-cnn-virtual-batch-size.html
index cca80ae9..874ecdfd 100644
--- a/articles/examples/mnist-cnn-virtual-batch-size.html
+++ b/articles/examples/mnist-cnn-virtual-batch-size.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Virtual batch size
@@ -199,7 +200,7 @@ Guides
diff --git a/articles/examples/mnist-cnn-virtual-batch-size_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/mnist-cnn-virtual-batch-size_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/mnist-cnn-virtual-batch-size_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/mnist-cnn.html b/articles/examples/mnist-cnn.html
index cbb2b246..addd84ac 100644
--- a/articles/examples/mnist-cnn.html
+++ b/articles/examples/mnist-cnn.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Simple CNN
@@ -177,7 +178,7 @@ Guides
diff --git a/articles/examples/mnist-cnn_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/mnist-cnn_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/mnist-cnn_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/mnist-dcgan.html b/articles/examples/mnist-dcgan.html
index 5343ee76..02b9214e 100644
--- a/articles/examples/mnist-dcgan.html
+++ b/articles/examples/mnist-dcgan.html
@@ -77,7 +77,8 @@ Guides
-
+
+
DCGAN
@@ -266,7 +267,7 @@ Guides
diff --git a/articles/examples/mnist-dcgan_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/mnist-dcgan_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/mnist-dcgan_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/mnist-mixup.html b/articles/examples/mnist-mixup.html
index 31937768..d9510598 100644
--- a/articles/examples/mnist-mixup.html
+++ b/articles/examples/mnist-mixup.html
@@ -77,7 +77,8 @@ Guides
-
+
+
MixUp augmentation
@@ -187,7 +188,7 @@ Guides
diff --git a/articles/examples/mnist-mixup_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/mnist-mixup_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/mnist-mixup_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/mnist-triplet.html b/articles/examples/mnist-triplet.html
index 5ae47134..fdb036ab 100644
--- a/articles/examples/mnist-triplet.html
+++ b/articles/examples/mnist-triplet.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Triplet loss
@@ -196,7 +197,7 @@ Guides
diff --git a/articles/examples/mnist-triplet_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/mnist-triplet_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/mnist-triplet_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/pets-unet.html b/articles/examples/pets-unet.html
index 6543772e..6ab3a0be 100644
--- a/articles/examples/pets-unet.html
+++ b/articles/examples/pets-unet.html
@@ -77,7 +77,8 @@ Guides
-
+
+
UNET implementation
@@ -309,7 +310,7 @@ Guides
diff --git a/articles/examples/pets-unet_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/pets-unet_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/pets-unet_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/text-classification.html b/articles/examples/text-classification.html
index f48170e1..f6437dee 100644
--- a/articles/examples/text-classification.html
+++ b/articles/examples/text-classification.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Text classification from scratch
@@ -88,14 +89,21 @@ Guides
-This example is a port of ‘Text classification from scratch’ from Keras documentation by Mark Omerick and François Chollet.
-First we implement a torch dataset that downloads and pre-processes the data. The initialize method is called when we instantiate a dataset. Our implementation:
+This example is a port of ‘Text
+classification from scratch’ from Keras documentation by Mark
+Omerick and François Chollet.
+First we implement a torch dataset that downloads and pre-processes the
+data. The initialize method is called when we instantiate a dataset. Our
+implementation:
-- Downloads the IMDB dataset if it doesn’t exist in the
root
directory.
+- Downloads the IMDB dataset if it doesn’t exist in the
+
root
directory.
- Extracts the files into
root
.
- Creates a tokenizer using the files in the training set.
-We also implement the .getitem
method that is used to extract a single element from the dataset and pre-process the file contents.
+We also implement the .getitem
method that is used to
+extract a single element from the dataset and pre-process the file
+contents.
library(torch)
library(tok)
@@ -174,7 +182,9 @@ Guides
train_ds <- imdb_dataset(output_length, vocab_size, "./imdb", split = "train")
test_ds <- imdb_dataset(output_length, vocab_size, "./imdb", split = "test")
-We now define the model we want to train. The model is a 1D convnet starting with an embedding layer and we plug a classifier at the output.
+We now define the model we want to train. The model is a 1D convnet
+starting with an embedding layer and we plug a classifier at the
+output.
model <- nn_module(
initialize = function(vocab_size, embedding_dim) {
@@ -226,7 +236,8 @@ Guides
We can finally obtain the metrics on the test dataset:
-Remember that in order to predict for texts, we need to apply the same pre-processing as used in the dataset definition.
+Remember that in order to predict for texts, we need to apply the same
+pre-processing as used in the dataset definition.
@@ -239,7 +250,7 @@ Guides
diff --git a/articles/examples/text-classification_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/text-classification_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/text-classification_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/text-generation.html b/articles/examples/text-generation.html
new file mode 100644
index 00000000..4558aaf8
--- /dev/null
+++ b/articles/examples/text-generation.html
@@ -0,0 +1,393 @@
+
+
+
+
+
+
+
+
+Training a causal language model from scratch • luz
+
+
+
+
+
+
+
+
+
+ Skip to contents
+
+
+
+
+
+
+
+
+
+ Training a causal language model from scratch
+
+
+ Source: vignettes/examples/text-generation.Rmd
+ text-generation.Rmd
+
+
+
+
+This example is an adaptation of the ‘Training a causal language
+model from scratch’ class from the Hugging
+Face NLP course.
+
+library(torch)
+library(tok)
+library(luz)
+library(minhub) # remotes::install_github("mlverse/minhub")
+#library(tidyverse)
+options(arrow.skip_nul = TRUE)
+library(arrow)
+
+Data
+
+The first step is to implement a torch dataset that gathers the data and
+pre-processes it into a format that is suitable for training the
+model.
+That means that we need to:
+
+- Download data
+- Train a tokenizer for this dataset
+- Be able to produce sequences of tokens in the format expected by the
+model
+
+We are going to use two datasets available on the Hugging Face Hub. The
+first contains the source code of all R packages available on CRAN. The
+second contains all R code that is available in GitHub data dumps. Both
+datasets are in the Parquet format. Below we implement a function
+that downloads and caches the data and then returns a single data frame
+containing all data.
+
+read_dataset <- function(source) {
+ d <- source |>
+ hfhub::hub_snapshot(repo_type = "dataset", allow_patterns = "parquet$") |>
+ fs::path("data/r") |>
+ arrow::open_dataset() |>
+ dplyr::filter(stringr::str_detect(path, ".*\\.[rR]$")) |>
+ dplyr::select(content) |>
+ dplyr::mutate(content = arrow::cast(content, arrow::string())) |>
+ dplyr::filter(!is.na(content)) |>
+ dplyr::collect() %>%
+ # the dataset contains invalid utf8 characters...
+ # we need to remove them, otherwise we get an error from tokenizers
+ dplyr::filter(utf8::utf8_valid(content))
+}
+
+read_datasets <- function() {
+ dplyr::bind_rows(
+ read_dataset("dfalbel/cran-packages"),
+ read_dataset("dfalbel/github-r-repos")
+ )
+}
+Next we implement a function that trains a tokenizer for our
+dataset.
+
+create_tokenizer <- function(text, vocab_size, special_tokens) {
+ tok <- tok::tokenizer$new(tok::model_bpe$new())
+
+ tok$pre_tokenizer <- tok::pre_tokenizer_byte_level$new(add_prefix_space = FALSE)
+ tok$decoder <- tok::decoder_byte_level$new()
+ tok$post_processor <- tok::processor_byte_level$new(trim_offsets = FALSE)
+
+ tok$train_from_memory(
+ text,
+ tok::trainer_bpe$new(vocab_size = vocab_size, special_tokens = special_tokens)
+ )
+ tok
+}
+
+# test code to debug the tokenizer
+# data <- read_datasets()
+# tok <- create_tokenizer(data$content)
+We can finally implement the torch dataset that we are going to use
+for training the model. We are going to use the
+torch::iterable_dataset
instead of
+torch::dataset
. The main motivation is that we can’t really
+know the total number of samples in the dataset, so we can’t implement a
+.getitem()
method to get an arbitrary sample. Instead, we
+implement the .iter
method that returns a new sample every
+time it’s called.
+
+r_sources_dataset <- torch::iterable_dataset(
+ "r_sources_dataset",
+ initialize = function(root = ".", vocab_size = 20000, context_length = 128) {
+ self$data <- read_datasets()
+ self$context_length <- context_length
+ self$index <- sample.int(nrow(self$data))
+
+ # we only create a tokenizer if it doesn't exist, otherwise we just load it
+ tok_path <- file.path(root, glue::glue("tokenizer-{vocab_size}.json"))
+ if (!file.exists(tok_path)) {
+ self$tok <- create_tokenizer(
+ as.character(self$data$content),
+ vocab_size,
+ c("<fbegin>", "<fend>")
+ )
+ fs::dir_create(root)
+ self$tok$save(tok_path)
+ } else {
+ self$tok <- tok::tokenizer$from_file(tok_path)
+ }
+ },
+ .iter = function() {
+ i <- 1L
+ sequence <- c()
+ function() {
+ while (length(sequence) < (self$context_length + 1) && i <= nrow(self$data)) {
+ sequence <<- c(
+ sequence,
+ self$tok$encode(paste("<fbegin>", as.character(self$data$content[self$index[i]]), "<fend>"))$ids
+ )
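+ # `<<-` updates the counter in the enclosing environment, so the position
+ # in the data is kept between calls instead of restarting at 1
+ i <<- i + 1L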
+ }
+
+ if (length(sequence) < (self$context_length + 1)) {
+ return(coro::exhausted())
+ }
+
+ on.exit({
+ sequence <<- sequence[-seq_len(self$context_length)]
+ })
+ list(
+ input_ids = sequence[seq_len(self$context_length)] + 1L,
+ labels = sequence[2:(self$context_length + 1)] + 1L
+ )
+ }
+ }
+)
+
+# debug code for the dataset
+# ds <- r_sources_dataset("~/Downloads/")
+# it <- ds$.iter()
+# it()
+# ds$tok$get_vocab_size()
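The chunking performed by the iterator can be checked in isolation. The following is a minimal plain-R sketch, not the dataset itself: `tokens` is a made-up stand-in for the concatenated token ids of several encoded files, and the tiny `context_length` is chosen only to keep the result readable.

```r
# Plain-R sketch of the chunking logic; `tokens` stands in for the
# concatenated, zero-based token ids of several files (hypothetical values).
context_length <- 4L
tokens <- 0:10

chunks <- list()
sequence <- tokens
while (length(sequence) >= context_length + 1L) {
  chunks[[length(chunks) + 1L]] <- list(
    # ids are shifted by one because torch indexes from 1
    input_ids = sequence[seq_len(context_length)] + 1L,
    # labels are the inputs shifted one position to the right
    labels = sequence[2:(context_length + 1L)] + 1L
  )
  # drop only the tokens consumed as inputs; the last label token
  # becomes the first input of the next chunk
  sequence <- sequence[-seq_len(context_length)]
}

str(chunks)
```

Each chunk needs `context_length + 1` tokens because the label for the last input position is the token that follows it, which is why leftover tokens shorter than that are not emitted.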
+This dataset is likely too large for us to train the model on all
+documents in this example. It’s also hard to predict how long it will
+take for it to train until the end. In order to make it easier, we
+define a wrapper dataset that is used to run the above dataset for a
+fixed number of steps. This is not required, but makes using luz more
+pleasant, as we can easily define for how many tokens we want to train
+our model.
+
+fixed_steps_iterable_dataset <- iterable_dataset(
+ "fixed_steps_dataset",
+ initialize = function(dataset, steps) {
+ self$dataset <- dataset
+ self$steps <- steps
+ },
+ .iter = function() {
+ i <- 1L
+ iter <- NULL
+ function() {
+ if (i > self$steps) {
+ return(coro::exhausted())
+ }
+
+ i <<- i + 1L
+
+ if (is.null(iter) || coro::is_exhausted(data <- iter())) {
+ iter <<- self$dataset$.iter()
+ data <- iter()
+ }
+
+ data
+ }
+ },
+ .length = function() {
+ self$steps
+ }
+)
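The restart behaviour of the wrapper is easier to see without torch. The sketch below mirrors the same idea with plain closures; `make_iter` and `fixed_steps_iter` are hypothetical helpers, and `NULL` is used as the end-of-iteration marker in place of `coro::exhausted()` only to keep the example dependency-free.

```r
# Returns a fresh iterator over `values`; it yields NULL once exhausted.
# NULL is a stand-in for coro::exhausted() in this sketch.
make_iter <- function(values) {
  i <- 0L
  function() {
    if (i >= length(values)) return(NULL)
    i <<- i + 1L
    values[[i]]
  }
}

# Wraps an iterator factory so that exactly `steps` items are produced,
# restarting the underlying iterator whenever it runs out.
fixed_steps_iter <- function(new_iter, steps) {
  n <- 0L
  iter <- NULL
  function() {
    if (n >= steps) return(NULL)
    n <<- n + 1L
    item <- if (is.null(iter)) NULL else iter()
    if (is.null(item)) {
      iter <<- new_iter()
      item <- iter()
    }
    item
  }
}

next_item <- fixed_steps_iter(function() make_iter(c("a", "b", "c")), steps = 5L)
out <- character()
while (!is.null(item <- next_item())) out <- c(out, item)
out
```

With three underlying items and five steps, the wrapper yields the three items and then the first two again, so an "epoch" here is defined by the step count rather than by one pass over the data.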
+We finally define the model we are going to train. We’ll use a small
+version of GPT2. We also define a generate
method allowing
+us to sample from the model given an initial context.
+
+net <- nn_module(
+ initialize = function() {
+ self$gpt <- minhub::gpt2(
+ vocab_size = 20000,
+ pdrop = 0.1
+ )
+ },
+ forward = function(x) {
+ self$gpt(x)$transpose(2,3)
+ },
+ generate = function(x, temperature = 1, iter = 50, top_k = 10) {
+ # samples from the model given a context vector.
+ for (i in seq_len(iter)) {
+ logits <- self$forward(x)[,,-1]
+ logits <- logits/temperature
+ c(prob, ind) %<-% logits$topk(top_k)
+ logits <- torch_full_like(logits, -Inf)$scatter_(-1, ind, prob)
+ logits <- nnf_softmax(logits, dim = -1)
+ id_next <- torch_multinomial(logits, num_samples = 1)
+ x <- torch_cat(list(x, id_next), dim = 2)
+ }
+ x
+ }
+)
+
+# debug code for the model
+# ds <- torch::dataloader(r_sources_dataset("~/Downloads/"), batch_size = 32)
+# batch <- coro::collect(ds, 1)[[1]]
+# str(batch)
+# m <- net()
+# str(m(batch$input_ids))
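The sampling step inside `generate` combines temperature scaling with top-k filtering. As a rough illustration, here is the same idea on a plain numeric vector; `sample_top_k` is a hypothetical helper written for this sketch and the logits are made up.

```r
# Samples one token index from `logits` using temperature and top-k filtering.
sample_top_k <- function(logits, temperature = 1, top_k = 10) {
  logits <- logits / temperature
  # indices of the `top_k` largest logits
  keep <- order(logits, decreasing = TRUE)[seq_len(min(top_k, length(logits)))]
  # everything outside the top k gets -Inf, i.e. probability zero
  filtered <- rep(-Inf, length(logits))
  filtered[keep] <- logits[keep]
  # numerically stable softmax
  probs <- exp(filtered - max(filtered))
  probs <- probs / sum(probs)
  sample.int(length(probs), size = 1, prob = probs)
}

set.seed(1)
logits <- c(2.0, 0.5, -1.0, 3.0, 0.1)
draws <- replicate(200, sample_top_k(logits, temperature = 0.8, top_k = 2))
table(draws)
```

With `top_k = 2` only the two largest logits (positions 1 and 4) can ever be drawn. A temperature below 1 sharpens the distribution towards the largest logit, while a temperature above 1 flattens it.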
+To make it easier to inspect training, we will also define a callback
+that prints a sample from the model every epoch.
+
+# samples from the model using the context.
+generate <- function(model, tok, context, ...) {
+ local_no_grad() # disables gradient for sampling
+ x <- tok$encode(context)$ids + 1L
+ x <- torch_tensor(x)[NULL,]$to(device = model$device)
+ content <- as.integer(model$generate(x, ...)$cpu())
+ tok$decode(content - 1L)
+}
+
+display_cb <- luz_callback(
+ initialize = function() {},
+ on_epoch_end = function() {
+ local_no_grad()
+ # sample from the model...
+ context <- "# creates a linear model"
+ text <- generate(ctx$model, dataset$dataset$tok, context, iter = 100)
+ cli::cli_rule()
+ cat(text, "\n")
+ cli::cli_rule()
+ }
+)
+We can finally train the model. We define that we want to train the
+model for half a billion tokens in a total of 100 epochs.
+
+n_tokens <- 500e6
+batch_size <- 16
+epochs <- 100
+context_length <- 256L
+
+steps <- n_tokens / context_length / epochs
+dataset <- fixed_steps_iterable_dataset(
+ r_sources_dataset(context_length = context_length),
+ steps = steps
+)
+
+fitted <- net %>%
+ setup(
+ optimizer = optim_adam,
+ loss = nn_cross_entropy_loss()
+ ) %>%
+ set_opt_hparams(lr = 3e-4) |>
+ fit(
+ dataset,
+ epochs = epochs,
+ dataloader_options = list(batch_size = batch_size),
+ callbacks = list(
+ luz_callback_lr_scheduler(
+ torch::lr_one_cycle,
+ max_lr = 0.1,
+ epochs = epochs,
+ steps_per_epoch = steps/batch_size,
+ call_on = "on_batch_end"
+ ),
+ luz_callback_gradient_clip(max_norm = 1),
+ display_cb()
+ ),
+ verbose = TRUE
+ )
+
+luz::luz_save(fitted, "model.pt")
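The step arithmetic used in this training setup can be verified on its own. The sketch below only repeats the numbers defined in the code; `batches_per_epoch` and `total_tokens` are names introduced here for illustration.

```r
n_tokens <- 500e6
batch_size <- 16
epochs <- 100
context_length <- 256L

# sequences of `context_length` tokens drawn per epoch
steps <- n_tokens / context_length / epochs
# batches (optimizer updates) per epoch, as passed to the scheduler
batches_per_epoch <- steps / batch_size
# tokens seen over the whole run
total_tokens <- steps * context_length * epochs

c(steps = steps, batches_per_epoch = batches_per_epoch, total_tokens = total_tokens)
```

Neither `steps` nor `batches_per_epoch` is a whole number with these settings, so rounding them (for example with `ceiling()`) may be worth considering when adapting the example.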
+We can then use the model to generate text given a prompt with:
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/articles/get-started.html b/articles/get-started.html
index b5ff5a94..d82dc2d4 100644
--- a/articles/get-started.html
+++ b/articles/get-started.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Get started with luz
@@ -91,18 +92,42 @@ Guides
-Luz is a high-level API for torch that aims to encapsulate the training loop into a set of reusable pieces of code. Luz reduces the boilerplate code required to train a model with torch and avoids the error prone zero_grad()
- backward()
- step()
sequence of calls, and also simplifies the process of moving data and models between CPUs and GPUs. Luz is designed to be highly flexible by providing a layered API that allows it to be useful no matter the level of control you need for your training loop.
-Luz is heavily inspired by other higher level frameworks for deep learning, to cite a few:
+Luz is a high-level API for torch that aims to encapsulate the
+training loop into a set of reusable pieces of code.
+Luz reduces the boilerplate code required to train a model with torch
+and avoids the error prone zero_grad()
-
+backward()
- step()
sequence of calls, and
+also simplifies the process of moving data and models between CPUs and
+GPUs. Luz is designed to be highly flexible by providing a layered API
+that allows it to be useful no matter the level of control you need for
+your training loop.
+Luz is heavily inspired by other higher level frameworks for deep
+learning, to cite a few:
-FastAI: we are heavily inspired by the FastAI library, especially the Learner
object and the callbacks API.
-Keras: We are also heavily inspired by Keras, especially callback names. The lightning module interface is similar to compile
, too.
-PyTorch Lightning: The idea of the luz_module
being a subclass of nn_module
is inspired by the LightningModule
object in lightning.
-HuggingFace Accelerate: The internal device placement API is heavily inspired by Accelerate, but is much more modest in features. Currently only CPU and Single GPU are supported.
+FastAI: we are heavily
+inspired by the FastAI library, especially the Learner
+object and the callbacks API.
+Keras: We are also heavily
+inspired by Keras, especially callback names. The lightning module
+interface is similar to compile
, too.
+PyTorch
+Lightning: The idea of the luz_module
being a subclass
+of nn_module
is inspired by the
+LightningModule
object in
+lightning.
+HuggingFace
+Accelerate: The internal device placement API is heavily inspired by
+Accelerate, but is much more modest in features. Currently only CPU and
+Single GPU are supported.
Training a nn_module
-As much as possible, luz tries to reuse the existing structures from torch. A model in luz is defined identically as you would define it if using raw torch. For a specific example, this is the definition of a feed-forward CNN that can be used to classify digits from the MNIST dataset:
+As much as possible, luz tries to reuse the existing structures from
+torch. A model in luz is defined identically as you would define it if
+using raw torch. For a specific example, this is the definition of a
+feed-forward CNN that can be used to classify digits from the MNIST
+dataset:
-We can now train this model in the train_dl
and validate it in the test_dl
torch::dataloaders()
with:
+We can now train this model in the train_dl
and validate
+it in the test_dl
torch::dataloaders()
+with:
fitted <- net %>%
setup(
@@ -144,21 +171,51 @@ Training a nn_module
fit(train_dl, epochs = 10, valid_data = test_dl)
Let’s understand what happens in this chunk of code:
-- The
setup
function allows you to configure the loss (objective) function and the optimizer that you will use to train your model. Optionally you can pass a list of metrics that are tracked during the training procedure. Note: the loss function can be any function taking input
and target
tensors and returning a scalar tensor value, and the optimizer can be any core torch optimizer or custom ones created with the torch::optimizer()
function.
-- The
set_hparams()
function allows you to set hyper-parameters that should be passed to the module initialize()
method. For example in this case we pass num_classes = 10
.
-- The
set_opt_hparams()
function allows you to pass hyper-parameters that are used by the optimizer function. For example, optim_adam()
can take the lr
parameter specifying the learning rate and we specify it with lr = 0.003
.
-- The
fit
method will take the model specification provided by setup()
and run the training procedure using the specified training and validation torch::dataloaders()
as well as the number of epochs. Note: we again reuse core torch data structures, instead of providing our own data loading functionality.
-- The returned object
fitted
contains the trained model as well as the record of metrics and losses produced during training. It can also be used for producing predictions and for evaluating the trained model on other datasets.
+- The
setup
function allows you to configure the loss
+(objective) function and the optimizer that you will use to train your
+model. Optionally you can pass a list of metrics that are tracked during
+the training procedure. Note: the loss function can be
+any function taking input
and target
tensors
+and returning a scalar tensor value, and the optimizer can be any core
+torch optimizer or custom ones created with the
+torch::optimizer()
function.
+- The
set_hparams()
function allows you to set
+hyper-parameters that should be passed to the module
+initialize()
method. For example in this case we pass
+num_classes = 10
.
+- The
set_opt_hparams()
function allows you to pass
+hyper-parameters that are used by the optimizer function. For example,
+optim_adam()
can take the lr
parameter
+specifying the learning rate and we specify it with
+lr = 0.003
.
+- The
fit
method will take the model specification
+provided by setup()
and run the training procedure using
+the specified training and validation torch::dataloaders()
+as well as the number of epochs. Note: we again reuse
+core torch data structures, instead of providing our own data loading
+functionality.
+- The returned object
fitted
contains the trained model
+as well as the record of metrics and losses produced during training. It
+can also be used for producing predictions and for evaluating the
+trained model on other datasets.
-When fitting, luz will use the fastest possible accelerator; if a CUDA-capable GPU is available it will be used, otherwise we fall back to the CPU. It also automatically moves data, optimizers, and models to the selected device so you don’t need to handle it manually (which is in general very error prone).
-To create predictions from the trained model you can use the predict
method:
+When fitting, luz will use the fastest possible accelerator; if a
+CUDA-capable GPU is available it will be used, otherwise we fall back to
+the CPU. It also automatically moves data, optimizers, and models to the
+selected device so you don’t need to handle it manually (which is in
+general very error prone).
+To create predictions from the trained model you can use the
+predict
method:
predictions <- predict(fitted, test_dl)
The training loop
-You now have a general idea of how to use the fit
function and now it’s important to have an overview of what’s happening inside it. In pseudocode, here’s what fit
does. This is not fully detailed but should help you to build your intuition:
+You now have a general idea of how to use the fit
+function and now it’s important to have an overview of what’s happening
+inside it. In pseudocode, here’s what fit
does. This is not
+fully detailed but should help you to build your intuition:
# -> Initialize objects: model, optimizers.
# -> Select fitting device.
@@ -184,25 +241,45 @@ The training loop
Metrics
-One of the most important parts in machine learning projects is choosing the evaluation metric. Luz allows tracking many different metrics during training with minimal code changes.
-In order to track metrics, you only need to modify the metrics
parameter in the setup
function:
-
-Luz provides implementations of a few of the most used metrics. If a metric is not available you can always implement a new one using the luz_metric
function.
-In order to implement a new luz_metric
we need to implement 3 methods:
+One of the most important parts in machine learning projects is
+choosing the evaluation metric. Luz allows tracking many different
+metrics during training with minimal code changes.
+In order to track metrics, you only need to modify the
+metrics
parameter in the setup
function:
+fitted <- net %>%
+  setup(
+    ...
+    metrics = list(
+      luz_metric_accuracy
+    )
+  ) %>%
+  fit(...)
+Luz provides implementations of a few of the most used metrics. If a
+metric is not available you can always implement a new one using the
+luz_metric
function.
+In order to implement a new luz_metric
we need to
+implement 3 methods:
-initialize
: defines the metric initial state. This function is called for each epoch for both training and validation loops.
-update
: updates the metric internal state. This function is called at every training and validation step with the predictions obtained by the model and the target values obtained from the dataloader.
-compute
: uses the internal state to compute metric values. This function is called whenever we need to obtain the current metric value. E.g., it’s called every training step for metrics displayed in the progress bar, but only once per epoch to record its value when the progress bar is not displayed.
+initialize
: defines the metric initial state. This
+function is called for each epoch for both training and validation
+loops.
+update
: updates the metric internal state. This
+function is called at every training and validation step with the
+predictions obtained by the model and the target values obtained from
+the dataloader.
+compute
: uses the internal state to compute metric
+values. This function is called whenever we need to obtain the current
+metric value. E.g., it’s called every training step for metrics displayed
+in the progress bar, but only once per epoch to record its value
+when the progress bar is not displayed.
-Optionally, you can implement an abbrev
field that gives the metric an abbreviation that will be used when displaying metric information in the console or tracking record. If no abbrev
is passed, the class name will be used.
-Let’s take a look at the implementation of luz_metric_accuracy
so you can see how to implement a new one:
+Optionally, you can implement an abbrev
field that gives
+the metric an abbreviation that will be used when displaying metric
+information in the console or tracking record. If no abbrev
+is passed, the class name will be used.
+Let’s take a look at the implementation of
+luz_metric_accuracy
so you can see how to implement a new
+one:
luz_metric_accuracy <- luz_metric(
# An abbreviation to be shown in progress bars, or
@@ -230,13 +307,20 @@ Metrics
self$correct/self$total
}
)
-Note: It’s good practice that the compute
metric returns regular R values instead of torch tensors and other parts of luz will expect that.
+Note: It’s good practice for the
+compute
method to return regular R values instead of torch
+tensors, since other parts of luz expect that.
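The three methods described above can be sketched for a different metric. The block below is a minimal sketch of a mean absolute error metric, assuming the `torch` and `luz` packages are installed; the name `luz_metric_mae_sketch` and its fields `abs_error_sum` and `n` are made up for illustration and are not part of luz.

```r
library(torch)
library(luz)

# Hypothetical metric for illustration only (not shipped with luz).
luz_metric_mae_sketch <- luz_metric(
  # Abbreviation shown in progress bars and tracking records.
  abbrev = "MAE",
  # Reset the running state at the start of each epoch.
  initialize = function() {
    self$abs_error_sum <- 0
    self$n <- 0
  },
  # Accumulate the absolute error and the number of observations per batch.
  update = function(preds, target) {
    self$abs_error_sum <- self$abs_error_sum +
      torch_sum(torch_abs(preds - target))$item()
    self$n <- self$n + target$numel()
  },
  # Return a regular R numeric, not a torch tensor.
  compute = function() {
    self$abs_error_sum / self$n
  }
)
```

Such a metric would then be listed in the `metrics` argument of `setup()`, for example as `luz_metric_mae_sketch()`. Keeping the state as plain R numbers means `compute` returns a regular R value without any conversion.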
Evaluate
-Once a model has been trained you might want to evaluate its performance on a different dataset. For that reason, luz provides the ?evaluate
function that takes a fitted model and a dataset and computes the metrics attached to the model.
-Evaluate returns a luz_module_evaluation
object that you can query for metrics using the get_metrics
function or simply print
to see the results.
+Once a model has been trained you might want to evaluate its
+performance on a different dataset. For that reason, luz provides the
+?evaluate
function that takes a fitted model and a dataset
+and computes the metrics attached to the model.
+Evaluate returns a luz_module_evaluation
object that you
+can query for metrics using the get_metrics
function or
+simply print
to see the results.
For example:
evaluation <- fitted %>% evaluate(data = valid_dl)
@@ -252,16 +336,32 @@ Evaluate
Customizing with callbacks
-Luz provides different ways to customize the training progress depending on the level of control you need in the training loop. The fastest way and the more ‘reusable’, in the sense that you can create training modifications that can be used in many different situations, is via callbacks.
-The training loop in luz has many breakpoints that can call arbitrary R functions. This functionality allows you to customize the training process without having to modify the general training logic.
-Luz implements 3 default callbacks that occur in every training procedure:
+Luz provides different ways to customize the training progress
+depending on the level of control you need in the training loop. The
+fastest and most reusable way, in the sense that you can create
+training modifications that work in many different situations, is
+via callbacks.
+The training loop in luz has many breakpoints that can call
+arbitrary R functions. This functionality allows you to customize the
+training process without having to modify the general training
+logic.
+Luz implements 3 default callbacks that occur in every training
+procedure:
-train-eval callback: Sets the model to train()
or eval()
depending on if the procedure is doing training or validation.
-metrics callback: evaluate metrics during training and validation process.
-progress callback: implements a progress bar and prints progress information during training.
+train-eval callback: Sets the model to
+train()
or eval()
depending on whether the
+procedure is doing training or validation.
+metrics callback: evaluates metrics during the
+training and validation process.
+progress callback: implements a progress bar and
+prints progress information during training.
-You can also implement custom callbacks that modify or act specifically for your training procedure. For example:
-Let’s implement a callback that prints ‘Iteration n
’ (where n
is the iteration number) for every batch in the training set and ‘Done’ when an epoch is finished. For that task we use the luz_callback
function:
+You can also implement custom callbacks that modify or act
+specifically for your training procedure. For example:
+Let’s implement a callback that prints ‘Iteration n
’
+(where n
is the iteration number) for every batch in the
+training set and ‘Done’ when an epoch is finished. For that task we use
+the luz_callback
function:
print_callback <- luz_callback(
name = "print_callback",
@@ -275,16 +375,30 @@ Customizing with callbacks cat(self$message, "\n")
}
)
-luz_callback()
takes named functions as ...
arguments, where the name indicates the moment at which the callback should be called. For instance on_train_batch_end()
is called for every batch at the end of the training procedure, and on_epoch_end()
is called at the end of every epoch.
-The returned value of luz_callback()
is a function that initializes an instance of the callback. Callbacks can have initialization parameters, like the name of a file where you want to log the results. In that case, you can pass an initialize
method when creating the callback definition, and save these parameters to the self
object. In the above example, the callback has a message
parameter that is printed at the end of each epoch.
-Once a callback is defined it can be passed to the fit
function via the callbacks
parameter:
+luz_callback()
takes named functions as ...
+arguments, where the name indicates the moment at which the callback
+should be called. For instance on_train_batch_end()
is
+called at the end of every training batch, and
+on_epoch_end()
is called at the end of every epoch.
+The returned value of luz_callback()
is a function that
+initializes an instance of the callback. Callbacks can have
+initialization parameters, like the name of a file where you want to log
+the results. In that case, you can pass an initialize
+method when creating the callback definition, and save these parameters
+to the self
object. In the above example, the callback has
+a message
parameter that is printed at the end of each
+epoch.
+Once a callback is defined it can be passed to the fit
+function via the callbacks
parameter:
-Callbacks can be called in many different positions of the training loop, including combinations of them. Here’s an overview of possible callback breakpoints:
+Callbacks can be called in many different positions of the training
+loop, including combinations of them. Here’s an overview of possible
+callback breakpoints:
Start Fit
- on_fit_begin
Start Epoch Loop
@@ -320,10 +434,27 @@ Customizing with callbacks
-Every step marked with on_*
is a point in the training procedure that is available for callbacks to be called.
-The other important part of callbacks is the ctx
(context) object. See help("ctx")
for details.
-By default, callbacks are called in the same order as they were passed to fit
(or predict
or evaluate
), but you can provide a weight
attribute that will control the order in which it will be called. For example, if one callback has weight = 10
and another has weight = 1
, then the first one is called after the second one. Callbacks that don’t specify a weight
attribute are considered weight = 0
. A few built-in callbacks in luz already provide a weight value. For example, the ?luz_callback_early_stopping
has a weight of Inf
, since in general we want to run it as the last thing in the loop.
-The ctx
object is used in luz to share information between the training loop and callbacks, model methods, and metrics. The table below describes information available in the ctx
by default. Other callbacks could potentially modify these attributes or add new ones.
+Every step marked with on_*
is a point in the training
+procedure that is available for callbacks to be called.
+The other important part of callbacks is the ctx
+(context) object. See help("ctx")
for details.
+By default, callbacks are called in the same order as they were
+passed to fit
(or predict
or
+evaluate
), but you can provide a weight
+attribute that will control the order in which it will be called. For
+example, if one callback has weight = 10
and another has
+weight = 1
, then the first one is called after the second
+one. Callbacks that don’t specify a weight
attribute are
+considered weight = 0
. A few built-in callbacks in luz
+already provide a weight value. For example, the
+?luz_callback_early_stopping
has a weight of
+Inf
, since in general we want to run it as the last thing
+in the loop.
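The ordering rule above can be illustrated in plain R. This is only a sketch of the rule as described (lower weights run first, a missing weight counts as 0, ties keep the order in which callbacks were passed); it is not luz's internal implementation, and the callback names are invented.

```r
# Hypothetical callbacks, represented as plain lists for illustration.
callbacks <- list(
  list(name = "early_stopping", weight = Inf),
  list(name = "logger", weight = 10),
  list(name = "printer"),             # no weight: treated as 0
  list(name = "clipper", weight = 1),
  list(name = "timer")                # no weight: treated as 0
)

# A missing weight attribute is considered weight = 0.
get_weight <- function(cb) if (is.null(cb$weight)) 0 else cb$weight
weights <- vapply(callbacks, get_weight, numeric(1))

# Sort by weight; break ties by the position in which they were passed.
call_order <- order(weights, seq_along(weights))
ordered_names <- vapply(callbacks[call_order], function(cb) cb$name, character(1))

print(ordered_names)
# "printer" "timer" "clipper" "logger" "early_stopping"
```

The callback with weight `Inf` ends up last, which matches the reason given for early stopping.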
+The ctx
object is used in luz to share information
+between the training loop and callbacks, model methods, and metrics. The
+table below describes information available in the ctx
by
+default. Other callbacks could potentially modify these attributes or
+add new ones.
Attributes in ctx
can be used to produce the desired behavior of callbacks. You can find information about the context object using help("ctx")
. In our example, we use the ctx$iter
attribute to print the iteration number for each training batch.
+Attributes in ctx
can be used to produce the desired
+behavior of callbacks. You can find information about the context object
+using help("ctx")
. In our example, we use the
+ctx$iter
attribute to print the iteration number for each
+training batch.
Next steps
-In this article you learned how to train your first model using luz and the basics of customization using both custom metrics and callbacks.
-Luz also allows more flexible modifications of the training loop described in vignette("custom-loop")
.
-You should now be able to follow the examples marked with the ‘basic’ category in the examples gallery.
+In this article you learned how to train your first model using luz
+and the basics of customization using both custom metrics and
+callbacks.
+Luz also allows more flexible modifications of the training loop
+described in vignette("custom-loop")
.
+You should now be able to follow the examples marked with the ‘basic’
+category in the examples
+gallery.
@@ -472,7 +654,7 @@ Next steps
-Site built with pkgdown 2.0.7.9000.
+Site built with pkgdown 2.0.7.
diff --git a/articles/get-started_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/get-started_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/get-started_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/index.html b/articles/index.html
index 2bd5852c..81608f8f 100644
--- a/articles/index.html
+++ b/articles/index.html
@@ -63,21 +63,19 @@ All vignettes
- Accelerator API
- -
-
- CharGPT
-
- Checkpointing your models
-
+
- Using the learning rate finder
+ -
- Custom loops with luz
-
-
- Binary classification
+ - CharGPT
-
-
- Get started with luz
+ - Binary classification
-
- Examples
-
-
- Using the learning rate finder
- -
- Autoencoder
-
- Virtual batch size
@@ -94,6 +92,10 @@ All vignettes
-
- Text classification from scratch
-
+
- Training a causal language model from scratch
+ -
+
- Get started with luz
+ -
@@ -103,7 +105,7 @@ All vignettes
diff --git a/articles/lr-finder.html b/articles/lr-finder.html
index 876a03ce..7a8d20d0 100644
--- a/articles/lr-finder.html
+++ b/articles/lr-finder.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Using the learning rate finder
@@ -94,9 +95,25 @@ Guides
library(torchvision)
set.seed(1)
torch::torch_manual_seed(1703)
-In this article we discuss how to find a good learning rate for your model. Finding a good learning rate is essential to be able to fit your model. If it’s too low, you will need too many iterations for your loss to converge, and that might be impractical if your model takes too long to run. If it’s too high, the loss can explode and you might never be able to minimize the loss.
-The learning rate can be considered another hyperparameter of your model that needs to be tuned but, there are techniques that allow you to select a good learning rate for your model without having to use the costly strategy of fitting many models with different learning rates and then choosing the one with better results.
-This article by Leslie Smith that became popular once their approach had been implemented in the popular FastAI framework, proposes that we should start with a very small learning rate and slowly increase it until we reach a high learning rate. At each iteration we record the loss value and in the end we plot it against the learning rate. We can then use these results to decide on a good learning rate. That’s what lr_finder
does, and we will show how to use it.
+In this article we discuss how to find a good learning rate for your
+model. Finding a good learning rate is essential to be able to fit your
+model. If it’s too low, you will need too many iterations for your loss
+to converge, and that might be impractical if your model takes too long
+to run. If it’s too high, the loss can explode and you might never be
+able to minimize the loss.
+The learning rate can be considered another hyperparameter of your
+model that needs to be tuned, but there are techniques that allow you to
+select a good learning rate for your model without having to use the
+costly strategy of fitting many models with different learning rates and
+then choosing the one with the best results.
+This article by Leslie
+Smith, which became popular once its approach was implemented in
+the FastAI framework, proposes that we should start with a very
+small learning rate and slowly increase it until we reach a high
+learning rate. At each iteration we record the loss value and in the end
+we plot it against the learning rate. We can then use these results to
+decide on a good learning rate. That’s what lr_finder
does,
+and we will show how to use it.
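The sweep described above can be sketched in plain R. The block assumes the learning rate grows geometrically between a start and an end value, and that the recorded losses are smoothed with an exponential moving average before plotting; both helpers (`lr_schedule`, `ema`) and their default values are hypothetical and are not the internals of lr_finder.

```r
# Hypothetical helper: n learning rates spaced geometrically
# between lr_start and lr_end (assumes n >= 2).
lr_schedule <- function(lr_start, lr_end, n) {
  lr_start * (lr_end / lr_start)^(seq(0, n - 1) / (n - 1))
}

# Hypothetical helper: exponential moving average of a loss series.
ema <- function(x, alpha = 0.1) {
  out <- numeric(length(x))
  out[1] <- x[1]
  for (i in seq_along(x)[-1]) {
    out[i] <- alpha * x[i] + (1 - alpha) * out[i - 1]
  }
  out
}

# One learning rate per iteration; the loss would be recorded at each one.
lrs <- lr_schedule(1e-7, 1e-1, 100)
```

A geometric schedule spends the same number of steps on each order of magnitude, which is why the resulting plot is usually read on a log-scaled learning rate axis.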
First let’s download and prepare the MNIST dataset:
dir <- "~/Downloads/mnist" # caching directory
@@ -108,7 +125,8 @@ Guides
)
#> Processing...
#> Done!
-We can now define our model. We are going to use a small, straightforward CNN in the LeNet style.
+We can now define our model. We are going to use a small,
+straightforward CNN in the LeNet style.
net <- nn_module(
"net",
@@ -135,7 +153,11 @@ Guides
self$classifier()
}
)
-We can now use the lr_finder
function to record the loss with different learning rates. It’s important to use the learning rate finder with all other hyperparameters of the model fixed because they can influence the choice of the learning rate. For example, depending on the batch size, you might want to choose different learning rates.
+We can now use the lr_finder
function to record the loss
+with different learning rates. It’s important to use the learning rate
+finder with all other hyperparameters of the model fixed because they
+can influence the choice of the learning rate. For example, depending on
+the batch size, you might want to choose different learning rates.
model <- net %>% setup(
loss = torch::nn_cross_entropy_loss(),
@@ -155,15 +177,26 @@ Guides
#> Classes 'lr_records' and 'data.frame': 100 obs. of 2 variables:
#> $ lr : num 1.15e-06 1.32e-06 1.51e-06 1.74e-06 2.00e-06 ...
#> $ loss: num 2.31 2.3 2.29 2.3 2.31 ...
-The result is a data frame with the losses and the learning rate in each step. You can use the built-in plot method to display the exact results, along with an exponentially smoothed value of the loss.
+The result is a data frame with the losses and the learning rate in
+each step. You can use the built-in plot method to display the exact
+results, along with an exponentially smoothed value of the loss.
plot(records) +
ggplot2::coord_cartesian(ylim = c(NA, 5))
-We can see that with small learning rates the loss doesn’t decrease. At some point the loss starts decreasing until it reaches a point where it starts increasing and explodes.
-And how do we choose the learning rate using this plot? Sylvain Gugger asked the same question in this blog post and we are quoting his answer:
+We can see that with small learning rates the loss doesn’t decrease.
+At some point the loss starts decreasing until it reaches a point where
+it starts increasing and explodes.
+And how do we choose the learning rate using this plot? Sylvain
+Gugger asked the same question in this blog
+post and we are quoting his answer:
-Not the one corresponding to the minimum. Why? Well the learning rate that corresponds to the minimum value is already a bit too high, since we are at the edge between improving and getting all over the place. We want to go one order of magnitude before, a value that’s still aggressive (so that we train quickly) but still on the safe side from an explosion.
+Not the one corresponding to the minimum. Why? Well the learning rate
+that corresponds to the minimum value is already a bit too high, since
+we are at the edge between improving and getting all over the place. We
+want to go one order of magnitude before, a value that’s still
+aggressive (so that we train quickly) but still on the safe side from an
+explosion.
In the above example we would choose 1e-3 instead of 1e-2.
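The rule of thumb in the quote can be written down as a small helper. This is a sketch under assumptions: `suggest_lr` is a hypothetical function, and the `toy_records` data frame is synthetic, made up only to mirror the `lr` and `loss` columns shown earlier.

```r
# Hypothetical helper: take the learning rate at the minimum loss
# and step back one order of magnitude.
suggest_lr <- function(records) {
  lr_at_min <- records$lr[which.min(records$loss)]
  lr_at_min / 10
}

# Synthetic values for illustration only (not real lr_finder output).
toy_records <- data.frame(
  lr   = 10^seq(-6, -1),
  loss = c(2.3, 2.3, 2.0, 1.2, 0.6, 4.0)
)

suggested <- suggest_lr(toy_records)
```

With real records it is usually safer to apply this to a smoothed loss curve, since a single noisy batch can move the raw minimum.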
@@ -178,7 +211,7 @@ Guides
diff --git a/articles/lr-finder_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/lr-finder_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/lr-finder_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/authors.html b/authors.html
index 91358a77..60012493 100644
--- a/authors.html
+++ b/authors.html
@@ -95,7 +95,7 @@ Citation
diff --git a/index.html b/index.html
index 6d2484de..02b9dcd1 100644
--- a/index.html
+++ b/index.html
@@ -5,14 +5,24 @@
-
+
Higher Level API for torch • luz
-
+
+
Luz is a higher level API for torch providing abstractions to allow for much less verbose training loops.
This package is still under development.
It is heavily inspired by other higher level frameworks for deep learning, to cite a few:
@@ -190,7 +201,7 @@ Dev status
diff --git a/news/index.html b/news/index.html
index 1caa70f2..488f9cb4 100644
--- a/news/index.html
+++ b/news/index.html
@@ -101,7 +101,8 @@ Bug fixes
luz 0.3.1
CRAN release: 2022-09-06
-- Re-submission to fix vignette rendering.
+- Re-submission to fix vignette rendering.
+
luz 0.3.0
CRAN release: 2022-08-19
@@ -113,7 +114,8 @@ Breaking changes
Documentation
-- Many wording improvements in the getting started guides (#81 #94, @jonthegeek).
+- Many wording improvements in the getting started guides (#81 #94, @jonthegeek).
+
New features
- Added MixUp callback and helper loss function and functional logic. (#82, @skeydan).
@@ -151,7 +153,8 @@ Internal changes
luz 0.1.0
CRAN release: 2021-06-17
-- Added a
NEWS.md
file to track changes to the package.
+- Added a
NEWS.md
file to track changes to the package.
+
@@ -161,7 +164,7 @@ luz 0.1.0
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/pkgdown.yml b/pkgdown.yml
index 3b8ad870..5832483b 100644
--- a/pkgdown.yml
+++ b/pkgdown.yml
@@ -1,15 +1,14 @@
-pandoc: 2.7.3
-pkgdown: 2.0.7.9000
-pkgdown_sha: c9206802f2888992de92aa41f517ba7812f05331
+pandoc: 2.19.2
+pkgdown: 2.0.7
+pkgdown_sha: ~
articles:
accelerator: accelerator.html
- chargpt: examples/chargpt.html
checkpoints: checkpoints.html
+ lr-finder: lr-finder.html
custom-loop: custom-loop.html
+ chargpt: examples/chargpt.html
dogs-vs-cats-binary-classification: examples/dogs-vs-cats-binary-classification.html
- get-started: get-started.html
index: examples/index.html
- lr-finder: lr-finder.html
mnist-autoencoder: examples/mnist-autoencoder.html
mnist-cnn-virtual-batch-size: examples/mnist-cnn-virtual-batch-size.html
mnist-cnn: examples/mnist-cnn.html
@@ -18,5 +17,7 @@ articles:
mnist-triplet: examples/mnist-triplet.html
pets-unet: examples/pets-unet.html
text-classification: examples/text-classification.html
-last_built: 2023-09-15T17:29Z
+ text-generation: examples/text-generation.html
+ get-started: get-started.html
+last_built: 2023-10-17T16:26Z
diff --git a/reference/accelerator.html b/reference/accelerator.html
index 6e94c1a0..ad0990cf 100644
--- a/reference/accelerator.html
+++ b/reference/accelerator.html
@@ -99,7 +99,7 @@ Arguments
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/reference/as_dataloader.html b/reference/as_dataloader.html
index e9bfbfe2..404c070e 100644
--- a/reference/as_dataloader.html
+++ b/reference/as_dataloader.html
@@ -159,7 +159,7 @@ Overriding
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/reference/context.html b/reference/context.html
index 9c598532..a377b0d8 100644
--- a/reference/context.html
+++ b/reference/context.html
@@ -517,7 +517,7 @@ Arguments
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/reference/ctx.html b/reference/ctx.html
index fe747f6a..0e921c9b 100644
--- a/reference/ctx.html
+++ b/reference/ctx.html
@@ -90,7 +90,7 @@ See also
diff --git a/reference/evaluate.html b/reference/evaluate.html
index 9bdcad2d..9ef5c86f 100644
--- a/reference/evaluate.html
+++ b/reference/evaluate.html
@@ -141,12 +141,12 @@ Details
evaluation <- fitted %>% evaluate(data = valid_dl)
metrics <- get_metrics(evaluation)
print(evaluation)
-## A `luz_module_evaluation`
-## -- Results ---------------------------------------------------------------------
-## loss: 1.5146
-## mae: 1.0251
-## mse: 1.5159
-## rmse: 1.2312
+## A `luz_module_evaluation`
+## -- Results ---------------------------------------------------------------------
+## loss: 1.5146
+## mae: 1.0251
+## mse: 1.5159
+## rmse: 1.2312
diff --git a/reference/fit.luz_module_generator.html b/reference/fit.luz_module_generator.html
index 3850f33c..885905d8 100644
--- a/reference/fit.luz_module_generator.html
+++ b/reference/fit.luz_module_generator.html
@@ -170,7 +170,7 @@ See also
diff --git a/reference/get_metrics.html b/reference/get_metrics.html
index 2add3e85..692a7bf7 100644
--- a/reference/get_metrics.html
+++ b/reference/get_metrics.html
@@ -103,7 +103,7 @@ Methods (by class)
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/reference/index.html b/reference/index.html
index 29fd15f7..5c6cf36f 100644
--- a/reference/index.html
+++ b/reference/index.html
@@ -359,7 +359,7 @@ Serialization
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/reference/lr_finder-1.png b/reference/lr_finder-1.png
index 6801c27f..2b724bd4 100644
Binary files a/reference/lr_finder-1.png and b/reference/lr_finder-1.png differ
diff --git a/reference/lr_finder.html b/reference/lr_finder.html
index 382b47f7..6cdb2739 100644
--- a/reference/lr_finder.html
+++ b/reference/lr_finder.html
@@ -146,7 +146,7 @@ Examples
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/reference/luz_callback.html b/reference/luz_callback.html
index fdcd20ed..a2d9fdae 100644
--- a/reference/luz_callback.html
+++ b/reference/luz_callback.html
@@ -152,41 +152,41 @@ Details
Callbacks can be called in many different positions of the training
loop, including combinations of them. Here’s an overview of possible
callback breakpoints:
-Start Fit
- - on_fit_begin
- Start Epoch Loop
- - on_epoch_begin
- Start Train
- - on_train_begin
- Start Batch Loop
- - on_train_batch_begin
- Start Default Training Step
- - on_train_batch_after_pred
- - on_train_batch_after_loss
- - on_train_batch_before_backward
- - on_train_batch_before_step
- - on_train_batch_after_step
- End Default Training Step:
- - on_train_batch_end
- End Batch Loop
- - on_train_end
- End Train
- Start Valid
- - on_valid_begin
- Start Batch Loop
- - on_valid_batch_begin
- Start Default Validation Step
- - on_valid_batch_after_pred
- - on_valid_batch_after_loss
- End Default Validation Step
- - on_valid_batch_end
- End Batch Loop
- - on_valid_end
- End Valid
- - on_epoch_end
- End Epoch Loop
- - on_fit_end
-End Fit
+
+Start Fit
+ - on_fit_begin
+ Start Epoch Loop
+ - on_epoch_begin
+ Start Train
+ - on_train_begin
+ Start Batch Loop
+ - on_train_batch_begin
+ Start Default Training Step
+ - on_train_batch_after_pred
+ - on_train_batch_after_loss
+ - on_train_batch_before_backward
+ - on_train_batch_before_step
+ - on_train_batch_after_step
+ End Default Training Step:
+ - on_train_batch_end
+ End Batch Loop
+ - on_train_end
+ End Train
+ Start Valid
+ - on_valid_begin
+ Start Batch Loop
+ - on_valid_batch_begin
+ Start Default Validation Step
+ - on_valid_batch_after_pred
+ - on_valid_batch_after_loss
+ End Default Validation Step
+ - on_valid_batch_end
+ End Batch Loop
+ - on_valid_end
+ End Valid
+ - on_epoch_end
+ End Epoch Loop
+ - on_fit_end
+End Fit
Every step marked with on_*
is a point in the training procedure that
is available for callbacks to be called.
The other important part of callbacks is the ctx
(context) object. See
@@ -208,14 +208,14 @@
Prediction callbackspredict(). In this case the supported
callback methods are detailed above.
-
Start predict
- - on_predict_begin
- Start prediction loop
- - on_predict_batch_begin
- - on_predict_batch_end
- End prediction loop
- - on_predict_end
-End predict
+
+Start predict
+ - on_predict_begin
+ Start prediction loop
+ - on_predict_batch_begin
+ - on_predict_batch_end
+ End prediction loop
+ - on_predict_end
+End predict
Evaluate callbacks
@@ -224,18 +224,18 @@ Evaluate callbacksevaluate(), in this case, the callbacks that
are used are equivalent to those of the validation loop when using fit()
:
-
Start Valid
- - on_valid_begin
- Start Batch Loop
- - on_valid_batch_begin
- Start Default Validation Step
- - on_valid_batch_after_pred
- - on_valid_batch_after_loss
- End Default Validation Step
- - on_valid_batch_end
- End Batch Loop
- - on_valid_end
-End Valid
+
+Start Valid
+ - on_valid_begin
+ Start Batch Loop
+ - on_valid_batch_begin
+ Start Default Validation Step
+ - on_valid_batch_after_pred
+ - on_valid_batch_after_loss
+ End Default Validation Step
+ - on_valid_batch_end
+ End Batch Loop
+ - on_valid_end
+End Valid
See also
@@ -278,7 +278,7 @@ Examples
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/reference/luz_callback_auto_resume.html b/reference/luz_callback_auto_resume.html
index 2013b39d..b853084c 100644
--- a/reference/luz_callback_auto_resume.html
+++ b/reference/luz_callback_auto_resume.html
@@ -177,16 +177,16 @@ Examples#> Caused by error in `self[[callback_nm]]()`:
#> ! Error on epoch 5
#> set metric epoch value
-#> 1 train loss 1 1.302326
-#> 2 train loss 2 1.141849
-#> 3 train loss 3 1.094023
-#> 4 train loss 4 1.082328
-#> 5 train loss 5 1.083923
-#> 6 train loss 6 1.072870
-#> 7 train loss 7 1.083111
-#> 8 train loss 8 1.079866
-#> 9 train loss 9 1.074621
-#> 10 train loss 10 1.075743
+#> 1 train loss 1 1.217334
+#> 2 train loss 2 1.079304
+#> 3 train loss 3 1.040630
+#> 4 train loss 4 1.027106
+#> 5 train loss 5 1.023069
+#> 6 train loss 6 1.017577
+#> 7 train loss 7 1.016829
+#> 8 train loss 8 1.020484
+#> 9 train loss 9 1.022464
+#> 10 train loss 10 1.025988
diff --git a/articles/examples/index_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/index_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/index_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty
diff --git a/reference/luz_callback_csv_logger.html b/reference/luz_callback_csv_logger.html
index ef9e3e98..8caa5ca8 100644
--- a/reference/luz_callback_csv_logger.html
+++ b/reference/luz_callback_csv_logger.html
@@ -106,7 +106,7 @@ See also
diff --git a/reference/luz_callback_early_stopping.html b/reference/luz_callback_early_stopping.html
index 594b229e..6e55924f 100644
--- a/reference/luz_callback_early_stopping.html
+++ b/reference/luz_callback_early_stopping.html
@@ -150,7 +150,7 @@ Examples
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/reference/luz_callback_gradient_clip.html b/reference/luz_callback_gradient_clip.html
index 9cd58b87..51e17643 100644
--- a/reference/luz_callback_gradient_clip.html
+++ b/reference/luz_callback_gradient_clip.html
@@ -101,7 +101,7 @@ References
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/reference/luz_callback_interrupt.html b/reference/luz_callback_interrupt.html
index 1f59d51c..0c117239 100644
--- a/reference/luz_callback_interrupt.html
+++ b/reference/luz_callback_interrupt.html
@@ -122,7 +122,7 @@ Examples
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
Checkpointing your models
@@ -93,15 +94,29 @@Guides
library(torch)
set.seed(1)
torch::torch_manual_seed(1703)
-When fitting models takes too long, you might want to save intermediate state to disk so that, if something goes wrong during training (e.g. the process is killed or the network fails), you can recover from where it stopped.
-You might also want to recover intermediate results to evaluate the model in different moments of the training, like comparing results after 10 epochs and after 30 epochs.
-This article describes luz features that are built to handle those cases. These features are optional and are enabled once you add specific callbacks to your fit
call.
+When fitting models takes too long, you might want to save intermediate
+state to disk so that, if something goes wrong during training (e.g. the
+process is killed or the network fails), you can recover from where it stopped.
+You might also want to recover intermediate results to evaluate the
+model at different moments of the training, like comparing results after
+10 epochs and after 30 epochs.
+This article describes luz features that are built to handle those
+cases. These features are optional and are enabled once you add specific
+callbacks to your fit
call.
Resuming training runs that crashed
-If you have a long training run that can crash for whatever reason (computer turned off, process kileed in cluster, etc), we recommend you to add luz_callback_autoresume()
to your list of callbacks.
luz_callback_autoresume()
will automatically checkpoint the whole state of your model at the end of each epoch. If something fails during training you can simply rerun the same script, whithout any code changes and the checkpoint will be reloaded and the training will start from where it stopped.
For example, lets’s take a randomly generated training dataset and a linear model to show how autoresume works.
+If you have a long training run that can crash for whatever reason
+(computer turned off, process killed in a cluster, etc.), we recommend that you
+add luz_callback_auto_resume()
to your list of
+callbacks.
luz_callback_auto_resume()
will automatically checkpoint
+the whole state of your model at the end of each epoch. If something
+fails during training, you can simply rerun the same script, without any
+code changes, and the checkpoint will be reloaded and training will
+resume from where it stopped.
+For example, let’s take a randomly generated training dataset and a
+linear model to show how auto-resume works.
Here’s the training data:
x <- torch_randn(1000, 10)
@@ -112,7 +127,9 @@ Resuming training runs that crashed
setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
set_hparams(in_features = 10, out_features = 1) %>%
set_opt_hparams(lr = 0.01)
Let’s now create a callback that simulates a random failure that could happen. This callback will just raise an R error on the 5th epoch.
+Let’s now create a callback that simulates a random failure that
+could happen. This callback will just raise an R error on the 5th
+epoch.
interrupt <- luz_callback(
"interrupt",
@@ -124,7 +141,8 @@ Resuming training runs that crashed
}
}
)
Let’s now start training adding the luz_callback_auto_resume()
:
Let’s now start training adding the
+luz_callback_auto_resume()
:
autoresume <- luz_callback_auto_resume(path = "state.pt")
inter <- interrupt()
@@ -140,14 +158,17 @@ Resuming training runs that crashed
#> on_epoch_end.
#> Caused by error in `self[[callback_nm]]()`:
#> ! Error on epoch 5
To resume model training exactly from where it stopped you just need to restart fitting, using the exact same model, callbacks, etc:
+To resume model training exactly from where it stopped, you just need
+to restart fitting, using the exact same model, callbacks, etc.:
-With this, the model fitting process will be continued exactly from where it stopped. Records, optimizer and model state are recovered from the previous run so you can have the full results:
+With this, the model fitting process will be continued exactly from
+where it stopped. Records, optimizer and model state are recovered from
+the previous run so you can have the full results:
plot(results)
Resuming training runs that crashed
Checkpointing
-Sometimes you want to have more control over how checkpoints are handled. In this case you can use luz_callback_model_checkpoint()
to save checkpoints to a specified file or directory.
-Let’s use the same example as in the resuming section: We first generate some data.
+Sometimes you want to have more control over how checkpoints are
+handled. In this case you can use
+luz_callback_model_checkpoint()
to save checkpoints to a
+specified file or directory.
+Let’s use the same example as in the resuming section: We first
+generate some data.
x <- torch_randn(1000, 10)
y <- torch_randn(1000, 1)
@@ -166,7 +191,8 @@ Checkpointingsetup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
set_hparams(in_features = 10, out_features = 1) %>%
set_opt_hparams(lr = 0.01)
Let’s now fit the model using luz_callback_model_checkpoint()
.
Let’s now fit the model using
+luz_callback_model_checkpoint()
.
checkpoint <- luz_callback_model_checkpoint(
path = "checkpoints/",
@@ -178,7 +204,12 @@ Checkpointing= list(checkpoint),
verbose = FALSE
)
You can see now that the checkpoints
directory contains files with state dumps for each epoch. By default, luz_callback_model_checkpoint
will save the state for each epochs and format the name including the resulting loss. This can be configured withing the path parameter, see ?luz_callback_model_checkpoint
for details.
You can see now that the checkpoints
directory contains
+files with state dumps for each epoch. By default,
+luz_callback_model_checkpoint
will save the state for each
+epoch and format the file name to include the resulting loss. This can be
+configured within the path parameter; see
+?luz_callback_model_checkpoint
for details.
fs::dir_ls("checkpoints")
#> checkpoints/epoch-01-train_loss-1.237.pt
@@ -191,11 +222,21 @@ Checkpointing#> checkpoints/epoch-08-train_loss-0.998.pt
#> checkpoints/epoch-09-train_loss-1.001.pt
#> checkpoints/epoch-10-train_loss-1.002.pt
Finally, you can load a specific checkpoint to the fitted
result using luz_load_checkpoint
. Note that loading the checkpoint into a a luz_fitted_module
is going to modify the model weights in-place.
Finally, you can load a specific checkpoint into the
+fitted
result using luz_load_checkpoint
. Note
+that loading the checkpoint into a luz_fitted_module
is
+going to modify the model weights in-place.
luz_load_checkpoint(results, fs::dir_ls("checkpoints")[1])
You can then start making predictions, or evaluate your model using the reloeded weights.
-You might also want to start a new training run from a checkpoint. For this, you can use the luz_callback_resume_from_checkpoint()
. By default, it will only recover the model weights from the checkpoint file, but you can configure it to restore records, callback and optimizer state too. If a checkpoint directory is passed then training will resume from the last checkpoint file as returned by fs::dir_ls
.
You can then start making predictions, or evaluate your model using
+the reloaded weights.
+You might also want to start a new training run from a checkpoint.
+For this, you can use the
+luz_callback_resume_from_checkpoint()
. By default, it will
+only recover the model weights from the checkpoint file, but you can
+configure it to restore records, callback and optimizer state too. If a
+checkpoint directory is passed then training will resume from the last
+checkpoint file as returned by fs::dir_ls
.
Here’s how you would use this callback:
resume <- luz_callback_resume_from_checkpoint(path = "checkpoints/")
@@ -209,8 +250,15 @@ Checkpointing
Custom callbacks state
-Sometimes callbacks also need to keep their internal state in order to allow continuing training exactly from where it stopped. In this case, callbacks can implement the state_dict()
and the load_state_dict()
methods that are automatically called when saving and reloading checkpoints.
-For example, suppose that you have a callback that tracks gradients for weights at every epoch. You want to use the tracked weights to further analyse the training procedure. It could be implemented like:
+Sometimes callbacks also need to keep their internal state in order
+to allow continuing training exactly from where it stopped. In this
+case, callbacks can implement the state_dict()
and the
+load_state_dict()
methods that are automatically called
+when saving and reloading checkpoints.
+For example, suppose that you have a callback that tracks gradients
+for weights at every epoch. You want to use the tracked weights to
+further analyse the training procedure. It could be implemented
+like:
cb_weight_grad <- luz_callback(
"weight_grad",
@@ -225,7 +273,14 @@ Custom callbacks state }
}
)
-In the above example, the gradients
field is a state in the callback. If training fails for some reason, gradients
will be lost. If it’s important for you to also checkpoint the callback state, you can implement the state_dict()
method must returning a named list of objects that compose the state of the callback and load_state_dict()
taking the same named list returned by state_dict()
and restoring the callback state.
+In the above example, the gradients
field is a
+state in the callback. If training fails for some
+reason, gradients
will be lost. If it’s important for you
+to also checkpoint the callback state, you can implement the
+state_dict()
method, which must return a named list of objects
+that compose the state of the callback, and
+load_state_dict()
, which takes the same named list returned by
+state_dict()
and restores the callback state.
The callback above could be reimplemented with:
cb_weight_grad <- luz_callback(
@@ -262,7 +317,7 @@ Custom callbacks state
-Site built with pkgdown 2.0.7.9000.
+Site built with pkgdown 2.0.7.
diff --git a/articles/checkpoints_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/checkpoints_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/checkpoints_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/custom-loop.html b/articles/custom-loop.html
index b27b17d8..658563a7 100644
--- a/articles/custom-loop.html
+++ b/articles/custom-loop.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Custom loops with luz
@@ -91,15 +92,36 @@ Guides
-Luz is a higher level API for torch that is designed to be highly flexible by providing a layered API that allows it to be useful no matter the level of control your need for your training loop.
-In the getting started vignette we have seen the basics of luz and how to quickly modify parts of the training loop using callbacks and custom metrics. In this document we will describe how luz allows the user to get fine-grained control of the training loop.
-Apart from the use of callbacks, there are three more ways that you can use luz (depending on how much control you need):
+Luz is a higher-level API for torch that is designed to be highly
+flexible by providing a layered API that allows it to be useful no
+matter the level of control you need for your training loop.
+In the getting started vignette we have seen the basics of luz and
+how to quickly modify parts of the training loop using callbacks and
+custom metrics. In this document we will describe how luz allows the
+user to get fine-grained control of the training loop.
+Apart from the use of callbacks, there are three more ways that you
+can use luz (depending on how much control you need):
-Multiple optimizers or losses: You might be optimizing two loss functions each with its own optimizer, but you still don’t want to modify the backward()
- zero_grad()
and step()
calls. This is common in models like GANs (Generative Adversarial Networks) when you have competing neural networks trained with different losses and optimizers.
-Fully flexible steps: You might want to be in control of how to call backward()
, zero_grad()
and step()
. You might also want to have more control of gradient computation. For example, you might want to use ‘virtual batch sizes’, where you accumulate the gradients for a few steps before updating the weights.
-Completely flexible loops: Your training loop can be anything you want but you still want to use luz to handle device placement of the dataloaders, optimizers and models. See vignette("accelerator")
.
+Multiple optimizers or losses: You might be
+optimizing two loss functions, each with its own optimizer, but you still
+don’t want to modify the backward()
,
+zero_grad()
and step()
calls. This is common
+in models like GANs (Generative Adversarial Networks), where you have
+competing neural networks trained with different losses and
+optimizers.
+Fully flexible steps: You might want to be in
+control of how to call backward()
,
+zero_grad()
and step()
. You might also want to
+have more control of gradient computation. For example, you might want
+to use ‘virtual batch sizes’, where you accumulate the gradients for a
+few steps before updating the weights.
+Completely flexible loops: Your training loop
+can be anything you want but you still want to use luz to handle device
+placement of the dataloaders, optimizers and models. See
+vignette("accelerator")
.
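As a rough illustration of the ‘virtual batch sizes’ idea mentioned in the list above, the following sketch accumulates gradients over several small batches in a plain torch loop before updating the weights. It does not use luz at all; the toy data, the micro-batch size of 8 and the `accumulate_steps` name are assumptions made only for this example.

```r
library(torch)

torch_manual_seed(1)

# Toy regression data (hypothetical, for illustration only).
x <- torch_randn(64, 10)
y <- torch_randn(64, 1)

model <- nn_linear(10, 1)
opt <- optim_sgd(model$parameters, lr = 0.01)

micro_batch <- 8        # size of each real batch
accumulate_steps <- 4   # virtual batch size = 8 * 4 = 32
n_updates <- 0
i <- 0

opt$zero_grad()
for (start in seq(1, 64, by = micro_batch)) {
  i <- i + 1
  idx <- start:(start + micro_batch - 1)
  loss <- nnf_mse_loss(model(x[idx, ]), y[idx, ])

  # Scale the loss so the accumulated gradient is the mean over the
  # virtual batch; backward() adds to the gradients already stored.
  (loss / accumulate_steps)$backward()

  # Only update the weights once every `accumulate_steps` micro-batches.
  if (i %% accumulate_steps == 0) {
    opt$step()
    opt$zero_grad()
    n_updates <- n_updates + 1
  }
}
```

With 64 observations and micro-batches of 8, the loop sees 8 micro-batches but performs only 2 weight updates. The same pattern could be moved into a custom step() method, where the context object would supply the batch and the optimizers.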
-Let’s consider a simplified version of the net
that we implemented in the getting started vignette:
+Let’s consider a simplified version of the net
that we
+implemented in the getting started vignette:
net <- nn_module(
"Net",
@@ -128,11 +150,18 @@ Guides
Multiple optimizers
-Suppose we want to do an experiment where we train the first fully connected layer using a learning rate of 0.1 and the second one using a learning rate of 0.01. We will minimize the same nn_cross_entropy_loss()
for both, but for the first layer we want to add L1 regularization on the weights.
-In order to use luz for this, we will implement two methods in the net
module:
+Suppose we want to do an experiment where we train the first fully
+connected layer using a learning rate of 0.1 and the second one using a
+learning rate of 0.01. We will minimize the same
+nn_cross_entropy_loss()
for both, but for the first layer
+we want to add L1 regularization on the weights.
+In order to use luz for this, we will implement two methods in the
+net
module:
-set_optimizers
: returns a named list of optimizers depending on the ctx
.
-loss
: computes the loss depending on the selected optimizer.
+set_optimizers
: returns a named list of optimizers
+depending on the ctx
.
+loss
: computes the loss depending on the selected
+optimizer.
Let’s go to the code:
@@ -163,19 +192,35 @@ Multiple optimizersnnf_cross_entropy(pred, target)
}
)
-Notice that the model optimizers will be initialized according to the set_optimizers()
method’s return value (a list). In this case, we are initializing the optimizers using different model parameters and learning rates.
-The loss()
method is responsible for computing the loss that will then be back-propagated to compute gradients and update the weights. This loss()
method can access the ctx
object that will contain an opt_name
field, describing which optimizer is currently being used. Note that this function will be called once for each optimizer for each training and validation step. See help("ctx")
for complete information about the context object.
-We can finally setup
and fit
this module, however we no longer need to specify optimizers and loss functions.
+Notice that the model optimizers will be initialized according to the
+set_optimizers()
method’s return value (a list). In this
+case, we are initializing the optimizers using different model
+parameters and learning rates.
+The loss()
method is responsible for computing the loss
+that will then be back-propagated to compute gradients and update the
+weights. This loss()
method can access the ctx
+object that will contain an opt_name
field, describing
+which optimizer is currently being used. Note that this function will be
+called once for each optimizer for each training and validation step.
+See help("ctx")
for complete information about the context
+object.
+We can finally setup
and fit
this module,
+however we no longer need to specify optimizers and loss functions.
fitted <- net %>%
setup(metrics = list(luz_metric_accuracy)) %>%
fit(train_dl, epochs = 10, valid_data = test_dl)
-Now let’s re-implement this same model using the slightly more flexible approach of overriding the training and validation step.
+Now let’s re-implement this same model using the slightly more
+flexible approach of overriding the training and validation step.
Fully flexible step
-Instead of implementing the loss()
method, we can implement the step()
method. This allows us to flexibly modify what happens when training and validating for each batch in the dataset. You are now responsible for updating the weights by stepping the optimizers and back-propagating the loss.
+Instead of implementing the loss()
method, we can
+implement the step()
method. This allows us to flexibly
+modify what happens when training and validating for each batch in the
+dataset. You are now responsible for updating the weights by stepping
+the optimizers and back-propagating the loss.
The important things to notice here are:
-The step()
method is used for both training and validation. You need to be careful to only modify the weights when training. Again, you can get complete information regarding the context object using help("ctx")
.
-ctx$optimizers
is a named list holding each optimizer that was created when the set_optimizers()
method was called.
-You need to manually track the losses by saving saving them in a named list in ctx$loss
. By convention, we use the same name as the optimizer it refers to. It is good practice to detach()
them before saving to reduce memory usage.
-Callbacks that would be called inside the default step()
method like on_train_batch_after_pred
, on_train_batch_after_loss
, etc, won’t be automatically called. You can still cal them manually by adding ctx$call_callbacks("<callback name>")
inside your training step. See the code for fit_one_batch()
and valid_one_batch
to find all the callbacks that won’t be called.
-If you want luz metrics to work with your custom step()
method, you must assign ctx$pred
with the model predictions as metrics will always be called with metric$update(ctx$pred, ctx$target)
.
+The step()
method is used for both training and
+validation. You need to be careful to only modify the weights when
+training. Again, you can get complete information regarding the context
+object using help("ctx")
.
+ctx$optimizers
is a named list holding each
+optimizer that was created when the set_optimizers()
method
+was called.
+You need to manually track the losses by saving them in a
+named list in ctx$loss
. By convention, we use the same name
+as the optimizer it refers to. It is good practice to
+detach()
them before saving to reduce memory
+usage.
+Callbacks that would be called inside the default
+step()
method like on_train_batch_after_pred
,
+on_train_batch_after_loss
, etc, won’t be automatically
+called. You can still call them manually by adding
+ctx$call_callbacks("<callback name>")
inside your
+training step. See the code for fit_one_batch()
and
+valid_one_batch()
to find all the callbacks that won’t be
+called.
+If you want luz metrics to work with your custom
+step()
method, you must assign ctx$pred
with
+the model predictions as metrics will always be called with
+metric$update(ctx$pred, ctx$target)
.
Next steps
-In this article you learned how to customize the step()
of your training loop using luz layered functionality.
-Luz also allows more flexible modifications of the training loop described in the Accelerator vignette (vignette("accelerator")
).
-You should now be able to follow the examples marked with the ‘intermediate’ and ‘advanced’ category in the examples gallery.
+In this article you learned how to customize the step()
+of your training loop using luz layered functionality.
+Luz also allows more flexible modifications of the training loop
+described in the Accelerator vignette
+(vignette("accelerator")
).
+You should now be able to follow the examples marked with the
+‘intermediate’ and ‘advanced’ category in the examples
+gallery.
@@ -248,7 +317,7 @@ Next steps
-Site built with pkgdown 2.0.7.9000.
+Site built with pkgdown 2.0.7.
diff --git a/articles/custom-loop_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/custom-loop_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/custom-loop_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/chargpt.html b/articles/examples/chargpt.html
index 98f09978..270e7efa 100644
--- a/articles/examples/chargpt.html
+++ b/articles/examples/chargpt.html
@@ -77,7 +77,8 @@ Guides
-
+
+
CharGPT
@@ -88,15 +89,24 @@ Guides
-This example is inspired by the chargpt project by Andrey Karpathy. We are going to train character-level language model on Shakespeare texts.
+This example is inspired by the chargpt
+project by Andrej Karpathy. We are going to train a character-level
+language model on Shakespeare texts.
We first load the libraries that we plan to use:
-Next we define the torch dataset that will pre-process data for the model. It splits the text into a character vector, each element containing exactly one character.
-Then lists all unique characters into the vocab
attribute. The order of the characters in the vocabulary is used to encode each character to an integer value, that will be used in the embedding layer.
-The .getitem()
method, can take chunks of block_size
characters and encode them into their integer representation.
+Next we define the torch dataset that will pre-process data for the
+model. It splits the text into a character vector, each element
+containing exactly one character.
+It then lists all unique characters in the vocab
+attribute. The order of the characters in the vocabulary is used to
+encode each character as an integer value, which will be used in the
+embedding layer.
+The .getitem()
method takes chunks of
+block_size
characters and encodes them into their integer
+representation.
url <- "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
@@ -124,8 +134,15 @@ Guides
dataset <- char_dataset(readr::read_file(url))
dataset[1] # this allows us to see an element of the dataset
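The encoding described above can be sketched in base R, without torch. This is only an illustration of the idea, not the dataset used in the example: the short `text`, the `get_item()` helper and the choice to return the next-character targets shifted by one position are assumptions made for this sketch.

```r
# Toy text and block size (hypothetical values for illustration).
text <- "hello world"
block_size <- 4

# Split the text into a character vector, one character per element.
chars <- strsplit(text, "")[[1]]

# The vocabulary lists each unique character once; a character's position
# in the vocabulary is its integer code.
vocab <- unique(chars)

# Take a chunk of block_size + 1 characters starting at position i and
# encode it. x holds the inputs and y the same chunk shifted by one
# (the "next character" targets, an assumption of this sketch).
get_item <- function(i) {
  chunk <- chars[i:(i + block_size)]
  ids <- match(chunk, vocab)
  list(
    x = ids[1:block_size],
    y = ids[2:(block_size + 1)]
  )
}

item <- get_item(1)
item$x  # codes for "h", "e", "l", "l"
item$y  # codes for "e", "l", "l", "o"
```

Because the codes come from `match()`, they start at 1, which fits R's 1-based indexing when they are later used to look up rows of an embedding.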
-We then define the neural net we are going to train. Defining a GPT-2 model is quite verbose, so we are going to use the minhub implementation directly. You can find the full model definition here, and this code is entirely self-contained, so you don’t need to install minhub, if you don’t want to.
-We also implemented the generate
method for the model, that allows one to generate completions using the model. It applies the model in a loop, at each iteration prediction what’s the next character.
+We then define the neural net we are going to train. Defining a GPT-2
+model is quite verbose, so we are going to use the minhub implementation
+directly. You can find the full model definition here,
+and this code is entirely self-contained, so you don’t need to install
+minhub, if you don’t want to.
+We also implemented the generate
method for the model,
+which allows one to generate completions using the model. It applies the
+model in a loop, predicting the next character at each
+iteration.
model <- torch::nn_module(
initialize = function(vocab_size) {
@@ -155,7 +172,8 @@ Guides
x
}
)
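The generation loop described above can be sketched in base R with a stand-in for the network. Everything here is hypothetical: `toy_model()` is not a trained model, it simply scores the character that follows the last one in a four-letter vocabulary, and the sketch picks the highest-scoring character (greedy decoding) rather than sampling.

```r
vocab <- c("a", "b", "c", "d")

# Hypothetical stand-in for a trained network: given the integer codes seen
# so far, return one score per vocabulary entry for the next character.
toy_model <- function(ids) {
  last <- ids[length(ids)]
  scores <- rep(0, length(vocab))
  scores[last %% length(vocab) + 1] <- 1
  scores
}

# Apply the model in a loop: at each iteration predict the next character
# and append it to the context, so it is seen by the following iteration.
generate_ids <- function(model, ids, iter) {
  for (step in seq_len(iter)) {
    scores <- model(ids)
    next_id <- which.max(scores)  # greedy choice, an assumption of this sketch
    ids <- c(ids, next_id)
  }
  ids
}

context <- match(c("a", "b"), vocab)
out <- generate_ids(toy_model, context, iter = 3)
paste(vocab[out], collapse = "")
```

Replacing `which.max()` with a draw from the predicted distribution would turn the greedy loop into a sampling loop, which typically gives more varied completions.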
-Next, we implemented a callback that is used for nicely displaying generated samples during the model training:
+Next, we implemented a callback that is used for nicely displaying
+generated samples during the model training:
# samples from the model using the context.
generate <- function(model, vocab, context, ...) {
@@ -203,7 +221,8 @@ Guides
luz_callback_gradient_clip(max_norm = 1)
)
)
-One epoch, is reasonable for this dataset and takes ~1h on the M1 MBP. You can generate new samples with:
+One epoch is reasonable for this dataset and takes about 1 hour on an M1
+MacBook Pro. You can generate new samples with:
context <- "O God, O God!"
text <- generate(fitted$model, dataset$vocab, context, iter = 100)
@@ -220,7 +239,7 @@ Guides
diff --git a/articles/examples/chargpt_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/chargpt_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/chargpt_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/dogs-vs-cats-binary-classification.html b/articles/examples/dogs-vs-cats-binary-classification.html
index 3e4361be..d4d6932e 100644
--- a/articles/examples/dogs-vs-cats-binary-classification.html
+++ b/articles/examples/dogs-vs-cats-binary-classification.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Binary classification
@@ -175,7 +176,7 @@ Guides
diff --git a/articles/examples/dogs-vs-cats-binary-classification_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/dogs-vs-cats-binary-classification_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/dogs-vs-cats-binary-classification_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/index.html b/articles/examples/index.html
index 4714a06a..980d72f7 100644
--- a/articles/examples/index.html
+++ b/articles/examples/index.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Examples
@@ -88,7 +89,9 @@ Guides
-This gallery of examples uses luz to train and validate a range of common deep learning architectures. The gallery also demonstrates basic and advanced usage of luz.
+This gallery of examples uses luz to train and validate a range of
+common deep learning architectures. The gallery also demonstrates basic
+and advanced usage of luz.
@@ -129,7 +134,8 @@
basic
-Builds an autoencoder for the MNIST dataset. Demonstrates overwriting the predict method
+Builds an autoencoder for the MNIST dataset. Demonstrates overwriting
+the predict method
See code
@@ -145,7 +151,8 @@
Showcases how to create a fully customized training step
-See code
+See
+code
@@ -219,7 +226,8 @@
intermediate
-Implements a UNET model to separate the background of images of cats and dogs.
+Implements a UNET model to separate the background of images of cats and
+dogs.
See code
@@ -240,6 +248,23 @@
+
+
+
+
+
+Training a causal language model from scratch
+
+advanced
+
+Implements datasets and trains a causal language model from scratch
+using R source code.
+
+See code
+
+
+
+
@@ -274,7 +299,7 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/mnist-autoencoder.html b/articles/examples/mnist-autoencoder.html
index 4e2ff1e5..8827238a 100644
--- a/articles/examples/mnist-autoencoder.html
+++ b/articles/examples/mnist-autoencoder.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Autoencoder
@@ -180,7 +181,7 @@ Guides
diff --git a/articles/examples/mnist-autoencoder_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/mnist-autoencoder_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/mnist-autoencoder_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/mnist-cnn-virtual-batch-size.html b/articles/examples/mnist-cnn-virtual-batch-size.html
index cca80ae9..874ecdfd 100644
--- a/articles/examples/mnist-cnn-virtual-batch-size.html
+++ b/articles/examples/mnist-cnn-virtual-batch-size.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Virtual batch size
@@ -199,7 +200,7 @@ Guides
diff --git a/articles/examples/mnist-cnn-virtual-batch-size_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/mnist-cnn-virtual-batch-size_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/mnist-cnn-virtual-batch-size_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/mnist-cnn.html b/articles/examples/mnist-cnn.html
index cbb2b246..addd84ac 100644
--- a/articles/examples/mnist-cnn.html
+++ b/articles/examples/mnist-cnn.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Simple CNN
@@ -177,7 +178,7 @@ Guides
diff --git a/articles/examples/mnist-cnn_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/mnist-cnn_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/mnist-cnn_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/mnist-dcgan.html b/articles/examples/mnist-dcgan.html
index 5343ee76..02b9214e 100644
--- a/articles/examples/mnist-dcgan.html
+++ b/articles/examples/mnist-dcgan.html
@@ -77,7 +77,8 @@ Guides
-
+
+
DCGAN
@@ -266,7 +267,7 @@ Guides
diff --git a/articles/examples/mnist-dcgan_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/mnist-dcgan_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/mnist-dcgan_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/mnist-mixup.html b/articles/examples/mnist-mixup.html
index 31937768..d9510598 100644
--- a/articles/examples/mnist-mixup.html
+++ b/articles/examples/mnist-mixup.html
@@ -77,7 +77,8 @@ Guides
-
+
+
MixUp augmentation
@@ -187,7 +188,7 @@ Guides
diff --git a/articles/examples/mnist-mixup_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/mnist-mixup_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/mnist-mixup_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/mnist-triplet.html b/articles/examples/mnist-triplet.html
index 5ae47134..fdb036ab 100644
--- a/articles/examples/mnist-triplet.html
+++ b/articles/examples/mnist-triplet.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Triplet loss
@@ -196,7 +197,7 @@ Guides
diff --git a/articles/examples/mnist-triplet_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/mnist-triplet_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/mnist-triplet_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/pets-unet.html b/articles/examples/pets-unet.html
index 6543772e..6ab3a0be 100644
--- a/articles/examples/pets-unet.html
+++ b/articles/examples/pets-unet.html
@@ -77,7 +77,8 @@ Guides
-
+
+
UNET implementation
@@ -309,7 +310,7 @@ Guides
diff --git a/articles/examples/pets-unet_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/pets-unet_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/pets-unet_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/text-classification.html b/articles/examples/text-classification.html
index f48170e1..f6437dee 100644
--- a/articles/examples/text-classification.html
+++ b/articles/examples/text-classification.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Text classification from scratch
@@ -88,14 +89,21 @@ Guides
-This example is a port of ‘Text classification from scratch’ from Keras documentation by Mark Omerick and François Chollet.
-First we implement a torch dataset that downloads and pre-process the data. The initialize method is called when we instantiate a dataset. Our implementation:
+This example is a port of ‘Text
+classification from scratch’ from Keras documentation by Mark
+Omerick and François Chollet.
+First we implement a torch dataset that downloads and pre-processes the
+data. The initialize method is called when we instantiate a dataset. Our
+implementation:
-- Downloads the IMDB dataset if it doesn’t exist in the
root
directory.
+- Downloads the IMDB dataset if it doesn’t exist in the
+
root
directory.
- Extracts the files into
root
.
- Creates a tokenizer using the files in the training set.
-We also implement the .getitem
method that is used to extract a single element from the dataset and pre-process the file contents.
+We also implement the .getitem
method that is used to
+extract a single element from the dataset and pre-process the file
+contents.
library(torch)
library(tok)
@@ -174,7 +182,9 @@ Guides
train_ds <- imdb_dataset(output_length, vocab_size, "./imdb", split = "train")
test_ds <- imdb_dataset(output_length, vocab_size, "./imdb", split = "test")
-We now define the model we want to train. The model is a 1D convnet starting with an embedding layer and we plug a classifier at the output.
+We now define the model we want to train. The model is a 1D convnet
+starting with an embedding layer and we plug a classifier at the
+output.
model <- nn_module(
initialize = function(vocab_size, embedding_dim) {
@@ -226,7 +236,8 @@ Guides
We can finally obtain the metrics on the test dataset:
-Remember that in order to predict for texts, we need make the same pre-processing as used in the dataset definition.
+Remember that in order to predict for texts, we need to apply the same
+pre-processing as used in the dataset definition.
@@ -239,7 +250,7 @@ Guides
diff --git a/articles/examples/text-classification_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/text-classification_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/text-classification_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/examples/text-generation.html b/articles/examples/text-generation.html
new file mode 100644
index 00000000..4558aaf8
--- /dev/null
+++ b/articles/examples/text-generation.html
@@ -0,0 +1,393 @@
+
+
+
+
+
+
+
+
+Training a causal language model from scratch • luz
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Training a causal language model from scratch
+
+
+ Source: vignettes/examples/text-generation.Rmd
+ text-generation.Rmd
+
+
+
+
+This example is an adaptation of the ‘Training a causal language
+model from scratch’ class from the Hugging
+Face NLP course.
+
+library(torch)
+library(tok)
+library(luz)
+library(minhub) # remotes::install_github("mlverse/minhub")
+#library(tidyverse)
+options(arrow.skip_nul = TRUE)
+library(arrow)
+
+Data
+
+The first step is to implement a torch dataset that gathers data and
+pre-processes it into a format that is suitable for training the
+model.
+That means that we need to:
+
+- Download data
+- Train a tokenizer for this dataset
+- Be able to produce sequences of tokens in the format expected by the
+model
+
+We are going to use 2 datasets available on the Hugging Face Hub. The
+first contains the source code of all R packages available on CRAN. The
+second contains all R code that is available in GitHub data dumps. Both
+datasets are in the Parquet format. Below we implement a function
+that downloads and caches the data and then returns a single data frame
+containing all data.
+
+read_dataset <- function(source) {
+ d <- source |>
+ hfhub::hub_snapshot(repo_type = "dataset", allow_patterns = "parquet$") |>
+ fs::path("data/r") |>
+ arrow::open_dataset() |>
+ dplyr::filter(stringr::str_detect(path, ".*\\.[rR]$")) |>
+ dplyr::select(content) |>
+ dplyr::mutate(content = arrow::cast(content, arrow::string())) |>
+ dplyr::filter(!is.na(content)) |>
+ dplyr::collect() |>
+ # the dataset contains invalid utf8 characters...
+ # we need to remove them, otherwise we get an error from tokenizers
+ dplyr::filter(utf8::utf8_valid(content))
+}
+
+read_datasets <- function() {
+ dplyr::bind_rows(
+ read_dataset("dfalbel/cran-packages"),
+ read_dataset("dfalbel/github-r-repos")
+ )
+}
+Next we implement a function that trains a tokenizer for our
+dataset.
+
+create_tokenizer <- function(text, vocab_size, special_tokens) {
+ tok <- tok::tokenizer$new(tok::model_bpe$new())
+
+ tok$pre_tokenizer <- tok::pre_tokenizer_byte_level$new(add_prefix_space = FALSE)
+ tok$decoder <- tok::decoder_byte_level$new()
+ tok$post_processor <- tok::processor_byte_level$new(trim_offsets = FALSE)
+
+ tok$train_from_memory(
+ text,
+ tok::trainer_bpe$new(vocab_size = vocab_size, special_tokens = special_tokens)
+ )
+ tok
+}
+
+# test code to debug the tokenizer
+# data <- read_datasets()
+# tok <- create_tokenizer(data$content)
+We can finally implement the torch dataset that we are going to use
+for training the model. We are going to use
+torch::iterable_dataset instead of torch::dataset. The main motivation
+is that we can’t really know the total number of samples in the dataset,
+so we can’t implement a .getitem() method to get any arbitrary sample.
+Thus we implement the .iter method, which returns a function that yields
+a new sample every time it’s called.
+
+r_sources_dataset <- torch::iterable_dataset(
+ "r_sources_dataset",
+ initialize = function(root = ".", vocab_size = 20000, context_length = 128) {
+ self$data <- read_datasets()
+ self$context_length <- context_length
+ self$index <- sample.int(nrow(self$data))
+
+ # we only create a tokenizer if it doesn't exist, otherwise we just load it
+ tok_path <- file.path(root, glue::glue("tokenizer-{vocab_size}.json"))
+ if (!file.exists(tok_path)) {
+ self$tok <- create_tokenizer(
+ as.character(self$data$content),
+ vocab_size,
+ c("<fbegin>", "<fend>")
+ )
+ fs::dir_create(root)
+ self$tok$save(tok_path)
+ } else {
+ self$tok <- tok::tokenizer$from_file(tok_path)
+ }
+ },
+ .iter = function() {
+ i <- 1L
+ sequence <- c()
+ function() {
+ while (length(sequence) < (self$context_length + 1) && i <= nrow(self$data)) {
+ sequence <<- c(
+ sequence,
+ self$tok$encode(paste("<fbegin>", as.character(self$data$content[self$index[i]]), "<fend>"))$ids
+ )
+ i <<- i + 1L # update the iterator's counter, not a call-local copy
+ }
+
+ if (length(sequence) < (self$context_length + 1)) {
+ return(coro::exhausted())
+ }
+
+ on.exit({
+ sequence <<- sequence[-seq_len(self$context_length)]
+ })
+ list(
+ input_ids = sequence[seq_len(self$context_length)] + 1L,
+ labels = sequence[2:(self$context_length + 1)] + 1L
+ )
+ }
+ }
+)
+
+# debug code for the dataset
+# ds <- r_sources_dataset("~/Downloads/")
+# it <- ds$.iter()
+# it()
+# ds$tok$get_vocab_size()
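The windowing done by the .iter closure can be hard to picture, so here is a minimal base-R sketch of it, run on a made-up stream of token ids rather than real tokenizer output. The helper name make_window_iter, the NULL sentinel and the toy ids are assumptions for illustration only. The + 1L shift mirrors the dataset above, presumably because the tokenizer ids start at 0 while torch for R indexes from 1.

```r
# Hypothetical helper: cuts a 0-based token stream into (input_ids, labels)
# pairs the same way the .iter closure above does.
make_window_iter <- function(tokens, context_length) {
  sequence <- tokens
  function() {
    # We need context_length inputs plus one extra token for the last label.
    if (length(sequence) < context_length + 1) {
      return(NULL)  # stand-in for coro::exhausted()
    }
    out <- list(
      input_ids = sequence[seq_len(context_length)] + 1L,
      labels = sequence[2:(context_length + 1)] + 1L
    )
    # Drop the consumed inputs; the last label becomes the next first input.
    sequence <<- sequence[-seq_len(context_length)]
    out
  }
}

it <- make_window_iter(tokens = 0:9, context_length = 4)
first <- it()   # inputs 1..4, labels 2..5
second <- it()  # inputs 5..8, labels 6..9
third <- it()   # NULL: only two tokens are left
```

Each label vector is the input vector shifted by one position, which is the next-token prediction target the model is trained on.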
+This dataset is likely too large for us to train the model on all
+documents in this example. It’s also hard to predict how long training
+would take to reach the end of the data. To make this easier, we
+define a wrapper dataset that runs the above dataset for a
+fixed number of steps. This is not required, but it makes using luz more
+pleasant, as we can easily define how many tokens we want to train
+our model for.
+
+fixed_steps_iterable_dataset <- iterable_dataset(
+ "fixed_steps_dataset",
+ initialize = function(dataset, steps) {
+ self$dataset <- dataset
+ self$steps <- steps
+ },
+ .iter = function() {
+ i <- 1L
+ iter <- NULL
+ function() {
+ if (i > self$steps) {
+ return(coro::exhausted())
+ }
+
+ i <<- i + 1L
+
+ if (is.null(iter) || coro::is_exhausted(data <- iter())) {
+ iter <<- self$dataset$.iter()
+ data <- iter()
+ }
+
+ data
+ }
+ },
+ .length = function() {
+ self$steps
+ }
+)
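The restart-on-exhaustion behaviour of this wrapper can be checked without torch. The following is a minimal base-R sketch under stated assumptions: the sentinel object, make_counter and fixed_steps are hypothetical names that stand in for coro::exhausted(), the wrapped dataset and the wrapper's .iter method.

```r
# Sentinel playing the role of coro::exhausted() in this toy version.
exhausted <- structure(list(), class = "toy_exhausted")
is_exhausted <- function(x) inherits(x, "toy_exhausted")

# A finite toy iterator factory: yields 1, 2, ..., n and then the sentinel.
make_counter <- function(n) {
  function() {
    i <- 0L
    function() {
      if (i >= n) return(exhausted)
      i <<- i + 1L
      i
    }
  }
}

# Wraps an iterator factory so it yields exactly `steps` items,
# restarting the underlying iterator whenever it runs out.
fixed_steps <- function(new_iter, steps) {
  i <- 1L
  iter <- NULL
  function() {
    if (i > steps) return(exhausted)
    i <<- i + 1L
    if (is.null(iter) || is_exhausted(data <- iter())) {
      iter <<- new_iter()
      data <- iter()
    }
    data
  }
}

it <- fixed_steps(make_counter(3), steps = 7)
values <- integer(0)
while (!is_exhausted(x <- it())) values <- c(values, x)
values
```

With a three-item source and seven steps, the wrapper cycles through the source twice and then takes one more item, so an epoch always has the same length regardless of the size of the underlying data.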
+We finally define the model we are going to train. We’ll use a small
+version of GPT2. We also define a generate method allowing
+us to sample from the model given an initial context.
+
+net <- nn_module(
+ initialize = function() {
+ self$gpt <- minhub::gpt2(
+ vocab_size = 20000,
+ pdrop = 0.1
+ )
+ },
+ forward = function(x) {
+ self$gpt(x)$transpose(2,3)
+ },
+ generate = function(x, temperature = 1, iter = 50, top_k = 10) {
+ # samples from the model given a context vector.
+ for (i in seq_len(iter)) {
+ logits <- self$forward(x)[,,-1]
+ logits <- logits/temperature
+ c(prob, ind) %<-% logits$topk(top_k)
+ logits <- torch_full_like(logits, -Inf)$scatter_(-1, ind, prob)
+ logits <- nnf_softmax(logits, dim = -1)
+ id_next <- torch_multinomial(logits, num_samples = 1)
+ x <- torch_cat(list(x, id_next), dim = 2)
+ }
+ x
+ }
+)
+
+# debug code for the model
+# ds <- torch::dataloader(r_sources_dataset("~/Downloads/"), batch_size = 32)
+# batch <- coro::collect(ds, 1)[[1]]
+# str(batch)
+# m <- net()
+# str(m(batch$input_ids))
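The sampling step inside the generate method combines temperature scaling, top-k filtering and a softmax. As a rough illustration, here is the same idea on a plain numeric vector in base R. The function top_k_probs and the example logits are made up for this sketch, and it handles a single position with no batch dimension, unlike the tensor code above.

```r
# Hypothetical helper: temperature scaling, top-k filtering and softmax
# on a plain numeric vector of logits.
top_k_probs <- function(logits, temperature = 1, top_k = 10) {
  logits <- logits / temperature
  # Indices of the top_k largest logits.
  keep <- order(logits, decreasing = TRUE)[seq_len(min(top_k, length(logits)))]
  # Everything outside the top k is set to -Inf, as in the scatter_ call above.
  filtered <- rep(-Inf, length(logits))
  filtered[keep] <- logits[keep]
  # Softmax; exp(-Inf) is 0, so filtered-out tokens get zero probability.
  w <- exp(filtered - max(filtered))
  w / sum(w)
}

set.seed(1)
probs <- top_k_probs(c(2.0, 1.0, 0.5, -1.0, 3.0), temperature = 0.8, top_k = 2)
# Draw the next token id from the remaining candidates.
next_id <- sample(seq_along(probs), size = 1, prob = probs)
```

Lower temperatures sharpen the distribution towards the largest logit, while a smaller top_k restricts sampling to fewer candidates.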
+To make it easier to inspect training, we will also define a callback
+that prints a sample from the model every epoch.
+
+# samples from the model using the context.
+generate <- function(model, tok, context, ...) {
+ local_no_grad() # disables gradient for sampling
+ x <- tok$encode(context)$ids + 1L
+ x <- torch_tensor(x)[NULL,]$to(device = model$device)
+ content <- as.integer(model$generate(x, ...)$cpu())
+ tok$decode(content - 1L)
+}
+
+display_cb <- luz_callback(
+ initialize = function() {},
+ on_epoch_end = function() {
+ local_no_grad()
+ # sample from the model...
+ context <- "# creates a linear model"
+ text <- generate(ctx$model, dataset$dataset$tok, context, iter = 100)
+ cli::cli_rule()
+ cat(text, "\n")
+ cli::cli_rule()
+ }
+)
+We can finally train the model. We define that we want to train the
+model for half a billion tokens in a total of 100 epochs.
+
+n_tokens <- 500e6
+batch_size <- 16
+epochs <- 100
+context_length <- 256L
+
+steps <- n_tokens / context_length / epochs
+dataset <- fixed_steps_iterable_dataset(
+ r_sources_dataset(context_length = context_length),
+ steps = steps
+)
+
+fitted <- net %>%
+ setup(
+ optimizer = optim_adam,
+ loss = nn_cross_entropy_loss()
+ ) %>%
+ set_opt_hparams(lr = 3e-4) |>
+ fit(
+ dataset,
+ epochs = epochs,
+ dataloader_options = list(batch_size = batch_size),
+ callbacks = list(
+ luz_callback_lr_scheduler(
+ torch::lr_one_cycle,
+ max_lr = 0.1,
+ epochs = epochs,
+ steps_per_epoch = steps/batch_size,
+ call_on = "on_batch_end"
+ ),
+ luz_callback_gradient_clip(max_norm = 1),
+ display_cb()
+ ),
+ verbose = TRUE
+ )
+
+luz::luz_save(fitted, "model.pt")
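To make the token budget concrete, the arithmetic behind steps and steps_per_epoch can be worked through on its own. This is only a restatement of the values used above; the intermediate variable names are introduced here for illustration.

```r
n_tokens <- 500e6
batch_size <- 16
epochs <- 100
context_length <- 256L

# Tokens seen in one epoch.
tokens_per_epoch <- n_tokens / epochs
# Sequences of context_length tokens per epoch (the `steps` value above).
samples_per_epoch <- tokens_per_epoch / context_length
# Optimizer steps per epoch (the `steps / batch_size` value above).
batches_per_epoch <- samples_per_epoch / batch_size

c(tokens_per_epoch, samples_per_epoch, batches_per_epoch)
```

Neither of the last two values is a whole number (19531.25 sequences and roughly 1220.7 batches per epoch), so the realised token count will differ slightly from the nominal half a billion.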
+We can then use the model to generate text given a prompt with:
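A possible way to do this, sketched under assumptions: the wrapper name generate_from_saved is hypothetical, it relies on the generate() helper defined earlier in this article, and the default paths assume the model was saved as model.pt and the tokenizer was cached as tokenizer-20000.json in the working directory (the defaults used by the dataset above).

```r
# Hypothetical wrapper around the generate() helper defined earlier in this
# article; nothing is loaded until the function is called.
generate_from_saved <- function(context,
                                model_path = "model.pt",
                                tok_path = "tokenizer-20000.json",
                                ...) {
  fitted <- luz::luz_load(model_path)        # object written by luz_save()
  tok <- tok::tokenizer$from_file(tok_path)  # tokenizer cached by the dataset
  generate(fitted$model, tok, context, ...)
}

# Usage (requires the trained model and tokenizer files on disk):
# cat(generate_from_saved("# creates a linear model", iter = 100), "\n")
```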
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/articles/get-started.html b/articles/get-started.html
index b5ff5a94..d82dc2d4 100644
--- a/articles/get-started.html
+++ b/articles/get-started.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Get started with luz
@@ -91,18 +92,42 @@ Guides
-Luz is a high-level API for torch that aims to encapsulate the training loop into a set of reusable pieces of code. Luz reduces the boilerplate code required to train a model with torch and avoids the error prone zero_grad()
- backward()
- step()
sequence of calls, and also simplifies the process of moving data and models between CPUs and GPUs. Luz is designed to be highly flexible by providing a layered API that allows it to be useful no matter the level of control you need for your training loop.
-Luz is heavily inspired by other higher level frameworks for deep learning, to cite a few:
+Luz is a high-level API for torch that aims to encapsulate the
+training loop into a set of reusable pieces of code.
+Luz reduces the boilerplate code required to train a model with torch
+and avoids the error prone zero_grad()
-
+backward()
- step()
sequence of calls, and
+also simplifies the process of moving data and models between CPUs and
+GPUs. Luz is designed to be highly flexible by providing a layered API
+that allows it to be useful no matter the level of control you need for
+your training loop.
+Luz is heavily inspired by other higher level frameworks for deep
+learning, to cite a few:
-FastAI: we are heavily inspired by the FastAI library, especially the Learner
object and the callbacks API.
-Keras: We are also heavily inspired by Keras, especially callback names. The lightning module interface is similar to compile
, too.
-PyTorch Lightning: The idea of the luz_module
being a subclass of nn_module
is inspired by the LightningModule
object in lightning.
-HuggingFace Accelerate: The internal device placement API is heavily inspired by Accelerate, but is much more modest in features. Currently only CPU and Single GPU are supported.
+FastAI: we are heavily
+inspired by the FastAI library, especially the Learner
+object and the callbacks API.
+Keras: We are also heavily
+inspired by Keras, especially callback names. The lightning module
+interface is similar to compile
, too.
+PyTorch
+Lightning: The idea of the luz_module
being a subclass
+of nn_module
is inspired by the
+LightningModule
object in
+lightning.
+HuggingFace
+Accelerate: The internal device placement API is heavily inspired by
+Accelerate, but is much more modest in features. Currently only CPU and
+Single GPU are supported.
Training a nn_module
-As much as possible, luz tries to reuse the existing structures from torch. A model in luz is defined identically as you would define it if using raw torch. For a specific example, this is the definition of a feed-forward CNN that can be used to classify digits from the MNIST dataset:
+As much as possible, luz tries to reuse the existing structures from
+torch. A model in luz is defined identically as you would define it if
+using raw torch. For a specific example, this is the definition of a
+feed-forward CNN that can be used to classify digits from the MNIST
+dataset:
-We can now train this model in the train_dl
and validate it in the test_dl
torch::dataloaders()
with:
+We can now train this model in the train_dl
and validate
+it in the test_dl
torch::dataloaders()
+with:
fitted <- net %>%
setup(
@@ -144,21 +171,51 @@ Training a nn_module
fit(train_dl, epochs = 10, valid_data = test_dl)
Let’s understand what happens in this chunk of code:
-- The
setup
function allows you to configure the loss (objective) function and the optimizer that you will use to train your model. Optionally you can pass a list of metrics that are tracked during the training procedure. Note: the loss function can be any function taking input
and target
tensors and returning a scalar tensor value, and the optimizer can be any core torch optimizer or custom ones created with the torch::optimizer()
function.
-- The
set_hparams()
function allows you to set hyper-parameters that should be passed to the module initialize()
method. For example in this case we pass num_classes = 10
.
-- The
set_opt_hparams()
function allows you to pass hyper-parameters that are used by the optimizer function. For example, optim_adam()
can take the lr
parameter specifying the learning rate and we specify it with lr = 0.003
.
-- The
fit
method will take the model specification provided by setup()
and run the training procedure using the specified training and validation torch::dataloaders()
as well as the number of epochs. Note: we again reuse core torch data structures, instead of providing our own data loading functionality.
-- The returned object
fitted
contains the trained model as well as the record of metrics and losses produced during training. It can also be used for producing predictions and for evaluating the trained model on other datasets.
+- The
setup
function allows you to configure the loss
+(objective) function and the optimizer that you will use to train your
+model. Optionally you can pass a list of metrics that are tracked during
+the training procedure. Note: the loss function can be
+any function taking input
and target
tensors
+and returning a scalar tensor value, and the optimizer can be any core
+torch optimizer or custom ones created with the
+torch::optimizer()
function.
+- The
set_hparams()
function allows you to set
+hyper-parameters that should be passed to the module
+initialize()
method. For example in this case we pass
+num_classes = 10
.
+- The
set_opt_hparams()
function allows you to pass
+hyper-parameters that are used by the optimizer function. For example,
+optim_adam()
can take the lr
parameter
+specifying the learning rate and we specify it with
+lr = 0.003
.
+- The
fit
method will take the model specification
+provided by setup()
and run the training procedure using
+the specified training and validation torch::dataloaders()
+as well as the number of epochs. Note: we again reuse
+core torch data structures, instead of providing our own data loading
+functionality.
+- The returned object
fitted
contains the trained model
+as well as the record of metrics and losses produced during training. It
+can also be used for producing predictions and for evaluating the
+trained model on other datasets.
-When fitting, luz will use the fastest possible accelerator; if a CUDA-capable GPU is available it will be used, otherwise we fall back to the CPU. It also automatically moves data, optimizers, and models to the selected device so you don’t need to handle it manually (which is in general very error prone).
-To create predictions from the trained model you can use the predict
method:
+When fitting, luz will use the fastest possible accelerator; if a
+CUDA-capable GPU is available it will be used, otherwise we fall back to
+the CPU. It also automatically moves data, optimizers, and models to the
+selected device so you don’t need to handle it manually (which is in
+general very error prone).
+To create predictions from the trained model you can use the
+predict
method:
predictions <- predict(fitted, test_dl)
The training loop
-You now have a general idea of how to use the fit
function and now it’s important to have an overview of what’s happening inside it. In pseudocode, here’s what fit
does. This is not fully detailed but should help you to build your intuition:
+You now have a general idea of how to use the fit
+function and now it’s important to have an overview of what’s happening
+inside it. In pseudocode, here’s what fit
does. This is not
+fully detailed but should help you to build your intuition:
# -> Initialize objects: model, optimizers.
# -> Select fitting device.
@@ -184,25 +241,45 @@ The training loop
Metrics
-One of the most important parts in machine learning projects is choosing the evaluation metric. Luz allows tracking many different metrics during training with minimal code changes.
-In order to track metrics, you only need to modify the metrics
parameter in the setup
function:
-
-Luz provides implementations of a few of the most used metrics. If a metric is not available you can always implement a new one using the luz_metric
function.
-In order to implement a new luz_metric
we need to implement 3 methods:
+One of the most important parts in machine learning projects is
+choosing the evaluation metric. Luz allows tracking many different
+metrics during training with minimal code changes.
+In order to track metrics, you only need to modify the
+metrics
parameter in the setup
function:
+fitted <- net %>%
+  setup(
+    ...
+    metrics = list(
+      luz_metric_accuracy
+    )
+  ) %>%
+  fit(...)
+Luz provides implementations of a few of the most used metrics. If a
+metric is not available you can always implement a new one using the
+luz_metric
function.
+In order to implement a new luz_metric
we need to
+implement 3 methods:
-initialize
: defines the metric initial state. This function is called for each epoch for both training and validation loops.
-update
: updates the metric internal state. This function is called at every training and validation step with the predictions obtained by the model and the target values obtained from the dataloader.
-compute
: uses the internal state to compute metric values. This function is called whenever we need to obtain the current metric value. Eg, it’s called every training step for metrics displayed in the progress bar, but only called once per epoch to record it’s value when the progress bar is not displayed.
+initialize
: defines the metric initial state. This
+function is called for each epoch for both training and validation
+loops.
+update
: updates the metric internal state. This
+function is called at every training and validation step with the
+predictions obtained by the model and the target values obtained from
+the dataloader.
+compute
: uses the internal state to compute metric
+values. This function is called whenever we need to obtain the current
+metric value. For example, it’s called at every training step for metrics
+displayed in the progress bar, but only once per epoch to record its value
+when the progress bar is not displayed.
-Optionally, you can implement an abbrev
field that gives the metric an abbreviation that will be used when displaying metric information in the console or tracking record. If no abbrev
is passed, the class name will be used.
-Let’s take a look at the implementation of luz_metric_accuracy
so you can see how to implement a new one:
+Optionally, you can implement an abbrev
field that gives
+the metric an abbreviation that will be used when displaying metric
+information in the console or tracking record. If no abbrev
+is passed, the class name will be used.
+Let’s take a look at the implementation of
+luz_metric_accuracy
so you can see how to implement a new
+one:
luz_metric_accuracy <- luz_metric(
# An abbreviation to be shown in progress bars, or
@@ -230,13 +307,20 @@ Metrics
self$correct/self$total
}
)
-Note: It’s good practice that the compute
metric returns regular R values instead of torch tensors and other parts of luz will expect that.
+Note: It’s good practice for the compute method to return regular R
+values instead of torch tensors, since other parts of luz expect that.
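The initialize / update / compute lifecycle can also be mimicked in plain R, which may help when designing a new metric before wiring it into luz. This sketch deliberately does not use the luz_metric() API; new_accuracy is a hypothetical stand-in that keeps its state in a closure and works on ordinary vectors of class labels.

```r
# Plain-R stand-in for a metric object; not the luz_metric() API itself.
new_accuracy <- function() {
  correct <- 0
  total <- 0
  list(
    # Called once per batch with predicted and true class labels.
    update = function(preds, target) {
      correct <<- correct + sum(preds == target)
      total <<- total + length(target)
      invisible(NULL)
    },
    # Returns a regular R number rather than a tensor.
    compute = function() correct / total
  )
}

metric <- new_accuracy()               # "initialize": fresh state for an epoch
metric$update(c(1, 2, 2), c(1, 2, 3))  # 2 of 3 correct
metric$update(c(3, 3), c(3, 3))        # 2 of 2 correct
acc <- metric$compute()                # 4 correct out of 5
```

Because the state is accumulated across batches and only divided in compute, the result is the accuracy over all samples seen so far, not an average of per-batch accuracies.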
Evaluate
-Once a model has been trained you might want to evaluate its performance on a different dataset. For that reason, luz provides the ?evaluate
function that takes a fitted model and a dataset and computes the metrics attached to the model.
-Evaluate returns a luz_module_evaluation
object that you can query for metrics using the get_metrics
function or simply print
to see the results.
+Once a model has been trained you might want to evaluate its
+performance on a different dataset. For that reason, luz provides the
+?evaluate
function that takes a fitted model and a dataset
+and computes the metrics attached to the model.
+Evaluate returns a luz_module_evaluation
object that you
+can query for metrics using the get_metrics
function or
+simply print
to see the results.
For example:
evaluation <- fitted %>% evaluate(data = valid_dl)
@@ -252,16 +336,32 @@ Evaluate
Customizing with callbacks
-Luz provides different ways to customize the training progress depending on the level of control you need in the training loop. The fastest way and the more ‘reusable’, in the sense that you can create training modifications that can be used in many different situations, is via callbacks.
-The training loop in luz has many breakpoints that can call arbitrary R functions. This functionality allows you to customize the training process without having to modify the general training logic.
-Luz implements 3 default callbacks that occur in every training procedure:
+Luz provides different ways to customize the training progress
+depending on the level of control you need in the training loop. The
+fastest and most reusable way, in the sense that you can create
+training modifications that can be used in many different situations, is
+via callbacks.
+The training loop in luz has many breakpoints that can call
+arbitrary R functions. This functionality allows you to customize the
+training process without having to modify the general training
+logic.
+Luz implements 3 default callbacks that occur in every training
+procedure:
-train-eval callback: Sets the model to train()
or eval()
depending on if the procedure is doing training or validation.
-metrics callback: evaluate metrics during training and validation process.
-progress callback: implements a progress bar and prints progress information during training.
+train-eval callback: sets the model to train() or eval() depending on
+whether the procedure is doing training or validation.
+metrics callback: evaluates metrics during the training and validation
+process.
+progress callback: implements a progress bar and
+prints progress information during training.
-You can also implement custom callbacks that modify or act specifically for your training procedure. For example:
-Let’s implement a callback that prints ‘Iteration n
’ (where n
is the iteration number) for every batch in the training set and ‘Done’ when an epoch is finished. For that task we use the luz_callback
function:
+You can also implement custom callbacks that modify or act
+specifically for your training procedure. For example:
+Let’s implement a callback that prints ‘Iteration n
’
+(where n
is the iteration number) for every batch in the
+training set and ‘Done’ when an epoch is finished. For that task we use
+the luz_callback
function:
print_callback <- luz_callback(
name = "print_callback",
@@ -275,16 +375,30 @@ Customizing with callbacks cat(self$message, "\n")
}
)
-luz_callback()
takes named functions as ...
arguments, where the name indicates the moment at which the callback should be called. For instance on_train_batch_end()
is called for every batch at the end of the training procedure, and on_epoch_end()
is called at the end of every epoch.
-The returned value of luz_callback()
is a function that initializes an instance of the callback. Callbacks can have initialization parameters, like the name of a file where you want to log the results. In that case, you can pass an initialize
method when creating the callback definition, and save these parameters to the self
object. In the above example, the callback has a message
parameter that is printed at the end of each epoch.
-Once a callback is defined it can be passed to the fit
function via the callbacks
parameter:
+luz_callback()
takes named functions as ...
+arguments, where the name indicates the moment at which the callback
+should be called. For instance on_train_batch_end()
is
+called at the end of every training batch, and
+on_epoch_end()
is called at the end of every epoch.
+The returned value of luz_callback()
is a function that
+initializes an instance of the callback. Callbacks can have
+initialization parameters, like the name of a file where you want to log
+the results. In that case, you can pass an initialize
+method when creating the callback definition, and save these parameters
+to the self
object. In the above example, the callback has
+a message
parameter that is printed at the end of each
+epoch.
+Once a callback is defined it can be passed to the fit
+function via the callbacks
parameter:
-Callbacks can be called in many different positions of the training loop, including combinations of them. Here’s an overview of possible callback breakpoints:
+Callbacks can be called in many different positions of the training
+loop, including combinations of them. Here’s an overview of possible
+callback breakpoints:
Start Fit
- on_fit_begin
Start Epoch Loop
@@ -320,10 +434,27 @@ Customizing with callbacks
-Every step marked with on_*
is a point in the training procedure that is available for callbacks to be called.
-The other important part of callbacks is the ctx
(context) object. See help("ctx")
for details.
-By default, callbacks are called in the same order as they were passed to fit
(or predict
or evaluate
), but you can provide a weight
attribute that will control the order in which it will be called. For example, if one callback has weight = 10
and another has weight = 1
, then the first one is called after the second one. Callbacks that don’t specify a weight
attribute are considered weight = 0
. A few built-in callbacks in luz already provide a weight value. For example, the ?luz_callback_early_stopping
has a weight of Inf
, since in general we want to run it as the last thing in the loop.
-The ctx
object is used in luz to share information between the training loop and callbacks, model methods, and metrics. The table below describes information available in the ctx
by default. Other callbacks could potentially modify these attributes or add new ones.
+Every step marked with on_*
is a point in the training
+procedure that is available for callbacks to be called.
+The other important part of callbacks is the ctx
+(context) object. See help("ctx")
for details.
+By default, callbacks are called in the same order as they were
+passed to fit
(or predict
or
+evaluate
), but you can provide a weight
+attribute that will control the order in which it will be called. For
+example, if one callback has weight = 10
and another has
+weight = 1
, then the first one is called after the second
+one. Callbacks that don’t specify a weight
attribute are
+considered weight = 0
. A few built-in callbacks in luz
+already provide a weight value. For example, the
+?luz_callback_early_stopping
has a weight of
+Inf
, since in general we want to run it as the last thing
+in the loop.
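The ordering rule described above can be illustrated with a minimal base-R sketch. It does not use luz internals: the `callbacks` list and its `name` field are hypothetical stand-ins, and the only behavior taken from the text is that a missing weight counts as 0 and lower weights are called first.

```r
# Hypothetical stand-ins for callbacks: plain lists, not real luz callbacks.
callbacks <- list(
  list(name = "first_passed", weight = 10),
  list(name = "second_passed", weight = 1),
  list(name = "no_weight"),                      # missing weight is treated as 0
  list(name = "early_stopping_like", weight = Inf)
)

# Treat a missing weight as 0, as described in the text.
weights <- vapply(
  callbacks,
  function(cb) if (is.null(cb$weight)) 0 else cb$weight,
  numeric(1)
)

# order() leaves ties in their original order, so callbacks with equal
# weights still run in the order in which they were passed.
call_order <- vapply(
  callbacks[order(weights)],
  function(cb) cb$name,
  character(1)
)
print(call_order)
```

In this sketch the callback with `weight = Inf` ends up last, which matches the motivation given for the early stopping callback.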
+The ctx
object is used in luz to share information
+between the training loop and callbacks, model methods, and metrics. The
+table below describes information available in the ctx
by
+default. Other callbacks could potentially modify these attributes or
+add new ones.
Attributes in ctx
can be used to produce the desired behavior of callbacks. You can find information about the context object using help("ctx")
. In our example, we use the ctx$iter
attribute to print the iteration number for each training batch.
+Attributes in ctx
can be used to produce the desired
+behavior of callbacks. You can find information about the context object
+using help("ctx")
. In our example, we use the
+ctx$iter
attribute to print the iteration number for each
+training batch.
Next steps
-In this article you learned how to train your first model using luz and the basics of customization using both custom metrics and callbacks.
-Luz also allows more flexible modifications of the training loop described in vignette("custom-loop")
.
-You should now be able to follow the examples marked with the ‘basic’ category in the examples gallery.
+In this article you learned how to train your first model using luz
+and the basics of customization using both custom metrics and
+callbacks.
+Luz also allows more flexible modifications of the training loop
+described in vignette("custom-loop")
.
+You should now be able to follow the examples marked with the ‘basic’
+category in the examples
+gallery.
@@ -472,7 +654,7 @@ Next steps
-Site built with pkgdown 2.0.7.9000.
+Site built with pkgdown 2.0.7.
diff --git a/articles/get-started_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/get-started_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/get-started_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/articles/index.html b/articles/index.html
index 2bd5852c..81608f8f 100644
--- a/articles/index.html
+++ b/articles/index.html
@@ -63,21 +63,19 @@ All vignettes
- Accelerator API
- -
-
- CharGPT
-
- Checkpointing your models
-
+
- Using the learning rate finder
+ -
- Custom loops with luz
-
-
- Binary classification
+ - CharGPT
-
-
- Get started with luz
+ - Binary classification
-
- Examples
-
-
- Using the learning rate finder
- -
- Autoencoder
-
- Virtual batch size
@@ -94,6 +92,10 @@ All vignettes
-
- Text classification from scratch
-
+
- Training a causal language model from scratch
+ -
+
- Get started with luz
+ -
@@ -103,7 +105,7 @@ All vignettes
diff --git a/articles/lr-finder.html b/articles/lr-finder.html
index 876a03ce..7a8d20d0 100644
--- a/articles/lr-finder.html
+++ b/articles/lr-finder.html
@@ -77,7 +77,8 @@ Guides
-
+
+
Using the learning rate finder
@@ -94,9 +95,25 @@ Guides
library(torchvision)
set.seed(1)
torch::torch_manual_seed(1703)
-In this article we discuss how to find a good learning rate for your model. Finding a good learning rate is essential to be able to fit your model. If it’s too low, you will need too many iterations for your loss to converge, and that might be impractical if your model takes too long to run. If it’s too high, the loss can explode and you might never be able to minimize the loss.
-The learning rate can be considered another hyperparameter of your model that needs to be tuned but, there are techniques that allow you to select a good learning rate for your model without having to use the costly strategy of fitting many models with different learning rates and then choosing the one with better results.
-This article by Leslie Smith that became popular once their approach had been implemented in the popular FastAI framework, proposes that we should start with a very small learning rate and slowly increase it until we reach a high learning rate. At each iteration we record the loss value and in the end we plot it against the learning rate. We can then use these results to decide on a good learning rate. That’s what lr_finder
does, and we will show how to use it.
+In this article we discuss how to find a good learning rate for your
+model. Finding a good learning rate is essential to be able to fit your
+model. If it’s too low, you will need too many iterations for your loss
+to converge, and that might be impractical if your model takes too long
+to run. If it’s too high, the loss can explode and you might never be
+able to minimize the loss.
+The learning rate can be considered another hyperparameter of your
+model that needs to be tuned, but there are techniques that allow you to
+select a good learning rate for your model without having to use the
+costly strategy of fitting many models with different learning rates and
+then choosing the one with the best results.
+This article by Leslie
+Smith, which became popular once its approach had been implemented in
+the FastAI framework, proposes that we should start with a very
+small learning rate and slowly increase it until we reach a high
+learning rate. At each iteration we record the loss value and in the end
+we plot it against the learning rate. We can then use these results to
+decide on a good learning rate. That’s what lr_finder
does,
+and we will show how to use it.
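As a rough illustration of the sweep described above, the following base-R sketch builds an increasing sequence of learning rates. It assumes a geometric (log-spaced) increase, and the `start_lr`, `end_lr` and `steps` values are chosen for the example only; they are not necessarily what `lr_finder` uses.

```r
# Assumed sweep settings, for illustration only.
start_lr <- 1e-7
end_lr <- 1e-1
steps <- 100

# Log-spaced grid: each learning rate is a constant multiple of the previous one.
lrs <- exp(seq(log(start_lr), log(end_lr), length.out = steps))

# Ratio between consecutive learning rates (constant up to rounding error).
ratio <- lrs[2] / lrs[1]
print(range(lrs))
print(ratio)
```

At each of these learning rates one training step would be run and the loss recorded, which gives the kind of `lr` and `loss` table discussed in this article.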
First let’s download and prepare the MNIST dataset:
dir <- "~/Downloads/mnist" # caching directory
@@ -108,7 +125,8 @@ Guides
)
#> Processing...
#> Done!
-We can now define our model. We are going to use a small, straightforward CNN in the LeNet style.
+We can now define our model. We are going to use a small,
+straightforward CNN in the LeNet style.
net <- nn_module(
"net",
@@ -135,7 +153,11 @@ Guides
self$classifier()
}
)
-We can now use the lr_finder
function to record the loss with different learning rates. It’s important to use the learning rate finder with all other hyperparameters of the model fixed because they can influence the choice of the learning rate. For example, depending on the batch size, you might want to choose different learning rates.
+We can now use the lr_finder
function to record the loss
+with different learning rates. It’s important to use the learning rate
+finder with all other hyperparameters of the model fixed because they
+can influence the choice of the learning rate. For example, depending on
+the batch size, you might want to choose different learning rates.
model <- net %>% setup(
loss = torch::nn_cross_entropy_loss(),
@@ -155,15 +177,26 @@ Guides
#> Classes 'lr_records' and 'data.frame': 100 obs. of 2 variables:
#> $ lr : num 1.15e-06 1.32e-06 1.51e-06 1.74e-06 2.00e-06 ...
#> $ loss: num 2.31 2.3 2.29 2.3 2.31 ...
-The result is a data frame with the losses and the learning rate in each step. You can use the built-in plot method to display the exact results, along with a exponentially smoothed value of the loss.
+The result is a data frame with the losses and the learning rate in
+each step. You can use the built-in plot method to display the exact
+results, along with an exponentially smoothed value of the loss.
plot(records) +
ggplot2::coord_cartesian(ylim = c(NA, 5))
-We can see that with small learning rates the loss doesn’t decrease. At some point the loss starts decreasing until it reaches a point where it starts increasing and explodes.
-And how do we choose the learning rate using this plot? Sylvain Gugger asked the same question in this blog post and we are quoting his answer:
+We can see that with small learning rates the loss doesn’t decrease.
+At some point the loss starts decreasing until it reaches a point where
+it starts increasing and explodes.
+And how do we choose the learning rate using this plot? Sylvain
+Gugger asked the same question in this blog
+post and we are quoting his answer:
-Not the one corresponding to the minimum. Why? Well the learning rate that corresponds to the minimum value is already a bit too high, since we are at the edge between improving and getting all over the place. We want to go one order of magnitude before, a value that’s still aggressive (so that we train quickly) but still on the safe side from an explosion.
+Not the one corresponding to the minimum. Why? Well the learning rate
+that corresponds to the minimum value is already a bit too high, since
+we are at the edge between improving and getting all over the place. We
+want to go one order of magnitude before, a value that’s still
+aggressive (so that we train quickly) but still on the safe side from an
+explosion.
In the above example we would choose 1e-3 instead of 1e-2.
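The heuristic quoted above can be sketched in base R. The `records` data frame below is synthetic and only mimics the `lr` and `loss` columns shown earlier; the loss curve is a made-up parabola in log10(lr), so the numbers are illustrative only.

```r
# Synthetic stand-in for lr_finder() output: columns `lr` and `loss`.
log_lr <- seq(-6, -1, length.out = 101)
records <- data.frame(
  lr = 10^log_lr,
  loss = 0.5 + (log_lr + 2)^2  # hypothetical curve with its minimum near lr = 1e-2
)

# Learning rate at which the recorded loss is lowest.
lr_at_min <- records$lr[which.min(records$loss)]

# Step back one order of magnitude, as the quoted advice suggests.
suggested_lr <- lr_at_min / 10
print(suggested_lr)  # close to 1e-3 for this synthetic curve
```

With real records the raw loss is noisy, so applying the same steps to a smoothed version of the loss is probably more robust than using the raw minimum.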
@@ -178,7 +211,7 @@ Guides
diff --git a/articles/lr-finder_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/lr-finder_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/lr-finder_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
-// v0.0.1
-// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
-
-document.addEventListener('DOMContentLoaded', function() {
- const codeList = document.getElementsByClassName("sourceCode");
- for (var i = 0; i < codeList.length; i++) {
- var linkList = codeList[i].getElementsByTagName('a');
- for (var j = 0; j < linkList.length; j++) {
- if (linkList[j].innerHTML === "") {
- linkList[j].setAttribute('aria-hidden', 'true');
- }
- }
- }
-});
diff --git a/authors.html b/authors.html
index 91358a77..60012493 100644
--- a/authors.html
+++ b/authors.html
@@ -95,7 +95,7 @@ Citation
diff --git a/index.html b/index.html
index 6d2484de..02b9dcd1 100644
--- a/index.html
+++ b/index.html
@@ -5,14 +5,24 @@
-
+
Higher Level API for torch • luz
-
+
+
Luz is a higher level API for torch providing abstractions to allow for much less verbose training loops.
This package is still under development.
It is heavily inspired by other higher level frameworks for deep learning, to cite a few:
@@ -190,7 +201,7 @@ Dev status
diff --git a/news/index.html b/news/index.html
index 1caa70f2..488f9cb4 100644
--- a/news/index.html
+++ b/news/index.html
@@ -101,7 +101,8 @@ Bug fixes
luz 0.3.1
CRAN release: 2022-09-06
-- Re-submission to fix vignette rendering.
+- Re-submission to fix vignette rendering.
+
luz 0.3.0
CRAN release: 2022-08-19
@@ -113,7 +114,8 @@ Breaking changes
Documentation
-- Many wording improvements in the getting started guides (#81 #94, @jonthegeek).
+- Many wording improvements in the getting started guides (#81 #94, @jonthegeek).
+
New features
- Added MixUp callback and helper loss function and functional logic. (#82, @skeydan).
@@ -151,7 +153,8 @@ Internal changes
luz 0.1.0
CRAN release: 2021-06-17
-- Added a
NEWS.md
file to track changes to the package.
+- Added a
NEWS.md
file to track changes to the package.
+
@@ -161,7 +164,7 @@ luz 0.1.0
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/pkgdown.yml b/pkgdown.yml
index 3b8ad870..5832483b 100644
--- a/pkgdown.yml
+++ b/pkgdown.yml
@@ -1,15 +1,14 @@
-pandoc: 2.7.3
-pkgdown: 2.0.7.9000
-pkgdown_sha: c9206802f2888992de92aa41f517ba7812f05331
+pandoc: 2.19.2
+pkgdown: 2.0.7
+pkgdown_sha: ~
articles:
accelerator: accelerator.html
- chargpt: examples/chargpt.html
checkpoints: checkpoints.html
+ lr-finder: lr-finder.html
custom-loop: custom-loop.html
+ chargpt: examples/chargpt.html
dogs-vs-cats-binary-classification: examples/dogs-vs-cats-binary-classification.html
- get-started: get-started.html
index: examples/index.html
- lr-finder: lr-finder.html
mnist-autoencoder: examples/mnist-autoencoder.html
mnist-cnn-virtual-batch-size: examples/mnist-cnn-virtual-batch-size.html
mnist-cnn: examples/mnist-cnn.html
@@ -18,5 +17,7 @@ articles:
mnist-triplet: examples/mnist-triplet.html
pets-unet: examples/pets-unet.html
text-classification: examples/text-classification.html
-last_built: 2023-09-15T17:29Z
+ text-generation: examples/text-generation.html
+ get-started: get-started.html
+last_built: 2023-10-17T16:26Z
diff --git a/reference/accelerator.html b/reference/accelerator.html
index 6e94c1a0..ad0990cf 100644
--- a/reference/accelerator.html
+++ b/reference/accelerator.html
@@ -99,7 +99,7 @@ Arguments
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/reference/as_dataloader.html b/reference/as_dataloader.html
index e9bfbfe2..404c070e 100644
--- a/reference/as_dataloader.html
+++ b/reference/as_dataloader.html
@@ -159,7 +159,7 @@ Overriding
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/reference/context.html b/reference/context.html
index 9c598532..a377b0d8 100644
--- a/reference/context.html
+++ b/reference/context.html
@@ -517,7 +517,7 @@ Arguments
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/reference/ctx.html b/reference/ctx.html
index fe747f6a..0e921c9b 100644
--- a/reference/ctx.html
+++ b/reference/ctx.html
@@ -90,7 +90,7 @@ See also
diff --git a/reference/evaluate.html b/reference/evaluate.html
index 9bdcad2d..9ef5c86f 100644
--- a/reference/evaluate.html
+++ b/reference/evaluate.html
@@ -141,12 +141,12 @@ Details
evaluation <- fitted %>% evaluate(data = valid_dl)
metrics <- get_metrics(evaluation)
print(evaluation)
-## A `luz_module_evaluation`
-## -- Results ---------------------------------------------------------------------
-## loss: 1.5146
-## mae: 1.0251
-## mse: 1.5159
-## rmse: 1.2312
+## A `luz_module_evaluation`
+## -- Results ---------------------------------------------------------------------
+## loss: 1.5146
+## mae: 1.0251
+## mse: 1.5159
+## rmse: 1.2312
diff --git a/reference/fit.luz_module_generator.html b/reference/fit.luz_module_generator.html
index 3850f33c..885905d8 100644
--- a/reference/fit.luz_module_generator.html
+++ b/reference/fit.luz_module_generator.html
@@ -170,7 +170,7 @@ See also
diff --git a/reference/get_metrics.html b/reference/get_metrics.html
index 2add3e85..692a7bf7 100644
--- a/reference/get_metrics.html
+++ b/reference/get_metrics.html
@@ -103,7 +103,7 @@ Methods (by class)
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/reference/index.html b/reference/index.html
index 29fd15f7..5c6cf36f 100644
--- a/reference/index.html
+++ b/reference/index.html
@@ -359,7 +359,7 @@ Serialization
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/reference/lr_finder-1.png b/reference/lr_finder-1.png
index 6801c27f..2b724bd4 100644
Binary files a/reference/lr_finder-1.png and b/reference/lr_finder-1.png differ
diff --git a/reference/lr_finder.html b/reference/lr_finder.html
index 382b47f7..6cdb2739 100644
--- a/reference/lr_finder.html
+++ b/reference/lr_finder.html
@@ -146,7 +146,7 @@ Examples
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/reference/luz_callback.html b/reference/luz_callback.html
index fdcd20ed..a2d9fdae 100644
--- a/reference/luz_callback.html
+++ b/reference/luz_callback.html
@@ -152,41 +152,41 @@ Details
Callbacks can be called in many different positions of the training
loop, including combinations of them. Here’s an overview of possible
callback breakpoints:
-Start Fit
- - on_fit_begin
- Start Epoch Loop
- - on_epoch_begin
- Start Train
- - on_train_begin
- Start Batch Loop
- - on_train_batch_begin
- Start Default Training Step
- - on_train_batch_after_pred
- - on_train_batch_after_loss
- - on_train_batch_before_backward
- - on_train_batch_before_step
- - on_train_batch_after_step
- End Default Training Step:
- - on_train_batch_end
- End Batch Loop
- - on_train_end
- End Train
- Start Valid
- - on_valid_begin
- Start Batch Loop
- - on_valid_batch_begin
- Start Default Validation Step
- - on_valid_batch_after_pred
- - on_valid_batch_after_loss
- End Default Validation Step
- - on_valid_batch_end
- End Batch Loop
- - on_valid_end
- End Valid
- - on_epoch_end
- End Epoch Loop
- - on_fit_end
-End Fit
+
+ Start Fit- on_fit_begin
+
+ Start Epoch Loop- on_epoch_begin
+
+ Start Train- on_train_begin
+
+ Start Batch Loop- on_train_batch_begin
+
+ Start Default Training Step- on_train_batch_after_pred
+ - on_train_batch_after_loss
+ - on_train_batch_before_backward
+ - on_train_batch_before_step
+ - on_train_batch_after_step
+ :
+ End Default Training Step- on_train_batch_end
+
+ End Batch Loop- on_train_end
+
+ End Train
+ Start Valid- on_valid_begin
+
+ Start Batch Loop- on_valid_batch_begin
+
+ Start Default Validation Step- on_valid_batch_after_pred
+ - on_valid_batch_after_loss
+
+ End Default Validation Step- on_valid_batch_end
+
+ End Batch Loop- on_valid_end
+
+ End Valid- on_epoch_end
+
+ End Epoch Loop- on_fit_end
+ End Fit
Every step marked with on_*
is a point in the training procedure that
is available for callbacks to be called.
The other important part of callbacks is the ctx
(context) object. See
@@ -208,14 +208,14 @@
Prediction callbackspredict(). In this case the supported
callback methods are detailed above.
-
Start predict
- - on_predict_begin
- Start prediction loop
- - on_predict_batch_begin
- - on_predict_batch_end
- End prediction loop
- - on_predict_end
-End predict
+
+ Start predict- on_predict_begin
+
+ Start prediction loop- on_predict_batch_begin
+ - on_predict_batch_end
+
+ End prediction loop- on_predict_end
+ End predict
Evaluate callbacks
@@ -224,18 +224,18 @@ Evaluate callbacksevaluate(), in this case, the callbacks that
are used are equivalent to those of the validation loop when using fit()
:
-
Start Valid
- - on_valid_begin
- Start Batch Loop
- - on_valid_batch_begin
- Start Default Validation Step
- - on_valid_batch_after_pred
- - on_valid_batch_after_loss
- End Default Validation Step
- - on_valid_batch_end
- End Batch Loop
- - on_valid_end
-End Valid
+
+ Start Valid- on_valid_begin
+
+ Start Batch Loop- on_valid_batch_begin
+
+ Start Default Validation Step- on_valid_batch_after_pred
+ - on_valid_batch_after_loss
+
+ End Default Validation Step- on_valid_batch_end
+
+ End Batch Loop- on_valid_end
+ End Valid
See also
@@ -278,7 +278,7 @@ Examples
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
diff --git a/reference/luz_callback_auto_resume.html b/reference/luz_callback_auto_resume.html
index 2013b39d..b853084c 100644
--- a/reference/luz_callback_auto_resume.html
+++ b/reference/luz_callback_auto_resume.html
@@ -177,16 +177,16 @@ Examples#> Caused by error in `self[[callback_nm]]()`:
#> ! Error on epoch 5
#> set metric epoch value
-#> 1 train loss 1 1.302326
-#> 2 train loss 2 1.141849
-#> 3 train loss 3 1.094023
-#> 4 train loss 4 1.082328
-#> 5 train loss 5 1.083923
-#> 6 train loss 6 1.072870
-#> 7 train loss 7 1.083111
-#> 8 train loss 8 1.079866
-#> 9 train loss 9 1.074621
-#> 10 train loss 10 1.075743
+#> 1 train loss 1 1.217334
+#> 2 train loss 2 1.079304
+#> 3 train loss 3 1.040630
+#> 4 train loss 4 1.027106
+#> 5 train loss 5 1.023069
+#> 6 train loss 6 1.017577
+#> 7 train loss 7 1.016829
+#> 8 train loss 8 1.020484
+#> 9 train loss 9 1.022464
+#> 10 train loss 10 1.025988
diff --git a/articles/examples/index_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/index_files/accessible-code-block-0.0.1/empty-anchor.js
deleted file mode 100644
index ca349fd6..00000000
--- a/articles/examples/index_files/accessible-code-block-0.0.1/empty-anchor.js
+++ /dev/null
@@ -1,15 +0,0 @@
-// Hide empty
diff --git a/reference/luz_callback_csv_logger.html b/reference/luz_callback_csv_logger.html
index ef9e3e98..8caa5ca8 100644
--- a/reference/luz_callback_csv_logger.html
+++ b/reference/luz_callback_csv_logger.html
@@ -106,7 +106,7 @@ See also
diff --git a/reference/luz_callback_early_stopping.html b/reference/luz_callback_early_stopping.html
index 594b229e..6e55924f 100644
--- a/reference/luz_callback_early_stopping.html
+++ b/reference/luz_callback_early_stopping.html
@@ -150,7 +150,7 @@ Examples
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.
References
- Site built with pkgdown 2.0.7.9000.
+ Site built with pkgdown 2.0.7.