Skip to content

Commit

Permalink
Merge pull request #4 from dylanpieper/elmer
Browse files Browse the repository at this point in the history
fix delay
  • Loading branch information
dylanpieper authored Jan 30, 2025
2 parents 200e9e1 + ab30dac commit e3a681a
Showing 1 changed file with 7 additions and 152 deletions.
159 changes: 7 additions & 152 deletions R/batchLLM.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,11 @@ batchLLM <- function(df,
existing_output <- batch_log$data[[df_key]]$output
suppressMessages(
suppressWarnings(
batch_log$data[[df_key]]$output <- dplyr::left_join(existing_output, df, by = col_string) |>
dplyr::select(-dplyr::ends_with(".x")) |>
dplyr::rename_with(~ sub("\\.y$", "", .), dplyr::ends_with(".y"))
batch_log$data[[df_key]]$output <- dplyr::left_join(existing_output, df,
by = col_string,
suffix = c(".old", "")) |>
dplyr::select(-dplyr::ends_with(".old")) |>
dplyr::rename_with(~ sub("\\.keep$", "", .), dplyr::ends_with(".keep"))
)
)
}
Expand Down Expand Up @@ -280,9 +282,9 @@ batchLLM <- function(df,
delay_unit <- matches[3]

if (grepl("sec", delay_unit)) {
delay <- delay_value / 100
delay <- delay_value
} else if (grepl("min", delay_unit)) {
delay <- delay_value * 60 / 100
delay <- delay_value * 60
} else if (grepl("random", batch_delay)) {
delay <- sample(seq(0.05, 0.15, by = 0.01), 1)
} else {
Expand Down Expand Up @@ -587,150 +589,3 @@ scrape_metadata <- function(df_name = NULL, log_name = "batchLLM-log") {
return(metadata_df[!is.na(metadata_df$batch_number), ])
}
}

#' @title Interact with Anthropic's Claude API
#'
#' @description This function provides an interface to interact with Claude AI models via Anthropic's API, allowing for flexible text generation based on user inputs.
#' This function was adapted from the [claudeR](https://github.com/yrvelez/claudeR) repository by [yrvelez](https://github.com/yrvelez) on GitHub (MIT License).
#'
#' @param api_key Your API key for authentication.
#' @param prompt A string vector for Claude-2, or a list for Claude-3 specifying the input for the model.
#' @param model The model to use for the request. Default is the latest Claude-3 model.
#' @param max_tokens A maximum number of tokens to generate before stopping.
#' @param stop_sequences Optional. A list of strings upon which to stop generating.
#' @param temperature Optional. Amount of randomness injected into the response.
#' @param top_k Optional. Only sample from the top K options for each subsequent token.
#' @param top_p Optional. Does nucleus sampling.
#' @param system_prompt Optional. An optional system role specification.
#' @return The resulting completion up to and excluding the stop sequences.
#' @importFrom httr add_headers POST content http_status
#' @importFrom jsonlite fromJSON toJSON
#' @export
#'
#' @examples
#' \dontrun{
#' library(batchLLM)
#'
#' # Set API in the env or use api_key parameter in the claudeR call
#' Sys.setenv(ANTHROPIC_API_KEY = "your_anthropic_api_key")
#'
#' # Using Claude-2
#' response <- claudeR(
#' prompt = "What is the capital of France?",
#' model = "claude-2.1",
#' max_tokens = 50
#' )
#' cat(response)
#'
#' # Using Claude-3
#' response <- claudeR(
#' prompt = list(
#' list(role = "user", content = "What is the capital of France?")
#' ),
#' model = "claude-3-5-sonnet-20240620",
#' max_tokens = 50,
#' temperature = 0.8
#' )
#' cat(response)
#'
#' # Using a system prompt
#' response <- claudeR(
#' prompt = list(
#' list(role = "user", content = "Summarize the history of France in one paragraph.")
#' ),
#' system_prompt = "You are a concise summarization assistant.",
#' max_tokens = 500
#' )
#' cat(response)
#' }
claudeR <- function(
prompt,
model = "claude-3-5-sonnet-20240620",
max_tokens = 500,
stop_sequences = NULL,
temperature = .7,
top_k = -1,
top_p = -1,
api_key = NULL,
system_prompt = NULL) {
if (grepl("claude-3", model) && !is.list(prompt)) {
stop("\U0001F6D1 Claude-3 requires the input in a list format, e.g., list(list(role = \"user\", content = \"What is the capital of France?\"))")
}

if (is.null(api_key)) {
api_key <- Sys.getenv("ANTHROPIC_API_KEY")
if (api_key == "") {
stop("\U0001F6D1 Please provide an API key or set it as the ANTHROPIC_API_KEY environment variable.")
}
}

if (grepl("claude-2", model)) {
url <- "https://api.anthropic.com/v1/complete"
headers <- httr::add_headers(
"X-API-Key" = api_key,
"Content-Type" = "application/json",
"anthropic-version" = "2023-06-01"
)

prompt <- paste0("\n\nHuman: ", prompt, "\n\nAssistant: ")

stop_sequences <- "\n\nHuman: "

body <- paste0('{
"prompt": "', gsub("\n", "\\\\n", prompt), '",
"model": "', model, '",
"max_tokens_to_sample": ', max_tokens, ',
"stop_sequences": ["', paste(gsub("\n", "\\\\n", stop_sequences), collapse = '", "'), '"],
"temperature": ', temperature, ',
"top_k": ', top_k, ',
"top_p": ', top_p, "
}")

response <- httr::POST(url, headers, body = body)

if (httr::http_status(response)$category == "Success") {
result <- jsonlite::fromJSON(httr::content(response, "text", encoding = "UTF-8"))
return(trimws(result$completion))
} else {
warning(paste("\U0001F6D1 API request failed with status", httr::http_status(response)$message))
stop("\U0001F6D1 Error details:\n", httr::content(response, "text", encoding = "UTF-8"), "\n")
}
}

url <- "https://api.anthropic.com/v1/messages"

headers <- httr::add_headers(
"x-api-key" = api_key,
"anthropic-version" = "2023-06-01",
"Content-Type" = "application/json"
)

message_list <- lapply(prompt, function(msg) {
list(role = msg$role, content = msg$content)
})

request_body_list <- list(
model = model,
max_tokens = max_tokens,
temperature = temperature,
top_k = top_k,
top_p = top_p,
messages = message_list
)

if (!is.null(system_prompt)) {
request_body_list$system <- system_prompt
}

body <- jsonlite::toJSON(request_body_list, auto_unbox = TRUE)

response <- httr::POST(url, headers, body = body)

if (httr::http_status(response)$category == "Success") {
result <- jsonlite::fromJSON(httr::content(response, "text", encoding = "UTF-8"))
return(result$content$text)
} else {
warning(paste("\U0001F6D1 API request failed with status", httr::http_status(response)$message))
stop("\U0001F6D1 Error details:\n", httr::content(response, "text", encoding = "UTF-8"), "\n")
}
}

0 comments on commit e3a681a

Please sign in to comment.