Skip to content

Commit

Permalink
Release 0.6.0 (#158)
Browse files Browse the repository at this point in the history
* register S3 nn_prune_head

* Increment version number to 0.6.0

* fix `could not find function data()`
  • Loading branch information
cregouby authored Jun 17, 2024
1 parent bbf8e0d commit ebf731b
Show file tree
Hide file tree
Showing 12 changed files with 73 additions and 35 deletions.
2 changes: 2 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@
^.V8*
^doc$
^Meta$
^CRAN-SUBMISSION$
^revdep$
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ activate
/doc/
/Meta/
revdep
revdep/
tabnet.Rcheck
..Rcheck
tabnet_*.tar.gz
8 changes: 5 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: tabnet
Title: Fit 'TabNet' Models for Classification and Regression
Version: 0.5.0
Version: 0.6.0
Authors@R: c(
person("Daniel", "Falbel", , "[email protected]", role = "aut"),
person(, "RStudio", role = "cph"),
Expand All @@ -11,8 +11,8 @@ Authors@R: c(
comment = c(ORCID = "0000-0002-5573-3952"))
)
Description: Implements the 'TabNet' model by Sercan O. Arik et al. (2019)
<arXiv:1908.07442> with 'Coherent Hierarchical Multi-label
Classification Networks' by Giunchiglia et al. <arXiv:2010.10151> and
<doi:10.48550/arXiv.1908.07442> with 'Coherent Hierarchical Multi-label
Classification Networks' by Giunchiglia et al. <doi:10.48550/arXiv.2010.10151> and
provides a consistent interface for fitting and creating predictions.
It's also fully compatible with the 'tidymodels' ecosystem.
License: MIT + file LICENSE
Expand Down Expand Up @@ -51,6 +51,7 @@ Suggests:
recipes,
rmarkdown,
rsample,
spelling,
testthat (>= 3.0.0),
tidymodels,
tidyverse,
Expand All @@ -66,3 +67,4 @@ Config/testthat/start-first: interface, explain, params
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
Language: en-US
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

S3method(multi_predict,"_tabnet_fit")
S3method(predict,tabnet_fit)
S3method(print,tabnet_fit)
S3method(print,tabnet_pretrain)
Expand Down Expand Up @@ -46,6 +47,7 @@ importFrom(dplyr,select)
importFrom(dplyr,starts_with)
importFrom(dplyr,where)
importFrom(magrittr,"%>%")
importFrom(parsnip,multi_predict)
importFrom(rlang,.data)
importFrom(stats,predict)
importFrom(stats,update)
Expand Down
3 changes: 3 additions & 0 deletions R/package.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
vctrs::s3_register("ggplot2::autoplot", "tabnet_fit")
vctrs::s3_register("ggplot2::autoplot", "tabnet_pretrain")
vctrs::s3_register("ggplot2::autoplot", "tabnet_explain")
vctrs::s3_register("torch::nn_prune_head", "tabnet_fit")
vctrs::s3_register("torch::nn_prune_head", "tabnet_pretrain")
vctrs::s3_register("tune::min_grid", "tabnet")
}


