From 4c9b433039112adf3ec799012423dc79d5d4b53d Mon Sep 17 00:00:00 2001
From: Santiago Medina
Date: Tue, 12 Dec 2023 17:35:20 -0800
Subject: [PATCH] improve rag

---
 examples/rag/Cargo.toml  |  3 ++-
 examples/rag/src/main.rs | 39 ++++++++++++++++++++++++++++++++-------
 2 files changed, 34 insertions(+), 8 deletions(-)

diff --git a/examples/rag/Cargo.toml b/examples/rag/Cargo.toml
index e4b98c7..6fd6ee6 100644
--- a/examples/rag/Cargo.toml
+++ b/examples/rag/Cargo.toml
@@ -11,4 +11,5 @@ anyhow = "1.0.75"
 tokio = { version = "1.12.0", features = ["full"] }
 clap = "4.4.7"
 serde_json = "1.0.108"
-env_logger = "0.10.0"
\ No newline at end of file
+env_logger = "0.10.0"
+rand = "0.8.5"
diff --git a/examples/rag/src/main.rs b/examples/rag/src/main.rs
index 8d7af14..9021219 100644
--- a/examples/rag/src/main.rs
+++ b/examples/rag/src/main.rs
@@ -2,6 +2,7 @@ use anyhow::{Context, Result};
 use clap::Parser;
 use orca::{
     llm::{bert::Bert, quantized::Quantized, Embedding},
+    memory::Buffer,
     pipeline::simple::LLMPipeline,
     pipeline::Pipeline,
     prompt,
@@ -10,6 +11,7 @@ use orca::{
     qdrant::Qdrant,
     record::{pdf::Pdf, Spin},
 };
+use rand::Rng;
 use serde_json::json;
 
 #[derive(Parser, Debug)]
@@ -55,7 +57,7 @@ async fn main() -> Result<()> {
 
     // Use prompt to query Qdrant
     let query_embedding = bert.generate_embedding(prompt!(args.prompt)).await?;
-    let result = qdrant.search(&collection, query_embedding.to_vec()?.clone(), 5, None).await?;
+    let result = qdrant.search(&collection, query_embedding.to_vec()?.clone(), 1, None).await?;
 
     let prompt_for_model = r#"
     {{#chat}}
@@ -94,18 +96,41 @@ async fn main() -> Result<()> {
     });
 
     let mistral = Quantized::new()
-        .with_model(orca::llm::quantized::Model::L7bChat)
-        .with_sample_len(7500)
-        .load_model_from_path("../../weights/llama-2-7b-chat.ggmlv3.q2_K.bin")?
+        .with_model(orca::llm::quantized::Model::Mistral7bInstruct)
+        .with_sample_len(4000)
+        .with_seed(rand::thread_rng().gen_range(0..100))
+        .load_model_from_path("../../weights/mistral-7b-instruct-v0.1.Q4_K_M.gguf")?
         .build_model()?;
 
     let pipe = LLMPipeline::new(&mistral)
         .load_template("query", prompt_for_model)?
-        .load_context(&OrcaContext::new(context)?)?;
+        .load_context(&OrcaContext::new(context)?)?
+        .load_memory(Buffer::new());
 
-    let response = pipe.execute("query").await?;
+    let res = pipe.execute("query").await?;
 
-    println!("Response: {}", response.content());
+    println!("\nResponse: {}", res.content());
+
+    let stdin = std::io::stdin();
+    let mut input = String::new();
+
+    loop {
+        println!("Enter your prompt (type 'exit' to quit): ");
+        input.clear();
+        stdin.read_line(&mut input)?;
+        let trimmed_input = input.trim();
+
+        // Exit condition
+        if trimmed_input.eq_ignore_ascii_case("exit") {
+            break;
+        }
+
+        let pipe = pipe.clone().load_template("query", trimmed_input)?;
+
+        let res = pipe.execute("query").await?;
+
+        println!("\nResponse: {}", res.content());
+    }
 
     Ok(())
 }