From 792fcaa2aabde1012924b666b968217badfd2507 Mon Sep 17 00:00:00 2001
From: Andrei Bondarev
Date: Mon, 30 Sep 2024 15:45:17 -0400
Subject: [PATCH 1/2] Langchain::Assistant: accept a message with an image_url
 when using OpenAI

---
 README.md                                   |  6 +++
 lib/langchain/assistants/assistant.rb       | 42 ++++++++++-------
 lib/langchain/assistants/messages/base.rb   |  6 ++-
 .../assistants/messages/openai_message.rb   | 45 +++++++++++++++----
 spec/langchain/assistants/assistant_spec.rb | 25 ++++++++---
 .../messages/openai_message_spec.rb         | 22 +++++++--
 6 files changed, 111 insertions(+), 35 deletions(-)

diff --git a/README.md b/README.md
index 509c37bf7..fcc43937b 100644
--- a/README.md
+++ b/README.md
@@ -501,6 +501,12 @@ assistant = Langchain::Assistant.new(
 # Add a user message and run the assistant
 assistant.add_message_and_run!(content: "What's the latest news about AI?")
 
+# Supply an image to the assistant
+assistant.add_message_and_run!(
+  content: "Show me a picture of a cat",
+  image_url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+)
+
 # Access the conversation thread
 messages = assistant.messages
 
diff --git a/lib/langchain/assistants/assistant.rb b/lib/langchain/assistants/assistant.rb
index 5b5cf9a38..0ad20c390 100644
--- a/lib/langchain/assistants/assistant.rb
+++ b/lib/langchain/assistants/assistant.rb
@@ -63,13 +63,14 @@ def initialize(
 
     # Add a user message to the messages array
     #
-    # @param content [String] The content of the message
     # @param role [String] The role attribute of the message. Default: "user"
+    # @param content [String] The content of the message
+    # @param image_url [String] The URL of the image to include in the message
     # @param tool_calls [Array] The tool calls to include in the message
     # @param tool_call_id [String] The ID of the tool call to include in the message
     # @return [Array] The messages
-    def add_message(content: nil, role: "user", tool_calls: [], tool_call_id: nil)
-      message = build_message(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+    def add_message(role: "user", content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
+      message = build_message(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
 
       # Call the callback with the message
       add_message_callback.call(message) if add_message_callback # rubocop:disable Style/SafeNavigation
@@ -145,8 +146,8 @@ def run!
     # @param content [String] The content of the message
     # @param auto_tool_execution [Boolean] Whether or not to automatically run tools
     # @return [Array] The messages
-    def add_message_and_run(content:, auto_tool_execution: false)
-      add_message(content: content, role: "user")
+    def add_message_and_run(content: nil, image_url: nil, auto_tool_execution: false)
+      add_message(content: content, image_url: image_url, role: "user")
       run(auto_tool_execution: auto_tool_execution)
     end
 
@@ -154,8 +155,8 @@ def add_message_and_run(content:, auto_tool_execution: false)
     #
     # @param content [String] The content of the message
     # @return [Array] The messages
-    def add_message_and_run!(content:)
-      add_message_and_run(content: content, auto_tool_execution: true)
+    def add_message_and_run!(content: nil, image_url: nil)
+      add_message_and_run(content: content, image_url: image_url, auto_tool_execution: true)
     end
 
     # Submit tool output
@@ -388,11 +389,12 @@ def run_tools(tool_calls)
     #
     # @param role [String] The role of the message
     # @param content [String] The content of the message
+    # @param image_url [String] The URL of the image to include in the message
     # @param tool_calls [Array] The tool calls to include in the message
     # @param tool_call_id [String] The ID of the tool call to include in the message
     # @return [Langchain::Message] The Message object
-    def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
-      @llm_adapter.build_message(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+    def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
+      @llm_adapter.build_message(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
     end
 
     # Increment the tokens count based on the last interaction with the LLM
@@ -443,7 +445,7 @@ def extract_tool_call_args(tool_call:)
          raise NotImplementedError, "Subclasses must implement extract_tool_call_args"
        end
 
-        def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
+        def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
          raise NotImplementedError, "Subclasses must implement build_message"
        end
      end
@@ -457,7 +459,9 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
          params
        end
 
-        def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
+        def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
+          warn "Image URL is not supported by Ollama currently" if image_url
+
          Langchain::Messages::OllamaMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
        end
 
@@ -506,8 +510,8 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
          params
        end
 
-        def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
-          Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+        def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
+          Langchain::Messages::OpenAIMessage.new(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
        end
 
        # Extract the tool call information from the OpenAI tool call hash
@@ -564,7 +568,9 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
          params
        end
 
-        def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
+        def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
+          warn "Image URL is not supported by MistralAI currently" if image_url
+
          Langchain::Messages::MistralAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
        end
 
@@ -623,7 +629,9 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
          params
        end
 
-        def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
+        def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
+          warn "Image URL is not supported by Google Gemini" if image_url
+
          Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
        end
 
@@ -676,7 +684,9 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
          params
        end
 
-        def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
+        def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
+          warn "Image URL is not supported by Anthropic currently" if image_url
+
          Langchain::Messages::AnthropicMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
        end
 
diff --git a/lib/langchain/assistants/messages/base.rb b/lib/langchain/assistants/messages/base.rb
index 3f950e316..54f9b2126 100644
--- a/lib/langchain/assistants/messages/base.rb
+++ b/lib/langchain/assistants/messages/base.rb
@@ -3,7 +3,11 @@
 module Langchain
   module Messages
     class Base
-      attr_reader :role, :content, :tool_calls, :tool_call_id
+      attr_reader :role,
+        :content,
+        :image_url,
+        :tool_calls,
+        :tool_call_id
 
       # Check if the message came from a user
       #
diff --git a/lib/langchain/assistants/messages/openai_message.rb b/lib/langchain/assistants/messages/openai_message.rb
index b673ff377..0be0b0f53 100644
--- a/lib/langchain/assistants/messages/openai_message.rb
+++ b/lib/langchain/assistants/messages/openai_message.rb
@@ -15,17 +15,25 @@ class OpenAIMessage < Base
 
       # Initialize a new OpenAI message
      #
-      # @param [String] The role of the message
-      # @param [String] The content of the message
-      # @param [Array] The tool calls made in the message
-      # @param [String] The ID of the tool call
-      def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil) # TODO: Implement image_file: reference (https://platform.openai.com/docs/api-reference/messages/object#messages/object-content)
+      # @param role [String] The role of the message
+      # @param content [String] The content of the message
+      # @param image_url [String] The URL of the image
+      # @param tool_calls [Array] The tool calls made in the message
+      # @param tool_call_id [String] The ID of the tool call
+      def initialize(
+        role:,
+        content: nil,
+        image_url: nil,
+        tool_calls: [],
+        tool_call_id: nil
+      )
         raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
         raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
 
         @role = role
         # Some Tools return content as a JSON hence `.to_s`
         @content = content.to_s
+        @image_url = image_url
         @tool_calls = tool_calls
         @tool_call_id = tool_call_id
       end
@@ -43,9 +51,30 @@ def llm?
       def to_hash
         {}.tap do |h|
           h[:role] = role
-          h[:content] = content if content # Content is nil for tool calls
-          h[:tool_calls] = tool_calls if tool_calls.any?
-          h[:tool_call_id] = tool_call_id if tool_call_id
+
+          if tool_calls.any?
+            h[:tool_calls] = tool_calls
+          else
+            h[:tool_call_id] = tool_call_id if tool_call_id
+
+            h[:content] = []
+
+            if content && !content.empty?
+              h[:content] << {
+                type: "text",
+                text: content
+              }
+            end
+
+            if image_url
+              h[:content] << {
+                type: "image_url",
+                image_url: {
+                  url: image_url
+                }
+              }
+            end
+          end
         end
       end
 
diff --git a/spec/langchain/assistants/assistant_spec.rb b/spec/langchain/assistants/assistant_spec.rb
index f68f42593..95037ac49 100644
--- a/spec/langchain/assistants/assistant_spec.rb
+++ b/spec/langchain/assistants/assistant_spec.rb
@@ -87,6 +87,19 @@
     expect(subject.messages.first.content).to eq("hello")
   end
 
+  it "adds a message with image_url" do
+    message_with_image = {role: "user", content: "hello", image_url: "https://example.com/image.jpg"}
+    subject = described_class.new(llm: llm, messages: [])
+
+    expect {
+      subject.add_message(**message_with_image)
+    }.to change { subject.messages.count }.from(0).to(1)
+    expect(subject.messages.first).to be_a(Langchain::Messages::OpenAIMessage)
+    expect(subject.messages.first.role).to eq("user")
+    expect(subject.messages.first.content).to eq("hello")
+    expect(subject.messages.first.image_url).to eq("https://example.com/image.jpg")
+  end
+
   it "calls the add_message_callback with the message" do
     callback = double("callback", call: true)
     subject = described_class.new(llm: llm, messages: [], add_message_callback: callback)
@@ -211,8 +224,8 @@
         allow(subject.llm).to receive(:chat)
           .with(
             messages: [
-              {role: "system", content: instructions},
-              {role: "user", content: "Please calculate 2+2"}
+              {role: "system", content: [{type:"text", text: instructions}]},
+              {role: "user", content: [{type:"text", text: "Please calculate 2+2"}]}
             ],
             tools: calculator.class.function_schemas.to_openai_format,
             tool_choice: "auto"
@@ -255,16 +268,16 @@
         allow(subject.llm).to receive(:chat)
           .with(
             messages: [
-              {role: "system", content: instructions},
-              {role: "user", content: "Please calculate 2+2"},
-              {role: "assistant", content: "", tool_calls: [
+              {role: "system", content: [{type:"text", text: instructions}]},
+              {role: "user", content: [{type:"text", text:"Please calculate 2+2"}]},
+              {role: "assistant", tool_calls: [
                 {
                   "function" => {"arguments" => "{\"input\":\"2+2\"}", "name" => "langchain_tool_calculator__execute"},
                   "id" => "call_9TewGANaaIjzY31UCpAAGLeV",
                   "type" => "function"
                 }
               ]},
-              {content: "4.0", role: "tool", tool_call_id: "call_9TewGANaaIjzY31UCpAAGLeV"}
+              {content: [{type:"text", text:"4.0"}], role: "tool", tool_call_id: "call_9TewGANaaIjzY31UCpAAGLeV"}
             ],
             tools: calculator.class.function_schemas.to_openai_format,
             tool_choice: "auto"
diff --git a/spec/langchain/assistants/messages/openai_message_spec.rb b/spec/langchain/assistants/messages/openai_message_spec.rb
index fe307dae2..59006b6fd 100644
--- a/spec/langchain/assistants/messages/openai_message_spec.rb
+++ b/spec/langchain/assistants/messages/openai_message_spec.rb
@@ -10,7 +10,7 @@
       let(:message) { described_class.new(role: "user", content: "Hello, world!", tool_calls: [], tool_call_id: nil) }
 
       it "returns a hash with the role and content key" do
-        expect(message.to_hash).to eq({role: "user", content: "Hello, world!"})
+        expect(message.to_hash).to eq({role: "user", content: [{type:"text", text: "Hello, world!"}]})
       end
     end
 
@@ -18,7 +18,7 @@
       let(:message) { described_class.new(role: "tool", content: "Hello, world!", tool_calls: [], tool_call_id: "123") }
 
      it "returns a hash with the tool_call_id key" do
-        expect(message.to_hash).to eq({role: "tool", content: "Hello, world!", tool_call_id: "123"})
"123"}) + expect(message.to_hash).to eq({role: "tool", content: [{type:"text", text: "Hello, world!"}], tool_call_id: "123"}) end end @@ -29,10 +29,24 @@ "function" => {"name" => "weather__execute", "arguments" => "{\"input\":\"Saint Petersburg\"}"}} } - let(:message) { described_class.new(role: "assistant", content: "", tool_calls: [tool_call], tool_call_id: nil) } + let(:message) { described_class.new(role: "assistant", tool_calls: [tool_call], tool_call_id: nil) } it "returns a hash with the tool_calls key" do - expect(message.to_hash).to eq({role: "assistant", content: "", tool_calls: [tool_call]}) + expect(message.to_hash).to eq({role: "assistant", tool_calls: [tool_call]}) + end + end + + context "when image_url is present" do + let(:message) { described_class.new(role: "user", content: "Please describe this image", image_url: "https://example.com/image.jpg") } + + it "returns a hash with the image_url key" do + expect(message.to_hash).to eq({ + role: "user", + content: [ + {type:"text", text: "Please describe this image"}, + {type:"image_url", image_url: {url: "https://example.com/image.jpg"}} + ] + }) end end end From feac51cc8ad987f7bc61e6fa52d2220349b28847 Mon Sep 17 00:00:00 2001 From: Andrei Bondarev Date: Mon, 30 Sep 2024 15:47:07 -0400 Subject: [PATCH 2/2] CHANGELOG entry + fixing linter --- CHANGELOG.md | 1 + spec/langchain/assistants/assistant_spec.rb | 10 +++++----- .../assistants/messages/openai_message_spec.rb | 8 ++++---- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 73daca6cb..6319c7682 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ ## [Unreleased] +- Assistant can now process image_urls in the messages (currently only for OpenAI) ## [0.16.1] - 2024-09-30 - Deprecate Langchain::LLM::GooglePalm diff --git a/spec/langchain/assistants/assistant_spec.rb b/spec/langchain/assistants/assistant_spec.rb index 95037ac49..1f375636d 100644 --- a/spec/langchain/assistants/assistant_spec.rb +++ b/spec/langchain/assistants/assistant_spec.rb @@ -224,8 +224,8 @@ allow(subject.llm).to receive(:chat) .with( messages: [ - {role: "system", content: [{type:"text", text: instructions}]}, - {role: "user", content: [{type:"text", text: "Please calculate 2+2"}]} + {role: "system", content: [{type: "text", text: instructions}]}, + {role: "user", content: [{type: "text", text: "Please calculate 2+2"}]} ], tools: calculator.class.function_schemas.to_openai_format, tool_choice: "auto" @@ -268,8 +268,8 @@ allow(subject.llm).to receive(:chat) .with( messages: [ - {role: "system", content: [{type:"text", text: instructions}]}, - {role: "user", content: [{type:"text", text:"Please calculate 2+2"}]}, + {role: "system", content: [{type: "text", text: instructions}]}, + {role: "user", content: [{type: "text", text: "Please calculate 2+2"}]}, {role: "assistant", tool_calls: [ { "function" => {"arguments" => "{\"input\":\"2+2\"}", "name" => "langchain_tool_calculator__execute"}, @@ -277,7 +277,7 @@ "type" => "function" } ]}, - {content: [{type:"text", text:"4.0"}], role: "tool", tool_call_id: "call_9TewGANaaIjzY31UCpAAGLeV"} + {content: [{type: "text", text: "4.0"}], role: "tool", tool_call_id: "call_9TewGANaaIjzY31UCpAAGLeV"} ], tools: calculator.class.function_schemas.to_openai_format, tool_choice: "auto" diff --git a/spec/langchain/assistants/messages/openai_message_spec.rb b/spec/langchain/assistants/messages/openai_message_spec.rb index 59006b6fd..3ca719a87 100644 --- a/spec/langchain/assistants/messages/openai_message_spec.rb +++ 
@@ -10,7 +10,7 @@
       let(:message) { described_class.new(role: "user", content: "Hello, world!", tool_calls: [], tool_call_id: nil) }
 
       it "returns a hash with the role and content key" do
-        expect(message.to_hash).to eq({role: "user", content: [{type:"text", text: "Hello, world!"}]})
+        expect(message.to_hash).to eq({role: "user", content: [{type: "text", text: "Hello, world!"}]})
       end
     end
 
@@ -18,7 +18,7 @@
       let(:message) { described_class.new(role: "tool", content: "Hello, world!", tool_calls: [], tool_call_id: "123") }
 
       it "returns a hash with the tool_call_id key" do
-        expect(message.to_hash).to eq({role: "tool", content: [{type:"text", text: "Hello, world!"}], tool_call_id: "123"})
+        expect(message.to_hash).to eq({role: "tool", content: [{type: "text", text: "Hello, world!"}], tool_call_id: "123"})
       end
     end
 
@@ -43,8 +43,8 @@
         expect(message.to_hash).to eq({
           role: "user",
           content: [
-            {type:"text", text: "Please describe this image"},
-            {type:"image_url", image_url: {url: "https://example.com/image.jpg"}}
+            {type: "text", text: "Please describe this image"},
+            {type: "image_url", image_url: {url: "https://example.com/image.jpg"}}
          ]
        })
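
Below is a minimal usage sketch of what these two patches enable, for illustration only (the environment-variable handling and the example prompt are assumptions, not part of the patches): an OpenAI-backed assistant accepts a user message carrying both text and an image_url, and OpenAIMessage#to_hash serializes it into the multi-part content array that the OpenAI chat API expects.

require "langchain"

# Assumes OPENAI_API_KEY is set and the configured chat model accepts image input.
llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
assistant = Langchain::Assistant.new(llm: llm)

# A single user message can now carry text plus an image_url.
assistant.add_message_and_run!(
  content: "Describe this image",
  image_url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
)

# The user message serializes to OpenAI's multi-part content format:
#   {role: "user",
#    content: [
#      {type: "text", text: "Describe this image"},
#      {type: "image_url", image_url: {url: "https://upload.wikimedia.org/..."}}]}
assistant.messages.first.to_hash

For the other adapters (Ollama, MistralAI, Google Gemini, Anthropic), build_message currently emits a warning and ignores the image_url, as per the adapter changes in patch 1.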