globalVariables(c("batch_size",
"checkpoint",
"dataset",
"epoch",
"has_checkpoint",
Expand Down
11 changes: 9 additions & 2 deletions R/parsnip.R
Original file line number Diff line number Diff line change
Expand Up @@ -459,11 +459,12 @@ tabnet <- function(mode = "unknown", cat_emb_dim = NULL, decision_width = NULL,
if (!requireNamespace("parsnip", quietly = TRUE))
stop("Package \"parsnip\" needed for this function to work. Please install it.", call. = FALSE)

if (!tabnet_env$parsnip_added) {
if (parsnip_is_missing_tabnet(tabnet_env)) {
add_parsnip_tabnet()
tabnet_env$parsnip_added <- TRUE
}


# Capture the arguments in quosures
args <- list(
cat_emb_dim = rlang::enquo(cat_emb_dim),
Expand Down Expand Up @@ -512,7 +513,8 @@ tabnet <- function(mode = "unknown", cat_emb_dim = NULL, decision_width = NULL,
tabnet_env <- new.env()
tabnet_env$parsnip_added <- FALSE


#' @export
#' @importFrom parsnip multi_predict
multi_predict._tabnet_fit <- function(object, new_data, type = NULL, epochs = NULL, ...) {

if (is.null(epochs))
Expand Down Expand Up @@ -567,3 +569,8 @@ update.tabnet <- function(object, parameters = NULL, epochs = NULL, penalty = NU
}

min_grid.tabnet <- function(x, grid, ...) tune::fit_max_value(x, grid, ...)

parsnip_is_missing_tabnet <- function(tabnet_env) {
current <- parsnip::get_model_env()
!(any(current$models == "tabnet") || tabnet_env$parsnip_added)
}
4 changes: 2 additions & 2 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ autoplot.tabnet_fit <- function(object, ...) {

if ("checkpoint" %in% names(collect_metrics)) {
checkpoints <- collect_metrics %>%
dplyr::filter(.data$checkpoint == TRUE, dataset == "train") %>%
dplyr::select(-.data$checkpoint) %>%
dplyr::filter(checkpoint == TRUE, dataset == "train") %>%
dplyr::select(-checkpoint) %>%
dplyr::mutate(size = 2)
p +
ggplot2::geom_point(data = checkpoints, ggplot2::aes(x = epoch, y = loss, color = dataset, size = .data$size ))
Expand Down
7 changes: 3 additions & 4 deletions cran-comments.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
## R CMD check results
── R CMD check results ─────────────────────────────────────────────────────────────────── tabnet 0.6.0 ────
Duration: 2m 13s

0 errors | 0 warnings | 1 note

* This is a new release.
0 errors ✔ | 0 warnings ✔ | 0 notes ✔
48 changes: 36 additions & 12 deletions inst/WORDLIST
Original file line number Diff line number Diff line change
@@ -1,28 +1,52 @@
Ames
Arik
Bugfixes
Eleonora
Explicitely
GLU
Giunchiglia
Interpretability
Interpretable
Lifecycle
MNAR
OOM
Pfister
Pretrain
Sercan
TabNet
TabNet's
adam
ai
al.
al
ames
Ames
Arik
arXiv
autograd
beeing
classif
config
cpu
cuda
dataloading
detailled
doi
dreamquark
entmax
et
Giunchiglia
giunchiglia
explainability
ggplot
interpretable
Interpretable
Lifecycle
mse
neurips
NMAR
nn
orginal
overfit
overfits
pre
pretrain
pretrained
pretraining
PyTorch
Sercan
TabNet
reusage
sparsemax
subprocesses
th
tidymodels
zeallot
7 changes: 1 addition & 6 deletions man/autoplot.tabnet_explain.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/autoplot.tabnet_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions tests/testthat/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ suppressPackageStartupMessages(library(data.tree))


# ames small data
data("ames", package = "modeldata")
utils::data("ames", package = "modeldata")
ids <- sample(nrow(ames), 256)
small_ames <- ames[ids,]
x <- ames[ids,-which(names(ames) == "Sale_Price")]
Expand All @@ -20,7 +20,7 @@ ames_fit_vsplit <- tabnet_fit(x, y, tabnet_model=ames_pretrain_vsplit, epochs =
num_steps = 1, attention_width = 1, num_shared = 1, num_independent = 1)

# attrition small data
data("attrition", package = "modeldata")
utils::data("attrition", package = "modeldata")
ids <- sample(nrow(attrition), 256)

# attrition common models
Expand All @@ -35,7 +35,7 @@ attr_fitted <- tabnet_fit(attrix, attriy, epochs = 12)
attr_fitted_vsplit <- tabnet_fit(attrix, attriy, epochs = 12, valid_split=0.3)

# data.tree Node dataset
data("acme", package = "data.tree")
utils::data("acme", package = "data.tree")
acme_df <- data.tree::ToDataFrameTypeCol(acme, acme$attributesAll) %>%
select(-starts_with("level_"))

Expand All @@ -46,4 +46,4 @@ attrition_tree <- attrition %>%
data.tree::as.Node()

# Run after all tests
withr::defer(teardown_env())
withr::defer(testthat::teardown_env())

0 comments on commit ebf731b

Please sign in to comment.