diff --git a/DESCRIPTION b/DESCRIPTION
index fcf265c..4449631 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -45,6 +45,7 @@ Imports:
     withr,
     zeallot
 Suggests:
+    cli,
     knitr,
     modeldata,
     patchwork,
diff --git a/NAMESPACE b/NAMESPACE
index 495de1c..f99eb7c 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -21,23 +21,36 @@ S3method(tabnet_pretrain,recipe)
 S3method(update,tabnet)
 export("%>%")
 export(attention_width)
+export(cat_emb_dim)
 export(check_compliant_node)
+export(checkpoint_epochs)
 export(decision_width)
+export(drop_last)
+export(encoder_activation)
 export(feature_reusage)
+export(lr_scheduler)
 export(mask_type)
+export(mlp_activation)
+export(mlp_hidden_multiplier)
 export(momentum)
 export(nn_prune_head.tabnet_fit)
 export(nn_prune_head.tabnet_pretrain)
 export(node_to_df)
 export(num_independent)
+export(num_independent_decoder)
 export(num_shared)
+export(num_shared_decoder)
 export(num_steps)
+export(optimizer)
+export(penalty)
 export(tabnet)
 export(tabnet_config)
 export(tabnet_explain)
 export(tabnet_fit)
 export(tabnet_nn)
 export(tabnet_pretrain)
+export(verbose)
+export(virtual_batch_size)
 importFrom(dplyr,filter)
 importFrom(dplyr,last_col)
 importFrom(dplyr,mutate)
diff --git a/NEWS.md b/NEWS.md
index 0741b9f..247e0bd 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -3,6 +3,7 @@
 ## Bugfixes
 
 * improve function documentation consistency before translation
+* fix ".... is not an exported object from 'namespace:dials'" error when using tune() on tabnet parameters. (#160 @cphaarmeyer)
 
 
 # tabnet 0.6.0
diff --git a/R/dials.R b/R/dials.R
index 0c49d2d..76cac0c 100644
--- a/R/dials.R
+++ b/R/dials.R
@@ -3,6 +3,11 @@ check_dials <- function() {
     stop("Package \"dials\" needed for this function to work. Please install it.", call. = FALSE)
 }
 
+check_cli <- function() {
+  if (!requireNamespace("cli", quietly = TRUE))
+    stop("Package \"cli\" needed for this function to work. Please install it.", call. = FALSE)
+}
+
 
 
 #' Parameters for the tabnet model
@@ -17,56 +22,70 @@ check_dials <- function() {
 #' @rdname tabnet_params
 #' @return A `dials` parameter to be used when tuning TabNet models.
 #' @export
-decision_width <- function(range = c(8L, 64L), trans = NULL) {
+attention_width <- function(range = c(8L, 64L), trans = NULL) {
   check_dials()
   dials::new_quant_param(
     type = "integer",
     range = range,
     inclusive = c(TRUE, TRUE),
     trans = trans,
-    label = c(decision_width = "Width of the decision prediction layer"),
+    label = c(attention_width = "Width of the attention embedding for each mask"),
     finalize = NULL
   )
 }
 
 #' @rdname tabnet_params
 #' @export
-attention_width <- function(range = c(8L, 64L), trans = NULL) {
+decision_width <- function(range = c(8L, 64L), trans = NULL) {
   check_dials()
   dials::new_quant_param(
     type = "integer",
     range = range,
     inclusive = c(TRUE, TRUE),
     trans = trans,
-    label = c(attention_width = "Width of the attention embedding for each mask"),
+    label = c(decision_width = "Width of the decision prediction layer"),
     finalize = NULL
   )
 }
+
 #' @rdname tabnet_params
 #' @export
-num_steps <- function(range = c(3L, 10L), trans = NULL) {
+feature_reusage <- function(range = c(1, 2), trans = NULL) {
   check_dials()
   dials::new_quant_param(
-    type = "integer",
+    type = "double",
     range = range,
     inclusive = c(TRUE, TRUE),
     trans = trans,
-    label = c(num_steps = "Number of steps in the architecture"),
+    label = c(feature_reusage = "Coefficient for feature reusage in the masks"),
     finalize = NULL
   )
 }
 
 #' @rdname tabnet_params
 #' @export
-feature_reusage <- function(range = c(1, 2), trans = NULL) {
+momentum <- function(range = c(0.01, 0.4), trans = NULL) {
   check_dials()
   dials::new_quant_param(
     type = "double",
     range = range,
     inclusive = c(TRUE, TRUE),
     trans = trans,
-    label = c(feature_reusage = "Coefficient for feature reusage in the masks"),
+    label = c(momentum = "Momentum for batch normalization"),
+    finalize = NULL
+  )
+}
+
+
+#' @rdname tabnet_params
+#' @export
+mask_type <- function(values = c("sparsemax", "entmax")) {
+  check_dials()
+  dials::new_qual_param(
+    type = "character",
+    values = values,
+    label = c(mask_type = "Final layer of feature selector, either sparsemax or entmax"),
     finalize = NULL
   )
 }
@@ -101,28 +120,73 @@ num_shared <- function(range = c(1L, 5L), trans = NULL) {
 
 #' @rdname tabnet_params
 #' @export
-momentum <- function(range = c(0.01, 0.4), trans = NULL) {
+num_steps <- function(range = c(3L, 10L), trans = NULL) {
   check_dials()
   dials::new_quant_param(
-    type = "double",
+    type = "integer",
     range = range,
     inclusive = c(TRUE, TRUE),
     trans = trans,
-    label = c(momentum = "Momentum for batch normalization"),
+    label = c(num_steps = "Number of steps in the architecture"),
     finalize = NULL
   )
 }
 
-
-#' @rdname tabnet_params
+#' Non-tunable parameters for the tabnet model
+#'
+#' @param range unused
+#' @param trans unused
+#' @rdname tabnet_non_tunable
 #' @export
-mask_type <- function(values = c("sparsemax", "entmax")) {
-  check_dials()
-  dials::new_qual_param(
-    type = "character",
-    values = values,
-    label = c(mask_type = "Final layer of feature selector, either sparsemax or entmax"),
-    finalize = NULL
-  )
+cat_emb_dim <- function(range = NULL, trans = NULL) {
+  check_cli()
+  cli::cli_abort("{.var cat_emb_dim} cannot be used as a {.fun tune} parameter yet.")
 }
+
+#' @rdname tabnet_non_tunable
+#' @export
+checkpoint_epochs <- cat_emb_dim
+
+#' @rdname tabnet_non_tunable
+#' @export
+drop_last <- cat_emb_dim
+
+#' @rdname tabnet_non_tunable
+#' @export
+encoder_activation <- cat_emb_dim
+
+#' @rdname tabnet_non_tunable
+#' @export
+lr_scheduler <- cat_emb_dim
+
+#' @rdname tabnet_non_tunable
+#' @export
+mlp_activation <- cat_emb_dim
+
+#' @rdname tabnet_non_tunable
+#' @export
+mlp_hidden_multiplier <- cat_emb_dim
+
+#' @rdname tabnet_non_tunable
+#' @export
+num_independent_decoder <- cat_emb_dim
+
+#' @rdname tabnet_non_tunable
+#' @export
+num_shared_decoder <- cat_emb_dim
+
+#' @rdname tabnet_non_tunable
+#' @export
+optimizer <- cat_emb_dim
+
+#' @rdname tabnet_non_tunable
+#' @export
+penalty <- cat_emb_dim
+
+#' @rdname tabnet_non_tunable
+#' @export
+verbose <- cat_emb_dim
+
+#' @rdname tabnet_non_tunable
+#' @export
+virtual_batch_size <- cat_emb_dim
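For context, a minimal sketch of the behaviour introduced in R/dials.R above (assuming the tabnet, dials and cli packages are installed; printed output is paraphrased): the tunable parameter generators now live in tabnet itself and return ordinary dials parameter objects, while the non-tunable placeholders fail fast.

    library(tabnet)

    # Tunable parameters return regular dials parameter objects:
    decision_width()   # quantitative parameter, default range c(8L, 64L)
    mask_type()        # qualitative parameter, values "sparsemax"/"entmax"

    # Non-tunable parameters are exported only so tune() can resolve them;
    # calling one aborts with a cli error:
    try(cat_emb_dim())
    #> Error: `cat_emb_dim` cannot be used as a `tune()` parameter yet.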
diff --git a/R/parsnip.R b/R/parsnip.R
index e7e1ed9..62df1b9 100644
--- a/R/parsnip.R
+++ b/R/parsnip.R
@@ -85,7 +85,7 @@ add_parsnip_tabnet <- function() {
     eng = "torch",
     parsnip = "cat_emb_dim",
     original = "cat_emb_dim",
-    func = list(pkg = "dials", fun = "cat_emb_dim"),
+    func = list(pkg = "tabnet", fun = "cat_emb_dim"),
     has_submodel = FALSE
   )
 
@@ -94,7 +94,7 @@ add_parsnip_tabnet <- function() {
     eng = "torch",
     parsnip = "decision_width",
     original = "decision_width",
-    func = list(pkg = "dials", fun = "decision_width"),
+    func = list(pkg = "tabnet", fun = "decision_width"),
     has_submodel = FALSE
   )
 
@@ -103,7 +103,7 @@ add_parsnip_tabnet <- function() {
     eng = "torch",
     parsnip = "attention_width",
     original = "attention_width",
-    func = list(pkg = "dials", fun = "attention_width"),
+    func = list(pkg = "tabnet", fun = "attention_width"),
     has_submodel = FALSE
   )
 
@@ -112,7 +112,7 @@ add_parsnip_tabnet <- function() {
     eng = "torch",
     parsnip = "num_steps",
     original = "num_steps",
-    func = list(pkg = "dials", fun = "num_steps"),
+    func = list(pkg = "tabnet", fun = "num_steps"),
     has_submodel = FALSE
   )
 
@@ -121,7 +121,7 @@ add_parsnip_tabnet <- function() {
     eng = "torch",
     parsnip = "mask_type",
     original = "mask_type",
-    func = list(pkg = "dials", fun = "mask_type"),
+    func = list(pkg = "tabnet", fun = "mask_type"),
     has_submodel = FALSE
   )
 
@@ -130,7 +130,7 @@ add_parsnip_tabnet <- function() {
     eng = "torch",
     parsnip = "mlp_hidden_multiplier",
     original = "mlp_hidden_multiplier",
-    func = list(pkg = "dials", fun = "mlp_hidden_multiplier"),
+    func = list(pkg = "tabnet", fun = "mlp_hidden_multiplier"),
     has_submodel = FALSE
   )
 
@@ -139,7 +139,7 @@ add_parsnip_tabnet <- function() {
     eng = "torch",
     parsnip = "mlp_activation",
     original = "mlp_activation",
-    func = list(pkg = "dials", fun = "mlp_activation"),
+    func = list(pkg = "tabnet", fun = "mlp_activation"),
     has_submodel = FALSE
   )
 
@@ -148,7 +148,7 @@ add_parsnip_tabnet <- function() {
     eng = "torch",
     parsnip = "encoder_activation",
     original = "encoder_activation",
-    func = list(pkg = "dials", fun = "encoder_activation"),
+    func = list(pkg = "tabnet", fun = "encoder_activation"),
     has_submodel = FALSE
   )
 
@@ -157,7 +157,7 @@ add_parsnip_tabnet <- function() {
     eng = "torch",
     parsnip = "num_independent",
     original = "num_independent",
-    func = list(pkg = "dials", fun = "num_independent"),
+    func = list(pkg = "tabnet", fun = "num_independent"),
     has_submodel = FALSE
   )
 
@@ -166,7 +166,7 @@ add_parsnip_tabnet <- function() {
     eng = "torch",
     parsnip = "num_shared",
     original = "num_shared",
-    func = list(pkg = "dials", fun = "num_shared"),
+    func = list(pkg = "tabnet", fun = "num_shared"),
     has_submodel = FALSE
   )
 
@@ -175,7 +175,7 @@ add_parsnip_tabnet <- function() {
     eng = "torch",
     parsnip = "num_independent_decoder",
     original = "num_independent_decoder",
-    func = list(pkg = "dials", fun = "num_independent_decoder"),
+    func = list(pkg = "tabnet", fun = "num_independent_decoder"),
     has_submodel = FALSE
   )
 
"torch", parsnip = "num_shared_decoder", original = "num_shared_decoder", - func = list(pkg = "dials", fun = "num_shared_decoder"), + func = list(pkg = "tabnet", fun = "num_shared_decoder"), has_submodel = FALSE ) @@ -202,7 +202,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "feature_reusage", original = "feature_reusage", - func = list(pkg = "dials", fun = "feature_reusage"), + func = list(pkg = "tabnet", fun = "feature_reusage"), has_submodel = FALSE ) @@ -238,7 +238,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "virtual_batch_size", original = "virtual_batch_size", - func = list(pkg = "dials", fun = "virtual_batch_size"), + func = list(pkg = "tabnet", fun = "virtual_batch_size"), has_submodel = FALSE ) @@ -256,7 +256,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "optimizer", original = "optimizer", - func = list(pkg = "dials", fun = "optimizer"), + func = list(pkg = "tabnet", fun = "optimizer"), has_submodel = FALSE ) @@ -265,7 +265,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "loss", original = "loss", - func = list(pkg = "dials", fun = "loss"), + func = list(pkg = "tabnet", fun = "loss"), has_submodel = FALSE ) @@ -274,7 +274,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "clip_value", original = "clip_value", - func = list(pkg = "dials", fun = "clip_value"), + func = list(pkg = "tabnet", fun = "clip_value"), has_submodel = FALSE ) @@ -283,7 +283,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "drop_last", original = "drop_last", - func = list(pkg = "dials", fun = "drop_last"), + func = list(pkg = "tabnet", fun = "drop_last"), has_submodel = FALSE ) @@ -292,25 +292,25 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "lr_scheduler", original = "lr_scheduler", - func = list(pkg = "dials", fun = "lr_scheduler"), + func = list(pkg = "tabnet", fun = "lr_scheduler"), has_submodel = FALSE ) parsnip::set_model_arg( model = "tabnet", eng = "torch", - parsnip = "lr_decay", + parsnip = "rate_decay", original = "lr_decay", - func = list(pkg = "dials", fun = "lr_decay"), + func = list(pkg = "dials", fun = "rate_decay"), has_submodel = FALSE ) parsnip::set_model_arg( model = "tabnet", eng = "torch", - parsnip = "step_size", + parsnip = "rate_step_size", original = "step_size", - func = list(pkg = "dials", fun = "step_size"), + func = list(pkg = "dials", fun = "rate_step_size"), has_submodel = FALSE ) @@ -319,7 +319,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "checkpoint_epochs", original = "checkpoint_epochs", - func = list(pkg = "dials", fun = "checkpoint_epochs"), + func = list(pkg = "tabnet", fun = "checkpoint_epochs"), has_submodel = FALSE ) @@ -328,7 +328,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "verbose", original = "verbose", - func = list(pkg = "dials", fun = "verbose"), + func = list(pkg = "tabnet", fun = "verbose"), has_submodel = FALSE ) @@ -337,7 +337,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "importance_sample_size", original = "importance_sample_size", - func = list(pkg = "dials", fun = "importance_sample_size"), + func = list(pkg = "tabnet", fun = "importance_sample_size"), has_submodel = FALSE ) @@ -346,7 +346,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "early_stopping_monitor", original = "early_stopping_monitor", - func = list(pkg = "dials", fun = "early_stopping_monitor"), + func = list(pkg = "tabnet", fun = "early_stopping_monitor"), has_submodel = FALSE ) @@ -355,7 
diff --git a/inst/WORDLIST b/inst/WORDLIST
index dc00692..3978496 100644
--- a/inst/WORDLIST
+++ b/inst/WORDLIST
@@ -15,7 +15,6 @@ Pretrain
 Sercan
 TabNet
 TabNet's
-adam
 ai
 al
 ames
@@ -37,6 +36,7 @@ ggplot
 interpretable
 mse
 nn
+num
 orginal
 overfit
 overfits
@@ -49,4 +49,5 @@ sparsemax
 subprocesses
 th
 tidymodels
+tunable
 zeallot
diff --git a/man/tabnet.Rd b/man/tabnet.Rd
index 8184a81..37106d6 100644
--- a/man/tabnet.Rd
+++ b/man/tabnet.Rd
@@ -27,8 +27,8 @@ tabnet(
   clip_value = NULL,
   drop_last = NULL,
   lr_scheduler = NULL,
-  lr_decay = NULL,
-  step_size = NULL,
+  rate_decay = NULL,
+  rate_step_size = NULL,
   checkpoint_epochs = NULL,
   verbose = NULL,
   importance_sample_size = NULL,
@@ -113,11 +113,11 @@ decays the learning rate by \code{lr_decay} when no improvement after \code{step_size} epochs.
 It can also be a \link[torch:lr_scheduler]{torch::lr_scheduler} function that only takes
 the optimizer as parameter. The \code{step} method is called once per epoch.}
 
-\item{lr_decay}{multiplies the initial learning rate by \code{lr_decay} every
-\code{step_size} epochs. Unused if \code{lr_scheduler} is a \code{torch::lr_scheduler}
+\item{rate_decay}{multiplies the initial learning rate by \code{rate_decay} every
+\code{rate_step_size} epochs. Unused if \code{lr_scheduler} is a \code{torch::lr_scheduler}
 or \code{NULL}.}
 
-\item{step_size}{the learning rate scheduler step size. Unused if
+\item{rate_step_size}{the learning rate scheduler step size. Unused if
 \code{lr_scheduler} is a \code{torch::lr_scheduler} or \code{NULL}.}
 
 \item{checkpoint_epochs}{checkpoint model weights and architecture every
diff --git a/man/tabnet_non_tunable.Rd b/man/tabnet_non_tunable.Rd
new file mode 100644
index 0000000..0fa7487
--- /dev/null
+++ b/man/tabnet_non_tunable.Rd
@@ -0,0 +1,52 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/dials.R
+\name{cat_emb_dim}
+\alias{cat_emb_dim}
+\alias{checkpoint_epochs}
+\alias{drop_last}
+\alias{encoder_activation}
+\alias{lr_scheduler}
+\alias{mlp_activation}
+\alias{mlp_hidden_multiplier}
+\alias{num_independent_decoder}
+\alias{num_shared_decoder}
+\alias{optimizer}
+\alias{penalty}
+\alias{verbose}
+\alias{virtual_batch_size}
+\title{Non-tunable parameters for the tabnet model}
+\usage{
+cat_emb_dim(range = NULL, trans = NULL)
+
+checkpoint_epochs(range = NULL, trans = NULL)
+
+drop_last(range = NULL, trans = NULL)
+
+encoder_activation(range = NULL, trans = NULL)
+
+lr_scheduler(range = NULL, trans = NULL)
+
+mlp_activation(range = NULL, trans = NULL)
+
+mlp_hidden_multiplier(range = NULL, trans = NULL)
+
+num_independent_decoder(range = NULL, trans = NULL)
+
+num_shared_decoder(range = NULL, trans = NULL)
+
+optimizer(range = NULL, trans = NULL)
+
+penalty(range = NULL, trans = NULL)
+
+verbose(range = NULL, trans = NULL)
+
+virtual_batch_size(range = NULL, trans = NULL)
+}
+\arguments{
+\item{range}{unused}
+
+\item{trans}{unused}
+}
+\description{
+Non-tunable parameters for the tabnet model
+}
diff --git a/man/tabnet_params.Rd b/man/tabnet_params.Rd
index 239eb0b..cdcccfe 100644
--- a/man/tabnet_params.Rd
+++ b/man/tabnet_params.Rd
@@ -1,31 +1,31 @@
 % Generated by roxygen2: do not edit by hand
 % Please edit documentation in R/dials.R
-\name{decision_width}
-\alias{decision_width}
+\name{attention_width}
 \alias{attention_width}
-\alias{num_steps}
+\alias{decision_width}
 \alias{feature_reusage}
-\alias{num_independent}
-\alias{num_shared}
 \alias{momentum}
 \alias{mask_type}
+\alias{num_independent}
+\alias{num_shared}
+\alias{num_steps}
 \title{Parameters for the tabnet model}
\usage{
-decision_width(range = c(8L, 64L), trans = NULL)
-
 attention_width(range = c(8L, 64L), trans = NULL)
 
-num_steps(range = c(3L, 10L), trans = NULL)
+decision_width(range = c(8L, 64L), trans = NULL)
 
 feature_reusage(range = c(1, 2), trans = NULL)
 
+momentum(range = c(0.01, 0.4), trans = NULL)
+
+mask_type(values = c("sparsemax", "entmax"))
+
 num_independent(range = c(1L, 5L), trans = NULL)
 
 num_shared(range = c(1L, 5L), trans = NULL)
 
-momentum(range = c(0.01, 0.4), trans = NULL)
-
-mask_type(values = c("sparsemax", "entmax"))
+num_steps(range = c(3L, 10L), trans = NULL)
 }
 \arguments{
 \item{range}{the default range for the parameter value}
diff --git a/tests/testthat/test-dials.R b/tests/testthat/test-dials.R
new file mode 100644
index 0000000..e4200df
--- /dev/null
+++ b/tests/testthat/test-dials.R
@@ -0,0 +1,48 @@
+test_that("Check we can use hardhat:::extract_parameter_set_dials() with {dial} tune()ed parameter", {
+
+  model <- tabnet(batch_size = tune(), learn_rate = tune(), epochs = tune(),
+                  momentum = tune(), penalty = tune(), rate_step_size = tune()) %>%
+    parsnip::set_mode("regression") %>%
+    parsnip::set_engine("torch")
+
+  wf <- workflows::workflow() %>%
+    workflows::add_model(model) %>%
+    workflows::add_formula(Sale_Price ~ .)
+
+  expect_no_error(
+    wf %>% hardhat::extract_parameter_set_dials()
+  )
+})
+
+test_that("Check we can use hardhat:::extract_parameter_set_dials() with {tabnet} tune()ed parameter", {
+
+  model <- tabnet(num_steps = tune(), num_shared = tune(), mask_type = tune(),
+                  feature_reusage = tune(), attention_width = tune()) %>%
+    parsnip::set_mode("regression") %>%
+    parsnip::set_engine("torch")
+
+  wf <- workflows::workflow() %>%
+    workflows::add_model(model) %>%
+    workflows::add_formula(Sale_Price ~ .)
+
+  expect_no_error(
+    wf %>% hardhat::extract_parameter_set_dials()
+  )
+})
+
+test_that("Check non supported tune()ed parameter raise an explicit error", {
+
+  model <- tabnet(cat_emb_dim = tune(), checkpoint_epochs = 0) %>%
+    parsnip::set_mode("regression") %>%
+    parsnip::set_engine("torch")
+
+  wf <- workflows::workflow() %>%
+    workflows::add_model(model) %>%
+    workflows::add_formula(Sale_Price ~ .)
+
+  expect_error(
+    wf %>% hardhat::extract_parameter_set_dials(),
+    regexp = "cannot be used as a .* parameter yet"
+  )
+})
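The tests above stop at parameter-set extraction; a fuller, hypothetical tuning run over the tabnet-owned parameters would look roughly like this, assuming the tidymodels stack plus the ames data from modeldata (resample and grid sizes are deliberately tiny and illustrative):

    library(tidymodels)
    library(tabnet)
    data("ames", package = "modeldata")

    model <- tabnet(num_steps = tune(), attention_width = tune(), epochs = 1) %>%
      set_mode("regression") %>%
      set_engine("torch")

    wf <- workflow() %>%
      add_model(model) %>%
      add_formula(Sale_Price ~ .)

    # now resolves without the former
    # "... is not an exported object from 'namespace:dials'" error
    hardhat::extract_parameter_set_dials(wf)

    res <- tune_grid(wf, resamples = vfold_cv(ames, v = 2), grid = 2)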