Langchain::Assistant support for Google Gemini LLM (#513)
* Langchain::Assistant support for Google Gemini LLM

* wip

* fix

* fix

* wip

* remove gemini-ai gem

* Fix specs, add code comments

* Deprecation warning for Langchain::LLM::GooglePalm

* Update openai_validator.rb

* Adding code comments

* Additional Assistant specs

* Specs for GoogleGeminiResponse

* More specs

* Langchain::Tool::NewsRetriever: draft tool to fetch the latest news

* cleanup

* add logging

* Refactor NewsRetriever methods

* Adding method annotations

* Clean up GoogleVertexAI LLM

* Remove binding.pry

* Missing Gemfile.lock

* Google Gemini: Use UnifiedParameters and raise error if response is empty

* Fix spec

* Remove unneeded spec

* Fix specs

* changelog entry

* Fixes

* Replace HTTParty with Net::HTTP

* Update CHANGELOG.md

* Update CHANGELOG.md

* Use UnifiedParameters
andreibondarev authored May 14, 2024
1 parent 341d4d8 commit 950df37
Showing 38 changed files with 1,169 additions and 425 deletions.
2 changes: 2 additions & 0 deletions .env.example
@@ -12,12 +12,14 @@ GOOGLE_PALM_API_KEY=
GOOGLE_VERTEX_AI_PROJECT_ID=
# Automagical name which is picked up by Google Cloud Ruby auth module. If set, takes auth token from that key file.
GOOGLE_CLOUD_CREDENTIALS=
GOOGLE_GEMINI_API_KEY=
HUGGING_FACE_API_KEY=
LLAMACPP_MODEL_PATH=
LLAMACPP_N_THREADS=
LLAMACPP_N_GPU_LAYERS=
MILVUS_URL=
MISTRAL_AI_API_KEY=
NEWS_API_KEY=
OLLAMA_URL=http://localhost:11434
OPENAI_API_KEY=
OPEN_WEATHER_API_KEY=
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,10 @@
## [Unreleased]

## [0.13.0] - 2024-05-14
- New 🛠️ `Langchain::Tool::NewsRetriever` tool to fetch news via newsapi.org
- Langchain::Assistant works with `Langchain::LLM::GoogleVertexAI` and `Langchain::LLM::GoogleGemini` LLMs
- [BREAKING] Introduce new `Langchain::Messages::Base` abstraction

## [0.12.1] - 2024-05-13
- Langchain::LLM::Ollama now uses `llama3` by default
- Langchain::LLM::Anthropic#complete() now uses `claude-2.1` by default
26 changes: 4 additions & 22 deletions Gemfile.lock
@@ -82,7 +82,6 @@ GEM
rexml
crass (1.0.6)
date (3.3.4)
declarative (0.0.20)
diff-lcs (1.5.1)
docx (0.8.0)
nokogiri (~> 1.13, >= 1.13.0)
@@ -162,22 +161,12 @@ GEM
faraday (~> 2.0)
typhoeus (~> 1.4)
ffi (1.16.3)
google-apis-aiplatform_v1 (0.13.0)
google-apis-core (>= 0.12.0, < 2.a)
google-apis-core (0.13.0)
addressable (~> 2.5, >= 2.5.1)
googleauth (~> 1.9)
httpclient (>= 2.8.1, < 3.a)
mini_mime (~> 1.0)
representable (~> 3.0)
retriable (>= 2.0, < 4.a)
rexml
google-cloud-env (2.1.1)
faraday (>= 1.0, < 3.a)
google_palm_api (0.1.3)
faraday (>= 2.0.1, < 3.0)
google_search_results (2.0.1)
googleauth (1.10.0)
googleauth (1.11.0)
faraday (>= 1.0, < 3.a)
google-cloud-env (~> 2.1)
jwt (>= 1.4, < 3.0)
@@ -198,7 +187,6 @@ GEM
httparty (0.21.0)
mini_mime (>= 1.0.0)
multi_xml (>= 0.5.2)
httpclient (2.8.3)
hugging-face (0.3.5)
faraday (>= 1.0)
i18n (1.14.1)
@@ -212,7 +200,8 @@ GEM
json (2.7.1)
json-schema (4.0.0)
addressable (>= 2.8)
jwt (2.7.1)
jwt (2.8.1)
base64
language_server-protocol (3.17.0.3)
lint_roller (1.1.0)
llama_cpp (0.9.5)
@@ -341,12 +330,7 @@ GEM
faraday (>= 1.0)
faraday-multipart
faraday-retry
representable (3.2.0)
declarative (< 0.1.0)
trailblazer-option (>= 0.1.1, < 0.2.0)
uber (< 0.2.0)
require-hooks (0.2.2)
retriable (3.1.2)
rexml (3.2.6)
roo (2.10.1)
nokogiri (~> 1)
@@ -430,7 +414,6 @@ GEM
tiktoken_ruby (0.0.8-x86_64-linux-musl)
timeout (0.4.1)
to_bool (2.0.0)
trailblazer-option (0.1.2)
treetop (1.6.12)
polyglot (~> 0.3)
ttfunk (1.8.0)
@@ -439,7 +422,6 @@ GEM
ethon (>= 0.9.0)
tzinfo (2.0.6)
concurrent-ruby (~> 1.0)
uber (0.1.0)
unicode (0.4.4.4)
unicode-display_width (2.5.0)
unparser (0.6.13)
@@ -481,9 +463,9 @@ DEPENDENCIES
epsilla-ruby (~> 0.0.4)
eqn (~> 1.6.5)
faraday
google-apis-aiplatform_v1 (~> 0.7)
google_palm_api (~> 0.1.3)
google_search_results (~> 2.0.0)
googleauth
hnswlib (~> 0.8.1)
hugging-face (~> 0.3.4)
langchainrb!
5 changes: 3 additions & 2 deletions README.md
@@ -412,6 +412,7 @@ Assistants are Agent-like objects that leverage helpful instructions, LLMs, tool
| "file_system" | Interacts with the file system | | |
| "ruby_code_interpreter" | Interprets Ruby expressions | | `gem "safe_ruby", "~> 1.0.4"` |
| "google_search" | A wrapper around Google Search | `ENV["SERPAPI_API_KEY"]` (https://serpapi.com/manage-api-key) | `gem "google_search_results", "~> 2.0.0"` |
| "news_retriever" | A wrapper around NewsApi.org | `ENV["NEWS_API_KEY"]` (https://newsapi.org/) | |
| "weather" | Calls Open Weather API to retrieve the current weather | `ENV["OPEN_WEATHER_API_KEY"]` (https://home.openweathermap.org/api_keys) | `gem "open-weather-ruby-client", "~> 0.3.0"` |
| "wikipedia" | Calls Wikipedia API to retrieve the summary | | `gem "wikipedia-client", "~> 1.17.0"` |

@@ -445,14 +446,14 @@ assistant = Langchain::Assistant.new(
thread: thread,
instructions: "You are a Meteorologist Assistant that is able to pull the weather for any location",
tools: [
Langchain::Tool::GoogleSearch.new(api_key: ENV["SERPAPI_API_KEY"])
Langchain::Tool::Weather.new(api_key: ENV["OPEN_WEATHER_API_KEY"])
]
)
```
### Using an Assistant
You can now add your message to an Assistant.
```ruby
assistant.add_message content: "What's the weather in New York City?"
assistant.add_message content: "What's the weather in New York, New York?"
```

Run the Assistant to generate a response.
2 changes: 1 addition & 1 deletion langchain.gemspec
@@ -55,7 +55,7 @@ Gem::Specification.new do |spec|
spec.add_development_dependency "elasticsearch", "~> 8.2.0"
spec.add_development_dependency "epsilla-ruby", "~> 0.0.4"
spec.add_development_dependency "eqn", "~> 1.6.5"
spec.add_development_dependency "google-apis-aiplatform_v1", "~> 0.7"
spec.add_development_dependency "googleauth"
spec.add_development_dependency "google_palm_api", "~> 0.1.3"
spec.add_development_dependency "google_search_results", "~> 2.0.0"
spec.add_development_dependency "hnswlib", "~> 0.8.1"
3 changes: 3 additions & 0 deletions lib/langchain.rb
@@ -12,6 +12,7 @@
"ai21_response" => "AI21Response",
"ai21_validator" => "AI21Validator",
"csv" => "CSV",
"google_vertex_ai" => "GoogleVertexAI",
"html" => "HTML",
"json" => "JSON",
"jsonl" => "JSONL",
@@ -21,6 +22,7 @@
"openai" => "OpenAI",
"openai_validator" => "OpenAIValidator",
"openai_response" => "OpenAIResponse",
"openai_message" => "OpenAIMessage",
"pdf" => "PDF"
)
loader.collapse("#{__dir__}/langchain/llm/response")
@@ -31,6 +33,7 @@
loader.collapse("#{__dir__}/langchain/tool/file_system")
loader.collapse("#{__dir__}/langchain/tool/google_search")
loader.collapse("#{__dir__}/langchain/tool/ruby_code_interpreter")
loader.collapse("#{__dir__}/langchain/tool/news_retriever")
loader.collapse("#{__dir__}/langchain/tool/vectorsearch")
loader.collapse("#{__dir__}/langchain/tool/weather")
loader.collapse("#{__dir__}/langchain/tool/wikipedia")
95 changes: 75 additions & 20 deletions lib/langchain/assistants/assistant.rb
@@ -7,6 +7,12 @@ class Assistant
attr_reader :llm, :thread, :instructions
attr_accessor :tools

SUPPORTED_LLMS = [
Langchain::LLM::OpenAI,
Langchain::LLM::GoogleGemini,
Langchain::LLM::GoogleVertexAI
]

# Create a new assistant
#
# @param llm [Langchain::LLM::Base] LLM instance that the assistant will use
@@ -19,7 +25,9 @@ def initialize(
tools: [],
instructions: nil
)
raise ArgumentError, "Invalid LLM; currently only Langchain::LLM::OpenAI is supported" unless llm.instance_of?(Langchain::LLM::OpenAI)
unless SUPPORTED_LLMS.include?(llm.class)
raise ArgumentError, "Invalid LLM; currently only #{SUPPORTED_LLMS.join(", ")} are supported"
end
raise ArgumentError, "Thread must be an instance of Langchain::Thread" unless thread.is_a?(Langchain::Thread)
raise ArgumentError, "Tools must be an array of Langchain::Tool::Base instance(s)" unless tools.is_a?(Array) && tools.all? { |tool| tool.is_a?(Langchain::Tool::Base) }

@@ -30,7 +38,10 @@ def initialize(

# The first message in the thread should be the system instructions
# TODO: What if the user added old messages and the system instructions are already in there? Should this overwrite the existing instructions?
add_message(role: "system", content: instructions) if instructions
if llm.is_a?(Langchain::LLM::OpenAI)
add_message(role: "system", content: instructions) if instructions
end
# For Google Gemini, system instructions are added to the `system:` param in the `chat` method
end
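The `SUPPORTED_LLMS` allow-list and the class check in `initialize` above can be exercised in isolation. Below is a minimal sketch with stand-in classes; the real constants live under `Langchain::LLM`, and everything else here (the module, `validate_llm!`, the `Anthropic` stub) is hypothetical:

```ruby
# Stand-in classes; in langchainrb these are Langchain::LLM::OpenAI, etc.
module LLM
  class OpenAI; end
  class GoogleGemini; end
  class GoogleVertexAI; end
  class Anthropic; end # not supported by the Assistant in this commit
end

SUPPORTED_LLMS = [LLM::OpenAI, LLM::GoogleGemini, LLM::GoogleVertexAI]

# Mirrors the guard clause added in the diff: reject any LLM whose
# class is not on the allow-list.
def validate_llm!(llm)
  unless SUPPORTED_LLMS.include?(llm.class)
    raise ArgumentError, "Invalid LLM; currently only #{SUPPORTED_LLMS.join(", ")} are supported"
  end
  llm
end

validate_llm!(LLM::GoogleGemini.new) # passes silently
begin
  validate_llm!(LLM::Anthropic.new)  # not on the allow-list
rescue ArgumentError => e
  puts e.message
end
```

Checking `llm.class` against a list of classes (rather than `instance_of?` on a single class, as the removed line did) is what lets the same Assistant accept all three providers.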

# Add a user message to the thread
@@ -59,11 +70,12 @@ def run(auto_tool_execution: false)

while running
# TODO: I think we need to look at all messages and not just the last one.
case (last_message = thread.messages.last).role
when "system"
last_message = thread.messages.last

if last_message.system?
# Do nothing
running = false
when "assistant"
elsif last_message.llm?
if last_message.tool_calls.any?
if auto_tool_execution
run_tools(last_message.tool_calls)
@@ -76,11 +88,11 @@
# Do nothing
running = false
end
when "user"
elsif last_message.user?
# Run it!
response = chat_with_llm

if response.tool_calls
if response.tool_calls.any?
# Re-run the while(running) loop to process the tool calls
running = true
add_message(role: response.role, tool_calls: response.tool_calls)
@@ -89,12 +101,12 @@
running = false
add_message(role: response.role, content: response.chat_completion)
end
when "tool"
elsif last_message.tool?
# Run it!
response = chat_with_llm
running = true

if response.tool_calls
if response.tool_calls.any?
add_message(role: response.role, tool_calls: response.tool_calls)
elsif response.chat_completion
add_message(role: response.role, content: response.chat_completion)
@@ -121,8 +133,14 @@ def add_message_and_run(content:, auto_tool_execution: false)
# @param output [String] The output of the tool
# @return [Array<Langchain::Message>] The messages in the thread
def submit_tool_output(tool_call_id:, output:)
# TODO: Validate that `tool_call_id` is valid
add_message(role: "tool", content: output, tool_call_id: tool_call_id)
tool_role = if llm.is_a?(Langchain::LLM::OpenAI)
Langchain::Messages::OpenAIMessage::TOOL_ROLE
elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
Langchain::Messages::GoogleGeminiMessage::TOOL_ROLE
end

# TODO: Validate that `tool_call_id` is valid by scanning messages and checking if this tool call ID was invoked
add_message(role: tool_role, content: output, tool_call_id: tool_call_id)
end
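The role selection in `submit_tool_output` above exists because the providers name the tool-output role differently. A standalone sketch of that dispatch; the provider symbols are hypothetical stand-ins for the `Langchain::LLM` classes, and the role strings mirror the upstream `TOOL_ROLE` constants (treat them as assumptions):

```ruby
# Assumed role strings: OpenAI expects tool output under role "tool",
# Gemini under role "function".
OPENAI_TOOL_ROLE = "tool"
GOOGLE_GEMINI_TOOL_ROLE = "function"

# Pick the role the provider expects for a tool-output message.
def tool_role_for(provider)
  case provider
  when :openai
    OPENAI_TOOL_ROLE
  when :google_gemini, :google_vertex_ai
    GOOGLE_GEMINI_TOOL_ROLE
  else
    raise ArgumentError, "unsupported provider: #{provider}"
  end
end
```

Keeping the role lookup in one place means the rest of `submit_tool_output` stays provider-agnostic.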

# Delete all messages in the thread
@@ -156,10 +174,15 @@ def instructions=(new_instructions)
def chat_with_llm
Langchain.logger.info("Sending a call to #{llm.class}", for: self.class)

params = {messages: thread.openai_messages}
params = {messages: thread.array_of_message_hashes}

if tools.any?
params[:tools] = tools.map(&:to_openai_tools).flatten
if llm.is_a?(Langchain::LLM::OpenAI)
params[:tools] = tools.map(&:to_openai_tools).flatten
elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
params[:tools] = tools.map(&:to_google_gemini_tools).flatten
params[:system] = instructions if instructions
end
# TODO: Not sure that tool_choice should always be "auto"; Maybe we can let the user toggle it.
params[:tool_choice] = "auto"
end
@@ -173,11 +196,11 @@ def chat_with_llm
def run_tools(tool_calls)
# Iterate over each function invocation and submit tool output
tool_calls.each do |tool_call|
tool_call_id = tool_call.dig("id")

function_name = tool_call.dig("function", "name")
tool_name, method_name = function_name.split("-")
tool_arguments = JSON.parse(tool_call.dig("function", "arguments"), symbolize_names: true)
tool_call_id, tool_name, method_name, tool_arguments = if llm.is_a?(Langchain::LLM::OpenAI)
extract_openai_tool_call(tool_call: tool_call)
elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
extract_google_gemini_tool_call(tool_call: tool_call)
end

tool_instance = tools.find do |t|
t.name == tool_name
@@ -190,13 +213,41 @@ def run_tools(tool_calls)

response = chat_with_llm

if response.tool_calls
if response.tool_calls.any?
add_message(role: response.role, tool_calls: response.tool_calls)
elsif response.chat_completion
add_message(role: response.role, content: response.chat_completion)
end
end

# Extract the tool call information from the OpenAI tool call hash
#
# @param tool_call [Hash] The tool call hash
# @return [Array] The tool call information
def extract_openai_tool_call(tool_call:)
tool_call_id = tool_call.dig("id")

function_name = tool_call.dig("function", "name")
tool_name, method_name = function_name.split("__")
tool_arguments = JSON.parse(tool_call.dig("function", "arguments"), symbolize_names: true)

[tool_call_id, tool_name, method_name, tool_arguments]
end

# Extract the tool call information from the Google Gemini tool call hash
#
# @param tool_call [Hash] The tool call hash, format: {"functionCall"=>{"name"=>"weather__execute", "args"=>{"input"=>"NYC"}}}
# @return [Array] The tool call information
def extract_google_gemini_tool_call(tool_call:)
tool_call_id = tool_call.dig("functionCall", "name")

function_name = tool_call.dig("functionCall", "name")
tool_name, method_name = function_name.split("__")
tool_arguments = tool_call.dig("functionCall", "args").transform_keys(&:to_sym)

[tool_call_id, tool_name, method_name, tool_arguments]
end
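The two extraction helpers above differ only in payload shape: OpenAI returns a separate `id` plus JSON-string `arguments`, while Gemini returns a `functionCall` hash whose `name` doubles as the id. A self-contained sketch of both paths; the sample payloads are assumed shapes modeled on the code comments in the diff:

```ruby
require "json"

# Assumed OpenAI-style tool call: separate id, JSON-encoded arguments string.
openai_call = {
  "id" => "call_abc123",
  "function" => {"name" => "weather__execute", "arguments" => "{\"input\":\"NYC\"}"}
}

# Assumed Gemini-style tool call: args arrive as a Hash with string keys,
# and there is no separate id field.
gemini_call = {
  "functionCall" => {"name" => "weather__execute", "args" => {"input" => "NYC"}}
}

def extract_openai_tool_call(tool_call)
  tool_name, method_name = tool_call.dig("function", "name").split("__")
  args = JSON.parse(tool_call.dig("function", "arguments"), symbolize_names: true)
  [tool_call["id"], tool_name, method_name, args]
end

def extract_google_gemini_tool_call(tool_call)
  function_name = tool_call.dig("functionCall", "name")
  tool_name, method_name = function_name.split("__")
  args = tool_call.dig("functionCall", "args").transform_keys(&:to_sym)
  # The function name stands in for the missing tool-call id.
  [function_name, tool_name, method_name, args]
end

extract_openai_tool_call(openai_call)
# returns ["call_abc123", "weather", "execute", {input: "NYC"}]
extract_google_gemini_tool_call(gemini_call)
# returns ["weather__execute", "weather", "execute", {input: "NYC"}]
```

Both helpers return the same `[id, tool_name, method_name, args]` tuple, so `run_tools` can stay provider-agnostic; note the commit also changes the tool/method name separator from `-` to `__`, visible in the diff above.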

# Build a message
#
# @param role [String] The role of the message
@@ -205,7 +256,11 @@ def run_tools(tool_calls)
# @param tool_call_id [String] The ID of the tool call to include in the message
# @return [Langchain::Message] The Message object
def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
Message.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
if llm.is_a?(Langchain::LLM::OpenAI)
Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
end
end

# TODO: Fix the message truncation when context window is exceeded