From b1c4fdb58c7a373a2ea691e0cd20dbf3dae40ad4 Mon Sep 17 00:00:00 2001 From: Andrei Bondarev Date: Thu, 1 Aug 2024 16:18:01 -0400 Subject: [PATCH 1/2] Add rerank method and bump version --- CHANGELOG.md | 3 ++ Gemfile.lock | 2 +- README.md | 59 ++++++++++++++++++++++++-------------- lib/cohere/client.rb | 24 ++++++++++++++++ lib/cohere/version.rb | 2 +- spec/cohere/client_spec.rb | 47 +++++++++++++++++++++++------- spec/fixtures/rerank.json | 33 +++++++++++++++++++++ 7 files changed, 136 insertions(+), 34 deletions(-) create mode 100644 spec/fixtures/rerank.json diff --git a/CHANGELOG.md b/CHANGELOG.md index fcb037e..2a16cb1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ ## [Unreleased] +## [0.9.11] - 2024-08-01 +- New `rerank()` method + ## [0.9.10] - 2024-05-10 - /chat endpoint does not require `message:` parameter anymore diff --git a/Gemfile.lock b/Gemfile.lock index 80489a8..cc15050 100644 --- a/Gemfile.lock +++ b/Gemfile.lock @@ -1,7 +1,7 @@ PATH remote: . specs: - cohere-ruby (0.9.10) + cohere-ruby (0.9.11) faraday (>= 2.0.1, < 3.0) GEM diff --git a/README.md b/README.md index 1ab14fa..a54b487 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Cohere

- Weaviate logo + Cohere logo +   Ruby logo

