Skip to content

Commit

Permalink
Assistant geneerator
Browse files Browse the repository at this point in the history
  • Loading branch information
andreibondarev committed Sep 19, 2024
1 parent f0623c4 commit 4957225
Show file tree
Hide file tree
Showing 21 changed files with 404 additions and 35 deletions.
32 changes: 10 additions & 22 deletions Gemfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ PATH
remote: .
specs:
langchainrb_rails (0.1.11)
langchainrb (>= 0.7, < 0.15)
langchainrb (>= 0.7, < 0.17)

GEM
remote: https://rubygems.org/
Expand Down Expand Up @@ -81,10 +81,10 @@ GEM
minitest (>= 5.1)
mutex_m
tzinfo (~> 2.0)
addressable (2.8.6)
public_suffix (>= 2.0.2, < 6.0)
addressable (2.8.7)
public_suffix (>= 2.0.2, < 7.0)
ast (2.4.2)
baran (0.1.11)
baran (0.1.12)
base64 (0.2.0)
bigdecimal (3.1.8)
brakeman (6.1.2)
Expand All @@ -95,7 +95,6 @@ GEM
thor (~> 1.0)
byebug (11.1.3)
coderay (1.1.3)
colorize (1.1.0)
concurrent-ruby (1.3.4)
connection_pool (2.4.1)
crass (1.0.6)
Expand All @@ -108,24 +107,21 @@ GEM
railties (>= 3.0.0)
globalid (1.2.1)
activesupport (>= 6.1)
i18n (1.14.5)
i18n (1.14.6)
concurrent-ruby (~> 1.0)
io-console (0.7.2)
irb (1.14.0)
rdoc (>= 4.0.0)
reline (>= 0.4.2)
json (2.7.2)
json-schema (4.3.0)
json-schema (4.3.1)
addressable (>= 2.8)
langchainrb (0.11.4)
activesupport (>= 7.0.8)
langchainrb (0.16.0)
baran (~> 0.1.9)
colorize (~> 1.1.0)
json-schema (~> 4)
matrix
pragmatic_segmenter (~> 0.3.0)
tiktoken_ruby (~> 0.0.8)
to_bool (~> 2.0.0)
rainbow (~> 3.1.0)
zeitwerk (~> 2.5)
language_server-protocol (3.17.0.3)
lint_roller (1.1.0)
Expand Down Expand Up @@ -165,8 +161,7 @@ GEM
parser (3.3.4.0)
ast (~> 2.4.1)
racc
pragmatic_segmenter (0.3.23)
unicode
pragmatic_segmenter (0.3.24)
pry (0.14.2)
coderay (~> 1.1)
method_source (~> 1.0)
Expand All @@ -175,7 +170,7 @@ GEM
pry (>= 0.13, < 0.15)
psych (5.1.2)
stringio
public_suffix (5.0.5)
public_suffix (6.0.1)
racc (1.8.1)
rack (3.1.7)
rack-session (2.0.0)
Expand Down Expand Up @@ -216,7 +211,6 @@ GEM
zeitwerk (~> 2.6)
rainbow (3.1.1)
rake (13.2.1)
rb_sys (0.9.96)
rdoc (6.7.0)
psych (>= 4.0.0)
regexp_parser (2.9.2)
Expand Down Expand Up @@ -271,15 +265,9 @@ GEM
stringio (3.1.1)
strscan (3.1.0)
thor (1.3.2)
tiktoken_ruby (0.0.8)
rb_sys (>= 0.9.86)
tiktoken_ruby (0.0.8-x86_64-darwin)
tiktoken_ruby (0.0.8-x86_64-linux)
timeout (0.4.1)
to_bool (2.0.0)
tzinfo (2.0.6)
concurrent-ruby (~> 1.0)
unicode (0.4.4.5)
unicode-display_width (2.5.0)
webrick (1.8.1)
websocket-driver (0.7.6)
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,8 @@ prompt = Prompt.create!(template: "Tell me a {adjective} joke about {subject}.")
prompt.render(adjective: "funny", subject: "elephants")
# => "Tell me a funny joke about elephants."
```

### Assistant Generator - adds assistant capabilities to your ActiveRecord model
```bash
rails generate langchainrb_rails:assistant
```
2 changes: 1 addition & 1 deletion langchainrb_rails.gemspec
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Gem::Specification.new do |spec|
spec.executables = spec.files.grep(%r{\Aexe/}) { |f| File.basename(f) }
spec.require_paths = ["lib"]

spec.add_dependency "langchainrb", ">= 0.7", "< 0.15"
spec.add_dependency "langchainrb", ">= 0.7", "< 0.17"

spec.add_development_dependency "pry-byebug", "~> 3.10.0"
spec.add_development_dependency "yard", "~> 0.9.34"
Expand Down
89 changes: 89 additions & 0 deletions lib/langchainrb_overrides/assistant.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# frozen_string_literal: true

require "active_record"

module Langchain
class Assistant
attr_accessor :id

alias_method :original_initialize, :initialize

def initialize(id: nil, **)
@id = id
original_initialize(**)
end

def save
::ActiveRecord::Base.transaction do
ar_assistant = if id
self.class.find_assistant(id)
else
::Assistant.new
end

ar_assistant.update!(
instructions: instructions,
tool_choice: tool_choice,
tools: tools.map(&:class).map(&:name)
)

messages.each do |message|
ar_message = ar_assistant.messages.find_or_initialize_by(id: message.id)
ar_message.update!(
role: message.role,
content: message.content,
tool_calls: message.tool_calls,
tool_call_id: message.tool_call_id
)
message.id = ar_message.id
end

@id = ar_assistant.id
true
end
end

# def save
# if @persistence_adapter
# @record = @persistence_adapter.save(self)
# self.id = @record.id
# @record
# else
# warn "No persistence adapter set, cannot save assistant"
# false
# end
# end

class << self
def find_assistant(id)
::Assistant.find(id)
end

def load(id)
ar_assistant = find_assistant(id)

tools = ar_assistant.tools.map { |tool_name| Object.const_get(tool_name).new }

assistant = Langchain::Assistant.new(
id: ar_assistant.id,
llm: ar_assistant.llm,
tools: tools,
instructions: ar_assistant.instructions,
tool_choice: ar_assistant.tool_choice
)

ar_assistant.messages.each do |ar_message|
messages = assistant.add_message(
role: ar_message.role,
content: ar_message.content,
tool_calls: ar_message.tool_calls,
tool_call_id: ar_message.tool_call_id
)
messages.last.id = ar_message.id
end

assistant
end
end
end
end
7 changes: 7 additions & 0 deletions lib/langchainrb_overrides/message.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module Langchain
module Messages
class Base
attr_accessor :id
end
end
end
9 changes: 7 additions & 2 deletions lib/langchainrb_rails.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
require "forwardable"
require "langchain"
require "rails"
require_relative "langchainrb_rails/version"
require "langchainrb_rails/railtie"

require "langchainrb_rails/config"
require "langchainrb_rails/prompting"
require "langchainrb_rails/railtie"
require "langchainrb_rails/version"

require_relative "langchainrb_overrides/vectorsearch/pgvector"
require_relative "langchainrb_overrides/assistant"
require_relative "langchainrb_overrides/message"

module LangchainrbRails
class Error < StandardError; end
Expand All @@ -18,6 +22,7 @@ module ActiveRecord

module Generators
autoload :BaseGenerator, "langchainrb_rails/generators/langchainrb_rails/base_generator"
autoload :AssistantGenerator, "langchainrb_rails/generators/langchainrb_rails/assistant_generator"
autoload :ChromaGenerator, "langchainrb_rails/generators/langchainrb_rails/chroma_generator"
autoload :PgvectorGenerator, "langchainrb_rails/generators/langchainrb_rails/pgvector_generator"
autoload :QdrantGenerator, "langchainrb_rails/generators/langchainrb_rails/qdrant_generator"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# frozen_string_literal: true

require "rails/generators"
require "rails/generators/active_record"

module LangchainrbRails
module Generators
#
# Usage:
# rails generate langchainrb_rails:assistant --llm=openai
#
class AssistantGenerator < Rails::Generators::Base
include ::ActiveRecord::Generators::Migration

# TODO: Move constant this to a shared place
LLMS = {
"anthropic" => "Langchain::LLM::Anthropic",
"cohere" => "Langchain::LLM::Cohere",
"google_palm" => "Langchain::LLM::GooglePalm",
"google_gemini" => "Langchain::LLM::GoogleGemini",
"google_vertex_ai" => "Langchain::LLM::GoogleVertexAI",
"hugging_face" => "Langchain::LLM::HuggingFace",
"llama_cpp" => "Langchain::LLM::LlamaCpp",
"mistral_ai" => "Langchain::LLM::MistralAI",
"ollama" => "Langchain::LLM::Ollama",
"openai" => "Langchain::LLM::OpenAI",
"replicate" => "Langchain::LLM::Replicate"
}.freeze

class_option :llm,
type: :string,
required: true,
default: "openai",
desc: "LLM provider that will be used to generate embeddings and completions",
enum: LLMS.keys

desc "This generator adds Assistant and Message models and tables to your Rails app"
source_root File.join(__dir__, "templates")

def copy_migration
migration_template "assistant/migrations/create_assistants.rb", "db/migrate/create_assistants.rb", migration_version: migration_version
migration_template "assistant/migrations/create_messages.rb", "db/migrate/create_messages.rb", migration_version: migration_version
end

def create_model_file
template "assistant/models/assistant.rb", "app/models/assistant.rb"
template "assistant/models/message.rb", "app/models/message.rb"
end

def migration_version
"[#{::ActiveRecord::VERSION::MAJOR}.#{::ActiveRecord::VERSION::MINOR}]"
end

# TODO: Depending on the LLM provider, we may need to add additional gems
# def add_to_gemfile
# end

private

# @return [String] LLM provider to use
def llm
options["llm"]
end

# @return [Langchain::LLM::*] LLM class
def llm_class
Langchain::LLM.const_get(LLMS[llm])
end
end
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,39 @@ module Generators
class BaseGenerator < Rails::Generators::Base
include ::ActiveRecord::Generators::Migration

class_option :model, type: :string, required: true, desc: "ActiveRecord Model to add vectorsearch to", aliases: "-m"
class_option :llm, type: :string, required: true, desc: "LLM provider that will be used to generate embeddings and completions"

# Available LLM providers to be passed in as --llm option
LLMS = {
"anthropic" => "Langchain::LLM::Anthropic",
"cohere" => "Langchain::LLM::Cohere",
"google_palm" => "Langchain::LLM::GooglePalm",
"google_gemini" => "Langchain::LLM::GoogleGemini",
"google_vertex_ai" => "Langchain::LLM::GoogleVertexAI",
"hugging_face" => "Langchain::LLM::HuggingFace",
"llama_cpp" => "Langchain::LLM::LlamaCpp",
"mistral_ai" => "Langchain::LLM::MistralAI",
"ollama" => "Langchain::LLM::Ollama",
"openai" => "Langchain::LLM::OpenAI",
"replicate" => "Langchain::LLM::Replicate"
}.freeze

class_option :model,
type: :string,
required: true,
aliases: "-m",
desc: "ActiveRecord Model to add vectorsearch to"

class_option :llm,
type: :string,
required: true,
default: "openai",
desc: "LLM provider that will be used to generate embeddings and completions",
enum: LLMS.keys

# Run bundle install after running the generator
def after_generate
run "bundle install"
end

def post_install_message
say "Please do the following to start Q&A with your #{model_name} records:", :green
say "1. Run `bundle install` to install the new gems."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ def add_to_model
end

# Adds `chroma-db` gem to the Gemfile
# TODO: Can we automatically run `bundle install`?
def add_to_gemfile
gem "chroma-db", version: "~> 0.6.0"
gem "chroma-db"
end

private
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ def add_to_model
end

# Adds `pinecone` gem to the Gemfile
# TODO: Can we automatically run `bundle install`?
def add_to_gemfile
gem "pinecone", version: "~> 0.1.6"
gem "pinecone"
end

private
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ class PromptGenerator < Rails::Generators::Base

source_root File.join(__dir__, "templates")

def create_prompt_model
def create_model_file
template "prompt_model.rb", "app/models/prompt.rb"
migration_template "create_prompts.rb", "db/migrate/create_prompts.rb"
end

def copy_migration
migration_template "create_prompts.rb", "db/migrate/create_prompts.rb", migration_version: migration_version
end

def migration_version
Expand Down
Loading

0 comments on commit 4957225

Please sign in to comment.