diff --git a/.env.example b/.env.example index acc3d2240..cdfbf6360 100644 --- a/.env.example +++ b/.env.example @@ -12,12 +12,14 @@ GOOGLE_PALM_API_KEY= GOOGLE_VERTEX_AI_PROJECT_ID= # Automagical name which is picked up by Google Cloud Ruby auth module. If set, takes auth token from that key file. GOOGLE_CLOUD_CREDENTIALS= +GOOGLE_GEMINI_API_KEY= HUGGING_FACE_API_KEY= LLAMACPP_MODEL_PATH= LLAMACPP_N_THREADS= LLAMACPP_N_GPU_LAYERS= MILVUS_URL= MISTRAL_AI_API_KEY= +NEWS_API_KEY= OLLAMA_URL=http://localhost:11434 OPENAI_API_KEY= OPEN_WEATHER_API_KEY= diff --git a/CHANGELOG.md b/CHANGELOG.md index 7871e753c..e75c333f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ ## [Unreleased] +## [0.13.0] - 2024-05-14 +- New 🛠️ `Langchain::Tool::NewsRetriever` tool to fetch news via newsapi.org +- Langchain::Assistant works with `Langchain::LLM::GoogleVertexAI` and `Langchain::LLM::GoogleGemini` llms +- [BREAKING] Introduce new `Langchain::Messages::Base` abstraction + ## [0.12.1] - 2024-05-13 - Langchain::LLM::Ollama now uses `llama3` by default - Langchain::LLM::Anthropic#complete() now uses `claude-2.1` by default diff --git a/Gemfile.lock b/Gemfile.lock index 8ec4d8004..5ad9c2f0a 100644 --- a/Gemfile.lock +++ b/Gemfile.lock @@ -82,7 +82,6 @@ GEM rexml crass (1.0.6) date (3.3.4) - declarative (0.0.20) diff-lcs (1.5.1) docx (0.8.0) nokogiri (~> 1.13, >= 1.13.0) @@ -162,22 +161,12 @@ GEM faraday (~> 2.0) typhoeus (~> 1.4) ffi (1.16.3) - google-apis-aiplatform_v1 (0.13.0) - google-apis-core (>= 0.12.0, < 2.a) - google-apis-core (0.13.0) - addressable (~> 2.5, >= 2.5.1) - googleauth (~> 1.9) - httpclient (>= 2.8.1, < 3.a) - mini_mime (~> 1.0) - representable (~> 3.0) - retriable (>= 2.0, < 4.a) - rexml google-cloud-env (2.1.1) faraday (>= 1.0, < 3.a) google_palm_api (0.1.3) faraday (>= 2.0.1, < 3.0) google_search_results (2.0.1) - googleauth (1.10.0) + googleauth (1.11.0) faraday (>= 1.0, < 3.a) google-cloud-env (~> 2.1) jwt (>= 1.4, < 3.0) @@ -198,7 +187,6 @@ GEM httparty (0.21.0) mini_mime (>= 1.0.0) multi_xml (>= 0.5.2) - httpclient (2.8.3) hugging-face (0.3.5) faraday (>= 1.0) i18n (1.14.1) @@ -212,7 +200,8 @@ GEM json (2.7.1) json-schema (4.0.0) addressable (>= 2.8) - jwt (2.7.1) + jwt (2.8.1) + base64 language_server-protocol (3.17.0.3) lint_roller (1.1.0) llama_cpp (0.9.5) @@ -341,12 +330,7 @@ GEM faraday (>= 1.0) faraday-multipart faraday-retry - representable (3.2.0) - declarative (< 0.1.0) - trailblazer-option (>= 0.1.1, < 0.2.0) - uber (< 0.2.0) require-hooks (0.2.2) - retriable (3.1.2) rexml (3.2.6) roo (2.10.1) nokogiri (~> 1) @@ -430,7 +414,6 @@ GEM tiktoken_ruby (0.0.8-x86_64-linux-musl) timeout (0.4.1) to_bool (2.0.0) - trailblazer-option (0.1.2) treetop (1.6.12) polyglot (~> 0.3) ttfunk (1.8.0) @@ -439,7 +422,6 @@ GEM ethon (>= 0.9.0) tzinfo (2.0.6) concurrent-ruby (~> 1.0) - uber (0.1.0) unicode (0.4.4.4) unicode-display_width (2.5.0) unparser (0.6.13) @@ -481,9 +463,9 @@ DEPENDENCIES epsilla-ruby (~> 0.0.4) eqn (~> 1.6.5) faraday - google-apis-aiplatform_v1 (~> 0.7) google_palm_api (~> 0.1.3) google_search_results (~> 2.0.0) + googleauth hnswlib (~> 0.8.1) hugging-face (~> 0.3.4) langchainrb! 
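How the pieces above fit together: a minimal usage sketch (the instruction and message strings are illustrative; the classes, constructor arguments, and environment variables all come from this changeset):

    llm = Langchain::LLM::GoogleGemini.new(api_key: ENV["GOOGLE_GEMINI_API_KEY"])

    assistant = Langchain::Assistant.new(
      llm: llm,
      thread: Langchain::Thread.new,
      instructions: "You are a News Assistant",
      tools: [Langchain::Tool::NewsRetriever.new(api_key: ENV["NEWS_API_KEY"])]
    )

    assistant.add_message_and_run(content: "What are the top technology headlines today?", auto_tool_execution: true)
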
diff --git a/README.md b/README.md index 5fbc0b6d0..38e5c0245 100644 --- a/README.md +++ b/README.md @@ -412,6 +412,7 @@ Assistants are Agent-like objects that leverage helpful instructions, LLMs, tool | "file_system" | Interacts with the file system | | | | "ruby_code_interpreter" | Interprets Ruby expressions | | `gem "safe_ruby", "~> 1.0.4"` | | "google_search" | A wrapper around Google Search | `ENV["SERPAPI_API_KEY"]` (https://serpapi.com/manage-api-key) | `gem "google_search_results", "~> 2.0.0"` | +| "news_retriever" | A wrapper around NewsApi.org | `ENV["NEWS_API_KEY"]` (https://newsapi.org/) | | | "weather" | Calls Open Weather API to retrieve the current weather | `ENV["OPEN_WEATHER_API_KEY"]` (https://home.openweathermap.org/api_keys) | `gem "open-weather-ruby-client", "~> 0.3.0"` | | "wikipedia" | Calls Wikipedia API to retrieve the summary | | `gem "wikipedia-client", "~> 1.17.0"` | @@ -445,14 +446,14 @@ assistant = Langchain::Assistant.new( thread: thread, instructions: "You are a Meteorologist Assistant that is able to pull the weather for any location", tools: [ - Langchain::Tool::GoogleSearch.new(api_key: ENV["SERPAPI_API_KEY"]) + Langchain::Tool::Weather.new(api_key: ENV["OPEN_WEATHER_API_KEY"]) ] ) ``` ### Using an Assistant You can now add your message to an Assistant. ```ruby -assistant.add_message content: "What's the weather in New York City?" +assistant.add_message content: "What's the weather in New York, New York?" ``` Run the Assistant to generate a response. diff --git a/langchain.gemspec b/langchain.gemspec index 82807f808..900521d63 100644 --- a/langchain.gemspec +++ b/langchain.gemspec @@ -55,7 +55,7 @@ Gem::Specification.new do |spec| spec.add_development_dependency "elasticsearch", "~> 8.2.0" spec.add_development_dependency "epsilla-ruby", "~> 0.0.4" spec.add_development_dependency "eqn", "~> 1.6.5" - spec.add_development_dependency "google-apis-aiplatform_v1", "~> 0.7" + spec.add_development_dependency "googleauth" spec.add_development_dependency "google_palm_api", "~> 0.1.3" spec.add_development_dependency "google_search_results", "~> 2.0.0" spec.add_development_dependency "hnswlib", "~> 0.8.1" diff --git a/lib/langchain.rb b/lib/langchain.rb index e796ba866..48b8b5180 100644 --- a/lib/langchain.rb +++ b/lib/langchain.rb @@ -12,6 +12,7 @@ "ai21_response" => "AI21Response", "ai21_validator" => "AI21Validator", "csv" => "CSV", + "google_vertex_ai" => "GoogleVertexAI", "html" => "HTML", "json" => "JSON", "jsonl" => "JSONL", @@ -21,6 +22,7 @@ "openai" => "OpenAI", "openai_validator" => "OpenAIValidator", "openai_response" => "OpenAIResponse", + "openai_message" => "OpenAIMessage", "pdf" => "PDF" ) loader.collapse("#{__dir__}/langchain/llm/response") @@ -31,6 +33,7 @@ loader.collapse("#{__dir__}/langchain/tool/file_system") loader.collapse("#{__dir__}/langchain/tool/google_search") loader.collapse("#{__dir__}/langchain/tool/ruby_code_interpreter") +loader.collapse("#{__dir__}/langchain/tool/news_retriever") loader.collapse("#{__dir__}/langchain/tool/vectorsearch") loader.collapse("#{__dir__}/langchain/tool/weather") loader.collapse("#{__dir__}/langchain/tool/wikipedia") diff --git a/lib/langchain/assistants/assistant.rb b/lib/langchain/assistants/assistant.rb index bcbfb06ad..92c959386 100644 --- a/lib/langchain/assistants/assistant.rb +++ b/lib/langchain/assistants/assistant.rb @@ -7,6 +7,12 @@ class Assistant attr_reader :llm, :thread, :instructions attr_accessor :tools + SUPPORTED_LLMS = [ + Langchain::LLM::OpenAI, + Langchain::LLM::GoogleGemini, + 
Langchain::LLM::GoogleVertexAI + ] + # Create a new assistant # # @param llm [Langchain::LLM::Base] LLM instance that the assistant will use @@ -19,7 +25,9 @@ def initialize( tools: [], instructions: nil ) - raise ArgumentError, "Invalid LLM; currently only Langchain::LLM::OpenAI is supported" unless llm.instance_of?(Langchain::LLM::OpenAI) + unless SUPPORTED_LLMS.include?(llm.class) + raise ArgumentError, "Invalid LLM; currently only #{SUPPORTED_LLMS.join(", ")} are supported" + end raise ArgumentError, "Thread must be an instance of Langchain::Thread" unless thread.is_a?(Langchain::Thread) raise ArgumentError, "Tools must be an array of Langchain::Tool::Base instance(s)" unless tools.is_a?(Array) && tools.all? { |tool| tool.is_a?(Langchain::Tool::Base) } @@ -30,7 +38,10 @@ def initialize( # The first message in the thread should be the system instructions # TODO: What if the user added old messages and the system instructions are already in there? Should this overwrite the existing instructions? - add_message(role: "system", content: instructions) if instructions + if llm.is_a?(Langchain::LLM::OpenAI) + add_message(role: "system", content: instructions) if instructions + end + # For Google Gemini, system instructions are added to the `system:` param in the `chat` method end # Add a user message to the thread @@ -59,11 +70,12 @@ def run(auto_tool_execution: false) while running # TODO: I think we need to look at all messages and not just the last one. - case (last_message = thread.messages.last).role - when "system" + last_message = thread.messages.last + + if last_message.system? # Do nothing running = false - when "assistant" + elsif last_message.llm? if last_message.tool_calls.any? if auto_tool_execution run_tools(last_message.tool_calls) @@ -76,11 +88,11 @@ def run(auto_tool_execution: false) # Do nothing running = false end - when "user" + elsif last_message.user? # Run it! response = chat_with_llm - if response.tool_calls + if response.tool_calls.any? # Re-run the while(running) loop to process the tool calls running = true add_message(role: response.role, tool_calls: response.tool_calls) @@ -89,12 +101,12 @@ def run(auto_tool_execution: false) running = false add_message(role: response.role, content: response.chat_completion) end - when "tool" + elsif last_message.tool? # Run it! response = chat_with_llm running = true - if response.tool_calls + if response.tool_calls.any? 
add_message(role: response.role, tool_calls: response.tool_calls) elsif response.chat_completion add_message(role: response.role, content: response.chat_completion) @@ -121,8 +133,14 @@ def add_message_and_run(content:, auto_tool_execution: false) # @param output [String] The output of the tool # @return [Array] The messages in the thread def submit_tool_output(tool_call_id:, output:) - # TODO: Validate that `tool_call_id` is valid - add_message(role: "tool", content: output, tool_call_id: tool_call_id) + tool_role = if llm.is_a?(Langchain::LLM::OpenAI) + Langchain::Messages::OpenAIMessage::TOOL_ROLE + elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class) + Langchain::Messages::GoogleGeminiMessage::TOOL_ROLE + end + + # TODO: Validate that `tool_call_id` is valid by scanning messages and checking if this tool call ID was invoked + add_message(role: tool_role, content: output, tool_call_id: tool_call_id) end # Delete all messages in the thread @@ -156,10 +174,15 @@ def instructions=(new_instructions) def chat_with_llm Langchain.logger.info("Sending a call to #{llm.class}", for: self.class) - params = {messages: thread.openai_messages} + params = {messages: thread.array_of_message_hashes} if tools.any? - params[:tools] = tools.map(&:to_openai_tools).flatten + if llm.is_a?(Langchain::LLM::OpenAI) + params[:tools] = tools.map(&:to_openai_tools).flatten + elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class) + params[:tools] = tools.map(&:to_google_gemini_tools).flatten + params[:system] = instructions if instructions + end # TODO: Not sure that tool_choice should always be "auto"; Maybe we can let the user toggle it. params[:tool_choice] = "auto" end @@ -173,11 +196,11 @@ def chat_with_llm def run_tools(tool_calls) # Iterate over each function invocation and submit tool output tool_calls.each do |tool_call| - tool_call_id = tool_call.dig("id") - - function_name = tool_call.dig("function", "name") - tool_name, method_name = function_name.split("-") - tool_arguments = JSON.parse(tool_call.dig("function", "arguments"), symbolize_names: true) + tool_call_id, tool_name, method_name, tool_arguments = if llm.is_a?(Langchain::LLM::OpenAI) + extract_openai_tool_call(tool_call: tool_call) + elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class) + extract_google_gemini_tool_call(tool_call: tool_call) + end tool_instance = tools.find do |t| t.name == tool_name @@ -190,13 +213,41 @@ def run_tools(tool_calls) response = chat_with_llm - if response.tool_calls + if response.tool_calls.any? 
add_message(role: response.role, tool_calls: response.tool_calls) elsif response.chat_completion add_message(role: response.role, content: response.chat_completion) end end + # Extract the tool call information from the OpenAI tool call hash + # + # @param tool_call [Hash] The tool call hash + # @return [Array] The tool call information + def extract_openai_tool_call(tool_call:) + tool_call_id = tool_call.dig("id") + + function_name = tool_call.dig("function", "name") + tool_name, method_name = function_name.split("__") + tool_arguments = JSON.parse(tool_call.dig("function", "arguments"), symbolize_names: true) + + [tool_call_id, tool_name, method_name, tool_arguments] + end + + # Extract the tool call information from the Google Gemini tool call hash + # + # @param tool_call [Hash] The tool call hash, format: {"functionCall"=>{"name"=>"weather__execute", "args"=>{"input"=>"NYC"}}} + # @return [Array] The tool call information + def extract_google_gemini_tool_call(tool_call:) + tool_call_id = tool_call.dig("functionCall", "name") + + function_name = tool_call.dig("functionCall", "name") + tool_name, method_name = function_name.split("__") + tool_arguments = tool_call.dig("functionCall", "args").transform_keys(&:to_sym) + + [tool_call_id, tool_name, method_name, tool_arguments] + end + # Build a message # # @param role [String] The role of the message @@ -205,7 +256,11 @@ def run_tools(tool_calls) # @param tool_call_id [String] The ID of the tool call to include in the message # @return [Langchain::Message] The Message object def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil) - Message.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id) + if llm.is_a?(Langchain::LLM::OpenAI) + Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id) + elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class) + Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id) + end end # TODO: Fix the message truncation when context window is exceeded diff --git a/lib/langchain/assistants/message.rb b/lib/langchain/assistants/message.rb deleted file mode 100644 index f486eba02..000000000 --- a/lib/langchain/assistants/message.rb +++ /dev/null @@ -1,58 +0,0 @@ -# frozen_string_literal: true - -module Langchain - # Langchain::Message are the messages that are sent to LLM chat methods - class Message - attr_reader :role, :content, :tool_calls, :tool_call_id - - ROLES = %w[ - system - assistant - user - tool - ].freeze - - # @param role [String] The role of the message - # @param content [String] The content of the message - # @param tool_calls [Array] Tool calls to be made - # @param tool_call_id [String] The ID of the tool call to be made - def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil) # TODO: Implement image_file: reference (https://platform.openai.com/docs/api-reference/messages/object#messages/object-content) - raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role) - raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? 
{ |tool_call| tool_call.is_a?(Hash) }
-
-      @role = role
-      # Some Tools return content as a JSON hence `.to_s`
-      @content = content.to_s
-      @tool_calls = tool_calls
-      @tool_call_id = tool_call_id
-    end
-
-    # Convert the message to an OpenAI API-compatible hash
-    #
-    # @return [Hash] The message as an OpenAI API-compatible hash
-    def to_openai_format
-      {}.tap do |h|
-        h[:role] = role
-        h[:content] = content if content # Content is nil for tool calls
-        h[:tool_calls] = tool_calls if tool_calls.any?
-        h[:tool_call_id] = tool_call_id if tool_call_id
-      end
-    end
-
-    def assistant?
-      role == "assistant"
-    end
-
-    def system?
-      role == "system"
-    end
-
-    def user?
-      role == "user"
-    end
-
-    def tool?
-      role == "tool"
-    end
-  end
-end
diff --git a/lib/langchain/assistants/messages/base.rb b/lib/langchain/assistants/messages/base.rb
new file mode 100644
index 000000000..f889de0f1
--- /dev/null
+++ b/lib/langchain/assistants/messages/base.rb
@@ -0,0 +1,16 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Messages
+    class Base
+      attr_reader :role, :content, :tool_calls, :tool_call_id
+
+      # Check if the message came from a user
+      #
+      # @return [Boolean] true/false whether the message came from a user
+      def user?
+        role == "user"
+      end
+    end
+  end
+end
diff --git a/lib/langchain/assistants/messages/google_gemini_message.rb b/lib/langchain/assistants/messages/google_gemini_message.rb
new file mode 100644
index 000000000..cf3f0fe04
--- /dev/null
+++ b/lib/langchain/assistants/messages/google_gemini_message.rb
@@ -0,0 +1,90 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Messages
+    class GoogleGeminiMessage < Base
+      # Google Gemini uses the following roles:
+      ROLES = [
+        "user",
+        "model",
+        "function"
+      ].freeze
+
+      TOOL_ROLE = "function"
+
+      # Initialize a new Google Gemini message
+      #
+      # @param role [String] The role of the message
+      # @param content [String] The content of the message
+      # @param tool_calls [Array] The tool calls made in the message
+      # @param tool_call_id [String] The ID of the tool call
+      def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
+        raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+        raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+        @role = role
+        # Some Tools return content as a JSON hence `.to_s`
+        @content = content.to_s
+        @tool_calls = tool_calls
+        @tool_call_id = tool_call_id
+      end
+
+      # Check if the message came from an LLM
+      #
+      # @return [Boolean] true/false whether this message was produced by an LLM
+      def llm?
+        model?
+      end
+
+      # Convert the message to a Google Gemini API-compatible hash
+      #
+      # @return [Hash] The message as a Google Gemini API-compatible hash
+      def to_hash
+        {}.tap do |h|
+          h[:role] = role
+          h[:parts] = if function?
+            [{
+              functionResponse: {
+                name: tool_call_id,
+                response: {
+                  name: tool_call_id,
+                  content: content
+                }
+              }
+            }]
+          elsif tool_calls.any?
+            tool_calls
+          else
+            [{text: content}]
+          end
+        end
+      end
+
+      # Google Gemini does not implement system prompts
+      def system?
+        false
+      end
+
+      # Check if the message is a tool call
+      #
+      # @return [Boolean] true/false whether this message is a tool call
+      def tool?
+        function?
+      end
+
+      # Check if the message is a function call
+      #
+      # @return [Boolean] true/false whether this message is a function call
+      def function?
+        role == "function"
+      end
+
+      # Check if the message came from an LLM
+      #
+      # @return [Boolean] true/false whether this message was produced by an LLM
+      def model?
+        role == "model"
+      end
+    end
+  end
+end
diff --git a/lib/langchain/assistants/messages/openai_message.rb b/lib/langchain/assistants/messages/openai_message.rb
new file mode 100644
index 000000000..b673ff377
--- /dev/null
+++ b/lib/langchain/assistants/messages/openai_message.rb
@@ -0,0 +1,74 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Messages
+    class OpenAIMessage < Base
+      # OpenAI uses the following roles:
+      ROLES = [
+        "system",
+        "assistant",
+        "user",
+        "tool"
+      ].freeze
+
+      TOOL_ROLE = "tool"
+
+      # Initialize a new OpenAI message
+      #
+      # @param role [String] The role of the message
+      # @param content [String] The content of the message
+      # @param tool_calls [Array] The tool calls made in the message
+      # @param tool_call_id [String] The ID of the tool call
+      def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil) # TODO: Implement image_file: reference (https://platform.openai.com/docs/api-reference/messages/object#messages/object-content)
+        raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+        raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+        @role = role
+        # Some Tools return content as a JSON hence `.to_s`
+        @content = content.to_s
+        @tool_calls = tool_calls
+        @tool_call_id = tool_call_id
+      end
+
+      # Check if the message came from an LLM
+      #
+      # @return [Boolean] true/false whether this message was produced by an LLM
+      def llm?
+        assistant?
+      end
+
+      # Convert the message to an OpenAI API-compatible hash
+      #
+      # @return [Hash] The message as an OpenAI API-compatible hash
+      def to_hash
+        {}.tap do |h|
+          h[:role] = role
+          h[:content] = content if content # Content is nil for tool calls
+          h[:tool_calls] = tool_calls if tool_calls.any?
+          h[:tool_call_id] = tool_call_id if tool_call_id
+        end
+      end
+
+      # Check if the message came from an LLM
+      #
+      # @return [Boolean] true/false whether this message was produced by an LLM
+      def assistant?
+        role == "assistant"
+      end
+
+      # Check if the message contains system instructions
+      #
+      # @return [Boolean] true/false whether this message contains system instructions
+      def system?
+        role == "system"
+      end
+
+      # Check if the message is a tool call
+      #
+      # @return [Boolean] true/false whether this message is a tool call
+      def tool?
+        role == "tool"
+      end
+    end
+  end
+end
diff --git a/lib/langchain/assistants/thread.rb b/lib/langchain/assistants/thread.rb
index 1c2023843..84a30a8b0 100644
--- a/lib/langchain/assistants/thread.rb
+++ b/lib/langchain/assistants/thread.rb
@@ -8,16 +8,16 @@ class Thread
     # @param messages [Array]
     def initialize(messages: [])
-      raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Langchain::Message) }
+      raise ArgumentError, "messages array must only contain Langchain::Messages::Base instance(s)" unless messages.is_a?(Array) && messages.all?
{ |m| m.is_a?(Langchain::Messages::Base) }

       @messages = messages
     end

-    # Convert the thread to an OpenAI API-compatible array of hashes
+    # Convert the thread to an LLM API-compatible array of hashes
     #
     # @return [Array] The thread as an OpenAI API-compatible array of hashes
-    def openai_messages
-      messages.map(&:to_openai_format)
+    def array_of_message_hashes
+      messages.map(&:to_hash)
     end

     # Add a message to the thread
@@ -25,7 +25,7 @@ def openai_messages
     # @param message [Langchain::Message] The message to add
     # @return [Array] The updated messages array
     def add_message(message)
-      raise ArgumentError, "message must be a Langchain::Message instance" unless message.is_a?(Langchain::Message)
+      raise ArgumentError, "message must be a Langchain::Messages::Base instance" unless message.is_a?(Langchain::Messages::Base)

       # Prepend the message to the thread
       messages << message
diff --git a/lib/langchain/llm/base.rb b/lib/langchain/llm/base.rb
index f07445230..2f58752d4 100644
--- a/lib/langchain/llm/base.rb
+++ b/lib/langchain/llm/base.rb
@@ -11,7 +11,8 @@ class ApiError < StandardError; end
   # - {Langchain::LLM::Azure}
   # - {Langchain::LLM::Cohere}
   # - {Langchain::LLM::GooglePalm}
-  # - {Langchain::LLM::GoogleVertexAi}
+  # - {Langchain::LLM::GoogleVertexAI}
+  # - {Langchain::LLM::GoogleGemini}
   # - {Langchain::LLM::HuggingFace}
   # - {Langchain::LLM::LlamaCpp}
   # - {Langchain::LLM::OpenAI}
diff --git a/lib/langchain/llm/google_gemini.rb b/lib/langchain/llm/google_gemini.rb
new file mode 100644
index 000000000..5be784daf
--- /dev/null
+++ b/lib/langchain/llm/google_gemini.rb
@@ -0,0 +1,67 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  # Usage:
+  #    llm = Langchain::LLM::GoogleGemini.new(api_key: ENV['GOOGLE_GEMINI_API_KEY'])
+  class GoogleGemini < Base
+    DEFAULTS = {
+      chat_completion_model_name: "gemini-1.5-pro-latest",
+      temperature: 0.0
+    }
+
+    attr_reader :defaults, :api_key
+
+    def initialize(api_key:, default_options: {})
+      @api_key = api_key
+      @defaults = DEFAULTS.merge(default_options)
+
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        temperature: {default: @defaults[:temperature]}
+      )
+      chat_parameters.remap(
+        messages: :contents,
+        system: :system_instruction,
+        tool_choice: :tool_config
+      )
+    end
+
+    # Generate a chat completion for a given prompt
+    #
+    # @param messages [Array] List of messages comprising the conversation so far
+    # @param model [String] The model to use
+    # @param tools [Array] A list of Tools the model may use to generate the next response
+    # @param tool_choice [String] Specifies the mode in which function calling should execute. If unspecified, the default value will be set to AUTO. Possible values: AUTO, ANY, NONE
+    # @param system [String] Developer-set system instruction
+    def chat(params = {})
+      params[:system] = {parts: [{text: params[:system]}]} if params[:system]
+      params[:tools] = {function_declarations: params[:tools]} if params[:tools]
+      params[:tool_choice] = {function_calling_config: {mode: params[:tool_choice].upcase}} if params[:tool_choice]
+
+      raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?
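+
+      # A sketch of the request body assembled below (a single user message shown;
+      # not the full Gemini API surface; see the remap in the constructor):
+      #   {contents: [{role: "user", parts: [{text: "Hi"}]}],
+      #    system_instruction: {parts: [{text: "..."}]},
+      #    generation_config: {temperature: 0.0}}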
+ + parameters = chat_parameters.to_params(params) + parameters[:generation_config] = {temperature: parameters.delete(:temperature)} if parameters[:temperature] + + uri = URI("https://generativelanguage.googleapis.com/v1beta/models/#{parameters[:model]}:generateContent?key=#{api_key}") + + request = Net::HTTP::Post.new(uri) + request.content_type = "application/json" + request.body = parameters.to_json + + response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http| + http.request(request) + end + + parsed_response = JSON.parse(response.body) + + wrapped_response = Langchain::LLM::GoogleGeminiResponse.new(parsed_response, model: parameters[:model]) + + if wrapped_response.chat_completion || Array(wrapped_response.tool_calls).any? + wrapped_response + else + raise StandardError.new(response) + end + end + end +end diff --git a/lib/langchain/llm/google_vertex_ai.rb b/lib/langchain/llm/google_vertex_ai.rb index 50c5567c0..4c936c372 100644 --- a/lib/langchain/llm/google_vertex_ai.rb +++ b/lib/langchain/llm/google_vertex_ai.rb @@ -2,150 +2,106 @@ module Langchain::LLM # - # Wrapper around the Google Vertex AI APIs: https://cloud.google.com/vertex-ai?hl=en + # Wrapper around the Google Vertex AI APIs: https://cloud.google.com/vertex-ai # # Gem requirements: - # gem "google-apis-aiplatform_v1", "~> 0.7" + # gem "googleauth" # # Usage: - # google_palm = Langchain::LLM::GoogleVertexAi.new(project_id: ENV["GOOGLE_VERTEX_AI_PROJECT_ID"]) + # llm = Langchain::LLM::GoogleVertexAI.new(project_id: ENV["GOOGLE_VERTEX_AI_PROJECT_ID"], region: "us-central1") # - class GoogleVertexAi < Base + class GoogleVertexAI < Base DEFAULTS = { - temperature: 0.1, # 0.1 is the default in the API, quite low ("grounded") + temperature: 0.1, max_output_tokens: 1000, top_p: 0.8, top_k: 40, dimensions: 768, - completion_model_name: "text-bison", # Optional: tect-bison@001 - embeddings_model_name: "textembedding-gecko" + embeddings_model_name: "textembedding-gecko", + chat_completion_model_name: "gemini-1.0-pro" }.freeze - # TODO: Implement token length validation - # LENGTH_VALIDATOR = Langchain::Utils::TokenLength::... - # Google Cloud has a project id and a specific region of deployment. # For GenAI-related things, a safe choice is us-central1. - attr_reader :project_id, :client, :region - - def initialize(project_id:, default_options: {}) - depends_on "google-apis-aiplatform_v1" + attr_reader :defaults, :url, :authorizer - @project_id = project_id - @region = default_options.fetch :region, "us-central1" + def initialize(project_id:, region:, default_options: {}) + depends_on "googleauth" - @client = Google::Apis::AiplatformV1::AiplatformService.new - - # TODO: Adapt for other regions; Pass it in via the constructor - # For the moment only us-central1 available so no big deal. 
- @client.root_url = "https://#{@region}-aiplatform.googleapis.com/" - @client.authorization = Google::Auth.get_application_default + @authorizer = ::Google::Auth.get_application_default + proj_id = project_id || @authorizer.project_id || @authorizer.quota_project_id + @url = "https://#{region}-aiplatform.googleapis.com/v1/projects/#{proj_id}/locations/#{region}/publishers/google/models/" @defaults = DEFAULTS.merge(default_options) + + chat_parameters.update( + model: {default: @defaults[:chat_completion_model_name]}, + temperature: {default: @defaults[:temperature]} + ) + chat_parameters.remap( + messages: :contents, + system: :system_instruction, + tool_choice: :tool_config + ) end # # Generate an embedding for a given text # # @param text [String] The text to generate an embedding for - # @return [Langchain::LLM::GoogleVertexAiResponse] Response object + # @param model [String] ID of the model to use + # @return [Langchain::LLM::GoogleGeminiResponse] Response object # - def embed(text:) - content = [{content: text}] - request = Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictRequest.new(instances: content) - - api_path = "projects/#{@project_id}/locations/us-central1/publishers/google/models/#{@defaults[:embeddings_model_name]}" - - # puts("api_path: #{api_path}") - - response = client.predict_project_location_publisher_model(api_path, request) + def embed( + text:, + model: @defaults[:embeddings_model_name] + ) + params = {instances: [{content: text}]} + + response = HTTParty.post( + "#{url}#{model}:predict", + body: params.to_json, + headers: { + "Content-Type" => "application/json", + "Authorization" => "Bearer #{@authorizer.fetch_access_token!["access_token"]}" + } + ) - Langchain::LLM::GoogleVertexAiResponse.new(response.to_h, model: @defaults[:embeddings_model_name]) + Langchain::LLM::GoogleGeminiResponse.new(response, model: model) end + # Generate a chat completion for given messages # - # Generate a completion for a given prompt - # - # @param prompt [String] The prompt to generate a completion for - # @param params extra parameters passed to GooglePalmAPI::Client#generate_text - # @return [Langchain::LLM::GooglePalmResponse] Response object - # - def complete(prompt:, **params) - default_params = { - prompt: prompt, - temperature: @defaults[:temperature], - top_k: @defaults[:top_k], - top_p: @defaults[:top_p], - max_output_tokens: @defaults[:max_output_tokens], - model: @defaults[:completion_model_name] - } - - if params[:stop_sequences] - default_params[:stop_sequences] = params.delete(:stop_sequences) + # @param messages [Array] Input messages + # @param model [String] The model that will complete your prompt + # @param tools [Array] The tools to use + # @param tool_choice [String] The tool choice to use + # @param system [String] The system instruction to use + # @return [Langchain::LLM::GoogleGeminiResponse] Response object + def chat(params = {}) + params[:system] = {parts: [{text: params[:system]}]} if params[:system] + params[:tools] = {function_declarations: params[:tools]} if params[:tools] + params[:tool_choice] = {function_calling_config: {mode: params[:tool_choice].upcase}} if params[:tool_choice] + + raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty? 
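+
+      # A sketch of the HTTP call assembled below (the endpoint prefix is built in the
+      # constructor; the bearer token comes from googleauth application default credentials):
+      #   POST https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/google/models/gemini-1.0-pro:generateContent
+      #   Authorization: Bearer <@authorizer.fetch_access_token!["access_token"]>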
+ + parameters = chat_parameters.to_params(params) + parameters[:generation_config] = {temperature: parameters.delete(:temperature)} if parameters[:temperature] + + uri = URI("#{url}#{parameters[:model]}:generateContent") + + request = Net::HTTP::Post.new(uri) + request.content_type = "application/json" + request["Authorization"] = "Bearer #{@authorizer.fetch_access_token!["access_token"]}" + request.body = parameters.to_json + + response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http| + http.request(request) end - if params[:max_output_tokens] - default_params[:max_output_tokens] = params.delete(:max_output_tokens) - end - - # to be tested - temperature = params.delete(:temperature) || @defaults[:temperature] - max_output_tokens = default_params.fetch(:max_output_tokens, @defaults[:max_output_tokens]) - - default_params.merge!(params) - - # response = client.generate_text(**default_params) - request = Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictRequest.new \ - instances: [{ - prompt: prompt # key used to be :content, changed to :prompt - }], - parameters: { - temperature: temperature, - maxOutputTokens: max_output_tokens, - topP: 0.8, - topK: 40 - } - - response = client.predict_project_location_publisher_model \ - "projects/#{project_id}/locations/us-central1/publishers/google/models/#{@defaults[:completion_model_name]}", - request - - Langchain::LLM::GoogleVertexAiResponse.new(response, model: default_params[:model]) - end + parsed_response = JSON.parse(response.body) - # - # Generate a summarization for a given text - # - # @param text [String] The text to generate a summarization for - # @return [String] The summarization - # - # TODO(ricc): add params for Temp, topP, topK, MaxTokens and have it default to these 4 values. - def summarize(text:) - prompt_template = Langchain::Prompt.load_from_path( - file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml") - ) - prompt = prompt_template.format(text: text) - - complete( - prompt: prompt, - # For best temperature, topP, topK, MaxTokens for summarization: see - # https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-summarization - temperature: 0.2, - top_p: 0.95, - top_k: 40, - # Most models have a context length of 2048 tokens (except for the newest models, which support 4096). - max_output_tokens: 256 - ) + Langchain::LLM::GoogleGeminiResponse.new(parsed_response, model: parameters[:model]) end - - # def chat(...) 
- # https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chathat - # Chat params: https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chat - # \"temperature\": 0.3,\n" - # + " \"maxDecodeSteps\": 200,\n" - # + " \"topP\": 0.8,\n" - # + " \"topK\": 40\n" - # + "}"; - # end end end diff --git a/lib/langchain/llm/response/google_gemini_response.rb b/lib/langchain/llm/response/google_gemini_response.rb new file mode 100644 index 000000000..2eaeca089 --- /dev/null +++ b/lib/langchain/llm/response/google_gemini_response.rb @@ -0,0 +1,45 @@ +# frozen_string_literal: true + +module Langchain::LLM + class GoogleGeminiResponse < BaseResponse + def initialize(raw_response, model: nil) + super(raw_response, model: model) + end + + def chat_completion + raw_response.dig("candidates", 0, "content", "parts", 0, "text") + end + + def role + raw_response.dig("candidates", 0, "content", "role") + end + + def tool_calls + if raw_response.dig("candidates", 0, "content") && raw_response.dig("candidates", 0, "content", "parts", 0).has_key?("functionCall") + raw_response.dig("candidates", 0, "content", "parts") + else + [] + end + end + + def embedding + embeddings.first + end + + def embeddings + [raw_response.dig("predictions", 0, "embeddings", "values")] + end + + def prompt_tokens + raw_response.dig("usageMetadata", "promptTokenCount") + end + + def completion_tokens + raw_response.dig("usageMetadata", "candidatesTokenCount") + end + + def total_tokens + raw_response.dig("usageMetadata", "totalTokenCount") + end + end +end diff --git a/lib/langchain/llm/response/google_vertex_ai_response.rb b/lib/langchain/llm/response/google_vertex_ai_response.rb deleted file mode 100644 index f75259195..000000000 --- a/lib/langchain/llm/response/google_vertex_ai_response.rb +++ /dev/null @@ -1,33 +0,0 @@ -# frozen_string_literal: true - -module Langchain::LLM - class GoogleVertexAiResponse < BaseResponse - attr_reader :prompt_tokens - - def initialize(raw_response, model: nil) - @prompt_tokens = prompt_tokens - super(raw_response, model: model) - end - - def completion - # completions&.dig(0, "output") - raw_response.predictions[0]["content"] - end - - def embedding - embeddings.first - end - - def completions - raw_response.predictions.map { |p| p["content"] } - end - - def total_tokens - raw_response.dig(:predictions, 0, :embeddings, :statistics, :token_count) - end - - def embeddings - [raw_response.dig(:predictions, 0, :embeddings, :values)] - end - end -end diff --git a/lib/langchain/llm/response/openai_response.rb b/lib/langchain/llm/response/openai_response.rb index ef354738a..0e3006855 100644 --- a/lib/langchain/llm/response/openai_response.rb +++ b/lib/langchain/llm/response/openai_response.rb @@ -25,7 +25,11 @@ def chat_completion end def tool_calls - chat_completions&.dig(0, "message", "tool_calls") + if chat_completions.dig(0, "message").has_key?("tool_calls") + chat_completions.dig(0, "message", "tool_calls") + else + [] + end end def embedding diff --git a/lib/langchain/tool/base.rb b/lib/langchain/tool/base.rb index 74b901f92..c118122ab 100644 --- a/lib/langchain/tool/base.rb +++ b/lib/langchain/tool/base.rb @@ -66,11 +66,21 @@ def self.logger_options # Returns the tool as a list of OpenAI formatted functions # - # @return [Hash] tool as an OpenAI tool + # @return [Array] List of hashes representing the tool as OpenAI formatted functions def to_openai_tools method_annotations end + # Returns the tool as a list of Google Gemini formatted functions + # + # @return [Array] List of hashes 
representing the tool as Google Gemini formatted functions + def to_google_gemini_tools + method_annotations.map do |annotation| + # Slice out only the content of the "function" key + annotation["function"] + end + end + # Return tool's method annotations as JSON # # @return [Hash] Tool's method annotations diff --git a/lib/langchain/tool/calculator/calculator.json b/lib/langchain/tool/calculator/calculator.json index 2b6511f5d..ac1f0c357 100644 --- a/lib/langchain/tool/calculator/calculator.json +++ b/lib/langchain/tool/calculator/calculator.json @@ -2,7 +2,7 @@ { "type": "function", "function": { - "name": "calculator-execute", + "name": "calculator__execute", "description": "Evaluates a pure math expression or if equation contains non-math characters (e.g.: \"12F in Celsius\") then it uses the google search calculator to evaluate the expression", "parameters": { "type": "object", diff --git a/lib/langchain/tool/database/database.json b/lib/langchain/tool/database/database.json index 8c1ce726b..801539f96 100644 --- a/lib/langchain/tool/database/database.json +++ b/lib/langchain/tool/database/database.json @@ -2,7 +2,7 @@ { "type": "function", "function": { - "name": "database-describe_tables", + "name": "database__describe_tables", "description": "Database Tool: Returns the schema for a list of tables", "parameters": { "type": "object", @@ -18,7 +18,7 @@ }, { "type": "function", "function": { - "name": "database-list_tables", + "name": "database__list_tables", "description": "Database Tool: Returns a list of tables in the database", "parameters": { "type": "object", @@ -29,7 +29,7 @@ }, { "type": "function", "function": { - "name": "database-execute", + "name": "database__execute", "description": "Database Tool: Executes a SQL query and returns the results", "parameters": { "type": "object", diff --git a/lib/langchain/tool/file_system/file_system.json b/lib/langchain/tool/file_system/file_system.json index d2d308c6b..011bc4f31 100644 --- a/lib/langchain/tool/file_system/file_system.json +++ b/lib/langchain/tool/file_system/file_system.json @@ -2,7 +2,7 @@ { "type": "function", "function": { - "name": "file_system-list_directory", + "name": "file_system__list_directory", "description": "File System Tool: Lists out the content of a specified directory", "parameters": { "type": "object", @@ -19,7 +19,7 @@ { "type": "function", "function": { - "name": "file_system-read_file", + "name": "file_system__read_file", "description": "File System Tool: Reads the contents of a file", "parameters": { "type": "object", @@ -36,7 +36,7 @@ { "type": "function", "function": { - "name": "file_system-write_to_file", + "name": "file_system__write_to_file", "description": "File System Tool: Write content to a file", "parameters": { "type": "object", diff --git a/lib/langchain/tool/news_retriever/news_retriever.json b/lib/langchain/tool/news_retriever/news_retriever.json new file mode 100644 index 000000000..1c1b22211 --- /dev/null +++ b/lib/langchain/tool/news_retriever/news_retriever.json @@ -0,0 +1,121 @@ +[ + { + "type": "function", + "function": { + "name": "news_retriever__get_everything", + "description": "News Retriever: Search through millions of articles from over 150,000 large and small news sources and blogs.", + "parameters": { + "type": "object", + "properties": { + "q": { + "type": "string", + "description": "Keywords or phrases to search for in the article title and body. Surround phrases with quotes (\") for exact match. 
Alternatively you can use the AND / OR / NOT keywords, and optionally group these with parentheses. Must be URL-encoded."
+        },
+        "search_in": {
+          "type": "string",
+          "description": "The fields to restrict your q search to.",
+          "enum": ["title", "description", "content"]
+        },
+        "sources": {
+          "type": "string",
+          "description": "A comma-separated string of identifiers (maximum 20) for the news sources or blogs you want headlines from. Use the /sources endpoint to locate these programmatically or look at the sources index."
+        },
+        "domains": {
+          "type": "string",
+          "description": "A comma-separated string of domains (eg bbc.co.uk, techcrunch.com, engadget.com) to restrict the search to."
+        },
+        "exclude_domains": {
+          "type": "string",
+          "description": "A comma-separated string of domains (eg bbc.co.uk, techcrunch.com, engadget.com) to remove from the results."
+        },
+        "from": {
+          "type": "string",
+          "description": "A date and optional time for the oldest article allowed. This should be in ISO 8601 format."
+        },
+        "to": {
+          "type": "string",
+          "description": "A date and optional time for the newest article allowed. This should be in ISO 8601 format."
+        },
+        "language": {
+          "type": "string",
+          "description": "The 2-letter ISO-639-1 code of the language you want to get headlines for.",
+          "enum": ["ar", "de", "en", "es", "fr", "he", "it", "nl", "no", "pt", "ru", "sv", "ud", "zh"]
+        },
+        "sort_by": {
+          "type": "string",
+          "description": "The order to sort the articles in.",
+          "enum": ["relevancy", "popularity", "publishedAt"]
+        },
+        "page_size": {
+          "type": "integer",
+          "description": "The number of results to return per page (request). 5 is the default, 100 is the maximum."
+        },
+        "page": {
+          "type": "integer",
+          "description": "Use this to page through the results if the total results found is greater than the page size."
+        }
+      }
+    }
+  },
+  {
+    "type": "function",
+    "function": {
+      "name": "news_retriever__get_top_headlines",
+      "description": "News Retriever: Provides live top and breaking headlines for a country, specific category in a country, single source, or multiple sources. You can also search with keywords. Articles are sorted by the earliest date published first.",
+      "parameters": {
+        "type": "object",
+        "properties": {
+          "country": {
+            "type": "string",
+            "description": "The 2-letter ISO 3166-1 code of the country you want to get headlines for."
+          },
+          "category": {
+            "type": "string",
+            "description": "The category you want to get headlines for.",
+            "enum": ["business", "entertainment", "general", "health", "science", "sports", "technology"]
+          },
+          "q": {
+            "type": "string",
+            "description": "Keywords or a phrase to search for."
+          },
+          "page_size": {
+            "type": "integer",
+            "description": "The number of results to return per page (request). 5 is the default, 100 is the maximum."
+          },
+          "page": {
+            "type": "integer",
+            "description": "Use this to page through the results if the total results found is greater than the page size."
+          }
+        }
+      }
+    }
+  },
+  {
+    "type": "function",
+    "function": {
+      "name": "news_retriever__get_sources",
+      "description": "News Retriever: This endpoint returns the subset of news publishers that top headlines (/v2/top-headlines) are available from.
It's mainly a convenience endpoint that you can use to keep track of the publishers available on the API, and you can pipe it straight through to your users.",
+      "parameters": {
+        "type": "object",
+        "properties": {
+          "country": {
+            "type": "string",
+            "description": "The 2-letter ISO 3166-1 code of the country you want to get headlines for. Default: all countries.",
+            "enum": ["ae", "ar", "at", "au", "be", "bg", "br", "ca", "ch", "cn", "co", "cu", "cz", "de", "eg", "fr", "gb", "gr", "hk", "hu", "id", "ie", "il", "in", "it", "jp", "kr", "lt", "lv", "ma", "mx", "my", "ng", "nl", "no", "nz", "ph", "pl", "pt", "ro", "rs", "ru", "sa", "se", "sg", "si", "sk", "th", "tr", "tw", "ua", "us", "ve", "za"]
+          },
+          "category": {
+            "type": "string",
+            "description": "The category you want to get headlines for. Default: all categories.",
+            "enum": ["business", "entertainment", "general", "health", "science", "sports", "technology"]
+          },
+          "language": {
+            "type": "string",
+            "description": "The 2-letter ISO-639-1 code of the language you want to get headlines for.",
+            "enum": ["ar", "de", "en", "es", "fr", "he", "it", "nl", "no", "pt", "ru", "sv", "ud", "zh"]
+          }
+        }
+      }
+    }
+  }
+]
diff --git a/lib/langchain/tool/news_retriever/news_retriever.rb b/lib/langchain/tool/news_retriever/news_retriever.rb
new file mode 100644
index 000000000..595b6c934
--- /dev/null
+++ b/lib/langchain/tool/news_retriever/news_retriever.rb
@@ -0,0 +1,132 @@
+# frozen_string_literal: true
+
+module Langchain::Tool
+  class NewsRetriever < Base
+    #
+    # A tool that retrieves the latest news from various sources via https://newsapi.org/.
+    # An API key needs to be obtained from https://newsapi.org/ to use this tool.
+    #
+    # Usage:
+    #    news_retriever = Langchain::Tool::NewsRetriever.new(api_key: ENV["NEWS_API_KEY"])
+    #
+    NAME = "news_retriever"
+    ANNOTATIONS_PATH = Langchain.root.join("./langchain/tool/#{NAME}/#{NAME}.json").to_path
+
+    def initialize(api_key: ENV["NEWS_API_KEY"])
+      @api_key = api_key
+    end
+
+    # Retrieve all news
+    #
+    # @param q [String] Keywords or phrases to search for in the article title and body.
+    # @param search_in [String] The fields to restrict your q search to. The possible options are: title, description, content.
+    # @param sources [String] A comma-separated string of identifiers (maximum 20) for the news sources or blogs you want headlines from. Use the /sources endpoint to locate these programmatically or look at the sources index.
+    # @param domains [String] A comma-separated string of domains (eg bbc.co.uk, techcrunch.com, engadget.com) to restrict the search to.
+    # @param exclude_domains [String] A comma-separated string of domains (eg bbc.co.uk, techcrunch.com, engadget.com) to remove from the results.
+    # @param from [String] A date and optional time for the oldest article allowed. This should be in ISO 8601 format.
+    # @param to [String] A date and optional time for the newest article allowed. This should be in ISO 8601 format.
+    # @param language [String] The 2-letter ISO-639-1 code of the language you want to get headlines for. Possible options: ar, de, en, es, fr, he, it, nl, no, pt, ru, sv, ud, zh.
+    # @param sort_by [String] The order to sort the articles in. Possible options: relevancy, popularity, publishedAt.
+    # @param page_size [Integer] The number of results to return per page. 20 is the API's default, 100 is the maximum. Our default is 5.
+    # @param page [Integer] Use this to page through the results.
+    #
+    # @return [String] JSON response
+    def get_everything(
+      q: nil,
+      search_in: nil,
+      sources: nil,
+      domains: nil,
+      exclude_domains: nil,
+      from: nil,
+      to: nil,
+      language: nil,
+      sort_by: nil,
+      page_size: 5, # The API default is 20 but that's too many.
+      page: nil
+    )
+      Langchain.logger.info("Retrieving all news", for: self.class)
+
+      params = {apiKey: @api_key}
+      params[:q] = q if q
+      params[:searchIn] = search_in if search_in
+      params[:sources] = sources if sources
+      params[:domains] = domains if domains
+      params[:excludeDomains] = exclude_domains if exclude_domains
+      params[:from] = from if from
+      params[:to] = to if to
+      params[:language] = language if language
+      params[:sortBy] = sort_by if sort_by
+      params[:pageSize] = page_size if page_size
+      params[:page] = page if page
+
+      send_request(path: "everything", params: params)
+    end
+
+    # Retrieve top headlines
+    #
+    # @param country [String] The 2-letter ISO 3166-1 code of the country you want to get headlines for. Possible options: ae, ar, at, au, be, bg, br, ca, ch, cn, co, cu, cz, de, eg, fr, gb, gr, hk, hu, id, ie, il, in, it, jp, kr, lt, lv, ma, mx, my, ng, nl, no, nz, ph, pl, pt, ro, rs, ru, sa, se, sg, si, sk, th, tr, tw, ua, us, ve, za.
+    # @param category [String] The category you want to get headlines for. Possible options: business, entertainment, general, health, science, sports, technology.
+    # @param sources [String] A comma-separated string of identifiers for the news sources or blogs you want headlines from. Use the /sources endpoint to locate these programmatically.
+    # @param q [String] Keywords or a phrase to search for.
+    # @param page_size [Integer] The number of results to return per page. 20 is the API's default, 100 is the maximum. Our default is 5.
+    # @param page [Integer] Use this to page through the results.
+    #
+    # @return [String] JSON response
+    def get_top_headlines(
+      country: nil,
+      category: nil,
+      sources: nil,
+      q: nil,
+      page_size: 5,
+      page: nil
+    )
+      Langchain.logger.info("Retrieving top news headlines", for: self.class)
+
+      params = {apiKey: @api_key}
+      params[:country] = country if country
+      params[:category] = category if category
+      params[:sources] = sources if sources
+      params[:q] = q if q
+      params[:pageSize] = page_size if page_size
+      params[:page] = page if page
+
+      send_request(path: "top-headlines", params: params)
+    end
+
+    # Retrieve news sources
+    #
+    # @param category [String] The category you want to get headlines for. Possible options: business, entertainment, general, health, science, sports, technology.
+    # @param language [String] The 2-letter ISO-639-1 code of the language you want to get headlines for. Possible options: ar, de, en, es, fr, he, it, nl, no, pt, ru, sv, ud, zh.
+    # @param country [String] The 2-letter ISO 3166-1 code of the country you want to get headlines for. Possible options: ae, ar, at, au, be, bg, br, ca, ch, cn, co, cu, cz, de, eg, fr, gb, gr, hk, hu, id, ie, il, in, it, jp, kr, lt, lv, ma, mx, my, ng, nl, no, nz, ph, pl, pt, ro, rs, ru, sa, se, sg, si, sk, th, tr, tw, ua, us, ve, za.
+ # + # @return [String] JSON response + def get_sources( + category: nil, + language: nil, + country: nil + ) + Langchain.logger.info("Retrieving news sources", for: self.class) + + params = {apiKey: @api_key} + params[:country] = country if country + params[:category] = category if category + params[:language] = language if language + + send_request(path: "top-headlines/sources", params: params) + end + + private + + def send_request(path:, params:) + uri = URI.parse("https://newsapi.org/v2/#{path}?#{URI.encode_www_form(params)}") + http = Net::HTTP.new(uri.host, uri.port) + http.use_ssl = true + + request = Net::HTTP::Get.new(uri.request_uri) + request["Content-Type"] = "application/json" + + response = http.request(request) + response.body + end + end +end diff --git a/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json b/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json index 5fff95298..718c0a993 100644 --- a/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +++ b/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json @@ -2,7 +2,7 @@ { "type": "function", "function": { - "name": "ruby_code_interpreter-execute", + "name": "ruby_code_interpreter__execute", "description": "Executes Ruby code in a sandboxes environment.", "parameters": { "type": "object", diff --git a/lib/langchain/tool/vectorsearch/vectorsearch.json b/lib/langchain/tool/vectorsearch/vectorsearch.json index cb63f1e77..f3577dcea 100644 --- a/lib/langchain/tool/vectorsearch/vectorsearch.json +++ b/lib/langchain/tool/vectorsearch/vectorsearch.json @@ -2,7 +2,7 @@ { "type": "function", "function": { - "name": "vectorsearch-similarity_search", + "name": "vectorsearch__similarity_search", "description": "Vectorsearch: Retrieves relevant document for the query", "parameters": { "type": "object", diff --git a/lib/langchain/tool/weather/weather.json b/lib/langchain/tool/weather/weather.json index 7a2ebc2d7..79ab09baa 100644 --- a/lib/langchain/tool/weather/weather.json +++ b/lib/langchain/tool/weather/weather.json @@ -2,7 +2,7 @@ { "type": "function", "function": { - "name": "weather-execute", + "name": "weather__execute", "description": "Returns current weather for a city", "parameters": { "type": "object", diff --git a/lib/langchain/tool/wikipedia/wikipedia.json b/lib/langchain/tool/wikipedia/wikipedia.json index ee246418d..da5782a34 100644 --- a/lib/langchain/tool/wikipedia/wikipedia.json +++ b/lib/langchain/tool/wikipedia/wikipedia.json @@ -2,7 +2,7 @@ { "type": "function", "function": { - "name": "wikipedia-execute", + "name": "wikipedia__execute", "description": "Executes Wikipedia API search and returns the answer", "parameters": { "type": "object", diff --git a/lib/langchain/tool/wikipedia/wikipedia.rb b/lib/langchain/tool/wikipedia/wikipedia.rb index 73268a824..82419a2b6 100644 --- a/lib/langchain/tool/wikipedia/wikipedia.rb +++ b/lib/langchain/tool/wikipedia/wikipedia.rb @@ -9,8 +9,8 @@ class Wikipedia < Base # gem "wikipedia-client", "~> 1.17.0" # # Usage: - # weather = Langchain::Tool::Wikipedia.new - # weather.execute(input: "The Roman Empire") + # wikipedia = Langchain::Tool::Wikipedia.new + # wikipedia.execute(input: "The Roman Empire") # NAME = "wikipedia" ANNOTATIONS_PATH = Langchain.root.join("./langchain/tool/#{NAME}/#{NAME}.json").to_path diff --git a/spec/fixtures/llm/google_gemini/chat.json b/spec/fixtures/llm/google_gemini/chat.json new file mode 100644 index 000000000..c8a48ee32 --- /dev/null +++ b/spec/fixtures/llm/google_gemini/chat.json @@ 
-0,0 +1,13 @@ +{ + "candidates": [ + { + "content": { + "parts": [{"text": "The answer is 4.0"}], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [] + } + ] +} diff --git a/spec/fixtures/llm/google_gemini/chat_with_tool_calls.json b/spec/fixtures/llm/google_gemini/chat_with_tool_calls.json new file mode 100644 index 000000000..b8bb1a75d --- /dev/null +++ b/spec/fixtures/llm/google_gemini/chat_with_tool_calls.json @@ -0,0 +1,20 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "calculator__execute", + "args": {"input": "2+2"} + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [] + } + ] +} diff --git a/spec/langchain/assistants/assistant_spec.rb b/spec/langchain/assistants/assistant_spec.rb index d553d9c2f..ae37121a9 100644 --- a/spec/langchain/assistants/assistant_spec.rb +++ b/spec/langchain/assistants/assistant_spec.rb @@ -1,142 +1,344 @@ # frozen_string_literal: true RSpec.describe Langchain::Assistant do - let(:thread) { Langchain::Thread.new } - let(:llm) { Langchain::LLM::OpenAI.new(api_key: "123") } - let(:calculator) { Langchain::Tool::Calculator.new } - let(:instructions) { "You are an expert assistant" } - - subject { - described_class.new( - llm: llm, - thread: thread, - tools: [calculator], - instructions: instructions - ) - } - - it "raises an error if tools array contains non-Langchain::Tool instance(s)" do - expect { described_class.new(tools: [Langchain::Tool::Calculator.new, "foo"]) }.to raise_error(ArgumentError) - end + context "when llm is OpenAI" do + let(:thread) { Langchain::Thread.new } + let(:llm) { Langchain::LLM::OpenAI.new(api_key: "123") } + let(:calculator) { Langchain::Tool::Calculator.new } + let(:instructions) { "You are an expert assistant" } - it "raises an error if LLM class does not implement `chat()` method" do - expect { described_class.new(llm: llm) }.to raise_error(ArgumentError) - end + subject { + described_class.new( + llm: llm, + thread: thread, + tools: [calculator], + instructions: instructions + ) + } - it "raises an error if thread is not an instance of Langchain::Thread" do - expect { described_class.new(thread: "foo") }.to raise_error(ArgumentError) - end + it "raises an error if tools array contains non-Langchain::Tool instance(s)" do + expect { described_class.new(tools: [Langchain::Tool::Calculator.new, "foo"]) }.to raise_error(ArgumentError) + end - describe "#initialize" do - it "adds a system message to the thread" do - described_class.new(llm: llm, thread: thread, instructions: instructions) - expect(thread.messages.first.role).to eq("system") - expect(thread.messages.first.content).to eq("You are an expert assistant") + it "raises an error if LLM class does not implement `chat()` method" do + expect { described_class.new(llm: llm) }.to raise_error(ArgumentError) end - end - describe "#add_message" do - it "adds a message to the thread" do - subject.add_message(content: "foo") - expect(thread.messages.last.role).to eq("user") - expect(thread.messages.last.content).to eq("foo") + it "raises an error if thread is not an instance of Langchain::Thread" do + expect { described_class.new(thread: "foo") }.to raise_error(ArgumentError) end - end - describe "submit_tool_output" do - it "adds a message to the thread" do - subject.submit_tool_output(tool_call_id: "123", output: "bar") - expect(thread.messages.last.role).to eq("tool") - expect(thread.messages.last.content).to eq("bar") + describe "#initialize" do + it "adds a 
system message to the thread" do + described_class.new(llm: llm, thread: thread, instructions: instructions) + expect(thread.messages.first.role).to eq("system") + expect(thread.messages.first.content).to eq("You are an expert assistant") + end end - end - describe "#run" do - let(:raw_openai_response) do - { - "id" => "chatcmpl-96QTYLFcp0haHHRhnqvTYL288357W", - "object" => "chat.completion", - "created" => 1711318768, - "model" => "gpt-3.5-turbo-0125", - "choices" => [ + describe "#add_message" do + it "adds a message to the thread" do + subject.add_message(content: "foo") + expect(thread.messages.last.role).to eq("user") + expect(thread.messages.last.content).to eq("foo") + end + end + + describe "submit_tool_output" do + it "adds a message to the thread" do + subject.submit_tool_output(tool_call_id: "123", output: "bar") + expect(thread.messages.last.role).to eq("tool") + expect(thread.messages.last.content).to eq("bar") + end + end + + describe "#run" do + let(:raw_openai_response) do + { + "id" => "chatcmpl-96QTYLFcp0haHHRhnqvTYL288357W", + "object" => "chat.completion", + "created" => 1711318768, + "model" => "gpt-3.5-turbo-0125", + "choices" => [ + { + "index" => 0, + "message" => { + "role" => "assistant", + "content" => nil, + "tool_calls" => [ + { + "id" => "call_9TewGANaaIjzY31UCpAAGLeV", + "type" => "function", + "function" => {"name" => "calculator__execute", "arguments" => "{\"input\":\"2+2\"}"} + } + ] + }, + "logprobs" => nil, + "finish_reason" => "tool_calls" + } + ], + "usage" => {"prompt_tokens" => 91, "completion_tokens" => 18, "total_tokens" => 109}, + "system_fingerprint" => "fp_3bc1b5746b" + } + end + + context "when auto_tool_execution is false" do + before do + allow(subject.llm).to receive(:chat) + .with( + messages: [ + {role: "system", content: instructions}, + {role: "user", content: "Please calculate 2+2"} + ], + tools: calculator.to_openai_tools, + tool_choice: "auto" + ) + .and_return(Langchain::LLM::OpenAIResponse.new(raw_openai_response)) + end + + it "runs the assistant" do + subject.add_message(role: "user", content: "Please calculate 2+2") + subject.run(auto_tool_execution: false) + + expect(subject.thread.messages.last.role).to eq("assistant") + expect(subject.thread.messages.last.tool_calls).to eq([raw_openai_response["choices"][0]["message"]["tool_calls"]][0]) + end + end + + context "when auto_tool_execution is true" do + let(:raw_openai_response2) do { - "index" => 0, - "message" => { - "role" => "assistant", - "content" => nil, - "tool_calls" => [ - { - "id" => "call_9TewGANaaIjzY31UCpAAGLeV", - "type" => "function", - "function" => {"name" => "calculator-execute", "arguments" => "{\"input\":\"2+2\"}"} - } - ] - }, - "logprobs" => nil, - "finish_reason" => "tool_calls" + "id" => "chatcmpl-96P6eEMDDaiwzRIHJZAliYHQ8ov3q", + "object" => "chat.completion", + "created" => 1711313504, + "model" => "gpt-3.5-turbo-0125", + "choices" => [{"index" => 0, "message" => {"role" => "assistant", "content" => "The result of 2 + 2 is 4."}, "logprobs" => nil, "finish_reason" => "stop"}], + "usage" => {"prompt_tokens" => 121, "completion_tokens" => 13, "total_tokens" => 134}, + "system_fingerprint" => "fp_3bc1b5746c" } - ], - "usage" => {"prompt_tokens" => 91, "completion_tokens" => 18, "total_tokens" => 109}, - "system_fingerprint" => "fp_3bc1b5746b" - } + end + + before do + allow(subject.llm).to receive(:chat) + .with( + messages: [ + {role: "system", content: instructions}, + {role: "user", content: "Please calculate 2+2"}, + {role: "assistant", content: "", 
tool_calls: [ + { + "function" => {"arguments" => "{\"input\":\"2+2\"}", "name" => "calculator__execute"}, + "id" => "call_9TewGANaaIjzY31UCpAAGLeV", + "type" => "function" + } + ]}, + {content: "4.0", role: "tool", tool_call_id: "call_9TewGANaaIjzY31UCpAAGLeV"} + ], + tools: calculator.to_openai_tools, + tool_choice: "auto" + ) + .and_return(Langchain::LLM::OpenAIResponse.new(raw_openai_response2)) + end + + it "runs the assistant and automatically executes tool calls" do + allow(subject.tools[0]).to receive(:execute).with( + input: "2+2" + ).and_return("4.0") + + subject.add_message(role: "user", content: "Please calculate 2+2") + subject.add_message(role: "assistant", tool_calls: raw_openai_response["choices"][0]["message"]["tool_calls"]) + + subject.run(auto_tool_execution: true) + + expect(subject.thread.messages[-2].role).to eq("tool") + expect(subject.thread.messages[-2].content).to eq("4.0") + + expect(subject.thread.messages[-1].role).to eq("assistant") + expect(subject.thread.messages[-1].content).to eq("The result of 2 + 2 is 4.") + end + end + + context "when messages are empty" do + let(:instructions) { nil } + + before do + allow_any_instance_of(Langchain::ContextualLogger).to receive(:warn).with("No messages in the thread") + end + + it "logs a warning" do + expect(subject.thread.messages).to be_empty + subject.run + expect(Langchain.logger).to have_received(:warn).with("No messages in the thread") + end + end + end + + describe "#extract_openai_tool_call" do + let(:tool_call) { {"id" => "call_9TewGANaaIjzY31UCpAAGLeV", "type" => "function", "function" => {"name" => "calculator__execute", "arguments" => "{\"input\":\"2+2\"}"}} } + + it "returns correct data" do + expect(subject.send(:extract_openai_tool_call, tool_call: tool_call)).to eq(["call_9TewGANaaIjzY31UCpAAGLeV", "calculator", "execute", {input: "2+2"}]) + end end + end + + context "when llm is GoogleGemini" do + let(:thread) { Langchain::Thread.new } + let(:llm) { Langchain::LLM::GoogleGemini.new(api_key: "123") } + let(:calculator) { Langchain::Tool::Calculator.new } + let(:instructions) { "You are an expert assistant" } - context "when auto_tool_execution is false" do - it "runs the assistant" do - allow(subject.llm).to receive(:chat).and_return(Langchain::LLM::OpenAIResponse.new(raw_openai_response)) + subject { + described_class.new( + llm: llm, + thread: thread, + tools: [calculator], + instructions: instructions + ) + } + + it "raises an error if tools array contains non-Langchain::Tool instance(s)" do + expect { described_class.new(tools: [Langchain::Tool::Calculator.new, "foo"]) }.to raise_error(ArgumentError) + end - subject.add_message(role: "user", content: "Please calculate 2+2") + it "raises an error if LLM class does not implement `chat()` method" do + expect { described_class.new(llm: llm) }.to raise_error(ArgumentError) + end - subject.run(auto_tool_execution: false) + it "raises an error if thread is not an instance of Langchain::Thread" do + expect { described_class.new(thread: "foo") }.to raise_error(ArgumentError) + end - expect(subject.thread.messages.last.role).to eq("assistant") - expect(subject.thread.messages.last.tool_calls).to eq([raw_openai_response["choices"][0]["message"]["tool_calls"]][0]) + describe "#add_message" do + it "adds a message to the thread" do + subject.add_message(content: "foo") + expect(thread.messages.last.role).to eq("user") + expect(thread.messages.last.content).to eq("foo") end end - context "when auto_tool_execution is true" do - let(:raw_openai_response2) do + 
describe "submit_tool_output" do + it "adds a message to the thread" do + subject.submit_tool_output(tool_call_id: "123", output: "bar") + expect(thread.messages.last.role).to eq("function") + expect(thread.messages.last.content).to eq("bar") + end + end + + describe "#run" do + let(:raw_google_gemini_response) do { - "id" => "chatcmpl-96P6eEMDDaiwzRIHJZAliYHQ8ov3q", - "object" => "chat.completion", - "created" => 1711313504, - "model" => "gpt-3.5-turbo-0125", - "choices" => [{"index" => 0, "message" => {"role" => "assistant", "content" => "The result of 2 + 2 is 4."}, "logprobs" => nil, "finish_reason" => "stop"}], - "usage" => {"prompt_tokens" => 121, "completion_tokens" => 13, "total_tokens" => 134}, - "system_fingerprint" => "fp_3bc1b5746c" + "candidates" => [ + { + "content" => { + "parts" => [ + { + "functionCall" => { + "name" => "calculator__execute", + "args" => {"input" => "2+2"} + } + } + ], + "role" => "model" + }, + "finishReason" => "STOP", + "index" => 0, + "safetyRatings" => [] + } + ] } end - it "runs the assistant and automatically executes tool calls" do - allow(subject.llm).to receive(:chat).and_return(Langchain::LLM::OpenAIResponse.new(raw_openai_response2)) - allow(subject.tools[0]).to receive(:execute).with( - input: "2+2" - ).and_return("4.0") + context "when auto_tool_execution is false" do + before do + allow(subject.llm).to receive(:chat) + .with( + messages: [{role: "user", parts: [{text: "Please calculate 2+2"}]}], + tools: calculator.to_google_gemini_tools, + tool_choice: "auto", + system: instructions + ) + .and_return(Langchain::LLM::GoogleGeminiResponse.new(raw_google_gemini_response)) + end + + it "runs the assistant" do + subject.add_message(role: "user", content: "Please calculate 2+2") + subject.run(auto_tool_execution: false) - subject.add_message(role: "user", content: "Please calculate 2+2") - subject.add_message(role: "assistant", tool_calls: raw_openai_response["choices"][0]["message"]["tool_calls"]) + expect(subject.thread.messages.last.role).to eq("model") + expect(subject.thread.messages.last.tool_calls).to eq([raw_google_gemini_response["candidates"][0]["content"]["parts"]][0]) + end + end + + context "when auto_tool_execution is true" do + let(:raw_google_gemini_response2) do + { + "candidates" => [ + { + "content" => { + "parts" => [{"text" => "The answer is 4.0"}], + "role" => "model" + }, + "finishReason" => "STOP", + "index" => 0, + "safetyRatings" => [] + } + ] + } + end - subject.run(auto_tool_execution: true) + before do + allow(subject.llm).to receive(:chat) + .with( + messages: [ + {role: "user", parts: [{text: "Please calculate 2+2"}]}, + {role: "model", parts: [{"functionCall" => {"name" => "calculator__execute", "args" => {"input" => "2+2"}}}]}, + {role: "function", parts: [{functionResponse: {name: "calculator__execute", response: {name: "calculator__execute", content: "4.0"}}}]} + ], + tools: calculator.to_google_gemini_tools, + tool_choice: "auto", + system: instructions + ) + .and_return(Langchain::LLM::GoogleGeminiResponse.new(raw_google_gemini_response2)) + end - expect(subject.thread.messages[-2].role).to eq("tool") - expect(subject.thread.messages[-2].content).to eq("4.0") + it "runs the assistant and automatically executes tool calls" do + allow(subject.tools[0]).to receive(:execute).with( + input: "2+2" + ).and_return("4.0") - expect(subject.thread.messages[-1].role).to eq("assistant") - expect(subject.thread.messages[-1].content).to eq("The result of 2 + 2 is 4.") + subject.add_message(role: "user", content: "Please 
calculate 2+2") + subject.add_message(role: "model", tool_calls: raw_google_gemini_response["candidates"][0]["content"]["parts"]) + + subject.run(auto_tool_execution: true) + + expect(subject.thread.messages[-2].role).to eq("function") + expect(subject.thread.messages[-2].content).to eq("4.0") + + expect(subject.thread.messages[-1].role).to eq("model") + expect(subject.thread.messages[-1].content).to eq("The answer is 4.0") + end end - end - context "when messages are empty" do - let(:instructions) { nil } + context "when messages are empty" do + let(:instructions) { nil } - before do - allow_any_instance_of(Langchain::ContextualLogger).to receive(:warn).with("No messages in the thread") + before do + allow_any_instance_of(Langchain::ContextualLogger).to receive(:warn).with("No messages in the thread") + end + + it "logs a warning" do + expect(subject.thread.messages).to be_empty + subject.run + expect(Langchain.logger).to have_received(:warn).with("No messages in the thread") + end end + end + + describe "#extract_google_gemini_tool_call" do + let(:tool_call) { {"functionCall" => {"name" => "calculator__execute", "args" => {"input" => "2+2"}}} } - it "logs a warning" do - expect(subject.thread.messages).to be_empty - subject.run - expect(Langchain.logger).to have_received(:warn).with("No messages in the thread") + it "returns correct data" do + expect(subject.send(:extract_google_gemini_tool_call, tool_call: tool_call)).to eq(["calculator__execute", "calculator", "execute", {input: "2+2"}]) end end end diff --git a/spec/langchain/assistants/messages/google_gemini_message_spec.rb b/spec/langchain/assistants/messages/google_gemini_message_spec.rb new file mode 100644 index 000000000..156e8520e --- /dev/null +++ b/spec/langchain/assistants/messages/google_gemini_message_spec.rb @@ -0,0 +1,19 @@ +# frozen_string_literal: true + +RSpec.describe Langchain::Messages::GoogleGeminiMessage do + it "raises an error if role is not one of allowed" do + expect { described_class.new(role: "foo") }.to raise_error(ArgumentError) + end + + describe "#to_hash" do + it "returns function" do + message = described_class.new(role: "function", content: "4.0", tool_call_id: "calculator__execute") + expect(message.to_hash).to eq({parts: [{functionResponse: {name: "calculator__execute", response: {content: "4.0", name: "calculator__execute"}}}], role: "function"}) + end + + it "returns tool_calls" do + message = described_class.new(role: "model", tool_calls: []) + expect(message.to_hash).to eq({parts: [{text: ""}], role: "model"}) + end + end +end diff --git a/spec/langchain/assistants/message_spec.rb b/spec/langchain/assistants/messages/openai_message_spec.rb similarity index 66% rename from spec/langchain/assistants/message_spec.rb rename to spec/langchain/assistants/messages/openai_message_spec.rb index 5df8f5ab5..fe307dae2 100644 --- a/spec/langchain/assistants/message_spec.rb +++ b/spec/langchain/assistants/messages/openai_message_spec.rb @@ -1,18 +1,16 @@ # frozen_string_literal: true -RSpec.describe Langchain::Message do - subject { described_class.new } - +RSpec.describe Langchain::Messages::OpenAIMessage do it "raises an error if role is not one of allowed" do expect { described_class.new(role: "foo") }.to raise_error(ArgumentError) end - describe "#to_openai_format" do + describe "#to_hash" do context "when role and content are not nil" do let(:message) { described_class.new(role: "user", content: "Hello, world!", tool_calls: [], tool_call_id: nil) } it "returns a hash with the role and content key" do - 
expect(message.to_openai_format).to eq({role: "user", content: "Hello, world!"}) + expect(message.to_hash).to eq({role: "user", content: "Hello, world!"}) end end @@ -20,7 +18,7 @@ let(:message) { described_class.new(role: "tool", content: "Hello, world!", tool_calls: [], tool_call_id: "123") } it "returns a hash with the tool_call_id key" do - expect(message.to_openai_format).to eq({role: "tool", content: "Hello, world!", tool_call_id: "123"}) + expect(message.to_hash).to eq({role: "tool", content: "Hello, world!", tool_call_id: "123"}) end end @@ -28,13 +26,13 @@ let(:tool_call) { {"id" => "call_9TewGANaaIjzY31UCpAAGLeV", "type" => "function", - "function" => {"name" => "weather-execute", "arguments" => "{\"input\":\"Saint Petersburg\"}"}} + "function" => {"name" => "weather__execute", "arguments" => "{\"input\":\"Saint Petersburg\"}"}} } let(:message) { described_class.new(role: "assistant", content: "", tool_calls: [tool_call], tool_call_id: nil) } it "returns a hash with the tool_calls key" do - expect(message.to_openai_format).to eq({role: "assistant", content: "", tool_calls: [tool_call]}) + expect(message.to_hash).to eq({role: "assistant", content: "", tool_calls: [tool_call]}) end end end diff --git a/spec/langchain/assistants/thread_spec.rb b/spec/langchain/assistants/thread_spec.rb index e70e5e2a9..a01c81d0a 100644 --- a/spec/langchain/assistants/thread_spec.rb +++ b/spec/langchain/assistants/thread_spec.rb @@ -2,16 +2,16 @@ RSpec.describe Langchain::Thread do it "raises an error if messages array contains non-Langchain::Message instance(s)" do - expect { described_class.new(messages: [Langchain::Message.new, "foo"]) }.to raise_error(ArgumentError) + expect { described_class.new(messages: [Langchain::Messages::OpenAIMessage.new, "foo"]) }.to raise_error(ArgumentError) end describe "#openai_messages" do it "returns an array of messages in OpenAI format" do - messages = [Langchain::Message.new(role: "user", content: "hello"), - Langchain::Message.new(role: "assistant", content: "hi")] + messages = [Langchain::Messages::OpenAIMessage.new(role: "user", content: "hello"), + Langchain::Messages::OpenAIMessage.new(role: "assistant", content: "hi")] thread = described_class.new(messages: messages) - openai_messages = thread.openai_messages + openai_messages = thread.array_of_message_hashes expect(openai_messages).to be_an(Array) expect(openai_messages.length).to eq(messages.length) @@ -24,7 +24,7 @@ end describe "#add_message" do - let(:message) { Langchain::Message.new(role: "user", content: "hello") } + let(:message) { Langchain::Messages::OpenAIMessage.new(role: "user", content: "hello") } it "adds a Langchain::Message instance to the messages array" do thread = described_class.new(messages: []) diff --git a/spec/langchain/llm/google_vertex_ai_spec.rb b/spec/langchain/llm/google_vertex_ai_spec.rb index d58727db0..09aed6a64 100644 --- a/spec/langchain/llm/google_vertex_ai_spec.rb +++ b/spec/langchain/llm/google_vertex_ai_spec.rb @@ -1,31 +1,28 @@ # frozen_string_literal: true -require "google-apis-aiplatform_v1" +require "googleauth" -RSpec.describe Langchain::LLM::GoogleVertexAi do - let(:subject) { described_class.new(project_id: "123") } +RSpec.describe Langchain::LLM::GoogleVertexAI do + let(:subject) { described_class.new(project_id: "123", region: "us-central1") } describe "#embed" do let(:embedding) { [-0.00879860669374466, 0.007578692398965359, 0.021136576309800148] } - let(:raw_embedding_response) { JSON.parse(File.read("spec/fixtures/llm/google_vertex_ai/embed.json"), 
symbolize_names: true) } + let(:raw_embedding_response) { JSON.parse(File.read("spec/fixtures/llm/google_vertex_ai/embed.json")) } before do allow(Google::Auth).to receive(:get_application_default).and_return( - double("Google::Auth::UserRefreshCredentials") + double("Google::Auth::UserRefreshCredentials", fetch_access_token!: {access_token: 123}) ) - allow(subject.client).to receive(:predict_project_location_publisher_model).and_return( - double("Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictResponse", to_h: raw_embedding_response) - ) + allow(HTTParty).to receive(:post).and_return(raw_embedding_response) end it "returns valid llm response object" do response = subject.embed(text: "Hello world") - expect(response).to be_a(Langchain::LLM::GoogleVertexAiResponse) + expect(response).to be_a(Langchain::LLM::GoogleGeminiResponse) expect(response.model).to eq("textembedding-gecko") expect(response.embedding).to eq(embedding) - expect(response.total_tokens).to eq(3) end end end diff --git a/spec/langchain/llm/response/google_gemini_response_spec.rb b/spec/langchain/llm/response/google_gemini_response_spec.rb new file mode 100644 index 000000000..d1416fa05 --- /dev/null +++ b/spec/langchain/llm/response/google_gemini_response_spec.rb @@ -0,0 +1,51 @@ +# frozen_string_literal: true + +RSpec.describe Langchain::LLM::GoogleGeminiResponse do + describe "#chat_completion" do + let(:raw_response) { + JSON.parse File.read("spec/fixtures/llm/google_gemini/chat.json") + } + let(:response) { described_class.new(raw_response) } + + it "returns text" do + expect(response.chat_completion).to eq("The answer is 4.0") + end + + it "returns role" do + expect(response.role).to eq("model") + end + end + + describe "#tool_calls" do + let(:raw_response) { + JSON.parse File.read("spec/fixtures/llm/google_gemini/chat_with_tool_calls.json") + } + let(:response) { described_class.new(raw_response) } + + it "returns tool_calls" do + expect(response.tool_calls).to eq([{"functionCall" => {"name" => "calculator__execute", "args" => {"input" => "2+2"}}}]) + end + end + + describe "#embeddings" do + let(:raw_embedding_response) { JSON.parse(File.read("spec/fixtures/llm/google_vertex_ai/embed.json")) } + + subject { described_class.new(raw_embedding_response) } + + it "returns embeddings" do + expect(subject.embeddings).to eq([[ + -0.00879860669374466, + 0.007578692398965359, + 0.021136576309800148 + ]]) + end + + it "#returns embedding" do + expect(subject.embedding).to eq([ + -0.00879860669374466, + 0.007578692398965359, + 0.021136576309800148 + ]) + end + end +end diff --git a/spec/langchain/llm/response/google_vertex_ai_response_spec.rb b/spec/langchain/llm/response/google_vertex_ai_response_spec.rb deleted file mode 100644 index ab2c40f18..000000000 --- a/spec/langchain/llm/response/google_vertex_ai_response_spec.rb +++ /dev/null @@ -1,29 +0,0 @@ -# frozen_string_literal: true - -RSpec.describe Langchain::LLM::GoogleVertexAiResponse do - let(:raw_embedding_response) { JSON.parse(File.read("spec/fixtures/llm/google_vertex_ai/embed.json"), symbolize_names: true) } - - describe "embeddings" do - subject { described_class.new(raw_embedding_response) } - - it "returns embeddings" do - expect(subject.embeddings).to eq([[ - -0.00879860669374466, - 0.007578692398965359, - 0.021136576309800148 - ]]) - end - - it "#returns embedding" do - expect(subject.embedding).to eq([ - -0.00879860669374466, - 0.007578692398965359, - 0.021136576309800148 - ]) - end - - it "#return total_tokens" do - expect(subject.total_tokens).to 
eq(3) - end - end -end
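
A minimal usage sketch for the `NewsRetriever` tool added above. Only `get_sources` and the private `send_request` are visible in this diff, so the constructor keyword (`api_key:`, pointed at `ENV["NEWS_API_KEY"]` by the README table) and the top-level "sources" array in the response are assumptions, not confirmed by the hunks:

require "langchain"
require "json"

# Assumed constructor; the diff only shows that @api_key is set somewhere
# and sent as the apiKey query param.
news = Langchain::Tool::NewsRetriever.new(api_key: ENV["NEWS_API_KEY"])

# get_sources issues GET https://newsapi.org/v2/top-headlines/sources and,
# per the @return tag above, hands back the raw response body as a JSON string.
raw = news.get_sources(category: "technology", language: "en")

payload = JSON.parse(raw)
# Assumed NewsAPI shape: a top-level "sources" array of publisher objects.
Array(payload["sources"]).each { |source| puts source["name"] }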
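
The tool JSON renames above switch the separator between tool and method from "-" to "__" (e.g. weather-execute → weather__execute), and the assistant spec's extract_openai_tool_call / extract_google_gemini_tool_call expectations pin how such a name decomposes back into its parts. A sketch of that split; the private implementation itself is not shown in this diff:

# "vectorsearch__similarity_search" → tool "vectorsearch", method
# "similarity_search". Splitting on the double underscore leaves the single
# underscores inside the method name intact.
function_name = "vectorsearch__similarity_search"
tool_name, method_name = function_name.split("__")

tool_name   # => "vectorsearch"
method_name # => "similarity_search"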
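
For the Gemini wire format, google_gemini_message_spec.rb above pins how a tool result is serialized: a "function"-role message wraps its output in a functionResponse part keyed by the originating function name. Condensed from that spec:

message = Langchain::Messages::GoogleGeminiMessage.new(
  role: "function",
  content: "4.0",
  tool_call_id: "calculator__execute"
)

message.to_hash
# => {parts: [{functionResponse: {name: "calculator__execute",
#                                 response: {content: "4.0", name: "calculator__execute"}}}],
#     role: "function"}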
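
Finally, the new GoogleGemini context in assistant_spec.rb exercises the whole loop. Condensed into a sketch below; note the spec mocks the LLM, so the "123" key and the "The answer is 4.0" reply are the spec's stubs, and a live API call would need a real key and would answer differently:

llm = Langchain::LLM::GoogleGemini.new(api_key: "123")

assistant = Langchain::Assistant.new(
  llm: llm,
  thread: Langchain::Thread.new,
  tools: [Langchain::Tool::Calculator.new],
  instructions: "You are an expert assistant"
)

assistant.add_message(role: "user", content: "Please calculate 2+2")
assistant.run(auto_tool_execution: true)

# In the stubbed spec, the model first emits a calculator__execute
# functionCall, the tool returns "4.0", and the follow-up completion is:
assistant.thread.messages.last.content # => "The answer is 4.0"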