Skip to content

Commit

Permalink
Merge pull request #14 from Butch78/feat-openai-json-mode
Browse files Browse the repository at this point in the history
feat: OpenAI: JSON mode
  • Loading branch information
santiagomed authored Dec 15, 2023
2 parents 785c566 + 3f3a4dd commit 03613aa
Showing 1 changed file with 61 additions and 1 deletion.
62 changes: 61 additions & 1 deletion orca-core/src/llm/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub struct Payload {
stop: Option<Vec<String>>,
messages: Vec<Message>,
stream: bool,
response_format: ResponseFormatWrapper,
}

#[derive(Serialize, Deserialize, Debug)]
Expand All @@ -29,6 +30,12 @@ pub struct EmbeddingPayload {
model: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ResponseFormatWrapper {
#[serde(rename = "type")]
pub format: ResponseFormat,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
id: String,
Expand Down Expand Up @@ -97,6 +104,19 @@ pub struct Usage {
total_tokens: i32,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ResponseFormat {
Text,
JsonObject,
}

impl From<ResponseFormat> for ResponseFormatWrapper {
fn from(format: ResponseFormat) -> Self {
ResponseFormatWrapper { format }
}
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
index: i32,
Expand Down Expand Up @@ -151,6 +171,10 @@ pub struct OpenAI {
///
/// The total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.
max_tokens: u16,

/// The format of the returned data. With the new update, the response can be set to a JSON object.
/// https://platform.openai.com/docs/guides/text-generation/json-mode
response_format: ResponseFormat,
}

impl Default for OpenAI {
Expand All @@ -159,12 +183,13 @@ impl Default for OpenAI {
client: Client::new(),
url: OPENAI_COMPLETIONS_URL.to_string(),
api_key: std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"),
model: "gpt-3.5-turbo".to_string(),
model: "gpt-3.5-turbo-1106".to_string(),
emedding_model: "text-embedding-ada-002".to_string(),
temperature: 1.0,
top_p: 1.0,
stream: false,
max_tokens: 1024u16,
response_format: ResponseFormat::Text,
}
}
}
Expand Down Expand Up @@ -215,6 +240,11 @@ impl OpenAI {
self
}

pub fn with_response_format(mut self, response_format: ResponseFormat) -> Self {
self.response_format = response_format;
self
}

/// Generate a request for the OpenAI API and set the parameters
pub fn generate_request(&self, messages: &[Message]) -> Result<reqwest::Request> {
let payload = Payload {
Expand All @@ -225,6 +255,7 @@ impl OpenAI {
stop: None,
messages: messages.to_vec(),
stream: self.stream,
response_format: self.response_format.clone().into(),
};
let req = self
.client
Expand Down Expand Up @@ -384,6 +415,35 @@ mod test {
assert!(response.to_string().to_lowercase().contains("berlin"));
}

#[tokio::test]
async fn test_generate_json_mode() {
let client = OpenAI::new().with_model("gpt-3.5-turbo-1106").with_response_format(ResponseFormat::JsonObject);
let mut context = HashMap::new();
context.insert("country1", "France");
context.insert("country2", "Germany");
let prompt = template!(
"my template",
r#"
{{#chat}}
{{#user}}
What is the capital of {{country1}}?
{{/user}}
{{#assistant}}
Paris
{{/assistant}}
{{#user}}
What is the capital of {{country2}} in a JSON format?
{{/user}}
{{/chat}}
"#
);
let prompt = prompt.render_context("my template", &context).unwrap();
let response = client.generate(prompt).await.unwrap();
assert!(response.to_string().to_lowercase().contains("berlin"));
// Assert response is a JSON object
assert!(response.to_string().starts_with("{"));
}

#[tokio::test]
async fn test_embedding() {
let client = OpenAI::new();
Expand Down

0 comments on commit 03613aa

Please sign in to comment.