Langchain::Assistant when using OpenAI accepts a message with image_url #799

Merged (2 commits) on Sep 30, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -1,4 +1,5 @@
## [Unreleased]
- Assistant can now process image_urls in the messages (currently only for OpenAI)

## [0.16.1] - 2024-09-30
- Deprecate Langchain::LLM::GooglePalm
6 changes: 6 additions & 0 deletions README.md
@@ -501,6 +501,12 @@ assistant = Langchain::Assistant.new(
# Add a user message and run the assistant
assistant.add_message_and_run!(content: "What's the latest news about AI?")

# Supply an image to the assistant
assistant.add_message_and_run!(
content: "Show me a picture of a cat",
image: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
)

# Access the conversation thread
messages = assistant.messages

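
For context on what the README example produces under the hood, here is a minimal sketch (not part of the diff; the prompt and URL are placeholders) of how the new image_url option is serialized into OpenAI's multi-part content format by the OpenAIMessage changes further down:

require "langchain"

# Build a user message that carries both text and an image URL
message = Langchain::Messages::OpenAIMessage.new(
  role: "user",
  content: "Describe this image",
  image_url: "https://example.com/image.jpg"
)

# to_hash emits OpenAI's multi-part content array
message.to_hash
# => {
#      role: "user",
#      content: [
#        {type: "text", text: "Describe this image"},
#        {type: "image_url", image_url: {url: "https://example.com/image.jpg"}}
#      ]
#    }
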
42 changes: 26 additions & 16 deletions lib/langchain/assistants/assistant.rb
@@ -63,13 +63,14 @@ def initialize(

# Add a user message to the messages array
#
# @param content [String] The content of the message
# @param role [String] The role attribute of the message. Default: "user"
# @param content [String] The content of the message
# @param image_url [String] The URL of the image to include in the message
# @param tool_calls [Array<Hash>] The tool calls to include in the message
# @param tool_call_id [String] The ID of the tool call to include in the message
# @return [Array<Langchain::Message>] The messages
def add_message(content: nil, role: "user", tool_calls: [], tool_call_id: nil)
message = build_message(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
def add_message(role: "user", content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
message = build_message(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)

# Call the callback with the message
add_message_callback.call(message) if add_message_callback # rubocop:disable Style/SafeNavigation
@@ -145,17 +146,17 @@ def run!
# @param content [String] The content of the message
# @param auto_tool_execution [Boolean] Whether or not to automatically run tools
# @return [Array<Langchain::Message>] The messages
def add_message_and_run(content:, auto_tool_execution: false)
add_message(content: content, role: "user")
def add_message_and_run(content: nil, image_url: nil, auto_tool_execution: false)
add_message(content: content, image_url: image_url, role: "user")
run(auto_tool_execution: auto_tool_execution)
end

# Add a user message and run the assistant with automatic tool execution
#
# @param content [String] The content of the message
# @return [Array<Langchain::Message>] The messages
def add_message_and_run!(content:)
add_message_and_run(content: content, auto_tool_execution: true)
def add_message_and_run!(content: nil, image_url: nil)
add_message_and_run(content: content, image_url: image_url, auto_tool_execution: true)
end

# Submit tool output
@@ -388,11 +389,12 @@ def run_tools(tool_calls)
#
# @param role [String] The role of the message
# @param content [String] The content of the message
# @param image_url [String] The URL of the image to include in the message
# @param tool_calls [Array<Hash>] The tool calls to include in the message
# @param tool_call_id [String] The ID of the tool call to include in the message
# @return [Langchain::Message] The Message object
def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
@llm_adapter.build_message(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
@llm_adapter.build_message(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
end

# Increment the tokens count based on the last interaction with the LLM
@@ -443,7 +445,7 @@ def extract_tool_call_args(tool_call:)
raise NotImplementedError, "Subclasses must implement extract_tool_call_args"
end

def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
raise NotImplementedError, "Subclasses must implement build_message"
end
end
@@ -457,7 +459,9 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
params
end

def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
warn "Image URL is not supported by Ollama currently" if image_url

Langchain::Messages::OllamaMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
end

@@ -506,8 +510,8 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
params
end

def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
Langchain::Messages::OpenAIMessage.new(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
end

# Extract the tool call information from the OpenAI tool call hash
@@ -564,7 +568,9 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
params
end

def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
warn "Image URL is not supported by MistralAI currently" if image_url

Langchain::Messages::MistralAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
end

@@ -623,7 +629,9 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
params
end

def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
warn "Image URL is not supported by Google Gemini" if image_url

Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
end

@@ -676,7 +684,9 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
params
end

def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
warn "Image URL is not supported by Anthropic currently" if image_url

Langchain::Messages::AnthropicMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
end

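
As a quick illustration of the non-OpenAI path added above (a sketch only; the Ollama constructor arguments are assumptions, not part of this PR), passing an image_url to an adapter without image support warns and builds the message without the image:

require "langchain"

llm = Langchain::LLM::Ollama.new(url: "http://localhost:11434")
assistant = Langchain::Assistant.new(llm: llm)

# The Ollama adapter prints a warning to stderr and ignores the image_url
assistant.add_message(
  content: "Describe this image",
  image_url: "https://example.com/image.jpg"
)
# stderr: Image URL is not supported by Ollama currently
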
6 changes: 5 additions & 1 deletion lib/langchain/assistants/messages/base.rb
@@ -3,7 +3,11 @@
module Langchain
module Messages
class Base
attr_reader :role, :content, :tool_calls, :tool_call_id
attr_reader :role,
:content,
:image_url,
:tool_calls,
:tool_call_id

# Check if the message came from a user
#
45 changes: 37 additions & 8 deletions lib/langchain/assistants/messages/openai_message.rb
@@ -15,17 +15,25 @@ class OpenAIMessage < Base

# Initialize a new OpenAI message
#
# @param [String] The role of the message
# @param [String] The content of the message
# @param [Array<Hash>] The tool calls made in the message
# @param [String] The ID of the tool call
def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil) # TODO: Implement image_file: reference (https://platform.openai.com/docs/api-reference/messages/object#messages/object-content)
# @param role [String] The role of the message
# @param content [String] The content of the message
# @param image_url [String] The URL of the image
# @param tool_calls [Array<Hash>] The tool calls made in the message
# @param tool_call_id [String] The ID of the tool call
def initialize(
role:,
content: nil,
image_url: nil,
tool_calls: [],
tool_call_id: nil
)
raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }

@role = role
# Some Tools return content as a JSON hence `.to_s`
@content = content.to_s
@image_url = image_url
@tool_calls = tool_calls
@tool_call_id = tool_call_id
end
@@ -43,9 +51,30 @@ def llm?
def to_hash
{}.tap do |h|
h[:role] = role
h[:content] = content if content # Content is nil for tool calls
h[:tool_calls] = tool_calls if tool_calls.any?
h[:tool_call_id] = tool_call_id if tool_call_id

if tool_calls.any?
h[:tool_calls] = tool_calls
else
h[:tool_call_id] = tool_call_id if tool_call_id

h[:content] = []

if content && !content.empty?
h[:content] << {
type: "text",
text: content
}
end

if image_url
h[:content] << {
type: "image_url",
image_url: {
url: image_url
}
}
end
end
end
end

25 changes: 19 additions & 6 deletions spec/langchain/assistants/assistant_spec.rb
@@ -87,6 +87,19 @@
expect(subject.messages.first.content).to eq("hello")
end

it "adds a message with image_url" do
message_with_image = {role: "user", content: "hello", image_url: "https://example.com/image.jpg"}
subject = described_class.new(llm: llm, messages: [])

expect {
subject.add_message(**message_with_image)
}.to change { subject.messages.count }.from(0).to(1)
expect(subject.messages.first).to be_a(Langchain::Messages::OpenAIMessage)
expect(subject.messages.first.role).to eq("user")
expect(subject.messages.first.content).to eq("hello")
expect(subject.messages.first.image_url).to eq("https://example.com/image.jpg")
end

it "calls the add_message_callback with the message" do
callback = double("callback", call: true)
subject = described_class.new(llm: llm, messages: [], add_message_callback: callback)
@@ -211,8 +224,8 @@
allow(subject.llm).to receive(:chat)
.with(
messages: [
{role: "system", content: instructions},
{role: "user", content: "Please calculate 2+2"}
{role: "system", content: [{type: "text", text: instructions}]},
{role: "user", content: [{type: "text", text: "Please calculate 2+2"}]}
],
tools: calculator.class.function_schemas.to_openai_format,
tool_choice: "auto"
@@ -255,16 +268,16 @@
allow(subject.llm).to receive(:chat)
.with(
messages: [
{role: "system", content: instructions},
{role: "user", content: "Please calculate 2+2"},
{role: "assistant", content: "", tool_calls: [
{role: "system", content: [{type: "text", text: instructions}]},
{role: "user", content: [{type: "text", text: "Please calculate 2+2"}]},
{role: "assistant", tool_calls: [
{
"function" => {"arguments" => "{\"input\":\"2+2\"}", "name" => "langchain_tool_calculator__execute"},
"id" => "call_9TewGANaaIjzY31UCpAAGLeV",
"type" => "function"
}
]},
{content: "4.0", role: "tool", tool_call_id: "call_9TewGANaaIjzY31UCpAAGLeV"}
{content: [{type: "text", text: "4.0"}], role: "tool", tool_call_id: "call_9TewGANaaIjzY31UCpAAGLeV"}
],
tools: calculator.class.function_schemas.to_openai_format,
tool_choice: "auto"
22 changes: 18 additions & 4 deletions spec/langchain/assistants/messages/openai_message_spec.rb
@@ -10,15 +10,15 @@
let(:message) { described_class.new(role: "user", content: "Hello, world!", tool_calls: [], tool_call_id: nil) }

it "returns a hash with the role and content key" do
expect(message.to_hash).to eq({role: "user", content: "Hello, world!"})
expect(message.to_hash).to eq({role: "user", content: [{type: "text", text: "Hello, world!"}]})
end
end

context "when tool_call_id is not nil" do
let(:message) { described_class.new(role: "tool", content: "Hello, world!", tool_calls: [], tool_call_id: "123") }

it "returns a hash with the tool_call_id key" do
expect(message.to_hash).to eq({role: "tool", content: "Hello, world!", tool_call_id: "123"})
expect(message.to_hash).to eq({role: "tool", content: [{type: "text", text: "Hello, world!"}], tool_call_id: "123"})
end
end

@@ -29,10 +29,24 @@
"function" => {"name" => "weather__execute", "arguments" => "{\"input\":\"Saint Petersburg\"}"}}
}

let(:message) { described_class.new(role: "assistant", content: "", tool_calls: [tool_call], tool_call_id: nil) }
let(:message) { described_class.new(role: "assistant", tool_calls: [tool_call], tool_call_id: nil) }

it "returns a hash with the tool_calls key" do
expect(message.to_hash).to eq({role: "assistant", content: "", tool_calls: [tool_call]})
expect(message.to_hash).to eq({role: "assistant", tool_calls: [tool_call]})
end
end

context "when image_url is present" do
let(:message) { described_class.new(role: "user", content: "Please describe this image", image_url: "https://example.com/image.jpg") }

it "returns a hash with the image_url key" do
expect(message.to_hash).to eq({
role: "user",
content: [
{type: "text", text: "Please describe this image"},
{type: "image_url", image_url: {url: "https://example.com/image.jpg"}}
]
})
end
end
end