@@ -42,7 +42,7 @@ client = Cohere::Client.new( ```ruby client.generate( - prompt: "Once upon a time in a magical land called" + prompt: "Once upon a time in a magical land called" ) ``` @@ -50,7 +50,7 @@ client.generate( ```ruby client.chat( - message: "Hey! How are you?" + message: "Hey! How are you?" ) ``` @@ -90,30 +90,45 @@ client.chat( ) ``` - - ### Embed ```ruby client.embed( - texts: ["hello!"] + texts: ["hello!"] +) +``` + +### Rerank + +```ruby +docs = [ + "Carson City is the capital city of the American state of Nevada.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", + "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.", + "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", + "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. 
As of 2017, capital punishment is legal in 30 of the 50 states.", +] + +client.rerank( + query: "What is the capital of the United States?", documents: docs ) ``` + ### Classify ```ruby examples = [ - { text: "Dermatologists don't like her!", label: "Spam" }, - { text: "Hello, open to this?", label: "Spam" }, - { text: "I need help please wire me $1000 right now", label: "Spam" }, - { text: "Nice to know you ;)", label: "Spam" }, - { text: "Please help me?", label: "Spam" }, - { text: "Your parcel will be delivered today", label: "Not spam" }, - { text: "Review changes to our Terms and Conditions", label: "Not spam" }, - { text: "Weekly sync notes", label: "Not spam" }, - { text: "Re: Follow up from today's meeting", label: "Not spam" }, - { text: "Pre-read for tomorrow", label: "Not spam" } + { text: "Dermatologists don't like her!", label: "Spam" }, + { text: "Hello, open to this?", label: "Spam" }, + { text: "I need help please wire me $1000 right now", label: "Spam" }, + { text: "Nice to know you ;)", label: "Spam" }, + { text: "Please help me?", label: "Spam" }, + { text: "Your parcel will be delivered today", label: "Not spam" }, + { text: "Review changes to our Terms and Conditions", label: "Not spam" }, + { text: "Weekly sync notes", label: "Not spam" }, + { text: "Re: Follow up from today's meeting", label: "Not spam" }, + { text: "Pre-read for tomorrow", label: "Not spam" } ] inputs = [ @@ -122,8 +137,8 @@ inputs = [ ] client.classify( - examples: examples, - inputs: inputs + examples: examples, + inputs: inputs ) ``` @@ -131,7 +146,7 @@ client.classify( ```ruby client.tokenize( - text: "hello world!" + text: "hello world!" ) ``` @@ -139,7 +154,7 @@ client.tokenize( ```ruby client.detokenize( - tokens: [33555, 1114 , 34] + tokens: [33555, 1114 , 34] ) ``` @@ -147,7 +162,7 @@ client.detokenize( ```ruby client.detect_language( - texts: ["Здравствуй, Мир"] + texts: ["Здравствуй, Мир"] ) ``` @@ -155,7 +170,7 @@ client.detect_language( ```ruby client.summarize( - text: "..." + text: "..." 
) ``` diff --git a/lib/cohere/client.rb b/lib/cohere/client.rb index ed5b2a4..eee47eb 100644 --- a/lib/cohere/client.rb +++ b/lib/cohere/client.rb @@ -119,6 +119,30 @@ def embed( response.body end + def rerank( + query:, + documents:, + model: nil, + top_n: nil, + rank_fields: nil, + return_documents: nil, + max_chunks_per_doc: nil + ) + response = connection.post("rerank") do |req| + req.body = { + query: query, + documents: documents + } + req.body[:model] = model if model + req.body[:top_n] = top_n if top_n + req.body[:rank_fields] = rank_fields if rank_fields + req.body[:return_documents] = return_documents if return_documents + req.body[:max_chunks_per_doc] = max_chunks_per_doc if max_chunks_per_doc + end + response.body + + end + def classify( inputs:, examples:, diff --git a/lib/cohere/version.rb b/lib/cohere/version.rb index 4c696a1..4c189a2 100644 --- a/lib/cohere/version.rb +++ b/lib/cohere/version.rb @@ -1,5 +1,5 @@ # frozen_string_literal: true module Cohere - VERSION = "0.9.10" + VERSION = "0.9.11" end diff --git a/spec/cohere/client_spec.rb b/spec/cohere/client_spec.rb index 772ec7d..8df59f6 100644 --- a/spec/cohere/client_spec.rb +++ b/spec/cohere/client_spec.rb @@ -3,7 +3,7 @@ require "spec_helper" RSpec.describe Cohere::Client do - let(:instance) { described_class.new(api_key: "123") } + subject { described_class.new(api_key: "123") } describe "#generate" do let(:generate_result) { JSON.parse(File.read("spec/fixtures/generate_result.json")) } @@ -16,7 +16,7 @@ end it "returns a response" do - expect(instance.generate( + expect(subject.generate( prompt: "Once upon a time in a magical land called" ).dig("generations").first.dig("text")).to eq(" The Past there was a Game called Warhammer Fantasy Battle.") end @@ -33,12 +33,39 @@ end it "returns a response" do - expect(instance.embed( + expect(subject.embed( texts: ["hello!"] ).dig("embeddings")).to eq([[1.2177734, 0.67529297, 2.0742188]]) end end + describe "#rerank" do + let(:embed_result) { 
JSON.parse(File.read("spec/fixtures/rerank.json")) } + let(:response) { OpenStruct.new(body: embed_result) } + let(:docs) {[ + "Carson City is the capital city of the American state of Nevada.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", + "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.", + "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", + "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.", + ]} + + before do + allow_any_instance_of(Faraday::Connection).to receive(:post) + .with("rerank") + .and_return(response) + end + + it "returns a response" do + expect( + subject + .rerank(query: "What is the capital of the United States?", documents: docs) + .dig("results") + .map {|h| h["index"]} + ).to eq([3, 4, 2, 0, 1]) + end + end + describe "#classify" do let(:classify_result) { JSON.parse(File.read("spec/fixtures/classify_result.json")) } let(:response) { OpenStruct.new(body: classify_result) } @@ -64,7 +91,7 @@ end it "returns a response" do - res = instance.classify( + res = subject.classify( inputs: inputs, examples: examples ).dig("classifications") @@ -85,7 +112,7 @@ end it "returns a response" do - expect(instance.tokenize( + expect(subject.tokenize( text: "Hello, world!" 
).dig("tokens")).to eq([33555, 1114, 34]) end @@ -102,7 +129,7 @@ end it "returns a response" do - expect(instance.tokenize( + expect(subject.tokenize( text: "Hello, world!", model: "base" ).dig("tokens")).to eq([33555, 1114, 34]) @@ -120,7 +147,7 @@ end it "returns a response" do - expect(instance.detokenize( + expect(subject.detokenize( tokens: [33555, 1114, 34] ).dig("text")).to eq("hello world!") end @@ -137,7 +164,7 @@ end it "returns a response" do - expect(instance.detokenize( + expect(subject.detokenize( tokens: [33555, 1114, 34], model: "base" ).dig("text")).to eq("hello world!") @@ -155,7 +182,7 @@ end it "returns a response" do - expect(instance.detect_language( + expect(subject.detect_language( texts: ["Здравствуй, Мир"] ).dig("results").first.dig("language_code")).to eq("ru") end @@ -172,7 +199,7 @@ end it "returns a response" do - expect(instance.summarize( + expect(subject.summarize( text: "Ice cream is a sweetened frozen food typically eaten as a snack or dessert. " \ "It may be made from milk or cream and is flavoured with a sweetener, " \ "either sugar or an alternative, and a spice, such as cocoa or vanilla, " \ diff --git a/spec/fixtures/rerank.json b/spec/fixtures/rerank.json new file mode 100644 index 0000000..aa6bb7c --- /dev/null +++ b/spec/fixtures/rerank.json @@ -0,0 +1,33 @@ +{ + "id": "fd2f37a7-78e5-4d43-9230-ca0804f8cab5", + "results": [ + { + "index": 3, + "relevance_score": 0.97997653 + }, + { + "index": 4, + "relevance_score": 0.27963173 + }, + { + "index": 2, + "relevance_score": 0.10502681 + }, + { + "index": 0, + "relevance_score": 0.10212547 + }, + { + "index": 1, + "relevance_score": 0.0721122 + } + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "search_units": 1 + } + } +} \ No newline at end of file From 449e39ad60f499c0f5f047e1afcbf105bb3fe9b3 Mon Sep 17 00:00:00 2001 From: Andrei Bondarev Date: Thu, 1 Aug 2024 16:22:16 -0400 Subject: [PATCH 2/2] Fix linter --- lib/cohere/client.rb | 1 - 
spec/cohere/client_spec.rb | 18 ++++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/lib/cohere/client.rb b/lib/cohere/client.rb index eee47eb..523e8c6 100644 --- a/lib/cohere/client.rb +++ b/lib/cohere/client.rb @@ -140,7 +140,6 @@ def rerank( req.body[:max_chunks_per_doc] = max_chunks_per_doc if max_chunks_per_doc end response.body - end def classify( diff --git a/spec/cohere/client_spec.rb b/spec/cohere/client_spec.rb index 8df59f6..4e1a929 100644 --- a/spec/cohere/client_spec.rb +++ b/spec/cohere/client_spec.rb @@ -42,13 +42,15 @@ describe "#rerank" do let(:embed_result) { JSON.parse(File.read("spec/fixtures/rerank.json")) } let(:response) { OpenStruct.new(body: embed_result) } - let(:docs) {[ - "Carson City is the capital city of the American state of Nevada.", - "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", - "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.", - "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", - "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.", - ]} + let(:docs) { + [ + "Carson City is the capital city of the American state of Nevada.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", + "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.", + "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. 
It is a federal district.", + "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states." + ] + } before do allow_any_instance_of(Faraday::Connection).to receive(:post) @@ -61,7 +63,7 @@ subject .rerank(query: "What is the capital of the United States?", documents: docs) .dig("results") - .map {|h| h["index"]} + .map { |h| h["index"] } ).to eq([3, 4, 2, 0, 1]) end end