-
-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f0623c4
commit 4957225
Showing
21 changed files
with
404 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# frozen_string_literal: true | ||
|
||
require "active_record" | ||
|
||
module Langchain | ||
class Assistant | ||
attr_accessor :id | ||
|
||
alias_method :original_initialize, :initialize | ||
|
||
def initialize(id: nil, **) | ||
@id = id | ||
original_initialize(**) | ||
end | ||
|
||
def save | ||
::ActiveRecord::Base.transaction do | ||
ar_assistant = if id | ||
self.class.find_assistant(id) | ||
else | ||
::Assistant.new | ||
end | ||
|
||
ar_assistant.update!( | ||
instructions: instructions, | ||
tool_choice: tool_choice, | ||
tools: tools.map(&:class).map(&:name) | ||
) | ||
|
||
messages.each do |message| | ||
ar_message = ar_assistant.messages.find_or_initialize_by(id: message.id) | ||
ar_message.update!( | ||
role: message.role, | ||
content: message.content, | ||
tool_calls: message.tool_calls, | ||
tool_call_id: message.tool_call_id | ||
) | ||
message.id = ar_message.id | ||
end | ||
|
||
@id = ar_assistant.id | ||
true | ||
end | ||
end | ||
|
||
# def save | ||
# if @persistence_adapter | ||
# @record = @persistence_adapter.save(self) | ||
# self.id = @record.id | ||
# @record | ||
# else | ||
# warn "No persistence adapter set, cannot save assistant" | ||
# false | ||
# end | ||
# end | ||
|
||
class << self | ||
def find_assistant(id) | ||
::Assistant.find(id) | ||
end | ||
|
||
def load(id) | ||
ar_assistant = find_assistant(id) | ||
|
||
tools = ar_assistant.tools.map { |tool_name| Object.const_get(tool_name).new } | ||
|
||
assistant = Langchain::Assistant.new( | ||
id: ar_assistant.id, | ||
llm: ar_assistant.llm, | ||
tools: tools, | ||
instructions: ar_assistant.instructions, | ||
tool_choice: ar_assistant.tool_choice | ||
) | ||
|
||
ar_assistant.messages.each do |ar_message| | ||
messages = assistant.add_message( | ||
role: ar_message.role, | ||
content: ar_message.content, | ||
tool_calls: ar_message.tool_calls, | ||
tool_call_id: ar_message.tool_call_id | ||
) | ||
messages.last.id = ar_message.id | ||
end | ||
|
||
assistant | ||
end | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
module Langchain | ||
module Messages | ||
class Base | ||
attr_accessor :id | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
lib/langchainrb_rails/generators/langchainrb_rails/assistant_generator.rb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# frozen_string_literal: true | ||
|
||
require "rails/generators" | ||
require "rails/generators/active_record" | ||
|
||
module LangchainrbRails | ||
module Generators | ||
# | ||
# Usage: | ||
# rails generate langchainrb_rails:assistant --llm=openai | ||
# | ||
class AssistantGenerator < Rails::Generators::Base | ||
include ::ActiveRecord::Generators::Migration | ||
|
||
# TODO: Move constant this to a shared place | ||
LLMS = { | ||
"anthropic" => "Langchain::LLM::Anthropic", | ||
"cohere" => "Langchain::LLM::Cohere", | ||
"google_palm" => "Langchain::LLM::GooglePalm", | ||
"google_gemini" => "Langchain::LLM::GoogleGemini", | ||
"google_vertex_ai" => "Langchain::LLM::GoogleVertexAI", | ||
"hugging_face" => "Langchain::LLM::HuggingFace", | ||
"llama_cpp" => "Langchain::LLM::LlamaCpp", | ||
"mistral_ai" => "Langchain::LLM::MistralAI", | ||
"ollama" => "Langchain::LLM::Ollama", | ||
"openai" => "Langchain::LLM::OpenAI", | ||
"replicate" => "Langchain::LLM::Replicate" | ||
}.freeze | ||
|
||
class_option :llm, | ||
type: :string, | ||
required: true, | ||
default: "openai", | ||
desc: "LLM provider that will be used to generate embeddings and completions", | ||
enum: LLMS.keys | ||
|
||
desc "This generator adds Assistant and Message models and tables to your Rails app" | ||
source_root File.join(__dir__, "templates") | ||
|
||
def copy_migration | ||
migration_template "assistant/migrations/create_assistants.rb", "db/migrate/create_assistants.rb", migration_version: migration_version | ||
migration_template "assistant/migrations/create_messages.rb", "db/migrate/create_messages.rb", migration_version: migration_version | ||
end | ||
|
||
def create_model_file | ||
template "assistant/models/assistant.rb", "app/models/assistant.rb" | ||
template "assistant/models/message.rb", "app/models/message.rb" | ||
end | ||
|
||
def migration_version | ||
"[#{::ActiveRecord::VERSION::MAJOR}.#{::ActiveRecord::VERSION::MINOR}]" | ||
end | ||
|
||
# TODO: Depending on the LLM provider, we may need to add additional gems | ||
# def add_to_gemfile | ||
# end | ||
|
||
private | ||
|
||
# @return [String] LLM provider to use | ||
def llm | ||
options["llm"] | ||
end | ||
|
||
# @return [Langchain::LLM::*] LLM class | ||
def llm_class | ||
Langchain::LLM.const_get(LLMS[llm]) | ||
end | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.