diff --git a/examples/rag-playground/package.json b/examples/rag-playground/package.json
index 1a85b1d..9effe8c 100644
--- a/examples/rag-playground/package.json
+++ b/examples/rag-playground/package.json
@@ -13,18 +13,19 @@
   },
   "devDependencies": {
     "@floating-ui/dom": "^1.6.1",
-    "@types/flexsearch": "^0.7.6",
-    "@typescript-eslint/eslint-plugin": "^6.20.0",
-    "@xenova/transformers": "^2.14.2",
-    "@xiaohk/utils": "^0.0.6",
     "@types/d3-array": "^3.2.1",
     "@types/d3-format": "^3.0.4",
     "@types/d3-random": "^3.0.3",
     "@types/d3-time-format": "^4.0.3",
-    "d3-format": "^3.1.0",
+    "@types/flexsearch": "^0.7.6",
+    "@typescript-eslint/eslint-plugin": "^6.20.0",
+    "@xenova/transformers": "^2.14.2",
+    "@xiaohk/utils": "^0.0.6",
     "d3-array": "^3.2.4",
+    "d3-format": "^3.1.0",
     "d3-random": "^3.0.1",
     "d3-time-format": "^4.1.0",
+    "dexie": "^3.2.4",
     "eslint": "^8.56.0",
     "eslint-config-prettier": "^9.1.0",
     "eslint-plugin-lit": "^1.11.0",
@@ -32,7 +33,6 @@
     "eslint-plugin-wc": "^2.0.4",
     "flexsearch": "^0.7.43",
     "gh-pages": "^6.1.1",
-    "idb": "^8.0.0",
     "lit": "^3.1.2",
     "prettier": "^3.2.4",
     "typescript": "^5.3.3",
diff --git a/examples/rag-playground/src/types/common-types.ts b/examples/rag-playground/src/types/common-types.ts
index d0c8d06..8404525 100644
--- a/examples/rag-playground/src/types/common-types.ts
+++ b/examples/rag-playground/src/types/common-types.ts
@@ -4,6 +4,11 @@
 
 export type DocumentRecordStreamData = [string, number[]];
 
+export interface DocumentDBEntry {
+  id: number;
+  text: string;
+}
+
 export interface DocumentRecord {
   embedding: number[];
   id: number;
diff --git a/examples/rag-playground/src/workers/mememo-worker.ts b/examples/rag-playground/src/workers/mememo-worker.ts
index 280848a..bf1a2f8 100644
--- a/examples/rag-playground/src/workers/mememo-worker.ts
+++ b/examples/rag-playground/src/workers/mememo-worker.ts
@@ -1,6 +1,7 @@
 import { HNSW } from '../../../../src/index';
 import type {
   DocumentRecord,
+  DocumentDBEntry,
   DocumentRecordStreamData
 } from '../types/common-types';
 import {
@@ -9,7 +10,8 @@ import {
   parseJSONTransform
 } from '@xiaohk/utils';
 import Flexsearch from 'flexsearch';
-import { openDB, IDBPDatabase } from 'idb';
+import Dexie from 'dexie';
+import type { Table, PromiseExtended } from 'dexie';
 //==========================================================================||
 //                            Types & Constants                             ||
 //==========================================================================||
@@ -63,11 +65,12 @@ const flexIndex: Flexsearch.Index = new Flexsearch.Index({
   tokenize: 'forward'
 }) as Flexsearch.Index;
 let workerDatasetName = 'my-dataset';
-let documentDBPromise: Promise<IDBPDatabase<string>> | null = null;
+let documentDBPromise: PromiseExtended<Table<DocumentDBEntry, number>> | null =
+  null;
 
 const hnswIndex = new HNSW({
-  distanceFunction: 'cosine',
+  distanceFunction: 'cosine-normalized',
   seed: 123,
-  useIndexedDB: false
+  useIndexedDB: true
 });
 //==========================================================================||
@@ -114,11 +117,13 @@ const startLoadCompressedData = (url: string, datasetName: string) => {
   // Update the indexed db store reference
   workerDatasetName = datasetName;
 
-  documentDBPromise = openDB(`${workerDatasetName}-store`, 1, {
-    upgrade(db) {
-      db.createObjectStore(workerDatasetName);
-    }
+  // Create a new store and clear content from previous sessions
+  const myDexie = new Dexie('mememo-document-store');
+  myDexie.version(1).stores({
+    mememo: 'id'
   });
+  const db = myDexie.table<DocumentDBEntry, number>('mememo');
+  documentDBPromise = db.clear().then(() => db);
 
   fetch(url).then(
     async response => {
@@ -168,16 +173,25 @@ const processPointStream = async (point: DocumentRecordStreamData) => {
     id: loadedPointCount
   };
 
-  // Index the point
+  // Index the point in Flexsearch
   pendingDataPoints.push(documentPoint);
   flexIndex.add(documentPoint.id, documentPoint.text);
-  await documentDB.put(workerDatasetName, documentPoint.text, documentPoint.id);
-  await hnswIndex.insert(String(documentPoint.id), documentPoint.embedding);
 
   loadedPointCount += 1;
 
-  // Notify the main thread if we have load enough data
   if (pendingDataPoints.length >= POINT_THRESHOLD) {
+    // Batch-index the buffered documents into IndexedDB and MeMemo
+    const keys = pendingDataPoints.map(d => String(d.id));
+    const embeddings = pendingDataPoints.map(d => d.embedding);
+    const documentEntries: DocumentDBEntry[] = pendingDataPoints.map(d => ({
+      id: d.id,
+      text: d.text
+    }));
+
+    await documentDB.bulkPut(documentEntries);
+    await hnswIndex.bulkInsert(keys, embeddings);
+
+    // Notify the main thread once we have loaded enough data
     const result: MememoWorkerMessage = {
       command: 'transferLoadData',
       payload: {
@@ -235,17 +249,19 @@ const searchPoint = async (query: string, limit: number, requestID: number) => {
   }) as unknown as number[];
 
   // Look up the indexes in indexedDB
-  const results = [];
-  for (const i of resultIndexes) {
-    const result = (await documentDB.get(workerDatasetName, i)) as string;
-    results.push(result);
+  const results = await documentDB.bulkGet(resultIndexes);
+  const documents: string[] = [];
+  for (const r of results) {
+    if (r !== undefined) {
+      documents.push(r.text);
+    }
   }
 
   const message: MememoWorkerMessage = {
     command: 'finishLexicalSearch',
     payload: {
       query,
-      results,
+      results: documents,
       requestID
     }
  };
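
A note on the pattern this patch adopts: instead of awaiting one `documentDB.put()` and one `hnswIndex.insert()` per streamed point, the worker now buffers up to `POINT_THRESHOLD` points and flushes them in bulk with Dexie's `bulkPut()` and MeMemo's `bulkInsert()`, cutting the per-document IndexedDB round trips. Below is a minimal standalone sketch of the Dexie side; the database and table names mirror the patch, but the `run` wrapper and the sample documents are invented for illustration:

```ts
import Dexie from 'dexie';
import type { Table } from 'dexie';

interface DocumentDBEntry {
  id: number;
  text: string;
}

// Same schema as the patch: 'id' is the primary key
const myDexie = new Dexie('mememo-document-store');
myDexie.version(1).stores({ mememo: 'id' });
const db: Table<DocumentDBEntry, number> = myDexie.table('mememo');

const run = async () => {
  // Batched write: one call (and one transaction) for the whole buffer
  const pending: DocumentDBEntry[] = [
    { id: 0, text: 'first document' },
    { id: 1, text: 'second document' }
  ];
  await db.bulkPut(pending);

  // Batched read: bulkGet() preserves key order and yields undefined for
  // missing keys, hence the filter before unwrapping .text
  const rows = await db.bulkGet([0, 1, 99]);
  const texts = rows
    .filter((r): r is DocumentDBEntry => r !== undefined)
    .map(r => r.text);
  console.log(texts); // ['first document', 'second document']
};

void run();
```

The switch to `distanceFunction: 'cosine-normalized'` presumably relies on the dataset's embeddings already being unit-normalized, letting the index skip re-normalization on every distance computation, while `useIndexedDB: true` lets the HNSW index persist its data in IndexedDB rather than holding everything in worker memory.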