diff --git a/examples/rag-playground/public/data/ml-arxiv-papers-1000.ndjson b/examples/rag-playground/public/data/ml-arxiv-papers-1000.ndjson new file mode 120000 index 0000000..79deb91 --- /dev/null +++ b/examples/rag-playground/public/data/ml-arxiv-papers-1000.ndjson @@ -0,0 +1 @@ +../../notebooks/ml-arxiv-papers-1000.ndjson \ No newline at end of file diff --git a/examples/rag-playground/src/components/text-viewer/text-viewer.ts b/examples/rag-playground/src/components/text-viewer/text-viewer.ts index 140c489..2a30545 100644 --- a/examples/rag-playground/src/components/text-viewer/text-viewer.ts +++ b/examples/rag-playground/src/components/text-viewer/text-viewer.ts @@ -2,12 +2,15 @@ import { LitElement, css, unsafeCSS, html, PropertyValues } from 'lit'; import { customElement, property, state, query } from 'lit/decorators.js'; import { unsafeHTML } from 'lit/directives/unsafe-html.js'; +import type { MememoWorkerMessage } from '../../workers/mememo-worker'; + // Assets import componentCSS from './text-viewer.css?inline'; import searchIcon from '../../images/icon-search.svg?raw'; import crossIcon from '../../images/icon-cross-thick.svg?raw'; import crossSmallIcon from '../../images/icon-cross.svg?raw'; +import MememoWorkerInline from '../../workers/mememo-worker?worker&inline'; import paperDataJSON from '../../../notebooks/ml-arxiv-papers-1000.json'; const paperData = paperDataJSON as string[]; @@ -28,11 +31,28 @@ export class MememoTextViewer extends LitElement { @state() showSearchBarCancelButton = false; + loaderWorker: Worker; + //==========================================================================|| // Lifecycle Methods || //==========================================================================|| constructor() { super(); + + this.loaderWorker = new MememoWorkerInline(); + this.loaderWorker.addEventListener( + 'message', + (e: MessageEvent) => + this.loaderWorkerMessageHandler(e) + ); + + const message: MememoWorkerMessage = { + command: 'startLoadData', + payload: { + url: '/data/ml-arxiv-papers-1000.ndjson' + } + }; + this.loaderWorker.postMessage(message); } /** @@ -53,6 +73,19 @@ export class MememoTextViewer extends LitElement { showSearchBarCancelButtonClicked() {} + loaderWorkerMessageHandler(e: MessageEvent) { + switch (e.data.command) { + case 'transferLoadData': { + console.log(e.data.payload); + break; + } + + default: { + console.error(`Unknown command ${e.data.command}`); + } + } + } + //==========================================================================|| // Private Helpers || //==========================================================================|| diff --git a/examples/rag-playground/src/types/common-types.ts b/examples/rag-playground/src/types/common-types.ts index 3a7e0f8..d0c8d06 100644 --- a/examples/rag-playground/src/types/common-types.ts +++ b/examples/rag-playground/src/types/common-types.ts @@ -2,6 +2,14 @@ * Type definitions */ +export type DocumentRecordStreamData = [string, number[]]; + +export interface DocumentRecord { + embedding: number[]; + id: number; + text: string; +} + export interface SimpleEventMessage { message: string; } @@ -40,42 +48,3 @@ export interface Size { width: number; height: number; } - -export interface PromptModel { - task: string; - prompt: string; - variables: string[]; - temperature: number; - stopSequences?: string[]; -} - -export type TextGenWorkerMessage = - | { - command: 'startTextGen'; - payload: { - requestID: string; - apiKey: string; - prompt: string; - temperature: number; - stopSequences?: string[]; - detail?: string; - }; - } - | { - command: 'finishTextGen'; - payload: { - requestID: string; - apiKey: string; - result: string; - prompt: string; - detail: string; - }; - } - | { - command: 'error'; - payload: { - requestID: string; - originalCommand: string; - message: string; - }; - }; diff --git a/examples/rag-playground/src/workers/mememo-worker.ts b/examples/rag-playground/src/workers/mememo-worker.ts new file mode 100644 index 0000000..32a3817 --- /dev/null +++ b/examples/rag-playground/src/workers/mememo-worker.ts @@ -0,0 +1,155 @@ +import { HNSW } from '../../../../src/index'; +import type { + DocumentRecord, + DocumentRecordStreamData +} from '../types/common-types'; +import { + timeit, + splitStreamTransform, + parseJSONTransform +} from '@xiaohk/utils'; + +export type MememoWorkerMessage = + | { + command: 'startLoadData'; + payload: { + /** NDJSON data url */ + url: string; + }; + } + | { + command: 'transferLoadData'; + payload: { + isFirstBatch: boolean; + isLastBatch: boolean; + points: DocumentRecord[]; + loadedPointCount: number; + }; + }; + +const DEV_MODE = import.meta.env.DEV; +const POINT_THRESHOLD = 100; + +let pendingDataPoints: DocumentRecord[] = []; +let loadedPointCount = 0; +let sentPointCount = 0; + +let lastDrawnPoints: DocumentRecord[] | null = null; + +/** + * Handle message events from the main thread + * @param e Message event + */ +self.onmessage = (e: MessageEvent) => { + // Stream point data + switch (e.data.command) { + case 'startLoadData': { + console.log('Worker: start streaming data'); + timeit('Stream data', true); + + const url = e.data.payload.url; + startLoadData(url); + break; + } + + default: { + console.error('Worker: unknown message', e.data.command); + break; + } + } +}; + +/** + * Start loading the text data + * @param url URL to the NDJSON file + */ +const startLoadData = (url: string) => { + fetch(url).then(async response => { + if (!response.ok) { + console.error('Failed to load data', response); + return; + } + + const reader = response.body + ?.pipeThrough(new TextDecoderStream()) + ?.pipeThrough(splitStreamTransform('\n')) + ?.pipeThrough(parseJSONTransform()) + ?.getReader(); + + while (true && reader !== undefined) { + const result = await reader.read(); + const point = result.value as DocumentRecordStreamData; + const done = result.done; + + if (done) { + timeit('Stream data', DEV_MODE); + pointStreamFinished(); + break; + } else { + processPointStream(point); + + // // TODO: Remove me in prod + // if (loadedPointCount >= 305000) { + // pointStreamFinished(); + // timeit('Stream data', DEBUG); + // break; + // } + } + } + }); +}; + +/** + * Process one data point + * @param point Loaded data point + */ +const processPointStream = (point: DocumentRecordStreamData) => { + const promptPoint: DocumentRecord = { + text: point[0], + embedding: point[1], + id: loadedPointCount + }; + + pendingDataPoints.push(promptPoint); + loadedPointCount += 1; + + // Notify the main thread if we have load enough data + if (pendingDataPoints.length >= POINT_THRESHOLD) { + const result: MememoWorkerMessage = { + command: 'transferLoadData', + payload: { + isFirstBatch: lastDrawnPoints === null, + isLastBatch: false, + points: pendingDataPoints, + loadedPointCount + } + }; + postMessage(result); + + sentPointCount += pendingDataPoints.length; + lastDrawnPoints = pendingDataPoints.slice(); + pendingDataPoints = []; + } +}; + +/** + * Construct tree and notify the main thread when finish reading all data + */ +const pointStreamFinished = () => { + // Send any left over points + + const result: MememoWorkerMessage = { + command: 'transferLoadData', + payload: { + isFirstBatch: lastDrawnPoints === null, + isLastBatch: true, + points: pendingDataPoints, + loadedPointCount + } + }; + postMessage(result); + + sentPointCount += pendingDataPoints.length; + lastDrawnPoints = pendingDataPoints.slice(); + pendingDataPoints = []; +}; diff --git a/examples/rag-playground/vite.config.ts b/examples/rag-playground/vite.config.ts index 88df2fe..227f3a7 100644 --- a/examples/rag-playground/vite.config.ts +++ b/examples/rag-playground/vite.config.ts @@ -8,10 +8,10 @@ export default defineConfig(({ command, mode }) => { // Development return { plugins: [ - hmrPlugin({ - include: ['./src/**/*.ts'], - presets: [presets.lit] - }) + // hmrPlugin({ + // include: ['./src/**/*.ts'], + // presets: [presets.lit] + // }) ] }; } else if (command === 'build') {