Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add data loader for spam dataset with tests and documentation #1224

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions R/spam-dataloader.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#' Spam Data Loader
#'
#' Creates a torch dataloader for the spam dataset from "The Elements of
#' Statistical Learning". Each row holds 57 numeric features extracted from an
#' email message, and the last column is a binary label indicating whether the
#' email is spam (1) or not spam (0).
#'
#' @param url A character string giving the URL or local path of the dataset.
#'   Defaults to "https://hastie.su.domains/ElemStatLearn/datasets/spam.data".
#' @param batch_size Number of samples per batch. Defaults to 32.
#' @param shuffle Logical; whether to shuffle the data. Defaults to TRUE.
#' @param download Logical; whether to download the dataset to a temporary
#'   file first instead of letting `read.table()` read `url` directly.
#'   Defaults to FALSE.
#' @return A dataloader object whose batches are named lists with elements
#'   `x` (a float tensor of predictors, 57 columns) and `y` (a long tensor of
#'   0/1 labels).
#' @examples
#' \dontrun{
#' dl <- spam_dataloader(batch_size = 32, shuffle = TRUE)
#' iter <- dl$.iter()
#' batch <- iter$.next()
#' print(batch)
#' }
#' @export
spam_dataloader <- function(url = "https://hastie.su.domains/ElemStatLearn/datasets/spam.data",
                            batch_size = 32, shuffle = TRUE, download = FALSE) {
  stopifnot(is.character(url), length(url) == 1)
  stopifnot(is.numeric(batch_size), length(batch_size) == 1, batch_size >= 1)

  # Resolve the data source. Only allocate a tempfile when we actually
  # download; otherwise read.table() can consume the URL/path directly.
  if (download) {
    data_path <- tempfile(fileext = ".data")
    status <- utils::download.file(url, data_path, quiet = TRUE)
    if (status != 0) {
      stop("Failed to download spam dataset from ", url, call. = FALSE)
    }
  } else {
    data_path <- url
  }

  # Load and preprocess the dataset.
  spam_data <- utils::read.table(data_path, header = FALSE)

  # The last column is the spam indicator, already coded 0/1 in the raw
  # file, so no re-coding is needed. (Subtracting 1 here — as an earlier
  # revision did — would shift the labels to -1/0 and break the documented
  # 0/1 contract.)
  x_data <- as.matrix(spam_data[, -ncol(spam_data), drop = FALSE])
  y_data <- as.numeric(spam_data[[ncol(spam_data)]])

  # Convert to tensors: float predictors, long (integer) class labels.
  x_tensor <- torch_tensor(x_data, dtype = torch_float())
  y_tensor <- torch_tensor(y_data, dtype = torch_long())

  # Minimal in-memory dataset. Implementing .getbatch() lets the dataloader
  # index several rows at once instead of one item at a time.
  spam_dataset <- dataset(
    name = "spam_dataset",
    initialize = function(x, y) {
      self$x <- x
      self$y <- y
    },
    .getbatch = function(index) {
      list(
        x = self$x[index, ],
        y = self$y[index]
      )
    },
    .length = function() {
      self$y$size(1)
    }
  )

  ds <- spam_dataset(x = x_tensor, y = y_tensor)
  dataloader(ds, batch_size = batch_size, shuffle = shuffle)
}
26 changes: 26 additions & 0 deletions tests/testthat/test-spam-dataloader.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
library(testthat)
library(torch)

test_that("spam_dataloader loads and batches data correctly", {
  # Guard a network-dependent, torch-dependent test.
  skip_if_not_installed("torch")
  skip_on_cran() # downloads the dataset from the network

  # Create the dataloader with batch size 32 and shuffling enabled.
  dl <- spam_dataloader(batch_size = 32, shuffle = TRUE, download = TRUE)

  # The returned object must be a torch dataloader.
  expect_true(
    inherits(dl, "dataloader"),
    info = "The returned object is not a dataloader."
  )

  # Pull the first batch off a fresh iterator.
  iter <- dl$.iter()
  batch <- iter$.next()

  # Batch structure: a list with two elements (x and y).
  # NOTE: custom messages must go through `info =`; a bare third positional
  # argument to expect_equal() is forwarded to waldo and errors.
  expect_length(batch, 2)
  expect_equal(
    batch[[1]]$dim()[2], 57,
    info = "The predictors (x) should have 57 features."
  )

  # Data types: `dtype` is an active binding on torch tensors, not a
  # method, so it is read without parentheses.
  expect_true(
    batch[[1]]$dtype == torch_float(),
    info = "The predictors (x) should have dtype torch_float."
  )
  expect_true(
    batch[[2]]$dtype == torch_long(),
    info = "The labels (y) should have dtype torch_long."
  )

  # Batch size: both tensors carry a full batch of 32 rows.
  expect_equal(
    batch[[1]]$size(1), 32,
    info = "The batch size for predictors (x) should match 32."
  )
  expect_equal(
    batch[[2]]$size(1), 32,
    info = "The batch size for labels (y) should match 32."
  )
})
44 changes: 28 additions & 16 deletions tools/buildlantern.R
Original file line number Diff line number Diff line change
@@ -1,20 +1,32 @@
# NOTE(review): this span is a GitHub diff rendering of tools/buildlantern.R —
# the first four code lines below are the pre-change version of the script and
# the remainder is the new version; indentation was stripped by the diff view.
if (dir.exists("src/lantern")) {
cat("Building lantern .... \n")

dir.create("src/lantern/build", showWarnings = FALSE, recursive = TRUE)
cat("Starting Lantern build process...\n")

if (dir.exists("src/lantern")) {
cat("Lantern directory exists. Proceeding...\n")

# Ensure the build directory is created
if (!dir.exists("src/lantern/build")) {
cat("Creating build directory...\n")
dir.create("src/lantern/build", recursive = TRUE)
cat("Build directory created.\n")
} else {
cat("Build directory already exists.\n")
}

# Run CMake commands
withr::with_dir("src/lantern/build", {
system("cmake ..")
system("cmake --build . --target lantern --config Release --parallel 8")
# NOTE(review): the new version below hard-codes the "MinGW Makefiles"
# generator and gcc/g++, which only works on Windows hosts with MinGW on
# PATH; the previous generic `cmake ..` / `cmake --build` invocation above
# was cross-platform — confirm this narrowing is intentional.
cat("Running CMake configuration with MinGW...\n")
cmake_result <- system(
"cmake -G \"MinGW Makefiles\" -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ -DCMAKE_BUILD_TYPE=Release ..",
intern = TRUE
)
cat("CMake configuration output:\n", paste(cmake_result, collapse = "\n"), "\n")

cat("Running make with MinGW...\n")
# NOTE(review): system(..., intern = TRUE) does not stop the script on a
# non-zero exit status; a failed make only shows up in the echoed output,
# and the script will continue to the sync/install steps below.
make_result <- system("mingw32-make VERBOSE=1", intern = TRUE)
cat("Make output:\n", paste(make_result, collapse = "\n"), "\n")
})

# copy lantern
source("R/install.R")
source("R/lantern_sync.R")
lantern_sync(TRUE)

# download torch
install_torch(path = normalizePath("inst/"), load = FALSE)

cat("Lantern build process completed successfully.\n")
} else {
cat("Lantern directory does not exist. Please check the path.\n")
}