Commit

improve rag
santiagomed committed Dec 13, 2023
1 parent cfd9cbe commit 4c9b433
Showing 2 changed files with 34 additions and 8 deletions.
3 changes: 2 additions & 1 deletion examples/rag/Cargo.toml
@@ -11,4 +11,5 @@ anyhow = "1.0.75"
 tokio = { version = "1.12.0", features = ["full"] }
 clap = "4.4.7"
 serde_json = "1.0.108"
-env_logger = "0.10.0"
+env_logger = "0.10.0"
+rand = "0.8.5"
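The new rand dependency exists only to randomize the model's sampling seed in main.rs below. A minimal standalone sketch of that call (the 0..100 range simply mirrors what the updated example passes to with_seed):

    use rand::Rng; // brings gen_range into scope

    fn main() {
        // Draw a small random seed, as the updated example does each run.
        let seed = rand::thread_rng().gen_range(0..100);
        println!("seed = {}", seed);
    }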
39 changes: 32 additions & 7 deletions examples/rag/src/main.rs
@@ -2,6 +2,7 @@ use anyhow::{Context, Result};
 use clap::Parser;
 use orca::{
     llm::{bert::Bert, quantized::Quantized, Embedding},
+    memory::Buffer,
     pipeline::simple::LLMPipeline,
     pipeline::Pipeline,
     prompt,
@@ -10,6 +11,7 @@ use orca::{
     qdrant::Qdrant,
     record::{pdf::Pdf, Spin},
 };
+use rand::Rng;
 use serde_json::json;

 #[derive(Parser, Debug)]
@@ -55,7 +57,7 @@ async fn main() -> Result<()> {

     // Use prompt to query Qdrant
     let query_embedding = bert.generate_embedding(prompt!(args.prompt)).await?;
-    let result = qdrant.search(&collection, query_embedding.to_vec()?.clone(), 5, None).await?;
+    let result = qdrant.search(&collection, query_embedding.to_vec()?.clone(), 1, None).await?;

     let prompt_for_model = r#"
     {{#chat}}
@@ -94,18 +96,41 @@ async fn main() -> Result<()> {
     });

     let mistral = Quantized::new()
-        .with_model(orca::llm::quantized::Model::L7bChat)
-        .with_sample_len(7500)
-        .load_model_from_path("../../weights/llama-2-7b-chat.ggmlv3.q2_K.bin")?
+        .with_model(orca::llm::quantized::Model::Mistral7bInstruct)
+        .with_sample_len(4000)
+        .with_seed(rand::thread_rng().gen_range(0..100))
+        .load_model_from_path("../../weights/mistral-7b-instruct-v0.1.Q4_K_M.gguf")?
         .build_model()?;

     let pipe = LLMPipeline::new(&mistral)
         .load_template("query", prompt_for_model)?
-        .load_context(&OrcaContext::new(context)?)?;
+        .load_context(&OrcaContext::new(context)?)?
+        .load_memory(Buffer::new());

-    let response = pipe.execute("query").await?;
+    let res = pipe.execute("query").await?;

-    println!("Response: {}", response.content());
+    println!("\nResponse: {}", res.content());
+
+    let stdin = std::io::stdin();
+    let mut input = String::new();
+
+    loop {
+        println!("Enter your prompt (type 'exit' to quit): ");
+        input.clear();
+        stdin.read_line(&mut input)?;
+        let trimmed_input = input.trim();
+
+        // Exit condition
+        if trimmed_input.eq_ignore_ascii_case("exit") {
+            break;
+        }
+
+        let pipe = pipe.clone().load_template("query", trimmed_input)?;
+
+        let res = pipe.execute("query").await?;
+
+        println!("\nResponse: {}", res.content());
+    }

     Ok(())
 }
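The interactive loop added above is ordinary std I/O wrapped around the pipeline: each turn clones the pipeline and loads the raw user input as a fresh "query" template. Stripped of the orca-specific calls, the same read-eval loop pattern looks like this (a sketch, with a hypothetical answer() standing in for pipe.execute):

    use std::io::{self, Write};

    // Hypothetical stand-in for the pipeline call in the example above.
    fn answer(prompt: &str) -> String {
        format!("(echo) {}", prompt)
    }

    fn main() -> io::Result<()> {
        let stdin = io::stdin();
        let mut input = String::new();
        loop {
            println!("Enter your prompt (type 'exit' to quit): ");
            io::stdout().flush()?;
            input.clear(); // reuse the buffer across iterations
            stdin.read_line(&mut input)?;
            let trimmed = input.trim();
            if trimmed.eq_ignore_ascii_case("exit") {
                break; // exit condition, as in the diff
            }
            println!("\nResponse: {}", answer(trimmed));
        }
        Ok(())
    }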
