
Langchain::Assistant support for Google Gemini LLM #513

Merged: 46 commits from google-gemini-assistant-support into main on May 14, 2024

Commits
7ac08a9
Langchain::Assistant support for Google Gemini LLM
andreibondarev Mar 6, 2024
abe4c10
wip
andreibondarev Mar 15, 2024
3c36865
Merge branch 'main' into google-gemini-assistant-support
andreibondarev Mar 17, 2024
49e62b5
fix
andreibondarev Mar 17, 2024
64704d9
fix
andreibondarev Mar 17, 2024
e420a8a
Merge branch 'main' into google-gemini-assistant-support
andreibondarev Mar 19, 2024
3db5d7d
Merge branch 'main' into google-gemini-assistant-support
andreibondarev Mar 25, 2024
294a130
Merge branch 'main' into google-gemini-assistant-support
andreibondarev Mar 26, 2024
2ad0f98
Merge branch 'main' into google-gemini-assistant-support
andreibondarev Mar 29, 2024
2781aef
Merge branch 'main' into google-gemini-assistant-support
andreibondarev Apr 5, 2024
002c659
Merge branch 'main' into google-gemini-assistant-support
andreibondarev Apr 17, 2024
7dc0682
Merge branch 'main' into google-gemini-assistant-support
andreibondarev Apr 21, 2024
ee96861
wip
andreibondarev Apr 22, 2024
13d2453
Merge branch 'main' into google-gemini-assistant-support
andreibondarev Apr 23, 2024
fb2acc8
remove gemini-ai gem
andreibondarev Apr 23, 2024
9927f5a
Fix specs, add code comments
andreibondarev Apr 24, 2024
b5d6cb1
Deprecating warning for Langchain::LLM::GooglePalm
andreibondarev Apr 24, 2024
a7fbcef
Update openai_validator.rb
andreibondarev Apr 24, 2024
74e7564
Merge branch 'main' into google-gemini-assistant-support
andreibondarev Apr 30, 2024
c71f0c9
Merge branch 'main' into google-gemini-assistant-support
andreibondarev May 4, 2024
d6c2ff8
Adding code comments
andreibondarev May 4, 2024
f649d41
Additional Assistant specs
andreibondarev May 4, 2024
7d21741
Specs for GoogleGeminiResponse
andreibondarev May 4, 2024
69ab96e
More specs
andreibondarev May 5, 2024
6dac617
Langchain::Tool::NewsRetriever draft tool to access and pull latest news
andreibondarev May 5, 2024
8ea9544
cleanup
andreibondarev May 5, 2024
8b42c34
add logging
andreibondarev May 5, 2024
e9d745a
Refactor NewsRetriever methods
andreibondarev May 5, 2024
c3b9bda
Adding method annotations
andreibondarev May 5, 2024
ffcb111
Merge branch 'main' into google-gemini-assistant-support
andreibondarev May 8, 2024
ff83fa5
Clean up GoogleVertexAI LLM
andreibondarev May 9, 2024
f92b9e6
Remove binding.pry
andreibondarev May 9, 2024
836bc8e
Missing gemfile.lock
andreibondarev May 9, 2024
8437063
Merge branch 'main' into google-gemini-assistant-support
andreibondarev May 11, 2024
f570ad4
Google Gemini: Use UnifiedParameters and raise error if response is e…
andreibondarev May 11, 2024
f3f84de
Merge branch 'main' into google-gemini-assistant-support
andreibondarev May 14, 2024
313fe47
Fix spec
andreibondarev May 14, 2024
b8613a0
Remove not needed spec
andreibondarev May 14, 2024
3985f99
Fix specs
andreibondarev May 14, 2024
1afdf6c
changelog entry
andreibondarev May 14, 2024
6404d93
Fixes
andreibondarev May 14, 2024
1527c10
Replace HTTParty with Net::HTTP
andreibondarev May 14, 2024
167f648
Update CHANGELOG.md
andreibondarev May 14, 2024
1112d5f
Update CHANGELOG.md
andreibondarev May 14, 2024
a0af00b
Use UnifiedParameters
andreibondarev May 14, 2024
50d4ffc
Merge branch 'main' into google-gemini-assistant-support
andreibondarev May 14, 2024
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@ GOOGLE_PALM_API_KEY=
GOOGLE_VERTEX_AI_PROJECT_ID=
# Automagical name which is picked up by Google Cloud Ruby auth module. If set, takes auth token from that key file.
GOOGLE_CLOUD_CREDENTIALS=
GOOGLE_GEMINI_API_KEY=
HUGGING_FACE_API_KEY=
LLAMACPP_MODEL_PATH=
LLAMACPP_N_THREADS=
13 changes: 7 additions & 6 deletions Gemfile.lock
@@ -167,9 +167,9 @@ GEM
faraday (~> 2.0)
typhoeus (~> 1.4)
ffi (1.16.3)
google-apis-aiplatform_v1 (0.13.0)
google-apis-core (>= 0.12.0, < 2.a)
google-apis-core (0.13.0)
google-apis-aiplatform_v1 (0.15.0)
google-apis-core (>= 0.14.0, < 2.a)
google-apis-core (0.14.0)
addressable (~> 2.5, >= 2.5.1)
googleauth (~> 1.9)
httpclient (>= 2.8.1, < 3.a)
@@ -182,7 +182,7 @@ GEM
google_palm_api (0.1.3)
faraday (>= 2.0.1, < 3.0)
google_search_results (2.0.1)
googleauth (1.10.0)
googleauth (1.11.0)
faraday (>= 1.0, < 3.a)
google-cloud-env (~> 2.1)
jwt (>= 1.4, < 3.0)
@@ -217,7 +217,8 @@ GEM
json (2.7.1)
json-schema (4.0.0)
addressable (>= 2.8)
jwt (2.7.1)
jwt (2.8.1)
base64
language_server-protocol (3.17.0.3)
lint_roller (1.1.0)
llama_cpp (0.9.5)
@@ -486,7 +487,7 @@ DEPENDENCIES
epsilla-ruby (~> 0.0.4)
eqn (~> 1.6.5)
faraday
google-apis-aiplatform_v1 (~> 0.7)
google-apis-aiplatform_v1 (~> 0.8)
google_palm_api (~> 0.1.3)
google_search_results (~> 2.0.0)
hnswlib (~> 0.8.1)
4 changes: 2 additions & 2 deletions README.md
@@ -445,14 +445,14 @@ assistant = Langchain::Assistant.new(
thread: thread,
instructions: "You are a Meteorologist Assistant that is able to pull the weather for any location",
tools: [
Langchain::Tool::GoogleSearch.new(api_key: ENV["SERPAPI_API_KEY"])
Langchain::Tool::Weather.new(api_key: ENV["OPEN_WEATHER_API_KEY"])
]
)
```
### Using an Assistant
You can now add your message to an Assistant.
```ruby
assistant.add_message content: "What's the weather in New York City?"
assistant.add_message content: "What's the weather in New York, New York?"
```

Run the Assistant to generate a response.
2 changes: 1 addition & 1 deletion langchain.gemspec
@@ -55,7 +55,7 @@ Gem::Specification.new do |spec|
spec.add_development_dependency "elasticsearch", "~> 8.2.0"
spec.add_development_dependency "epsilla-ruby", "~> 0.0.4"
spec.add_development_dependency "eqn", "~> 1.6.5"
spec.add_development_dependency "google-apis-aiplatform_v1", "~> 0.7"
spec.add_development_dependency "google-apis-aiplatform_v1", "~> 0.8"
spec.add_development_dependency "google_palm_api", "~> 0.1.3"
spec.add_development_dependency "google_search_results", "~> 2.0.0"
spec.add_development_dependency "hnswlib", "~> 0.8.1"
1 change: 1 addition & 0 deletions lib/langchain.rb
@@ -21,6 +21,7 @@
"openai" => "OpenAI",
"openai_validator" => "OpenAIValidator",
"openai_response" => "OpenAIResponse",
"openai_message" => "OpenAIMessage",
"pdf" => "PDF"
)
loader.collapse("#{__dir__}/langchain/llm/response")
85 changes: 66 additions & 19 deletions lib/langchain/assistants/assistant.rb
@@ -19,7 +19,7 @@ def initialize(
tools: [],
instructions: nil
)
raise ArgumentError, "Invalid LLM; currently only Langchain::LLM::OpenAI is supported" unless llm.instance_of?(Langchain::LLM::OpenAI)
raise ArgumentError, "Invalid LLM; currently only Langchain::LLM::OpenAI and Langchain::LLM::GoogleGemini are supported" unless [Langchain::LLM::OpenAI, Langchain::LLM::GoogleGemini].include?(llm.class)
raise ArgumentError, "Thread must be an instance of Langchain::Thread" unless thread.is_a?(Langchain::Thread)
raise ArgumentError, "Tools must be an array of Langchain::Tool::Base instance(s)" unless tools.is_a?(Array) && tools.all? { |tool| tool.is_a?(Langchain::Tool::Base) }

@@ -30,7 +30,10 @@

# The first message in the thread should be the system instructions
# TODO: What if the user added old messages and the system instructions are already in there? Should this overwrite the existing instructions?
add_message(role: "system", content: instructions) if instructions
if llm.is_a?(Langchain::LLM::OpenAI)
add_message(role: "system", content: instructions) if instructions
end
# For Google Gemini, system instructions are added to the `system:` param in the `chat` method
end

# Add a user message to the thread
@@ -59,11 +62,12 @@ def run(auto_tool_execution: false)

while running
# TODO: I think we need to look at all messages and not just the last one.
case (last_message = thread.messages.last).role
when "system"
last_message = thread.messages.last

if last_message.system?
# Do nothing
running = false
when "assistant"
elsif last_message.llm?
if last_message.tool_calls.any?
if auto_tool_execution
run_tools(last_message.tool_calls)
@@ -76,11 +80,11 @@
# Do nothing
running = false
end
when "user"
elsif last_message.user?
# Run it!
response = chat_with_llm

if response.tool_calls
if response.tool_calls.any?
# Re-run the while(running) loop to process the tool calls
running = true
add_message(role: response.role, tool_calls: response.tool_calls)
@@ -89,12 +93,12 @@
running = false
add_message(role: response.role, content: response.chat_completion)
end
when "tool"
elsif last_message.tool?
# Run it!
response = chat_with_llm
running = true

if response.tool_calls
if response.tool_calls.any?
add_message(role: response.role, tool_calls: response.tool_calls)
elsif response.chat_completion
add_message(role: response.role, content: response.chat_completion)
@@ -121,8 +125,14 @@
# @param output [String] The output of the tool
# @return [Array<Langchain::Message>] The messages in the thread
def submit_tool_output(tool_call_id:, output:)
tool_role = if llm.is_a?(Langchain::LLM::OpenAI)
Langchain::Messages::OpenAIMessage::TOOL_ROLE
elsif llm.is_a?(Langchain::LLM::GoogleGemini)
Langchain::Messages::GoogleGeminiMessage::TOOL_ROLE
end

# TODO: Validate that `tool_call_id` is valid
add_message(role: "tool", content: output, tool_call_id: tool_call_id)
add_message(role: tool_role, content: output, tool_call_id: tool_call_id)
end

# Delete all messages in the thread
@@ -156,10 +166,15 @@ def instructions=(new_instructions)
def chat_with_llm
Langchain.logger.info("Sending a call to #{llm.class}", for: self.class)

params = {messages: thread.openai_messages}
params = {messages: thread.array_of_message_hashes}

if tools.any?
params[:tools] = tools.map(&:to_openai_tools).flatten
if llm.is_a?(Langchain::LLM::OpenAI)
params[:tools] = tools.map(&:to_openai_tools).flatten
elsif llm.is_a?(Langchain::LLM::GoogleGemini)
params[:tools] = tools.map(&:to_google_gemini_tools).flatten
params[:system] = instructions if instructions
end
# TODO: Not sure that tool_choice should always be "auto"; Maybe we can let the user toggle it.
params[:tool_choice] = "auto"
end
@@ -173,11 +188,11 @@
def run_tools(tool_calls)
# Iterate over each function invocation and submit tool output
tool_calls.each do |tool_call|
tool_call_id = tool_call.dig("id")

function_name = tool_call.dig("function", "name")
tool_name, method_name = function_name.split("-")
tool_arguments = JSON.parse(tool_call.dig("function", "arguments"), symbolize_names: true)
tool_call_id, tool_name, method_name, tool_arguments = if llm.is_a?(Langchain::LLM::OpenAI)
extract_openai_tool_call(tool_call: tool_call)
elsif llm.is_a?(Langchain::LLM::GoogleGemini)
extract_google_gemini_tool_call(tool_call: tool_call)
end

tool_instance = tools.find do |t|
t.name == tool_name
@@ -190,13 +205,41 @@

response = chat_with_llm

if response.tool_calls
if response.tool_calls.any?
add_message(role: response.role, tool_calls: response.tool_calls)
elsif response.chat_completion
add_message(role: response.role, content: response.chat_completion)
end
end

# Extract the tool call information from the OpenAI tool call hash
#
# @param tool_call [Hash] The tool call hash
# @return [Array] The tool call information
def extract_openai_tool_call(tool_call:)
tool_call_id = tool_call.dig("id")

function_name = tool_call.dig("function", "name")
tool_name, method_name = function_name.split("__")
tool_arguments = JSON.parse(tool_call.dig("function", "arguments"), symbolize_names: true)

[tool_call_id, tool_name, method_name, tool_arguments]
end

# Extract the tool call information from the Google Gemini tool call hash
#
# @param tool_call [Hash] The tool call hash, format: {"functionCall"=>{"name"=>"weather__execute", "args"=>{"input"=>"NYC"}}}
# @return [Array] The tool call information
def extract_google_gemini_tool_call(tool_call:)
tool_call_id = tool_call.dig("functionCall", "name")

function_name = tool_call.dig("functionCall", "name")
tool_name, method_name = function_name.split("__")
tool_arguments = tool_call.dig("functionCall", "args").transform_keys(&:to_sym)

[tool_call_id, tool_name, method_name, tool_arguments]
end

# Build a message
#
# @param role [String] The role of the message
@@ -205,7 +248,11 @@
# @param tool_call_id [String] The ID of the tool call to include in the message
# @return [Langchain::Message] The Message object
def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
Message.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
if llm.is_a?(Langchain::LLM::OpenAI)
Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
elsif llm.is_a?(Langchain::LLM::GoogleGemini)
Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
end
end

# TODO: Fix the message truncation when context window is exceeded
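The tool-call extraction added in this PR is plain hash traversal; note that Gemini responses carry no separate tool-call id, so the function name doubles as the id. A minimal standalone sketch mirroring the PR's `extract_google_gemini_tool_call` (the sample hash follows the format documented in the code comment above):

```ruby
# Standalone sketch of the PR's Google Gemini tool-call extraction.
# Gemini has no dedicated tool-call id, so the function name is reused as the id.
def extract_google_gemini_tool_call(tool_call:)
  function_name = tool_call.dig("functionCall", "name")
  # Tool and method are joined with "__", e.g. "weather__execute"
  tool_name, method_name = function_name.split("__")
  # Gemini returns args as a plain hash with string keys; symbolize them
  tool_arguments = tool_call.dig("functionCall", "args").transform_keys(&:to_sym)

  [function_name, tool_name, method_name, tool_arguments]
end

tool_call = {"functionCall" => {"name" => "weather__execute", "args" => {"input" => "NYC"}}}
tool_call_id, tool_name, method_name, tool_arguments =
  extract_google_gemini_tool_call(tool_call: tool_call)
```

This keeps the Assistant's `run_tools` loop provider-agnostic: each provider supplies one extractor that normalizes its tool-call shape into the same four-element tuple.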
58 changes: 0 additions & 58 deletions lib/langchain/assistants/message.rb

This file was deleted.

16 changes: 16 additions & 0 deletions lib/langchain/assistants/messages/base.rb
@@ -0,0 +1,16 @@
# frozen_string_literal: true

module Langchain
module Messages
class Base
attr_reader :role, :content, :tool_calls, :tool_call_id

# Check if the message came from a user
#
      # @return [Boolean] whether the message came from a user
def user?
role == "user"
end
end
end
end
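The new `Langchain::Messages::Base` class lets `run` branch on role predicates (`user?`, `llm?`, `tool?`, `system?`) instead of matching provider-specific role strings. A self-contained sketch of the pattern — the `DemoMessage` subclass and its constructor are hypothetical, for illustration only:

```ruby
# Sketch of the role-predicate pattern introduced by Messages::Base.
module Langchain
  module Messages
    class Base
      attr_reader :role, :content, :tool_calls, :tool_call_id

      # Check if the message came from a user
      #
      # @return [Boolean] whether the message came from a user
      def user?
        role == "user"
      end
    end
  end
end

# Hypothetical demo subclass; real subclasses (OpenAIMessage, GoogleGeminiMessage)
# set provider-specific role constants such as TOOL_ROLE.
class DemoMessage < Langchain::Messages::Base
  def initialize(role:, content: nil)
    @role = role
    @content = content
  end
end

user_msg = DemoMessage.new(role: "user", content: "What's the weather in NYC?")
llm_msg = DemoMessage.new(role: "model")
```

Subclassing keeps the Assistant loop identical across providers even though OpenAI calls the assistant role "assistant" while Gemini calls it "model".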