From 785c566839298a28e353cc225bc3959315907f9a Mon Sep 17 00:00:00 2001
From: Santiago Medina
Date: Tue, 12 Dec 2023 17:35:36 -0800
Subject: [PATCH] improve rag

---
 orca-core/src/llm/quantized.rs   | 8 +++++++-
 orca-core/src/memory.rs          | 4 ++++
 orca-core/src/pipeline/simple.rs | 1 +
 orca-core/src/prompt/mod.rs      | 3 ++-
 4 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/orca-core/src/llm/quantized.rs b/orca-core/src/llm/quantized.rs
index a636d0f..3f75a58 100644
--- a/orca-core/src/llm/quantized.rs
+++ b/orca-core/src/llm/quantized.rs
@@ -122,11 +122,17 @@ impl Quantized {
         self.sample_len = sample_len;
         self
     }
+
     pub fn with_model(mut self, model: Model) -> Self {
         self.which = model;
         self
     }
 
+    pub fn with_seed(mut self, seed: u64) -> Self {
+        self.seed = seed;
+        self
+    }
+
     fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
         let tokenizer_path = match &self.tokenizer {
             Some(config) => std::path::PathBuf::from(config),
@@ -314,7 +320,7 @@ impl LLM for Quantized {
             Quantized::format_chat_prompt(prompt.to_chat()?)
         };
 
-        log::info!("prompt:\n{}", &prompt);
+        log::debug!("prompt:\n{}", &prompt);
         let mut result = String::new();
         let tokens = tokenizer.encode(prompt, true).map_err(anyhow::Error::msg)?;
         if log::log_enabled!(log::Level::Debug) {
diff --git a/orca-core/src/memory.rs b/orca-core/src/memory.rs
index 2b1ddd3..b43edd0 100644
--- a/orca-core/src/memory.rs
+++ b/orca-core/src/memory.rs
@@ -84,6 +84,10 @@ impl ChatBuffer {
             memory: ChatPrompt(Vec::new()),
         }
     }
+
+    pub fn from_chat(chat: &ChatPrompt) -> Self {
+        Self { memory: chat.clone() }
+    }
 }
 
 impl Memory for ChatBuffer {
diff --git a/orca-core/src/pipeline/simple.rs b/orca-core/src/pipeline/simple.rs
index 4417e05..4bc5a6d 100644
--- a/orca-core/src/pipeline/simple.rs
+++ b/orca-core/src/pipeline/simple.rs
@@ -201,6 +201,7 @@ impl Pipeline for LLMPipeline {
             let mut locked_memory = memory.lock().await; // Lock the memory
             let mem = locked_memory.memory();
             mem.save(prompt);
+            log::debug!("Memory: {}", mem);
             self.llm.generate(mem.clone_prompt()).await?
         } else {
             self.llm.generate(prompt.clone_prompt()).await?
diff --git a/orca-core/src/prompt/mod.rs b/orca-core/src/prompt/mod.rs
index 42d4eb5..4736104 100644
--- a/orca-core/src/prompt/mod.rs
+++ b/orca-core/src/prompt/mod.rs
@@ -235,6 +236,7 @@ pub trait Prompt: Sync + Send + Display {
     /// ```
     fn save(&mut self, _data: Box<dyn Prompt>) {
         unimplemented!("save not implemented for this prompt type");
+        // Err(anyhow::anyhow!("save not implemented for this prompt type"))
     }
 
     /// Convert the current prompt to a `ChatPrompt`.
@@ -242,7 +243,7 @@ pub trait Prompt: Sync + Send + Display {
     ///
     /// # Returns
     /// * `Result` - The `ChatPrompt` representation of the prompt or an error.
     fn to_chat(&self) -> Result<ChatPrompt> {
-        unimplemented!("Unable to convert prompt to ChatPrompt");
+        Err(anyhow::anyhow!("Unable to convert prompt to ChatPrompt"))
     }
     /// Clone the current prompt into a Boxed trait object.
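
Usage sketch (reviewer note, not part of the diff): one way the additions above could fit together. Only with_model, with_seed, and ChatBuffer::from_chat come from this patch; Quantized::new(), the Model::Mistral7bInstruct variant, and the module paths are assumptions for illustration only.

    use orca_core::llm::quantized::{Model, Quantized};
    use orca_core::memory::ChatBuffer;
    use orca_core::prompt::chat::ChatPrompt;

    fn build_llm() -> Quantized {
        Quantized::new()                          // assumed constructor
            .with_model(Model::Mistral7bInstruct) // variant name is a guess
            .with_seed(42)                        // added in this patch; presumably seeds the sampling RNG
    }

    // Added in this patch: pre-load a memory buffer from an existing chat
    // history instead of starting from an empty ChatPrompt.
    fn memory_from_history(history: &ChatPrompt) -> ChatBuffer {
        ChatBuffer::from_chat(history)
    }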