diff --git a/CHANGELOG.md b/CHANGELOG.md
index fcb037e..2a16cb1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,8 @@
## [Unreleased]
+## [0.9.11] - 2024-08-01
+- New `rerank()` method
+
## [0.9.10] - 2024-05-10
- /chat endpoint does not require `message:` parameter anymore
diff --git a/Gemfile.lock b/Gemfile.lock
index 80489a8..cc15050 100644
--- a/Gemfile.lock
+++ b/Gemfile.lock
@@ -1,7 +1,7 @@
PATH
remote: .
specs:
- cohere-ruby (0.9.10)
+ cohere-ruby (0.9.11)
faraday (>= 2.0.1, < 3.0)
GEM
diff --git a/README.md b/README.md
index 1ab14fa..a54b487 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
# Cohere
-
+
+
@@ -42,7 +42,7 @@ client = Cohere::Client.new(
```ruby
client.generate(
- prompt: "Once upon a time in a magical land called"
+ prompt: "Once upon a time in a magical land called"
)
```
@@ -50,7 +50,7 @@ client.generate(
```ruby
client.chat(
- message: "Hey! How are you?"
+ message: "Hey! How are you?"
)
```
@@ -90,30 +90,45 @@ client.chat(
)
```
-
-
### Embed
```ruby
client.embed(
- texts: ["hello!"]
+ texts: ["hello!"]
+)
+```
+
+### Rerank
+
+```ruby
+docs = [
+ "Carson City is the capital city of the American state of Nevada.",
+ "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
+ "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
+ "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
+ "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
+]
+
+client.rerank(
+ query: "What is the capital of the United States?",
+ documents: docs
)
```
+
### Classify
```ruby
examples = [
- { text: "Dermatologists don't like her!", label: "Spam" },
- { text: "Hello, open to this?", label: "Spam" },
- { text: "I need help please wire me $1000 right now", label: "Spam" },
- { text: "Nice to know you ;)", label: "Spam" },
- { text: "Please help me?", label: "Spam" },
- { text: "Your parcel will be delivered today", label: "Not spam" },
- { text: "Review changes to our Terms and Conditions", label: "Not spam" },
- { text: "Weekly sync notes", label: "Not spam" },
- { text: "Re: Follow up from today's meeting", label: "Not spam" },
- { text: "Pre-read for tomorrow", label: "Not spam" }
+ { text: "Dermatologists don't like her!", label: "Spam" },
+ { text: "Hello, open to this?", label: "Spam" },
+ { text: "I need help please wire me $1000 right now", label: "Spam" },
+ { text: "Nice to know you ;)", label: "Spam" },
+ { text: "Please help me?", label: "Spam" },
+ { text: "Your parcel will be delivered today", label: "Not spam" },
+ { text: "Review changes to our Terms and Conditions", label: "Not spam" },
+ { text: "Weekly sync notes", label: "Not spam" },
+ { text: "Re: Follow up from today's meeting", label: "Not spam" },
+ { text: "Pre-read for tomorrow", label: "Not spam" }
]
inputs = [
@@ -122,8 +137,8 @@ inputs = [
]
client.classify(
- examples: examples,
- inputs: inputs
+ examples: examples,
+ inputs: inputs
)
```
@@ -131,7 +146,7 @@ client.classify(
```ruby
client.tokenize(
- text: "hello world!"
+ text: "hello world!"
)
```
@@ -139,7 +154,7 @@ client.tokenize(
```ruby
client.detokenize(
- tokens: [33555, 1114 , 34]
+ tokens: [33555, 1114 , 34]
)
```
@@ -147,7 +162,7 @@ client.detokenize(
```ruby
client.detect_language(
- texts: ["Здравствуй, Мир"]
+ texts: ["Здравствуй, Мир"]
)
```
@@ -155,7 +170,7 @@ client.detect_language(
```ruby
client.summarize(
- text: "..."
+ text: "..."
)
```
diff --git a/lib/cohere/client.rb b/lib/cohere/client.rb
index ed5b2a4..523e8c6 100644
--- a/lib/cohere/client.rb
+++ b/lib/cohere/client.rb
@@ -119,6 +119,29 @@ def embed(
response.body
end
+ def rerank(
+ query:,
+ documents:,
+ model: nil,
+ top_n: nil,
+ rank_fields: nil,
+ return_documents: nil,
+ max_chunks_per_doc: nil
+ )
+ response = connection.post("rerank") do |req|
+ req.body = {
+ query: query,
+ documents: documents
+ }
+ req.body[:model] = model if model
+ req.body[:top_n] = top_n if top_n
+ req.body[:rank_fields] = rank_fields if rank_fields
+ req.body[:return_documents] = return_documents if return_documents
+ req.body[:max_chunks_per_doc] = max_chunks_per_doc if max_chunks_per_doc
+ end
+ response.body
+ end
+
def classify(
inputs:,
examples:,
diff --git a/lib/cohere/version.rb b/lib/cohere/version.rb
index 4c696a1..4c189a2 100644
--- a/lib/cohere/version.rb
+++ b/lib/cohere/version.rb
@@ -1,5 +1,5 @@
# frozen_string_literal: true
module Cohere
- VERSION = "0.9.10"
+ VERSION = "0.9.11"
end
diff --git a/spec/cohere/client_spec.rb b/spec/cohere/client_spec.rb
index 772ec7d..4e1a929 100644
--- a/spec/cohere/client_spec.rb
+++ b/spec/cohere/client_spec.rb
@@ -3,7 +3,7 @@
require "spec_helper"
RSpec.describe Cohere::Client do
- let(:instance) { described_class.new(api_key: "123") }
+ subject { described_class.new(api_key: "123") }
describe "#generate" do
let(:generate_result) { JSON.parse(File.read("spec/fixtures/generate_result.json")) }
@@ -16,7 +16,7 @@
end
it "returns a response" do
- expect(instance.generate(
+ expect(subject.generate(
prompt: "Once upon a time in a magical land called"
).dig("generations").first.dig("text")).to eq(" The Past there was a Game called Warhammer Fantasy Battle.")
end
@@ -33,12 +33,41 @@
end
it "returns a response" do
- expect(instance.embed(
+ expect(subject.embed(
texts: ["hello!"]
).dig("embeddings")).to eq([[1.2177734, 0.67529297, 2.0742188]])
end
end
+ describe "#rerank" do
+ let(:embed_result) { JSON.parse(File.read("spec/fixtures/rerank.json")) }
+ let(:response) { OpenStruct.new(body: embed_result) }
+ let(:docs) {
+ [
+ "Carson City is the capital city of the American state of Nevada.",
+ "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
+ "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
+ "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
+ "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
+ ]
+ }
+
+ before do
+ allow_any_instance_of(Faraday::Connection).to receive(:post)
+ .with("rerank")
+ .and_return(response)
+ end
+
+ it "returns a response" do
+ expect(
+ subject
+ .rerank(query: "What is the capital of the United States?", documents: docs)
+ .dig("results")
+ .map { |h| h["index"] }
+ ).to eq([3, 4, 2, 0, 1])
+ end
+ end
+
describe "#classify" do
let(:classify_result) { JSON.parse(File.read("spec/fixtures/classify_result.json")) }
let(:response) { OpenStruct.new(body: classify_result) }
@@ -64,7 +93,7 @@
end
it "returns a response" do
- res = instance.classify(
+ res = subject.classify(
inputs: inputs,
examples: examples
).dig("classifications")
@@ -85,7 +114,7 @@
end
it "returns a response" do
- expect(instance.tokenize(
+ expect(subject.tokenize(
text: "Hello, world!"
).dig("tokens")).to eq([33555, 1114, 34])
end
@@ -102,7 +131,7 @@
end
it "returns a response" do
- expect(instance.tokenize(
+ expect(subject.tokenize(
text: "Hello, world!",
model: "base"
).dig("tokens")).to eq([33555, 1114, 34])
@@ -120,7 +149,7 @@
end
it "returns a response" do
- expect(instance.detokenize(
+ expect(subject.detokenize(
tokens: [33555, 1114, 34]
).dig("text")).to eq("hello world!")
end
@@ -137,7 +166,7 @@
end
it "returns a response" do
- expect(instance.detokenize(
+ expect(subject.detokenize(
tokens: [33555, 1114, 34],
model: "base"
).dig("text")).to eq("hello world!")
@@ -155,7 +184,7 @@
end
it "returns a response" do
- expect(instance.detect_language(
+ expect(subject.detect_language(
texts: ["Здравствуй, Мир"]
).dig("results").first.dig("language_code")).to eq("ru")
end
@@ -172,7 +201,7 @@
end
it "returns a response" do
- expect(instance.summarize(
+ expect(subject.summarize(
text: "Ice cream is a sweetened frozen food typically eaten as a snack or dessert. " \
"It may be made from milk or cream and is flavoured with a sweetener, " \
"either sugar or an alternative, and a spice, such as cocoa or vanilla, " \
diff --git a/spec/fixtures/rerank.json b/spec/fixtures/rerank.json
new file mode 100644
index 0000000..aa6bb7c
--- /dev/null
+++ b/spec/fixtures/rerank.json
@@ -0,0 +1,33 @@
+{
+ "id": "fd2f37a7-78e5-4d43-9230-ca0804f8cab5",
+ "results": [
+ {
+ "index": 3,
+ "relevance_score": 0.97997653
+ },
+ {
+ "index": 4,
+ "relevance_score": 0.27963173
+ },
+ {
+ "index": 2,
+ "relevance_score": 0.10502681
+ },
+ {
+ "index": 0,
+ "relevance_score": 0.10212547
+ },
+ {
+ "index": 1,
+ "relevance_score": 0.0721122
+ }
+ ],
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "search_units": 1
+ }
+ }
+}
\ No newline at end of file