Add rerank endpoint (#284)
This PR adds the new [rerank
endpoint](https://docs.pinecone.io/guides/inference/rerank) to the TypeScript
client.

Note: the README has been updated to reflect this addition.

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [x] This change requires a documentation update
- [ ] Infrastructure change (CI configs, etc)
- [ ] Non-code change (docs, etc)
- [ ] None of the above: (explain here)

CI passes. Added new unit & integration tests.

---
- To see the specific tasks where the Asana app for GitHub is being
used, see below:
  - https://app.asana.com/0/0/1208242198563722
aulorbe committed Oct 23, 2024
1 parent 94f1a7b commit 02f4ef9
Showing 63 changed files with 801 additions and 466 deletions.
72 changes: 68 additions & 4 deletions README.md
@@ -1018,14 +1018,15 @@ If you do not specify a namespace, the records in the default namespace `''` wil

## Inference

Interact with Pinecone's Inference API (currently in public preview). The Pinecone Inference API is a service that gives
you access to embedding models hosted on Pinecone's infrastructure. Read more at [Understanding Pinecone Inference](https://docs.pinecone.io/guides/inference/understanding-inference).
Interact with Pinecone's [Inference API](https://docs.pinecone.io/guides/inference/understanding-inference) (currently in [public preview](https://docs.pinecone.io/release-notes/feature-availability)). The Pinecone Inference API is a service that gives
you access to inference models hosted on Pinecone's infrastructure.

**Notes:**

Models currently supported:
Supported models:

- [multilingual-e5-large](https://arxiv.org/pdf/2402.05672)
- Embedding: [multilingual-e5-large](https://docs.pinecone.io/models/multilingual-e5-large)
- Reranking: [bge-reranker-v2-m3](https://docs.pinecone.io/models/bge-reranker-v2-m3)

## Create embeddings

@@ -1089,3 +1090,66 @@ generateQueryEmbeddings().then((embeddingsResponse) => {

// << Send query to Pinecone to retrieve similar documents >>
```

## Rerank documents

Rerank documents against a query, returned in descending order of relevance.

**Note:** The `score` is an absolute measure of how relevant a given passage is to the query. It is normalized
to the range [0, 1], with scores closer to 1 indicating higher relevance.

```typescript
import { Pinecone } from '@pinecone-database/pinecone';

const pc = new Pinecone();
const rerankingModel = 'bge-reranker-v2-m3';
const myQuery = 'What are some good Turkey dishes for Thanksgiving?';
const myDocuments = [
  { text: 'I love turkey sandwiches with pastrami' },
  {
    text: 'A lemon brined Turkey with apple sausage stuffing is a classic Thanksgiving main',
  },
  { text: 'My favorite Thanksgiving dish is pumpkin pie' },
  { text: 'Turkey is a great source of protein' },
];

// >>> Sample without passing an `options` object:
const response = await pc.inference.rerank(
  rerankingModel,
  myQuery,
  myDocuments
);
console.log(response);
// {
//   model: 'bge-reranker-v2-m3',
//   data: [
//     { index: 1, score: 0.5633179, document: [Object] },
//     { index: 2, score: 0.02013874, document: [Object] },
//     { index: 3, score: 0.00035419367, document: [Object] },
//     { index: 0, score: 0.00021485926, document: [Object] }
//   ],
//   usage: { rerankUnits: 1 }
// }

// >>> Sample with an `options` object:
const rerankOptions = {
  topN: 3,
  returnDocuments: false,
};
// Named differently from `response` above so both samples can share one scope.
const responseWithOptions = await pc.inference.rerank(
  rerankingModel,
  myQuery,
  myDocuments,
  rerankOptions
);
console.log(responseWithOptions);
// {
//   model: 'bge-reranker-v2-m3',
//   data: [
//     { index: 1, score: 0.5633179, document: undefined },
//     { index: 2, score: 0.02013874, document: undefined },
//     { index: 3, score: 0.00035419367, document: undefined },
//   ],
//   usage: { rerankUnits: 1 }
// }
```
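
In the sample output above, each entry in `data` carries an `index` that appears to point at the document's position in the input array. The following is a minimal sketch, assuming that interpretation, of mapping reranked entries back to the original documents and dropping low-scoring ones. The `RerankedItem` and `TextDocument` types, the `pickRelevant` helper, and the `minScore` threshold are hypothetical and are not part of the client.

```typescript
// Hypothetical shapes mirroring the sample output above; not the client's exported types.
interface RerankedItem {
  index: number; // assumed: position of the document in the input array
  score: number; // relevance score in [0, 1]
}

interface TextDocument {
  text: string;
}

// Return the texts of documents scoring at least `minScore`, most relevant first.
function pickRelevant(
  documents: TextDocument[],
  ranked: RerankedItem[],
  minScore: number
): string[] {
  return ranked
    .filter((item) => item.score >= minScore) // filter() copies, so the input is not mutated
    .sort((a, b) => b.score - a.score) // defensive: do not rely on response order
    .map((item) => documents[item.index].text);
}

// Data copied from the README sample above.
const sampleDocuments: TextDocument[] = [
  { text: 'I love turkey sandwiches with pastrami' },
  {
    text: 'A lemon brined Turkey with apple sausage stuffing is a classic Thanksgiving main',
  },
  { text: 'My favorite Thanksgiving dish is pumpkin pie' },
  { text: 'Turkey is a great source of protein' },
];

const sampleRanked: RerankedItem[] = [
  { index: 1, score: 0.5633179 },
  { index: 2, score: 0.02013874 },
  { index: 3, score: 0.00035419367 },
  { index: 0, score: 0.00021485926 },
];

// 0.01 is an arbitrary illustrative cutoff, not a recommended value.
const relevantTexts = pickRelevant(sampleDocuments, sampleRanked, 0.01);
console.log(relevantTexts);
```

Because the score is described as an absolute measure, a fixed cutoff like this can be applied per query. The right value depends on the model and the data.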
4 changes: 2 additions & 2 deletions src/__tests__/pinecone.test.ts
@@ -14,8 +14,8 @@ jest.mock('../utils', () => {
getFetch: () => fakeFetch,
};
});
jest.mock('../data/fetch');
jest.mock('../data/upsert');
jest.mock('../data/vectors/fetch');
jest.mock('../data/vectors/upsert');
jest.mock('../data/indexHostSingleton');
jest.mock('../control', () => {
const realControl = jest.requireActual('../control');
2 changes: 1 addition & 1 deletion src/control/indexOperationsBuilder.ts
@@ -10,7 +10,7 @@ import {
normalizeUrl,
} from '../utils';
import { middleware } from '../utils/middleware';
import type { PineconeConfiguration } from '../data/types';
import type { PineconeConfiguration } from '../data/vectors/types';
import type { ConfigurationParameters as IndexOperationsApiConfigurationParameters } from '../pinecone-generated-ts-fetch/db_control';

export const indexOperationsBuilder = (
@@ -1,12 +1,12 @@
import { StartImportCommand } from '../../bulkImport/startImport';
import { ListImportsCommand } from '../../bulkImport/listImports';
import { DescribeImportCommand } from '../../bulkImport/describeImport';
import { CancelImportCommand } from '../../bulkImport/cancelImport';
import { BulkOperationsProvider } from '../../bulkImport/bulkOperationsProvider';
import { StartImportCommand } from '../../bulk/startImport';
import { ListImportsCommand } from '../../bulk/listImports';
import { DescribeImportCommand } from '../../bulk/describeImport';
import { CancelImportCommand } from '../../bulk/cancelImport';
import { BulkOperationsProvider } from '../../bulk/bulkOperationsProvider';
import {
ImportErrorModeOnErrorEnum,
ListImportsRequest,
StartImportOperationRequest,
ListBulkImportsRequest,
StartBulkImportRequest,
} from '../../../pinecone-generated-ts-fetch/db_data';
import { PineconeArgumentError } from '../../../errors';

@@ -20,10 +20,10 @@ describe('StartImportCommand', () => {

beforeEach(() => {
apiMock = {
startImport: jest.fn(),
listImports: jest.fn(),
describeImport: jest.fn(),
cancelImport: jest.fn(),
startBulkImport: jest.fn(),
listBulkImports: jest.fn(),
describeBulkImport: jest.fn(),
cancelBulkImport: jest.fn(),
};

apiProviderMock = {
@@ -40,7 +40,7 @@ describe('StartImportCommand', () => {
const uri = 's3://my-bucket/my-file.csv';
const errorMode = 'continue';

const expectedRequest: StartImportOperationRequest = {
const expectedRequest: StartBulkImportRequest = {
startImportRequest: {
uri,
errorMode: { onError: ImportErrorModeOnErrorEnum.Continue },
@@ -50,14 +50,14 @@ describe('StartImportCommand', () => {
await startImportCommand.run(uri, errorMode);

expect(apiProviderMock.provide).toHaveBeenCalled();
expect(apiMock.startImport).toHaveBeenCalledWith(expectedRequest);
expect(apiMock.startBulkImport).toHaveBeenCalledWith(expectedRequest);
});

test('should call startImport with correct request when errorMode is "abort"', async () => {
const uri = 's3://my-bucket/my-file.csv';
const errorMode = 'abort';

const expectedRequest: StartImportOperationRequest = {
const expectedRequest: StartBulkImportRequest = {
startImportRequest: {
uri,
errorMode: { onError: ImportErrorModeOnErrorEnum.Abort },
@@ -67,7 +67,7 @@ describe('StartImportCommand', () => {
await startImportCommand.run(uri, errorMode);

expect(apiProviderMock.provide).toHaveBeenCalled();
expect(apiMock.startImport).toHaveBeenCalledWith(expectedRequest);
expect(apiMock.startBulkImport).toHaveBeenCalledWith(expectedRequest);
});

test('should throw PineconeArgumentError for invalid errorMode', async () => {
@@ -78,13 +78,13 @@ describe('StartImportCommand', () => {
PineconeArgumentError
);

expect(apiMock.startImport).not.toHaveBeenCalled();
expect(apiMock.startBulkImport).not.toHaveBeenCalled();
});

test('should use "continue" as default when errorMode is undefined', async () => {
const uri = 's3://my-bucket/my-file.csv';

const expectedRequest: StartImportOperationRequest = {
const expectedRequest: StartBulkImportRequest = {
startImportRequest: {
uri,
errorMode: { onError: ImportErrorModeOnErrorEnum.Continue },
@@ -94,7 +94,7 @@ describe('StartImportCommand', () => {
await startImportCommand.run(uri, undefined);

expect(apiProviderMock.provide).toHaveBeenCalled();
expect(apiMock.startImport).toHaveBeenCalledWith(expectedRequest);
expect(apiMock.startBulkImport).toHaveBeenCalledWith(expectedRequest);
});

test('should throw error when URI/1st arg is missing', async () => {
@@ -112,14 +112,14 @@ describe('StartImportCommand', () => {
test('should call listImport with correct request', async () => {
const limit = 1;

const expectedRequest: ListImportsRequest = {
const expectedRequest: ListBulkImportsRequest = {
limit,
};

await listImportCommand.run(limit);

expect(apiProviderMock.provide).toHaveBeenCalled();
expect(apiMock.listImports).toHaveBeenCalledWith(expectedRequest);
expect(apiMock.listBulkImports).toHaveBeenCalledWith(expectedRequest);
});

test('should call describeImport with correct request', async () => {
@@ -129,7 +129,7 @@ describe('StartImportCommand', () => {
await describeImportCommand.run(importId);

expect(apiProviderMock.provide).toHaveBeenCalled();
expect(apiMock.describeImport).toHaveBeenCalledWith(req);
expect(apiMock.describeBulkImport).toHaveBeenCalledWith(req);
});

test('should call cancelImport with correct request', async () => {
@@ -139,6 +139,6 @@ describe('StartImportCommand', () => {
await cancelImportCommand.run(importId);

expect(apiProviderMock.provide).toHaveBeenCalled();
expect(apiMock.cancelImport).toHaveBeenCalledWith(req);
expect(apiMock.cancelBulkImport).toHaveBeenCalledWith(req);
});
});
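
The tests above pin down how `StartImportCommand.run` treats its `errorMode` argument: `'continue'` and `'abort'` are passed through, `undefined` falls back to `'continue'`, and any other value raises a `PineconeArgumentError` before the API is called. The following is a minimal standalone sketch of that behavior. `resolveErrorMode` and the local `ArgumentError` class are hypothetical stand-ins, not the SDK's implementation, and the tests do not show whether the real check is case-sensitive.

```typescript
// Stand-in for PineconeArgumentError so the sketch has no SDK dependency.
class ArgumentError extends Error {
  constructor(message: string) {
    super(message);
    this.name = 'ArgumentError';
  }
}

type OnError = 'continue' | 'abort';

// Hypothetical helper: normalize the user-supplied error mode into the
// `{ onError }` shape that the expected requests in the tests contain.
function resolveErrorMode(errorMode?: string): { onError: OnError } {
  if (errorMode === undefined) {
    return { onError: 'continue' }; // default exercised by the last startImport test
  }
  if (errorMode === 'continue' || errorMode === 'abort') {
    return { onError: errorMode };
  }
  // Assumption: exact string match only; anything else is rejected.
  throw new ArgumentError(
    `errorMode must be 'continue' or 'abort', received '${errorMode}'`
  );
}

const defaultMode = resolveErrorMode(undefined);
const abortMode = resolveErrorMode('abort');
console.log(defaultMode, abortMode);
```

Validating before building the request is what lets the test assert that `startBulkImport` was never called for an invalid mode.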
28 changes: 14 additions & 14 deletions src/data/__tests__/index.test.ts
@@ -1,16 +1,16 @@
import { FetchCommand } from '../fetch';
import { QueryCommand } from '../query';
import { UpdateCommand } from '../update';
import { UpsertCommand } from '../upsert';
import { DataOperationsProvider } from '../dataOperationsProvider';
import { FetchCommand } from '../vectors/fetch';
import { QueryCommand } from '../vectors/query';
import { UpdateCommand } from '../vectors/update';
import { UpsertCommand } from '../vectors/upsert';
import { VectorOperationsProvider } from '../vectors/vectorOperationsProvider';
import { Index } from '../index';
import type { ScoredPineconeRecord } from '../query';
import type { ScoredPineconeRecord } from '../vectors/query';

jest.mock('../fetch');
jest.mock('../query');
jest.mock('../update');
jest.mock('../upsert');
jest.mock('../dataOperationsProvider');
jest.mock('../vectors/fetch');
jest.mock('../vectors/query');
jest.mock('../vectors/update');
jest.mock('../vectors/upsert');
jest.mock('../vectors/vectorOperationsProvider');

describe('Index', () => {
let config;
@@ -27,7 +27,7 @@ describe('Index', () => {
});

describe('index initialization', () => {
test('passes config, indexName, indexHostUrl, and additionalHeaders to DataOperationsProvider', () => {
test('passes config, indexName, indexHostUrl, and additionalHeaders to VectorOperationsProvider', () => {
const indexHostUrl = 'https://test-api-pinecone.io';
const additionalHeaders = { 'x-custom-header': 'custom-value' };
new Index(
@@ -37,8 +37,8 @@ describe('Index', () => {
indexHostUrl,
additionalHeaders
);
expect(DataOperationsProvider).toHaveBeenCalledTimes(1);
expect(DataOperationsProvider).toHaveBeenCalledWith(
expect(VectorOperationsProvider).toHaveBeenCalledTimes(1);
expect(VectorOperationsProvider).toHaveBeenCalledWith(
config,
'index-name',
indexHostUrl,
@@ -1,15 +1,15 @@
import { deleteAll } from '../deleteAll';
import { deleteAll } from '../../vectors/deleteAll';
import { setupDeleteSuccess } from './deleteOne.test';

describe('deleteAll', () => {
test('calls the openapi delete endpoint, passing deleteAll with target namespace', async () => {
const { DataProvider, DPA } = setupDeleteSuccess(undefined);
const { VectorProvider, VOA } = setupDeleteSuccess(undefined);

const deleteAllFn = deleteAll(DataProvider, 'namespace');
const deleteAllFn = deleteAll(VectorProvider, 'namespace');
const returned = await deleteAllFn();

expect(returned).toBe(void 0);
expect(DPA._delete).toHaveBeenCalledWith({
expect(VOA.deleteVectors).toHaveBeenCalledWith({
deleteRequest: { deleteAll: true, namespace: 'namespace' },
});
});
@@ -1,35 +1,35 @@
import { deleteMany } from '../deleteMany';
import { deleteMany } from '../../vectors/deleteMany';
import { setupDeleteSuccess } from './deleteOne.test';
import { PineconeArgumentError } from '../../errors';
import { PineconeArgumentError } from '../../../errors';

describe('deleteMany', () => {
test('calls the openapi delete endpoint, passing ids with target namespace', async () => {
const { DataProvider, DPA } = setupDeleteSuccess(undefined);
const { VectorProvider, VOA } = setupDeleteSuccess(undefined);

const deleteManyFn = deleteMany(DataProvider, 'namespace');
const deleteManyFn = deleteMany(VectorProvider, 'namespace');
const returned = await deleteManyFn(['123', '456', '789']);

expect(returned).toBe(void 0);
expect(DPA._delete).toHaveBeenCalledWith({
expect(VOA.deleteVectors).toHaveBeenCalledWith({
deleteRequest: { ids: ['123', '456', '789'], namespace: 'namespace' },
});
});

test('calls the openapi delete endpoint, passing filter with target namespace', async () => {
const { DPA, DataProvider } = setupDeleteSuccess(undefined);
const { VOA, VectorProvider } = setupDeleteSuccess(undefined);

const deleteManyFn = deleteMany(DataProvider, 'namespace');
const deleteManyFn = deleteMany(VectorProvider, 'namespace');
const returned = await deleteManyFn({ genre: 'ambient' });

expect(returned).toBe(void 0);
expect(DPA._delete).toHaveBeenCalledWith({
expect(VOA.deleteVectors).toHaveBeenCalledWith({
deleteRequest: { filter: { genre: 'ambient' }, namespace: 'namespace' },
});
});

test('throws if pass in empty filter obj', async () => {
const { DataProvider } = setupDeleteSuccess(undefined);
const deleteManyFn = deleteMany(DataProvider, 'namespace');
const { VectorProvider } = setupDeleteSuccess(undefined);
const deleteManyFn = deleteMany(VectorProvider, 'namespace');
const toThrow = async () => {
await deleteManyFn({ some: '' });
};
@@ -40,8 +40,8 @@ describe('deleteMany', () => {
});

test('throws if pass no record IDs', async () => {
const { DataProvider } = setupDeleteSuccess(undefined);
const deleteManyFn = deleteMany(DataProvider, 'namespace');
const { VectorProvider } = setupDeleteSuccess(undefined);
const deleteManyFn = deleteMany(VectorProvider, 'namespace');
const toThrow = async () => {
await deleteManyFn([]);
};
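
The `deleteMany` tests above show two accepted argument shapes, an array of record IDs and a metadata filter object, each wrapped in a `deleteRequest` alongside the target namespace, plus two rejected inputs. The following is a minimal sketch of that dispatch. `buildDeleteManyRequest` is a hypothetical helper, not the SDK's code, and the rule used to reject `{ some: '' }` (no keys, or any empty-string value) is an assumption inferred from the single test case visible here.

```typescript
type MetadataFilter = Record<string, unknown>;

interface DeleteManyRequest {
  deleteRequest: {
    ids?: string[];
    filter?: MetadataFilter;
    namespace: string;
  };
}

// Hypothetical helper: turn a deleteMany argument into the request shape
// asserted by the tests above.
function buildDeleteManyRequest(
  arg: string[] | MetadataFilter,
  namespace: string
): DeleteManyRequest {
  if (Array.isArray(arg)) {
    if (arg.length === 0) {
      throw new Error('Must pass at least one record ID.');
    }
    return { deleteRequest: { ids: arg, namespace } };
  }

  const filter: MetadataFilter = arg;
  const keys = Object.keys(filter);
  // Assumption: a filter counts as empty when it has no keys or when any
  // value is an empty string; the real validation may differ.
  const hasEmptyValue = keys.some((key) => filter[key] === '');
  if (keys.length === 0 || hasEmptyValue) {
    throw new Error('Filter must contain at least one non-empty condition.');
  }
  return { deleteRequest: { filter, namespace } };
}

const byIds = buildDeleteManyRequest(['123', '456', '789'], 'namespace');
const byFilter = buildDeleteManyRequest({ genre: 'ambient' }, 'namespace');
console.log(JSON.stringify(byIds), JSON.stringify(byFilter));
```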

0 comments on commit 02f4ef9
