diff --git a/NAMESPACE b/NAMESPACE
index 568353c..c4fa537 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -38,6 +38,7 @@ importFrom(shiny,fileInput)
 importFrom(shiny,fluidPage)
 importFrom(shiny,fluidRow)
 importFrom(shiny,h1)
+importFrom(shiny,h6)
 importFrom(shiny,hr)
 importFrom(shiny,icon)
 importFrom(shiny,img)
diff --git a/R/batchLLM.R b/R/batchLLM.R
index 13b8736..9a91814 100644
--- a/R/batchLLM.R
+++ b/R/batchLLM.R
@@ -12,8 +12,10 @@
 #' @param LLM A string for the name of the LLM with the options: "openai", "anthropic", and "google". Default is "openai".
 #' @param model A string for the name of the model from the LLM. Default is "gpt-4o-mini".
 #' @param temperature A temperature for the LLM model. Default is .5.
+#' @param max_tokens The maximum number of tokens to generate before stopping. Default is 500.
 #' @param batch_delay A string for the batch delay with the options: "random", "min", and "sec". Numeric examples include "1min" and "30sec". Default is "random" which is an average of 10.86 seconds (n = 1,000 simulations).
 #' @param batch_size The number of rows to process in each batch. Default is 10.
+#' @param extract_XML Extract the LLM text completion from the model's response by returning only the content within the XML tags. This helps prevent unwanted text (e.g., preamble) from being included in the model's output. Default is TRUE.
 #' @param attempts The maximum number of loop retry attempts. Default is 1.
 #' @param log_name A string for the name of the log without the \code{.rds} file extension. Default is "batchLLM-log".
 #' @param hash_algo A string for a hashing algorithm from the 'digest' package. Default is \code{crc32c}.
@@ -55,6 +57,7 @@
 #'     prompt = "classify as a fact or misinformation in one word",
 #'     LLM = config$LLM,
 #'     model = config$model,
+#'     batch_size = 10,
 #'     batch_delay = "1min",
 #'     case_convert = "lower"
 #'   )
@@ -70,8 +73,10 @@ batchLLM <- function(df,
                      LLM = "openai",
                      model = "gpt-4o-mini",
                      temperature = .5,
+                     max_tokens = 500,
                      batch_delay = "random",
                      batch_size = 10,
+                     extract_XML = TRUE,
                      attempts = 1,
                      log_name = "batchLLM-log",
                      hash_algo = "crc32c",
@@ -131,62 +136,95 @@ batchLLM <- function(df,
     }
   }
 
-  batch_mutate <- function(df, df_col, df_string, col_string, system_prompt, batch_size, batch_delay, batch_num, batch_total, LLM, model, temperature, start_row, total_rows, log_name, case_convert, ...) {
-    mutate_row <- function(df_string, df_row, content_input, system_prompt, LLM, model, temperature, batch_delay, log_name, case_convert, ...) {
+  extract_xml <- function(content, tag = "results") {
+    if (is.null(content) || is.na(content) || !nzchar(content)) {
+      return(NA_character_)
+    }
+    result <- sub(paste0(".*<", tag, ">(.*?)</", tag, ">.*"), "\\1", content)
+    if (result == content) {
+      return(NA_character_)
+    }
+    return(result)
+  }
+
+  sanitize_output <- function(content, case_convert) {
+    if (!is.null(case_convert) && case_convert != "none") {
+      if (case_convert == "upper") {
+        content <- toupper(content)
+      } else if (case_convert == "lower") {
+        content <- tolower(content)
+      }
+    }
+    return(trimws(content))
+  }
+
+  batch_mutate <- function(df, df_col, df_string, col_string, system_prompt, batch_size, batch_delay, batch_num, batch_total, LLM, model, temperature, start_row, total_rows, log_name, case_convert, max_tokens, extract_XML, ...) {
+    build_prompt <- function(system_prompt, content_input, extract_XML) {
+      if (extract_XML) {
+        paste0(system_prompt, "(put response in a single level of XML tags <results></results>):", as.character(content_input))
+      } else {
+        paste0(system_prompt, ":", as.character(content_input))
+      }
+    }
+
+    mutate_row <- function(df_string, df_row, content_input, system_prompt, LLM, model, temperature, batch_delay, log_name, case_convert, max_tokens, extract_XML, ...) {
       if (length(content_input) == 1 && !is.na(content_input)) {
         tryCatch(
           {
+            prompt <- build_prompt(system_prompt, content_input, extract_XML)
             content_output <- NULL
+
            if (grepl("openai", LLM)) {
               completion <- create_chat_completion(
                 model = model,
                 temperature = temperature,
+                max_tokens = max_tokens,
                 messages = list(
-                  list(
-                    "role" = "system",
-                    "content" = system_prompt
-                  ),
-                  list(
-                    "role" = "user",
-                    "content" = as.character(content_input)
-                  )
+                  list("role" = "system", "content" = prompt),
+                  list("role" = "user", "content" = as.character(content_input))
                 ),
                 ...
               )
-              content_output <- trimws(completion$choices$message.content)
+              if (extract_XML) {
+                content_output <- extract_xml(completion$choices$message.content)
+              } else {
+                content_output <- completion$choices$message.content
+              }
             } else if (grepl("anthropic", LLM)) {
-              content_output <- claudeR(
+              completion <- claudeR(
                 prompt = if (grepl("claude-2", model)) {
-                  paste0(system_prompt, "(put response in a single level of XML tags <results></results>)", ":", as.character(content_input))
+                  prompt
                 } else if (grepl("claude-3", model)) {
-                  list(list(role = "user", content = paste0(system_prompt, "(put response in a single level of XML tags <results></results>)", ":", as.character(content_input))))
+                  list(list(role = "user", content = prompt))
                 },
                 model = model,
                 temperature = temperature,
+                max_tokens = max_tokens,
                 ...
               )
-              content_output <- sub(".*<results>(.*?)</results>.*", "\\1", content_output)
+              if (extract_XML) {
+                content_output <- extract_xml(completion)
+              } else {
+                content_output <- completion
+              }
             } else if (grepl("google", LLM)) {
-              content_output <- gemini.R::gemini_chat(
-                prompt = paste0(system_prompt, "(put response in a single level of XML tags <results></results>)", ":", as.character(content_input)),
+              completion <- gemini.R::gemini_chat(
+                prompt = prompt,
                 model = model,
                 temperature = temperature,
+                maxOutputTokens = max_tokens,
                 ...
               )
-              content_output <- sub(".*<results>(.*?)</results>.*", "\\1", content_output)
-            }
-            if (is.null(content_output)) {
-              stop("Failed to obtain content_output.")
-            }
-            if (!is.null(content_output) && length(content_output) > 0) {
-              output_text <- content_output[1]
-              if (!is.null(case_convert) || case_convert != "none") {
-                if (case_convert == "upper") {
-                  output_text <- toupper(output_text)
-                } else if (case_convert == "lower") {
-                  output_text <- tolower(output_text)
-                }
+              if (extract_XML) {
+                content_output <- extract_xml(completion$outputs[["text"]])
+              } else {
+                content_output <- completion$outputs[["text"]]
               }
+            }
+
+            if (is.null(content_output)) stop("Failed to obtain content_output.")
+            if (length(content_output) > 0) {
+              output_text <- sanitize_output(content_output[1], case_convert)
               return(output_text)
             } else {
               stop("Error: completion message content returned NULL or empty")
@@ -217,7 +255,9 @@ batchLLM <- function(df,
             temperature = temperature,
             batch_delay = batch_delay,
             log_name = log_name,
-            case_convert = case_convert
+            case_convert = case_convert,
+            max_tokens = max_tokens,
+            extract_XML = extract_XML
           )
         },
         error = function(e) {
@@ -266,8 +306,8 @@ batchLLM <- function(df,
   if (!col_string %in% colnames(df)) {
     stop(paste("Column", col_string, "does not exist in the input data frame."))
   }
-  param_id <- digest::digest(list(prompt, LLM, model, temperature, batch_size), algo = hash_algo)
-  new_col <- paste0(col_string, "_", param_id)
+  param_id <- digest::digest(list(LLM, model, temperature, prompt, batch_size, batch_delay, max_tokens, extract_XML), algo = hash_algo)
+  new_col <- paste0(col_string, "_", param_id)
   param_id <- digest::digest(df[[col_string]], algo = hash_algo)
   new_df_key <- paste0(df_string, "_", param_id)
   batch_log <- load_progress(log_name)
@@ -286,7 +326,7 @@ batchLLM <- function(df,
     0
   }
   if (nrow(output) == nrow(df) && !any(is.na(output[[new_col]]))) {
-    message("All rows have already been processed for this column. No further processing needed.")
+    message("All rows have already been processed for this column using the current config.")
     df[[new_col]] <- output[[new_col]]
     return(df)
   }
@@ -352,7 +392,9 @@ batchLLM <- function(df,
       start_row = start_row,
       total_rows = nrow(df),
       log_name = log_name,
-      case_convert = case_convert
+      case_convert = case_convert,
+      max_tokens = max_tokens,
+      extract_XML = extract_XML
     )
     output[rows_to_process, new_col] <- output_batch$llm_output
     current_time <- Sys.time()
@@ -402,7 +444,7 @@ batchLLM <- function(df,
       )
     }
     if (end_row == nrow(df)) {
-      message("All ", batch_total, " batches processed")
+      message("\u2714 All ", batch_total, " batches processed")
       break
     }
     start_time <- Sys.time()
@@ -600,10 +642,15 @@ scrape_metadata <- function(df_name = NULL, log_name = "batchLLM-log") {
 #' cat(response)
 #' }
 claudeR <- function(
-    prompt, model = "claude-3-5-sonnet-20240620", max_tokens = 100,
+    prompt,
+    model = "claude-3-5-sonnet-20240620",
+    max_tokens = 500,
     stop_sequences = NULL,
-    temperature = .7, top_k = -1, top_p = -1,
-    api_key = NULL, system_prompt = NULL) {
+    temperature = .7,
+    top_k = -1,
+    top_p = -1,
+    api_key = NULL,
+    system_prompt = NULL) {
   if (grepl("claude-3", model) && !is.list(prompt)) {
     stop("Claude-3 requires the input in a list format, e.g., list(list(role = \"user\", content = \"What is the capital of France?\"))")
   }
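A minimal standalone sketch of what the new `extract_xml()` helper does, copied from the hunk above and run against hypothetical completion strings (the `#>` lines show the expected results):

```r
# XML-extraction step introduced in R/batchLLM.R (see hunk above).
extract_xml <- function(content, tag = "results") {
  if (is.null(content) || is.na(content) || !nzchar(content)) {
    return(NA_character_)
  }
  result <- sub(paste0(".*<", tag, ">(.*?)</", tag, ">.*"), "\\1", content)
  if (result == content) {
    return(NA_character_) # no tags found, so treat the completion as missing
  }
  result
}

# Hypothetical model responses:
extract_xml("Sure! Here is the label: <results>fact</results> Anything else?")
#> [1] "fact"
extract_xml("fact") # no XML tags, so the row comes back as NA
#> [1] NA
```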
diff --git a/R/batchLLM_shiny.R b/R/batchLLM_shiny.R
index 3f6e8b6..02f262c 100644
--- a/R/batchLLM_shiny.R
+++ b/R/batchLLM_shiny.R
@@ -2,7 +2,7 @@
 #'
 #' @export
 #' @importFrom shiny fluidPage fluidRow column titlePanel tabPanel tabsetPanel conditionalPanel HTML sidebarLayout sidebarPanel
-#' @importFrom shiny textInput numericInput downloadButton updateTextInput tags br hr h1 p img uiOutput textAreaInput
+#' @importFrom shiny textInput numericInput downloadButton updateTextInput tags br hr h1 h6 p img uiOutput textAreaInput
 #' @importFrom shiny sliderInput actionButton icon mainPanel observe req
 #' @importFrom shiny selectInput updateSelectInput renderUI observeEvent
 #' @importFrom shiny runGadget paneViewer fileInput showNotification
@@ -23,7 +23,7 @@ batchLLM_shiny <- function() {
   library(batchLLM)
 
   df_objects <- "beliefs"
-  
+
   create_exportable_datatable <- function(data, filename_prefix) {
     datatable(
       data,
@@ -35,12 +35,18 @@ batchLLM_shiny <- function() {
         list(
           extend = "collection",
           buttons = list(
-            list(extend = "csv", filename = paste0(filename_prefix, "_export"),
-                 exportOptions = list(modifier = list(page = "all"))),
-            list(extend = "excel", filename = paste0(filename_prefix, "_export"),
-                 exportOptions = list(modifier = list(page = "all"))),
-            list(extend = "pdf", filename = paste0(filename_prefix, "_export"),
-                 exportOptions = list(modifier = list(page = "all")))
+            list(
+              extend = "csv", filename = paste0(filename_prefix, "_export"),
+              exportOptions = list(modifier = list(page = "all"))
+            ),
+            list(
+              extend = "excel", filename = paste0(filename_prefix, "_export"),
+              exportOptions = list(modifier = list(page = "all"))
+            ),
+            list(
+              extend = "pdf", filename = paste0(filename_prefix, "_export"),
+              exportOptions = list(modifier = list(page = "all"))
+            )
           ),
           text = "Download"
         )
@@ -147,10 +153,11 @@ batchLLM_shiny <- function() {
         radioGroupButtons(
           inputId = "toggle_delay",
           label = "Batch Delay:",
-          choices = c("Random (~10 Sec)" = "random", "30 Sec" = "30sec", "1 Min" = "1min"),
+          choices = c("Random" = "random", "30 Sec" = "30sec", "1 Min" = "1min"),
          selected = "random",
           justified = TRUE
         ),
+        h6("Random is an average of 10.86 seconds."),
         numericInput(
           inputId = "batch_size",
           label = "Batch Size (Rows per Batch):",
@@ -318,6 +325,14 @@ batchLLM_shiny <- function() {
           value = config$temperature,
           step = 0.1
         ),
+        sliderInput(
+          inputId = paste0(config$id, "_max_tokens"),
+          label = "Maximum Tokens:",
+          min = 100,
+          max = 4000,
+          value = 500,
+          step = 50
+        ),
         actionButton(
           inputId = paste0("remove_", config$id),
           label = "Remove LLM"
         )
@@ -451,7 +466,8 @@ batchLLM_shiny <- function() {
       list(
         LLM = input[[paste0(config$id, "_llm")]],
         model = input[[paste0(config$id, "_model")]],
-        temperature = input[[paste0(config$id, "_temperature")]]
+        temperature = input[[paste0(config$id, "_temperature")]],
+        max_tokens = input[[paste0(config$id, "_max_tokens")]]
       )
     })
 
@@ -486,6 +502,7 @@ batchLLM_shiny <- function() {
             batch_size = input$batch_size,
             model = config$model,
             temperature = config$temperature,
+            max_tokens = config$max_tokens,
             case_convert = input$case_convert
           )
         }
@@ -509,14 +526,20 @@ batchLLM_shiny <- function() {
       })
     })
 
-    output$data_results <- renderDataTable({
-      data_to_show <- if (!is.null(result())) result() else selected_data()
-      create_exportable_datatable(data_to_show, "data")
-    }, server = FALSE)
-
-    output$metadata_table <- renderDataTable({
-      create_exportable_datatable(current_metadata(), "metadata")
-    }, server = FALSE)
+    output$data_results <- renderDataTable(
+      {
+        data_to_show <- if (!is.null(result())) result() else selected_data()
+        create_exportable_datatable(data_to_show, "data")
+      },
+      server = FALSE
+    )
+
+    output$metadata_table <- renderDataTable(
+      {
+        create_exportable_datatable(current_metadata(), "metadata")
+      },
+      server = FALSE
+    )
 
     current_batch_data <- reactiveVal(NULL)
@@ -567,10 +590,13 @@ batchLLM_shiny <- function() {
       current_batch_data(batch_data)
     })
 
-    output$batch_table <- renderDataTable({
-      req(current_batch_data())
-      create_exportable_datatable(current_batch_data(), "batch")
-    }, server = FALSE)
+    output$batch_table <- renderDataTable(
+      {
+        req(current_batch_data())
+        create_exportable_datatable(current_batch_data(), "batch")
+      },
+      server = FALSE
+    )
 
     observe({
       req(input$df_name)
@@ -588,7 +614,7 @@ batchLLM_shiny <- function() {
         menuItem("Download Log", tabName = "download_log", icon = icon("download"))
       })
     })
-    
+
     observe({
       observeEvent(input$datafile, {
        current_metadata()
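As a quick orientation, a minimal runnable sketch of the namespaced "Maximum Tokens" slider pattern added above; `"openai_1"` is a hypothetical stand-in for the dynamically generated `config$id`, and the echo output is only for illustration (the gadget itself forwards the value to `batchLLM(max_tokens = ...)`):

```r
library(shiny)

# "openai_1" is a hypothetical stand-in for config$id.
id <- "openai_1"

ui <- fluidPage(
  sliderInput(
    inputId = paste0(id, "_max_tokens"),
    label = "Maximum Tokens:",
    min = 100, max = 4000, value = 500, step = 50
  ),
  verbatimTextOutput("chosen")
)

server <- function(input, output, session) {
  # The gadget reads this value into its per-LLM config list and passes it on
  # as the max_tokens argument of batchLLM(); here we simply echo it.
  output$chosen <- renderPrint(input[[paste0(id, "_max_tokens")]])
}

if (interactive()) shinyApp(ui, server)
```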
diff --git a/man/batchLLM.Rd b/man/batchLLM.Rd
index 1bcd9c4..55c516e 100644
--- a/man/batchLLM.Rd
+++ b/man/batchLLM.Rd
@@ -12,8 +12,10 @@ batchLLM(
   LLM = "openai",
   model = "gpt-4o-mini",
   temperature = 0.5,
+  max_tokens = 500,
   batch_delay = "random",
   batch_size = 10,
+  extract_XML = TRUE,
   attempts = 1,
   log_name = "batchLLM-log",
   hash_algo = "crc32c",
@@ -36,10 +38,14 @@ batchLLM(
 
 \item{temperature}{A temperature for the LLM model. Default is .5.}
 
+\item{max_tokens}{The maximum number of tokens to generate before stopping. Default is 500.}
+
 \item{batch_delay}{A string for the batch delay with the options: "random", "min", and "sec". Numeric examples include "1min" and "30sec". Default is "random" which is an average of 10.86 seconds (n = 1,000 simulations).}
 
 \item{batch_size}{The number of rows to process in each batch. Default is 10.}
 
+\item{extract_XML}{Extract the LLM text completion from the model's response by returning only the content within the XML tags. This helps prevent unwanted text (e.g., preamble) from being included in the model's output. Default is TRUE.}
+
 \item{attempts}{The maximum number of loop retry attempts. Default is 1.}
 
 \item{log_name}{A string for the name of the log without the \code{.rds} file extension. Default is "batchLLM-log".}
@@ -55,7 +61,7 @@ Returns the input data frame with an additional column containing the text compl
 The function also writes the output and metadata to the log file after each batch in a nested list format.
 }
 \description{
-Batch process Large Language Model (LLM) text completions by looping across the rows of a data frame column.
+Batch process large language model (LLM) text completions by looping across the rows of a data frame column.
 The package currently supports OpenAI's GPT, Anthropic's Claude, and Google's Gemini models, with built-in delays for API rate limiting.
 The package provides advanced text processing features, including automatic logging of batches and metadata to local files, side-by-side comparison of outputs from different LLMs, and integration of a user-friendly Shiny App Addin.
 Use cases include natural language processing tasks such as sentiment analysis, thematic analysis, classification, labeling or tagging, and language translation.
@@ -81,9 +87,10 @@ beliefs <- lapply(llm_configs, function(config) {
   batchLLM(
     df = beliefs,
     col = statement,
-    prompt = "Classify the sentiment using one word: positive, negative, or neutral",
+    prompt = "classify as a fact or misinformation in one word",
     LLM = config$LLM,
     model = config$model,
+    batch_size = 10,
     batch_delay = "1min",
     case_convert = "lower"
   )
diff --git a/man/claudeR.Rd b/man/claudeR.Rd
index 4dc7d25..b8f230d 100644
--- a/man/claudeR.Rd
+++ b/man/claudeR.Rd
@@ -7,7 +7,7 @@
 claudeR(
   prompt,
   model = "claude-3-5-sonnet-20240620",
-  max_tokens = 100,
+  max_tokens = 500,
   stop_sequences = NULL,
   temperature = 0.7,
   top_k = -1,
diff --git a/readme.md b/readme.md
index a24639f..c37af55 100644
--- a/readme.md
+++ b/readme.md
@@ -44,6 +44,7 @@ beliefs <- lapply(llm_configs, function(config) {
     prompt = "classify as a fact or misinformation in one word",
     LLM = config$LLM,
     model = config$model,
+    max_tokens = 100,
     batch_size = 10,
     batch_delay = "1min",
     case_convert = "lower"
   )
@@ -82,9 +83,8 @@ print(beliefs)
 
 ## 🤝 Contributing
 
-Contributions are welcome! Here are some upcoming features:
+Contributions are welcome! Here are some feature ideas:
 
-- Add max tokens parameter for different LLMs
 - Function to analyze agreement between models
 
 ## 📄 License
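Finally, a hedged end-to-end sketch of how the new arguments look from the user's side, adapted from the readme example above; the API key value is a placeholder, and `beliefs`/`statement` refer to the example data already used in the package docs:

```r
library(batchLLM)

Sys.setenv(OPENAI_API_KEY = "sk-...") # placeholder; use your own key

beliefs <- batchLLM(
  df = beliefs,
  col = statement,
  prompt = "classify as a fact or misinformation in one word",
  LLM = "openai",
  model = "gpt-4o-mini",
  max_tokens = 100,    # new: cap the completion length
  extract_XML = TRUE,  # new: keep only the <results>...</results> content
  batch_size = 10,
  batch_delay = "1min",
  case_convert = "lower"
)

head(beliefs)
```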