Commit 2bc04b5

add max_tokens parameter

dylanpieper committed Sep 15, 2024
1 parent 41d93df commit 2bc04b5
Showing 6 changed files with 148 additions and 67 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
@@ -38,6 +38,7 @@ importFrom(shiny,fileInput)
importFrom(shiny,fluidPage)
importFrom(shiny,fluidRow)
importFrom(shiny,h1)
importFrom(shiny,h6)
importFrom(shiny,hr)
importFrom(shiny,icon)
importFrom(shiny,img)
125 changes: 86 additions & 39 deletions R/batchLLM.R
@@ -12,8 +12,10 @@
#' @param LLM A string for the name of the LLM with the options: "openai", "anthropic", and "google". Default is "openai".
#' @param model A string for the name of the model from the LLM. Default is "gpt-4o-mini".
#' @param temperature The sampling temperature for the model. Default is .5.
#' @param max_tokens The maximum number of tokens to generate before stopping. Default is 500.
#' @param batch_delay A string for the delay between batches: "random", or a number of minutes or seconds such as "1min" or "30sec". Default is "random", which averages 10.86 seconds (n = 1,000 simulations).
#' @param batch_size The number of rows to process in each batch. Default is 10.
#' @param extract_XML Extract the LLM text completion from the model's response by returning only the content enclosed in XML tags. This helps prevent unwanted text (e.g., a preamble) from being included in the model's output. Default is TRUE.
#' @param attempts The maximum number of loop retry attempts. Default is 1.
#' @param log_name A string for the name of the log without the \code{.rds} file extension. Default is "batchLLM-log".
#' @param hash_algo A string for a hashing algorithm from the 'digest' package. Default is \code{crc32c}.
@@ -55,6 +57,7 @@
#' prompt = "classify as a fact or misinformation in one word",
#' LLM = config$LLM,
#' model = config$model,
#' batch_size = 10,
#' batch_delay = "1min",
#' case_convert = "lower"
#' )
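For context, here is a minimal sketch of a call exercising both parameters this commit adds. This is not the package's documented example: the `statement` column name for the bundled `beliefs` data frame, and the unquoted `col` usage, are assumptions for illustration.

library(batchLLM)

beliefs <- batchLLM(
  df = beliefs,
  col = statement,
  prompt = "classify as a fact or misinformation in one word",
  LLM = "openai",
  model = "gpt-4o-mini",
  temperature = .5,
  max_tokens = 500,      # new in this commit: cap on generated tokens
  batch_size = 10,
  batch_delay = "random",
  extract_XML = TRUE,    # new in this commit: keep only <results> content
  case_convert = "lower"
)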
@@ -70,8 +73,10 @@ batchLLM <- function(df,
LLM = "openai",
model = "gpt-4o-mini",
temperature = .5,
max_tokens = 500,
batch_delay = "random",
batch_size = 10,
extract_XML = TRUE,
attempts = 1,
log_name = "batchLLM-log",
hash_algo = "crc32c",
@@ -131,62 +136,95 @@ batchLLM <- function(df,
}
}

batch_mutate <- function(df, df_col, df_string, col_string, system_prompt, batch_size, batch_delay, batch_num, batch_total, LLM, model, temperature, start_row, total_rows, log_name, case_convert, ...) {
mutate_row <- function(df_string, df_row, content_input, system_prompt, LLM, model, temperature, batch_delay, log_name, case_convert, ...) {
extract_xml <- function(content, tag = "results") {
if (is.null(content) || is.na(content) || !nzchar(content)) {
return(NA_character_)
}
result <- sub(paste0(".*<", tag, ">(.*?)</", tag, ">.*"), "\\1", content)
if (result == content) {
return(NA_character_)
}
return(result)
}

sanitize_output <- function(content, case_convert) {
if (!is.null(case_convert) && case_convert != "none") {
if (case_convert == "upper") {
content <- toupper(content)
} else if (case_convert == "lower") {
content <- tolower(content)
}
}
return(trimws(content))
}

batch_mutate <- function(df, df_col, df_string, col_string, system_prompt, batch_size, batch_delay, batch_num, batch_total, LLM, model, temperature, start_row, total_rows, log_name, case_convert, max_tokens, extract_XML, ...) {
build_prompt <- function(system_prompt, content_input, extract_XML) {
if (extract_XML) {
paste0(system_prompt, "(put response in a single level of XML tags <results>):", as.character(content_input))
} else {
paste0(system_prompt, ":", as.character(content_input))
}
}

mutate_row <- function(df_string, df_row, content_input, system_prompt, LLM, model, temperature, batch_delay, log_name, case_convert, max_tokens, extract_XML, ...) {
if (length(content_input) == 1 && !is.na(content_input)) {
tryCatch(
{
prompt <- build_prompt(system_prompt, content_input, extract_XML)
content_output <- NULL

if (grepl("openai", LLM)) {
completion <- create_chat_completion(
model = model,
temperature = temperature,
max_tokens = max_tokens,
messages = list(
list(
"role" = "system",
"content" = system_prompt
),
list(
"role" = "user",
"content" = as.character(content_input)
)
list("role" = "system", "content" = prompt),
list("role" = "user", "content" = as.character(content_input))
),
...
)
content_output <- trimws(completion$choices$message.content)
if (extract_XML) {
content_output <- extract_xml(completion$choices$message.content)
} else {
content_output <- completion$choices$message.content
}
} else if (grepl("anthropic", LLM)) {
content_output <- claudeR(
completion <- claudeR(
prompt = if (grepl("claude-2", model)) {
paste0(system_prompt, "(put response in a single level of XML tags <results>)", ":", as.character(content_input))
prompt
} else if (grepl("claude-3", model)) {
list(list(role = "user", content = paste0(system_prompt, "(put response in a single level of XML tags <results>)", ":", as.character(content_input))))
list(list(role = "user", content = prompt))
},
model = model,
temperature = temperature,
max_tokens = max_tokens,
...
)
content_output <- sub(".*<results>(.*?)</results>.*", "\\1", content_output)
if (extract_XML) {
content_output <- extract_xml(completion)
} else {
content_output <- completion
}
} else if (grepl("google", LLM)) {
content_output <- gemini.R::gemini_chat(
prompt = paste0(system_prompt, "(put response in a single level of XML tags <results>)", ":", as.character(content_input)),
completion <- gemini.R::gemini_chat(
prompt = prompt,
model = model,
temperature = temperature,
maxOutputTokens = max_tokens,
...
)
content_output <- sub(".*<results>(.*?)</results>.*", "\\1", content_output)
}
if (is.null(content_output)) {
stop("Failed to obtain content_output.")
}
if (!is.null(content_output) && length(content_output) > 0) {
output_text <- content_output[1]
if (!is.null(case_convert) || case_convert != "none") {
if (case_convert == "upper") {
output_text <- toupper(output_text)
} else if (case_convert == "lower") {
output_text <- tolower(output_text)
}
if (extract_XML) {
content_output <- extract_xml(completion$outputs[["text"]])
} else {
content_output <- completion$outputs[["text"]]
}
}

if (is.null(content_output)) stop("Failed to obtain content_output.")
if (length(content_output) > 0) {
output_text <- sanitize_output(content_output[1], case_convert)
return(output_text)
} else {
stop("Error: completion message content returned NULL or empty")
@@ -217,7 +255,9 @@ batchLLM <- function(df,
temperature = temperature,
batch_delay = batch_delay,
log_name = log_name,
case_convert = case_convert
case_convert = case_convert,
max_tokens = max_tokens,
extract_XML = extract_XML
)
},
error = function(e) {
@@ -266,8 +306,8 @@ batchLLM <- function(df,
if (!col_string %in% colnames(df)) {
stop(paste("Column", col_string, "does not exist in the input data frame."))
}
param_id <- digest::digest(list(prompt, LLM, model, temperature, batch_size), algo = hash_algo)
new_col <- paste0(col_string, "_", param_id)
param_id <- digest::digest(list(LLM, model, temperature, prompt, batch_size, batch_delay, max_tokens, extract_XML), algo = hash_algo)
new_col <- paste0(col_string, "_", param_id)
param_id <- digest::digest(df[[col_string]], algo = hash_algo)
new_df_key <- paste0(df_string, "_", param_id)
batch_log <- load_progress(log_name)
@@ -286,7 +326,7 @@
0
}
if (nrow(output) == nrow(df) && !any(is.na(output[[new_col]]))) {
message("All rows have already been processed for this column. No further processing needed.")
message("All rows have already been processed for this column using the current config.")
df[[new_col]] <- output[[new_col]]
return(df)
}
@@ -352,7 +392,9 @@ batchLLM <- function(df,
start_row = start_row,
total_rows = nrow(df),
log_name = log_name,
case_convert = case_convert
case_convert = case_convert,
max_tokens = max_tokens,
extract_XML = extract_XML
)
output[rows_to_process, new_col] <- output_batch$llm_output
current_time <- Sys.time()
Expand Down Expand Up @@ -402,7 +444,7 @@ batchLLM <- function(df,
)
}
if (end_row == nrow(df)) {
message("All ", batch_total, " batches processed")
message("\u2714 All ", batch_total, " batches processed")
break
}
start_time <- Sys.time()
@@ -600,10 +642,15 @@ scrape_metadata <- function(df_name = NULL, log_name = "batchLLM-log") {
#' cat(response)
#' }
claudeR <- function(
prompt, model = "claude-3-5-sonnet-20240620", max_tokens = 100,
prompt,
model = "claude-3-5-sonnet-20240620",
max_tokens = 500,
stop_sequences = NULL,
temperature = .7, top_k = -1, top_p = -1,
api_key = NULL, system_prompt = NULL) {
temperature = .7,
top_k = -1,
top_p = -1,
api_key = NULL,
system_prompt = NULL) {
if (grepl("claude-3", model) && !is.list(prompt)) {
stop("Claude-3 requires the input in a list format, e.g., list(list(role = \"user\", content = \"What is the capital of France?\"))")
}
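A minimal sketch of calling claudeR() directly under the new defaults. The message-list format is required for claude-3 models, per the check above; passing the API key explicitly from an environment variable is an assumption, since key resolution isn't shown in this hunk.

response <- claudeR(
  prompt = list(list(role = "user", content = "What is the capital of France?")),
  model = "claude-3-5-sonnet-20240620",
  max_tokens = 500,  # default raised from 100 to 500 in this commit
  api_key = Sys.getenv("ANTHROPIC_API_KEY")  # assumed key-passing style
)
cat(response)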
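Stepping back, the contract of the new extract_xml() helper added near the top of this file is worth pinning down: it returns the text between a matched pair of tags, and NA when no pair is found, so a completion without a <results> wrapper is treated as a failed extraction rather than passed through. Illustrative calls, assuming the definition shown above:

extract_xml("Sure! Here you go: <results>fact</results>")   # "fact"
extract_xml("a reply with no closing <results> tag")        # NA_character_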
72 changes: 49 additions & 23 deletions R/batchLLM_shiny.R
@@ -2,7 +2,7 @@
#'
#' @export
#' @importFrom shiny fluidPage fluidRow column titlePanel tabPanel tabsetPanel conditionalPanel HTML sidebarLayout sidebarPanel
#' @importFrom shiny textInput numericInput downloadButton updateTextInput tags br hr h1 p img uiOutput textAreaInput
#' @importFrom shiny textInput numericInput downloadButton updateTextInput tags br hr h1 h6 p img uiOutput textAreaInput
#' @importFrom shiny sliderInput actionButton icon mainPanel observe req
#' @importFrom shiny selectInput updateSelectInput renderUI observeEvent
#' @importFrom shiny runGadget paneViewer fileInput showNotification
@@ -23,7 +23,7 @@ batchLLM_shiny <- function() {
library(batchLLM)

df_objects <- "beliefs"

create_exportable_datatable <- function(data, filename_prefix) {
datatable(
data,
Expand All @@ -35,12 +35,18 @@ batchLLM_shiny <- function() {
list(
extend = "collection",
buttons = list(
list(extend = "csv", filename = paste0(filename_prefix, "_export"),
exportOptions = list(modifier = list(page = "all"))),
list(extend = "excel", filename = paste0(filename_prefix, "_export"),
exportOptions = list(modifier = list(page = "all"))),
list(extend = "pdf", filename = paste0(filename_prefix, "_export"),
exportOptions = list(modifier = list(page = "all")))
list(
extend = "csv", filename = paste0(filename_prefix, "_export"),
exportOptions = list(modifier = list(page = "all"))
),
list(
extend = "excel", filename = paste0(filename_prefix, "_export"),
exportOptions = list(modifier = list(page = "all"))
),
list(
extend = "pdf", filename = paste0(filename_prefix, "_export"),
exportOptions = list(modifier = list(page = "all"))
)
),
text = "Download"
)
@@ -147,10 +153,11 @@
radioGroupButtons(
inputId = "toggle_delay",
label = "Batch Delay:",
choices = c("Random (~10 Sec)" = "random", "30 Sec" = "30sec", "1 Min" = "1min"),
choices = c("Random" = "random", "30 Sec" = "30sec", "1 Min" = "1min"),
selected = "random",
justified = TRUE
),
h6("Random is an average of 10.86 seconds."),
numericInput(
inputId = "batch_size",
label = "Batch Size (Rows per Batch):",
@@ -318,6 +325,14 @@
value = config$temperature,
step = 0.1
),
sliderInput(
inputId = paste0(config$id, "_max_tokens"),
label = "Maximum Tokens:",
min = 100,
max = 4000,
value = 500,
step = 50
),
actionButton(
inputId = paste0("remove_", config$id),
label = "Remove LLM"
@@ -451,7 +466,8 @@
list(
LLM = input[[paste0(config$id, "_llm")]],
model = input[[paste0(config$id, "_model")]],
temperature = input[[paste0(config$id, "_temperature")]]
temperature = input[[paste0(config$id, "_temperature")]],
max_tokens = input[[paste0(config$id, "_max_tokens")]]
)
})

@@ -486,6 +502,7 @@
batch_size = input$batch_size,
model = config$model,
temperature = config$temperature,
max_tokens = config$max_tokens,
case_convert = input$case_convert
)
}
Expand All @@ -509,14 +526,20 @@ batchLLM_shiny <- function() {
})
})

output$data_results <- renderDataTable({
data_to_show <- if (!is.null(result())) result() else selected_data()
create_exportable_datatable(data_to_show, "data")
}, server = FALSE)

output$metadata_table <- renderDataTable({
create_exportable_datatable(current_metadata(), "metadata")
}, server = FALSE)
output$data_results <- renderDataTable(
{
data_to_show <- if (!is.null(result())) result() else selected_data()
create_exportable_datatable(data_to_show, "data")
},
server = FALSE
)

output$metadata_table <- renderDataTable(
{
create_exportable_datatable(current_metadata(), "metadata")
},
server = FALSE
)

current_batch_data <- reactiveVal(NULL)

@@ -567,10 +590,13 @@
current_batch_data(batch_data)
})

output$batch_table <- renderDataTable({
req(current_batch_data())
create_exportable_datatable(current_batch_data(), "batch")
}, server = FALSE)
output$batch_table <- renderDataTable(
{
req(current_batch_data())
create_exportable_datatable(current_batch_data(), "batch")
},
server = FALSE
)

observe({
req(input$df_name)
Expand All @@ -588,7 +614,7 @@ batchLLM_shiny <- function() {
menuItem("Download Log", tabName = "download_log", icon = icon("download"))
})
})

observe({
observeEvent(input$datafile, {
current_metadata()
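Condensed, the Shiny-side flow for the new parameter, assembled from the hunks above (argument names outside the shown hunks, such as df and prompt, are assumptions):

# Each LLM config panel gains a Maximum Tokens slider...
sliderInput(
  inputId = paste0(config$id, "_max_tokens"),
  label = "Maximum Tokens:",
  min = 100, max = 4000, value = 500, step = 50
)

# ...the slider value is collected into the per-LLM config...
config <- list(
  LLM = input[[paste0(config$id, "_llm")]],
  model = input[[paste0(config$id, "_model")]],
  temperature = input[[paste0(config$id, "_temperature")]],
  max_tokens = input[[paste0(config$id, "_max_tokens")]]
)

# ...and forwarded to the batch call:
batchLLM(
  df = df,
  prompt = input$prompt,
  batch_size = input$batch_size,
  model = config$model,
  temperature = config$temperature,
  max_tokens = config$max_tokens,
  case_convert = input$case_convert
)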