Skip to content

Commit

Permalink
Langchain::Assistant when using OpenAI accepts a message with image_u…
Browse files Browse the repository at this point in the history
…rl (#799)

* Langchain::Assistant when using OpenAI accept a message with image_url

* CHANGELOG entry + fixing linter
  • Loading branch information
andreibondarev authored Sep 30, 2024
1 parent bdafbc1 commit 33ad323
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 35 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
## [Unreleased]
- Assistant can now process image_urls in the messages (currently only for OpenAI)

## [0.16.1] - 2024-09-30
- Deprecate Langchain::LLM::GooglePalm
Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,12 @@ assistant = Langchain::Assistant.new(
# Add a user message and run the assistant
assistant.add_message_and_run!(content: "What's the latest news about AI?")

# Supply an image to the assistant
assistant.add_message_and_run!(
content: "Show me a picture of a cat",
image: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
)

# Access the conversation thread
messages = assistant.messages

Expand Down
42 changes: 26 additions & 16 deletions lib/langchain/assistants/assistant.rb
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,14 @@ def initialize(

# Add a user message to the messages array
#
# @param content [String] The content of the message
# @param role [String] The role attribute of the message. Default: "user"
# @param content [String] The content of the message
# @param image_url [String] The URL of the image to include in the message
# @param tool_calls [Array<Hash>] The tool calls to include in the message
# @param tool_call_id [String] The ID of the tool call to include in the message
# @return [Array<Langchain::Message>] The messages
def add_message(content: nil, role: "user", tool_calls: [], tool_call_id: nil)
message = build_message(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
def add_message(role: "user", content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
message = build_message(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)

# Call the callback with the message
add_message_callback.call(message) if add_message_callback # rubocop:disable Style/SafeNavigation
Expand Down Expand Up @@ -145,17 +146,17 @@ def run!
# @param content [String] The content of the message
# @param auto_tool_execution [Boolean] Whether or not to automatically run tools
# @return [Array<Langchain::Message>] The messages
def add_message_and_run(content:, auto_tool_execution: false)
add_message(content: content, role: "user")
def add_message_and_run(content: nil, image_url: nil, auto_tool_execution: false)
add_message(content: content, image_url: image_url, role: "user")
run(auto_tool_execution: auto_tool_execution)
end

# Add a user message and run the assistant with automatic tool execution
#
# @param content [String] The content of the message
# @return [Array<Langchain::Message>] The messages
def add_message_and_run!(content:)
add_message_and_run(content: content, auto_tool_execution: true)
def add_message_and_run!(content: nil, image_url: nil)
add_message_and_run(content: content, image_url: image_url, auto_tool_execution: true)
end

# Submit tool output
Expand Down Expand Up @@ -388,11 +389,12 @@ def run_tools(tool_calls)
#
# @param role [String] The role of the message
# @param content [String] The content of the message
# @param image_url [String] The URL of the image to include in the message
# @param tool_calls [Array<Hash>] The tool calls to include in the message
# @param tool_call_id [String] The ID of the tool call to include in the message
# @return [Langchain::Message] The Message object
def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
@llm_adapter.build_message(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
@llm_adapter.build_message(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
end

# Increment the tokens count based on the last interaction with the LLM
Expand Down Expand Up @@ -443,7 +445,7 @@ def extract_tool_call_args(tool_call:)
raise NotImplementedError, "Subclasses must implement extract_tool_call_args"
end

def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
raise NotImplementedError, "Subclasses must implement build_message"
end
end
Expand All @@ -457,7 +459,9 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
params
end

def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
warn "Image URL is not supported by Ollama currently" if image_url

Langchain::Messages::OllamaMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
end

Expand Down Expand Up @@ -506,8 +510,8 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
params
end

def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
Langchain::Messages::OpenAIMessage.new(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
end

# Extract the tool call information from the OpenAI tool call hash
Expand Down Expand Up @@ -564,7 +568,9 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
params
end

def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
warn "Image URL is not supported by MistralAI currently" if image_url

Langchain::Messages::MistralAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
end

Expand Down Expand Up @@ -623,7 +629,9 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
params
end

def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
warn "Image URL is not supported by Google Gemini" if image_url

Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
end

Expand Down Expand Up @@ -676,7 +684,9 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
params
end

def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
warn "Image URL is not supported by Anthropic currently" if image_url

Langchain::Messages::AnthropicMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
end

Expand Down
6 changes: 5 additions & 1 deletion lib/langchain/assistants/messages/base.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
module Langchain
module Messages
class Base
attr_reader :role, :content, :tool_calls, :tool_call_id
attr_reader :role,
:content,
:image_url,
:tool_calls,
:tool_call_id

# Check if the message came from a user
#
Expand Down
45 changes: 37 additions & 8 deletions lib/langchain/assistants/messages/openai_message.rb
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,25 @@ class OpenAIMessage < Base

# Initialize a new OpenAI message
#
# @param [String] The role of the message
# @param [String] The content of the message
# @param [Array<Hash>] The tool calls made in the message
# @param [String] The ID of the tool call
def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil) # TODO: Implement image_file: reference (https://platform.openai.com/docs/api-reference/messages/object#messages/object-content)
# @param role [String] The role of the message
# @param content [String] The content of the message
# @param image_url [String] The URL of the image
# @param tool_calls [Array<Hash>] The tool calls made in the message
# @param tool_call_id [String] The ID of the tool call
def initialize(
role:,
content: nil,
image_url: nil,
tool_calls: [],
tool_call_id: nil
)
raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }

@role = role
# Some Tools return content as a JSON hence `.to_s`
@content = content.to_s
@image_url = image_url
@tool_calls = tool_calls
@tool_call_id = tool_call_id
end
Expand All @@ -43,9 +51,30 @@ def llm?
def to_hash
{}.tap do |h|
h[:role] = role
h[:content] = content if content # Content is nil for tool calls
h[:tool_calls] = tool_calls if tool_calls.any?
h[:tool_call_id] = tool_call_id if tool_call_id

if tool_calls.any?
h[:tool_calls] = tool_calls
else
h[:tool_call_id] = tool_call_id if tool_call_id

h[:content] = []

if content && !content.empty?
h[:content] << {
type: "text",
text: content
}
end

if image_url
h[:content] << {
type: "image_url",
image_url: {
url: image_url
}
}
end
end
end
end

Expand Down
25 changes: 19 additions & 6 deletions spec/langchain/assistants/assistant_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,19 @@
expect(subject.messages.first.content).to eq("hello")
end

it "adds a message with image_url" do
message_with_image = {role: "user", content: "hello", image_url: "https://example.com/image.jpg"}
subject = described_class.new(llm: llm, messages: [])

expect {
subject.add_message(**message_with_image)
}.to change { subject.messages.count }.from(0).to(1)
expect(subject.messages.first).to be_a(Langchain::Messages::OpenAIMessage)
expect(subject.messages.first.role).to eq("user")
expect(subject.messages.first.content).to eq("hello")
expect(subject.messages.first.image_url).to eq("https://example.com/image.jpg")
end

it "calls the add_message_callback with the message" do
callback = double("callback", call: true)
subject = described_class.new(llm: llm, messages: [], add_message_callback: callback)
Expand Down Expand Up @@ -211,8 +224,8 @@
allow(subject.llm).to receive(:chat)
.with(
messages: [
{role: "system", content: instructions},
{role: "user", content: "Please calculate 2+2"}
{role: "system", content: [{type: "text", text: instructions}]},
{role: "user", content: [{type: "text", text: "Please calculate 2+2"}]}
],
tools: calculator.class.function_schemas.to_openai_format,
tool_choice: "auto"
Expand Down Expand Up @@ -255,16 +268,16 @@
allow(subject.llm).to receive(:chat)
.with(
messages: [
{role: "system", content: instructions},
{role: "user", content: "Please calculate 2+2"},
{role: "assistant", content: "", tool_calls: [
{role: "system", content: [{type: "text", text: instructions}]},
{role: "user", content: [{type: "text", text: "Please calculate 2+2"}]},
{role: "assistant", tool_calls: [
{
"function" => {"arguments" => "{\"input\":\"2+2\"}", "name" => "langchain_tool_calculator__execute"},
"id" => "call_9TewGANaaIjzY31UCpAAGLeV",
"type" => "function"
}
]},
{content: "4.0", role: "tool", tool_call_id: "call_9TewGANaaIjzY31UCpAAGLeV"}
{content: [{type: "text", text: "4.0"}], role: "tool", tool_call_id: "call_9TewGANaaIjzY31UCpAAGLeV"}
],
tools: calculator.class.function_schemas.to_openai_format,
tool_choice: "auto"
Expand Down
22 changes: 18 additions & 4 deletions spec/langchain/assistants/messages/openai_message_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
let(:message) { described_class.new(role: "user", content: "Hello, world!", tool_calls: [], tool_call_id: nil) }

it "returns a hash with the role and content key" do
expect(message.to_hash).to eq({role: "user", content: "Hello, world!"})
expect(message.to_hash).to eq({role: "user", content: [{type: "text", text: "Hello, world!"}]})
end
end

context "when tool_call_id is not nil" do
let(:message) { described_class.new(role: "tool", content: "Hello, world!", tool_calls: [], tool_call_id: "123") }

it "returns a hash with the tool_call_id key" do
expect(message.to_hash).to eq({role: "tool", content: "Hello, world!", tool_call_id: "123"})
expect(message.to_hash).to eq({role: "tool", content: [{type: "text", text: "Hello, world!"}], tool_call_id: "123"})
end
end

Expand All @@ -29,10 +29,24 @@
"function" => {"name" => "weather__execute", "arguments" => "{\"input\":\"Saint Petersburg\"}"}}
}

let(:message) { described_class.new(role: "assistant", content: "", tool_calls: [tool_call], tool_call_id: nil) }
let(:message) { described_class.new(role: "assistant", tool_calls: [tool_call], tool_call_id: nil) }

it "returns a hash with the tool_calls key" do
expect(message.to_hash).to eq({role: "assistant", content: "", tool_calls: [tool_call]})
expect(message.to_hash).to eq({role: "assistant", tool_calls: [tool_call]})
end
end

context "when image_url is present" do
let(:message) { described_class.new(role: "user", content: "Please describe this image", image_url: "https://example.com/image.jpg") }

it "returns a hash with the image_url key" do
expect(message.to_hash).to eq({
role: "user",
content: [
{type: "text", text: "Please describe this image"},
{type: "image_url", image_url: {url: "https://example.com/image.jpg"}}
]
})
end
end
end
Expand Down

0 comments on commit 33ad323

Please sign in to comment.