Skip to content

Commit

Permalink
add message for unused weights and translation
Browse files Browse the repository at this point in the history
  • Loading branch information
cregouby committed Feb 21, 2024
1 parent 5378d2d commit 10c8e52
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 35 deletions.
9 changes: 9 additions & 0 deletions R/hardhat.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ tabnet_fit.default <- function(x, ...) {
#' @rdname tabnet_fit
tabnet_fit.data.frame <- function(x, y, tabnet_model = NULL, config = tabnet_config(), ...,
from_epoch = NULL, weights = NULL) {
if (!is.null(weights)) {
message(gettextf("Configured `weights` will not be used"))
}
processed <- hardhat::mold(x, y)
check_type(processed$outcomes)

Expand All @@ -134,6 +137,9 @@ tabnet_fit.data.frame <- function(x, y, tabnet_model = NULL, config = tabnet_con
#' @rdname tabnet_fit
tabnet_fit.formula <- function(formula, data, tabnet_model = NULL, config = tabnet_config(), ...,
from_epoch = NULL, weights = NULL) {
if (!is.null(weights)) {
message(gettextf("Configured `weights` will not be used"))
}
processed <- hardhat::mold(
formula, data,
blueprint = hardhat::default_formula_blueprint(
Expand All @@ -160,6 +166,9 @@ tabnet_fit.formula <- function(formula, data, tabnet_model = NULL, config = tabn
#' @rdname tabnet_fit
tabnet_fit.recipe <- function(x, data, tabnet_model = NULL, config = tabnet_config(), ...,
from_epoch = NULL, weights = NULL) {
if (!is.null(weights)) {
message(gettextf("Configured `weights` will not be used"))
}
processed <- hardhat::mold(x, data)
check_type(processed$outcomes)

Expand Down
Binary file modified inst/po/fr/LC_MESSAGES/R-tabnet.mo
Binary file not shown.
41 changes: 23 additions & 18 deletions po/R-fr.po
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ msgid ""
msgstr ""
"Project-Id-Version: tabnet 0.4.0.9000\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2024-02-10 17:01+0100\n"
"PO-Revision-Date: 2024-02-10 17:09+0100\n"
"POT-Creation-Date: 2024-02-21 22:37+0100\n"
"PO-Revision-Date: 2024-02-21 22:42+0100\n"
"Last-Translator: Christophe Regouby <[email protected]>\n"
"Language-Team: fr\n"
"Language: fr\n"
Expand All @@ -23,44 +23,49 @@ msgstr ""
msgid "`tabnet_explain()` is not defined for a '%s'."
msgstr "`tabnet_explain()` n’est pas défini pour un '%s'."

#: hardhat.R:108
#: hardhat.R:109
#, c-format
msgid "`tabnet_fit()` is not defined for a '%s'."
msgstr "`tabnet_fit()` n’est pas défini pour un '%s'."

#: hardhat.R:296
#: hardhat.R:118 hardhat.R:141 hardhat.R:170
#, c-format
msgid "Configured `weights` will not be used"
msgstr "Les `weights` configurés ne seront pas utilisés."

#: hardhat.R:309
#, c-format
msgid "`tabnet_pretrain()` is not defined for a '%s'."
msgstr "`tabnet_pretrain()` n’est pas défini pour un '%s'."

#: hardhat.R:404
#: hardhat.R:417
#, c-format
msgid "'%s' is not recognised as a proper TabNet model"
msgstr "'%s' n’est pas reconnu comme un modèle TabNet correct"

#: hardhat.R:411
#: hardhat.R:424
#, c-format
msgid "The model was trained for less than '%s' epochs"
msgstr "Le modèle a été entrainé sur moins de '%s' époques"

#: hardhat.R:423
#: hardhat.R:436
#, c-format
msgid "Found missing values in the `%s` outcome column."
msgstr "Il y a des valeurs manquantes dans la colonne de résultats `%s`."

#: hardhat.R:433 hardhat.R:548
#: hardhat.R:446 hardhat.R:561
msgid "Model dimensions don't match."
msgstr "Les dimensions ne correspondent pas entre les modèles."

#: hardhat.R:458
#: hardhat.R:471
#, c-format
msgid ""
"No model serialized weight can be found in `%s`, check the model history"
msgstr ""
"Il n’y a pas de points sérialisés de modèle dans `%s`, veuillez vérifier "
"l’historique du modèle"

#: hardhat.R:466
#: hardhat.R:479
msgid ""
"`tabnet_pretrain()` from a model is not currently supported.\n"
"The pretraining here will start with a network initialization"
Expand All @@ -69,42 +74,42 @@ msgstr ""
"pour l'instant.\n"
"Le pré-entraînement va commencer par une initialisation du réseau."

#: hardhat.R:503
#: hardhat.R:516
#, c-format
msgid "The model was trained for less than `%s` epochs"
msgstr "Le modèle a été entrainé sur moins de '%s' époques"

#: hardhat.R:600
#: hardhat.R:613
#, c-format
msgid "Mixed multi-outcome type '%s' is not supported"
msgstr "Le type '%s' n’est pas supporté pour des modèles multi-résultat"

#: hardhat.R:608
#: hardhat.R:621
#, c-format
msgid "Unknown outcome type '%s'"
msgstr "Le type `%s` est inconnu pour une colonne de résultat"

#: hardhat.R:615
#: hardhat.R:628
#, c-format
msgid "Outcome is factor and the prediction type is '%s'."
msgstr ""
"La colonne de résultats est catégorielle et la prédiction est de type '%s'."

#: hardhat.R:618
#: hardhat.R:631
#, c-format
msgid "Outcome is numeric and the prediction type is '%s'."
msgstr ""
"La colonne de résultats est numérique et la prédiction est de type '%s'."

#: hardhat.R:656
#: hardhat.R:669
msgid ""
"The provided hierarchical object is not recognized with a valid format that "
"can be checked"
msgstr ""
"L’objet hiérarchique fournit n’est pas reconnu dans un format valide qui "
"peut être vérifié"

#: hardhat.R:661
#: hardhat.R:674
#, c-format
msgid ""
"The attributes or colnames in the provided hierarchical object use the "
Expand Down Expand Up @@ -167,7 +172,7 @@ msgstr ""
"aléatoirement à une taille de 1e5. Vous pouvez rendre ce message silencieux "
"en configurant l’argument `importance_sample_size`."

#: parsnip.R:245
#: parsnip.R:460
msgid ""
"Package \"parsnip\" needed for this function to work. Please install it."
msgstr ""
Expand Down
39 changes: 22 additions & 17 deletions po/R-tabnet.pot
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
msgid ""
msgstr ""
"Project-Id-Version: tabnet 0.5.0.9000\n"
"POT-Creation-Date: 2024-02-10 17:01+0100\n"
"POT-Creation-Date: 2024-02-21 22:37+0100\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <[email protected]>\n"
Expand All @@ -19,79 +19,84 @@ msgstr ""
msgid "`tabnet_explain()` is not defined for a '%s'."
msgstr ""

#: hardhat.R:108
#: hardhat.R:109
#, c-format
msgid "`tabnet_fit()` is not defined for a '%s'."
msgstr ""

#: hardhat.R:296
#: hardhat.R:118 hardhat.R:141 hardhat.R:170
#, c-format
msgid "Configured `weights` will not be used"
msgstr ""

#: hardhat.R:309
#, c-format
msgid "`tabnet_pretrain()` is not defined for a '%s'."
msgstr ""

#: hardhat.R:404
#: hardhat.R:417
#, c-format
msgid "'%s' is not recognised as a proper TabNet model"
msgstr ""

#: hardhat.R:411
#: hardhat.R:424
#, c-format
msgid "The model was trained for less than '%s' epochs"
msgstr ""

#: hardhat.R:423
#: hardhat.R:436
#, c-format
msgid "Found missing values in the `%s` outcome column."
msgstr ""

#: hardhat.R:433 hardhat.R:548
#: hardhat.R:446 hardhat.R:561
msgid "Model dimensions don't match."
msgstr ""

#: hardhat.R:458
#: hardhat.R:471
#, c-format
msgid ""
"No model serialized weight can be found in `%s`, check the model history"
msgstr ""

#: hardhat.R:466
#: hardhat.R:479
msgid ""
"`tabnet_pretrain()` from a model is not currently supported.\n"
"The pretraining here will start with a network initialization"
msgstr ""

#: hardhat.R:503
#: hardhat.R:516
#, c-format
msgid "The model was trained for less than `%s` epochs"
msgstr ""

#: hardhat.R:600
#: hardhat.R:613
#, c-format
msgid "Mixed multi-outcome type '%s' is not supported"
msgstr ""

#: hardhat.R:608
#: hardhat.R:621
#, c-format
msgid "Unknown outcome type '%s'"
msgstr ""

#: hardhat.R:615
#: hardhat.R:628
#, c-format
msgid "Outcome is factor and the prediction type is '%s'."
msgstr ""

#: hardhat.R:618
#: hardhat.R:631
#, c-format
msgid "Outcome is numeric and the prediction type is '%s'."
msgstr ""

#: hardhat.R:656
#: hardhat.R:669
msgid ""
"The provided hierarchical object is not recognized with a valid format that "
"can be checked"
msgstr ""

#: hardhat.R:661
#: hardhat.R:674
#, c-format
msgid ""
"The attributes or colnames in the provided hierarchical object use the "
Expand Down Expand Up @@ -142,7 +147,7 @@ msgid ""
"message by using the `importance_sample_size` argument."
msgstr ""

#: parsnip.R:245
#: parsnip.R:460
msgid ""
"Package \"parsnip\" needed for this function to work. Please install it."
msgstr ""
Expand Down
27 changes: 27 additions & 0 deletions tests/testthat/test-hardhat_interfaces.R
Original file line number Diff line number Diff line change
Expand Up @@ -189,5 +189,32 @@ test_that("we can prune head of restored models from disk", {
expect_equal(all(stringr::str_detect("final_mapping", names(pruned_pretrain$children))),FALSE)
})

})

test_that("using weights raise a message", {

testthat::skip_on_ci()

# dataframe interface
expect_message(
fit <- tabnet_fit(x, y, epochs = 1, weights = 27),
"Configured `weights`"
)

# formula interface
expect_message(
fit <- tabnet_fit(Sale_Price ~ ., data = ames, epochs = 1, weights = 27),
"Configured `weights`"
)

# recipe interface
rec <- recipe(Attrition ~ ., data = attrition) %>%
step_normalize(all_numeric(), -all_outcomes())

expect_message(
fit <- tabnet_fit(rec, attrition[1:256,], epochs = 1, weights = "whatever"),
"Configured `weights`"
)

})

0 comments on commit 10c8e52

Please sign in to comment.