Skip to content

Commit

Permalink
re-add claudeR
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanpieper committed Jan 30, 2025
1 parent e3a681a commit e669554
Showing 1 changed file with 160 additions and 28 deletions.
188 changes: 160 additions & 28 deletions R/batchLLM.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ batchLLM <- function(df,
df_string <- if (!is.null(df_name)) df_name else deparse(substitute(df))
col <- rlang::enquo(col)
col_string <- rlang::as_name(col)

save_progress <- function(df, df_string, col_string, last_batch, total_time, prompt, LLM, model, temperature, new_col, status, log_name) {
log_file <- paste0(log_name, ".rds")
batch_log <- if (file.exists(log_file)) {
Expand All @@ -103,9 +103,10 @@ batchLLM <- function(df,
existing_output <- batch_log$data[[df_key]]$output
suppressMessages(
suppressWarnings(
batch_log$data[[df_key]]$output <- dplyr::left_join(existing_output, df,
batch_log$data[[df_key]]$output <- dplyr::left_join(existing_output, df,
by = col_string,
suffix = c(".old", "")) |>
suffix = c(".old", "")
) |>
dplyr::select(-dplyr::ends_with(".old")) |>
dplyr::rename_with(~ sub("\\.keep$", "", .), dplyr::ends_with(".keep"))
)
Expand All @@ -127,7 +128,7 @@ batchLLM <- function(df,
)
saveRDS(batch_log, log_file)
}

load_progress <- function(log_name) {
log_file <- paste0(log_name, ".rds")
if (file.exists(log_file)) {
Expand All @@ -136,7 +137,7 @@ batchLLM <- function(df,
return(list(data = list()))
}
}

sanitize_output <- function(content, tag = "results") {
if (is.null(content) || is.na(content) || !nzchar(content)) {
return(NA_character_)
Expand All @@ -148,7 +149,7 @@ batchLLM <- function(df,
result <- gsub("[[:punct:]]", "", result)
return(result)
}

case_convert_output <- function(content, case_convert) {
if (!is.null(case_convert) && case_convert != "none") {
if (case_convert == "upper") {
Expand All @@ -159,7 +160,7 @@ batchLLM <- function(df,
}
return(trimws(content))
}

batch_mutate <- function(df, df_col, df_string, col_string, system_prompt, batch_size, batch_delay, batch_num, batch_total, LLM, model, temperature, start_row, total_rows, log_name, case_convert, sanitize, max_tokens, ...) {
build_prompt <- function(system_prompt, content_input, sanitize) {
if (sanitize) {
Expand All @@ -168,14 +169,14 @@ batchLLM <- function(df,
paste0(system_prompt, ":", as.character(content_input))
}
}

mutate_row <- function(df_string, df_row, content_input, system_prompt, LLM, model, temperature, batch_delay, log_name, case_convert, sanitize, max_tokens, ...) {
if (length(content_input) == 1 && !is.na(content_input)) {
tryCatch(
{
prompt <- build_prompt(system_prompt, content_input, sanitize)
content_output <- NULL

if (grepl("openai", LLM)) {
completion <- create_chat_completion(
model = model,
Expand Down Expand Up @@ -223,7 +224,7 @@ batchLLM <- function(df,
content_output <- completion$outputs[["text"]]
}
}

if (is.null(content_output)) stop("Failed to obtain content_output.")
if (length(content_output) > 0) {
output_text <- case_convert_output(content_output[1], case_convert)
Expand All @@ -240,10 +241,10 @@ batchLLM <- function(df,
stop("\U0001F6D1 Invalid input: content_input must be a single non-NA value")
}
}

df <- df |> dplyr::mutate(row_number = dplyr::row_number() + start_row - 1)
result <- vector("character", nrow(df))

for (i in seq_len(nrow(df))) {
if (!is.na(df_col[i])) {
result[i] <- tryCatch(
Expand Down Expand Up @@ -273,14 +274,14 @@ batchLLM <- function(df,
}
Sys.sleep(runif(1, min = 0.1, max = 0.3))
}

if (batch_num < batch_total) {
message("\u231B\uFE0F Taking a break to make the API happy")

matches <- regmatches(batch_delay, regexec("(\\d+)(\\w+)", batch_delay))[[1]]
delay_value <- as.numeric(matches[2])
delay_unit <- matches[3]

if (grepl("sec", delay_unit)) {
delay <- delay_value
} else if (grepl("min", delay_unit)) {
Expand All @@ -290,19 +291,18 @@ batchLLM <- function(df,
} else {
stop("\U0001F6D1 Invalid unit of time.")
}

for (i in 1:100) {
Sys.sleep(delay / 100)
if (i %in% c(25, 50, 75, 100)) {
message(paste0("\U0001F4A4 ", i, "% through the break"))
}
}

}
df$llm_output <- result
return(df)
}

if (!is.data.frame(df) || !inherits(df, "data.frame")) {
stop("\U0001F6D1 Input must be a valid data frame.")
}
Expand Down Expand Up @@ -478,26 +478,26 @@ batchLLM <- function(df,
#' }
get_batches <- function(df_name = NULL, log_name = "batchLLM-log") {
log_file <- paste0(log_name, ".rds")

if (!file.exists(log_file)) {
stop("\u26A0\uFE0F Log file does not exist.")
}

batch_log <- readRDS(log_file)

if (is.null(df_name)) {
stop(paste(
"\U0001F6D1 Please define 'df_name' with a valid name:",
paste(unique(scrape_metadata()$df), collapse = ", ")
))
}

if (!df_name %in% names(batch_log$data)) {
stop(paste0("\U0001F6D1 No data found in ", log_name, ".rds for the specified df_name"))
}

output <- batch_log$data[[df_name]]$output

return(output)
}

Expand Down Expand Up @@ -526,7 +526,7 @@ get_batches <- function(df_name = NULL, log_name = "batchLLM-log") {
#' head(custom_metadata)
scrape_metadata <- function(df_name = NULL, log_name = "batchLLM-log") {
log_file <- paste0(log_name, ".rds")

if (!file.exists(log_file)) {
warning("\u26A0\uFE0F Log file does not exist.")
return(data.frame(
Expand All @@ -543,9 +543,9 @@ scrape_metadata <- function(df_name = NULL, log_name = "batchLLM-log") {
stringsAsFactors = FALSE
))
}

batch_log <- readRDS(log_file)

scrape_df <- function(df_name, df_data) {
if (is.null(df_data$metadata) || length(df_data$metadata) == 0) {
return(NULL)
Expand Down Expand Up @@ -574,7 +574,7 @@ scrape_metadata <- function(df_name = NULL, log_name = "batchLLM-log") {
})
do.call(rbind, unlist(metadata_list, recursive = FALSE))
}

if (!is.null(df_name)) {
if (!df_name %in% names(batch_log$data)) {
message(paste("\u26A0\uFE0F No data found for", df_name, "in the log file."))
Expand All @@ -589,3 +589,135 @@ scrape_metadata <- function(df_name = NULL, log_name = "batchLLM-log") {
return(metadata_df[!is.na(metadata_df$batch_number), ])
}
}

#' @title Interact with Anthropic's Claude API
#'
#' @description This function provides an interface to interact with Claude AI models via Anthropic's API, allowing for flexible text generation based on user inputs.
#' This function was adapted from the [claudeR](https://github.com/yrvelez/claudeR) repository by [yrvelez](https://github.com/yrvelez) on GitHub (MIT License).
#'
#' @param api_key Your API key for authentication.
#' @param prompt A string vector for Claude-2, or a list for Claude-3 specifying the input for the model.
#' @param model The model to use for the request. Default is the latest Claude-3 model.
#' @param max_tokens A maximum number of tokens to generate before stopping.
#' @param stop_sequences Optional. A list of strings upon which to stop generating.
#' @param temperature Optional. Amount of randomness injected into the response.
#' @param top_k Optional. Only sample from the top K options for each subsequent token.
#' @param top_p Optional. Does nucleus sampling.
#' @param system_prompt Optional. An optional system role specification.
#' @return The resulting completion up to and excluding the stop sequences.
#' @importFrom httr add_headers POST content http_status
#' @importFrom jsonlite fromJSON toJSON
#' @export
#'
#' @examples
#' \dontrun{
#' library(batchLLM)
#'
#' # Set API in the env or use api_key parameter in the claudeR call
#' Sys.setenv(ANTHROPIC_API_KEY = "your_anthropic_api_key")
#'
#' # Using Claude-2
#' response <- claudeR(
#' prompt = "What is the capital of France?",
#' model = "claude-2.1",
#' max_tokens = 50
#' )
#' cat(response)
#'
#' # Using Claude-3
#' response <- claudeR(
#' prompt = list(
#' list(role = "user", content = "What is the capital of France?")
#' ),
#' model = "claude-3-5-sonnet-20240620",
#' max_tokens = 50,
#' temperature = 0.8
#' )
#' cat(response)
#'
#' # Using a system prompt
#' response <- claudeR(
#' prompt = list(
#' list(role = "user", content = "Summarize the history of France in one paragraph.")
#' ),
#' system_prompt = "You are a concise summarization assistant.",
#' max_tokens = 500
#' )
#' cat(response)
#' }
claudeR <- function(
prompt,
model = "claude-3-5-sonnet-20240620",
max_tokens = 500,
stop_sequences = NULL,
temperature = .7,
top_k = -1,
top_p = -1,
api_key = NULL,
system_prompt = NULL) {
if (grepl("claude-3", model) && !is.list(prompt)) {
stop("\U0001F6D1 Claude-3 requires the input in a list format, e.g., list(list(role = \"user\", content = \"What is the capital of France?\"))")
}
if (is.null(api_key)) {
api_key <- Sys.getenv("ANTHROPIC_API_KEY")
if (api_key == "") {
stop("\U0001F6D1 Please provide an API key or set it as the ANTHROPIC_API_KEY environment variable.")
}
}
if (grepl("claude-2", model)) {
url <- "https://api.anthropic.com/v1/complete"
headers <- httr::add_headers(
"X-API-Key" = api_key,
"Content-Type" = "application/json",
"anthropic-version" = "2023-06-01"
)
prompt <- paste0("\n\nHuman: ", prompt, "\n\nAssistant: ")
stop_sequences <- "\n\nHuman: "
body <- paste0('{
"prompt": "', gsub("\n", "\\\\n", prompt), '",
"model": "', model, '",
"max_tokens_to_sample": ', max_tokens, ',
"stop_sequences": ["', paste(gsub("\n", "\\\\n", stop_sequences), collapse = '", "'), '"],
"temperature": ', temperature, ',
"top_k": ', top_k, ',
"top_p": ', top_p, "
}")
response <- httr::POST(url, headers, body = body)
if (httr::http_status(response)$category == "Success") {
result <- jsonlite::fromJSON(httr::content(response, "text", encoding = "UTF-8"))
return(trimws(result$completion))
} else {
warning(paste("\U0001F6D1 API request failed with status", httr::http_status(response)$message))
stop("\U0001F6D1 Error details:\n", httr::content(response, "text", encoding = "UTF-8"), "\n")
}
}
url <- "https://api.anthropic.com/v1/messages"
headers <- httr::add_headers(
"x-api-key" = api_key,
"anthropic-version" = "2023-06-01",
"Content-Type" = "application/json"
)
message_list <- lapply(prompt, function(msg) {
list(role = msg$role, content = msg$content)
})
request_body_list <- list(
model = model,
max_tokens = max_tokens,
temperature = temperature,
top_k = top_k,
top_p = top_p,
messages = message_list
)
if (!is.null(system_prompt)) {
request_body_list$system <- system_prompt
}
body <- jsonlite::toJSON(request_body_list, auto_unbox = TRUE)
response <- httr::POST(url, headers, body = body)
if (httr::http_status(response)$category == "Success") {
result <- jsonlite::fromJSON(httr::content(response, "text", encoding = "UTF-8"))
return(result$content$text)
} else {
warning(paste("\U0001F6D1 API request failed with status", httr::http_status(response)$message))
stop("\U0001F6D1 Error details:\n", httr::content(response, "text", encoding = "UTF-8"), "\n")
}
}

0 comments on commit e669554

Please sign in to comment.