Adding bulk load for Spark (DataBricks) #301

Merged · 1 commit · Jan 22, 2025
3 changes: 2 additions & 1 deletion DESCRIPTION
@@ -48,7 +48,8 @@ Suggests:
odbc,
duckdb,
pool,
ParallelLogger
ParallelLogger,
AzureStor
License: Apache License
VignetteBuilder: knitr
URL: https://ohdsi.github.io/DatabaseConnector/, https://github.com/OHDSI/DatabaseConnector
1 change: 1 addition & 0 deletions DatabaseConnector.Rproj
@@ -1,4 +1,5 @@
Version: 1.0
ProjectId: 9d51e576-41a3-432f-b696-8bfdc3eed676

RestoreWorkspace: No
SaveWorkspace: No
81 changes: 81 additions & 0 deletions R/BulkLoad.R
@@ -62,6 +62,25 @@
return(FALSE)
}
return(TRUE)
} else if (dbms(connection) == "spark") {
envSet <- FALSE
container <- FALSE

if (Sys.getenv("AZR_STORAGE_ACCOUNT") != "" && Sys.getenv("AZR_ACCOUNT_KEY") != "" && Sys.setenv("AZR_CONTAINER_NAME") != "") {
envSet <- TRUE
}

# List storage containers to confirm the container
# specified in the configuration exists
ensure_installed("AzureStor")
azureEndpoint <- getAzureEndpoint()
containerList <- getAzureContainerNames(azureEndpoint)

if (Sys.getenv("AZR_CONTAINER_NAME") %in% containerList) {
container <- TRUE
}

return(envSet & container)
} else {
return(FALSE)
}
@@ -72,6 +91,18 @@
return(if (sshUser == "") "root" else sshUser)
}

getAzureEndpoint <- function() {
azureEndpoint <- AzureStor::storage_endpoint(
paste0("https://", Sys.getenv("AZR_STORAGE_ACCOUNT"), ".dfs.core.windows.net"),
key = Sys.getenv("AZR_ACCOUNT_KEY")
)
return(azureEndpoint)
}

getAzureContainerNames <- function(azureEndpoint) {
return(names(AzureStor::list_storage_containers(azureEndpoint)))
}

countRows <- function(connection, sqlTableName) {
sql <- "SELECT COUNT(*) FROM @table"
count <- renderTranslateQuerySql(
@@ -354,3 +385,53 @@
delta <- Sys.time() - startTime
inform(paste("Bulk load to PostgreSQL took", signif(delta, 3), attr(delta, "units")))
}

bulkLoadSpark <- function(connection, sqlTableName, data) {
ensure_installed("AzureStor")
logTrace(sprintf("Inserting %d rows into table '%s' using DataBricks bulk load", nrow(data), sqlTableName))
start <- Sys.time()

csvFileName <- tempfile("spark_insert_", fileext = ".csv")
write.csv(x = data, na = "", file = csvFileName, row.names = FALSE, quote = TRUE)
on.exit(unlink(csvFileName))

azureEndpoint <- getAzureEndpoint()
containers <- AzureStor::list_storage_containers(azureEndpoint)
targetContainer <- containers[[Sys.getenv("AZR_CONTAINER_NAME")]]
AzureStor::storage_upload(
  targetContainer,
  src = csvFileName,
  dest = basename(csvFileName)
)

on.exit(
AzureStor::delete_storage_file(
targetContainer,
file = basename(csvFileName),
confirm = FALSE
),
add = TRUE
)

sql <- SqlRender::loadRenderTranslateSql(
  sqlFilename = "sparkCopy.sql",
  packageName = "DatabaseConnector",
  dbms = "spark",
  sqlTableName = sqlTableName,
  fileName = basename(csvFileName),
  azureAccountKey = Sys.getenv("AZR_ACCOUNT_KEY"),
  azureStorageAccount = Sys.getenv("AZR_STORAGE_ACCOUNT"),
  azureContainerName = Sys.getenv("AZR_CONTAINER_NAME")
)

tryCatch(
{
DatabaseConnector::executeSql(connection = connection, sql = sql, reportOverallTime = FALSE)
},
error = function(e) {
abort("Error in DataBricks bulk upload. Please check DataBricks/Azure Storage access.")
}
)
delta <- Sys.time() - start
inform(paste("Bulk load to DataBricks took", signif(delta, 3), attr(delta, "units")))

Check warning on line 435 in R/BulkLoad.R

View check run for this annotation

Codecov / codecov/patch

R/BulkLoad.R#L426-L435

Added lines #L426 - L435 were not covered by tests
}
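
For reference, here is a minimal sketch of the Azure prerequisite check that the new spark branch above performs, written as standalone AzureStor calls. It assumes the three AZR_* environment variables have already been set; all values are hypothetical:

# Sketch: manually verify the Azure prerequisites for Spark bulk load.
endpoint <- AzureStor::storage_endpoint(
  paste0("https://", Sys.getenv("AZR_STORAGE_ACCOUNT"), ".dfs.core.windows.net"),
  key = Sys.getenv("AZR_ACCOUNT_KEY")
)
# TRUE when the configured container exists in the storage account:
Sys.getenv("AZR_CONTAINER_NAME") %in% names(AzureStor::list_storage_containers(endpoint))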

9 changes: 9 additions & 0 deletions R/InsertTable.R
@@ -121,6 +121,13 @@
#' "some_aws_region", "AWS_BUCKET_NAME" = "some_bucket_name", "AWS_OBJECT_KEY" = "some_object_key",
#' "AWS_SSE_TYPE" = "server_side_encryption_type").
#'
#' Spark (DataBricks): The MPP bulk loading relies on the AzureStor package
#' to verify access to, and upload data to, an Azure ADLS Gen2 storage
#' container using Azure credentials. Credentials are configured as system
#' environment variables using the following keys: Sys.setenv("AZR_STORAGE_ACCOUNT" =
#' "some_azure_storage_account", "AZR_ACCOUNT_KEY" = "some_secret_account_key",
#' "AZR_CONTAINER_NAME" = "some_container_name").
#'
#' PDW: The MPP bulk loading relies upon the client
#' having a Windows OS and the DWLoader exe installed, and the following permissions granted: --Grant
#' BULK Load permissions - needed at a server level USE master; GRANT ADMINISTER BULK OPERATIONS TO
@@ -308,6 +315,8 @@
bulkLoadHive(connection, sqlTableName, sqlFieldNames, data)
} else if (dbms == "postgresql") {
bulkLoadPostgres(connection, sqlTableName, sqlFieldNames, sqlDataTypes, data)
} else if (dbms == "spark") {
bulkLoadSpark(connection, sqlTableName, data)
}
} else if (useCtasHack) {
# Inserting using CTAS hack ----------------------------------------------------------------
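
As a usage sketch of the documentation added above (all credential values are hypothetical, and connection is assumed to be an open Spark (DataBricks) connection):

# Hypothetical credentials, set before calling insertTable() with bulkLoad = TRUE:
Sys.setenv(
  "AZR_STORAGE_ACCOUNT" = "some_azure_storage_account",
  "AZR_ACCOUNT_KEY" = "some_secret_account_key",
  "AZR_CONTAINER_NAME" = "some_container_name"
)
insertTable(
  connection = connection,
  tableName = "scratch.insert_test", # hypothetical target table
  data = someData,                   # a data frame to upload
  bulkLoad = TRUE
)
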
34 changes: 34 additions & 0 deletions extras/TestBulkLoad.R
@@ -114,3 +114,37 @@ all.equal(data, data2)

renderTranslateExecuteSql(connection, "DROP TABLE scratch_mschuemi.insert_test;")
disconnect(connection)


# Spark ------------------------------------------------------------------------------
# Assumes the Spark (DataBricks) environment variables have been set
options(sqlRenderTempEmulationSchema = Sys.getenv("DATABRICKS_SCRATCH_SCHEMA"))
databricksConnectionString <- paste0("jdbc:databricks://", Sys.getenv('DATABRICKS_HOST'), "/default;transportMode=http;ssl=1;AuthMech=3;httpPath=", Sys.getenv('DATABRICKS_HTTP_PATH'))
connectionDetails <- createConnectionDetails(dbms = "spark",
connectionString = databricksConnectionString,
user = "token",
password = Sys.getenv("DATABRICKS_TOKEN"))


connection <- connect(connectionDetails)
system.time(
insertTable(connection = connection,
tableName = "scratch.scratch_asena5.insert_test",
data = data,
dropTableIfExists = TRUE,
createTable = TRUE,
tempTable = FALSE,
progressBar = TRUE,
camelCaseToSnakeCase = TRUE,
bulkLoad = TRUE)
)
data2 <- querySql(connection, "SELECT * FROM scratch.scratch_asena5.insert_test;", snakeCaseToCamelCase = TRUE, integer64AsNumeric = FALSE)

data <- data[order(data$id), ]
data2 <- data2[order(data2$id), ]
row.names(data) <- NULL
row.names(data2) <- NULL
all.equal(data, data2)

renderTranslateExecuteSql(connection, "DROP TABLE scratch.scratch_asena5.insert_test;")
disconnect(connection)
10 changes: 10 additions & 0 deletions inst/sql/sql_server/sparkCopy.sql
@@ -0,0 +1,10 @@
COPY INTO @sqlTableName
FROM 'abfss://@azureContainerName@@azureStorageAccount.dfs.core.windows.net/@fileName'
WITH (
CREDENTIAL (AZURE_SAS_TOKEN = '@azureAccountKey')
)
FILEFORMAT = CSV
FORMAT_OPTIONS (
'header' = 'true',
'inferSchema' = 'true'
);
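
To inspect the COPY INTO statement this template produces, it can be rendered outside of bulkLoadSpark(); a sketch with hypothetical parameter values:

# Sketch: render sparkCopy.sql with hypothetical values and print the result.
sql <- SqlRender::loadRenderTranslateSql(
  sqlFilename = "sparkCopy.sql",
  packageName = "DatabaseConnector",
  dbms = "spark",
  sqlTableName = "scratch.insert_test",               # hypothetical
  fileName = "spark_insert_123.csv",                  # hypothetical
  azureAccountKey = "some_secret_account_key",        # hypothetical
  azureStorageAccount = "some_azure_storage_account", # hypothetical
  azureContainerName = "some_container_name"          # hypothetical
)
cat(sql)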