Skip to content

Commit

Permalink
Enable streaming ndjson data
Browse files Browse the repository at this point in the history
Signed-off-by: Jay Wang <[email protected]>
  • Loading branch information
xiaohk committed Feb 3, 2024
1 parent 88efe83 commit 5c88e9e
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 43 deletions.
33 changes: 33 additions & 0 deletions examples/rag-playground/src/components/text-viewer/text-viewer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ import { LitElement, css, unsafeCSS, html, PropertyValues } from 'lit';
import { customElement, property, state, query } from 'lit/decorators.js';
import { unsafeHTML } from 'lit/directives/unsafe-html.js';

import type { MememoWorkerMessage } from '../../workers/mememo-worker';

// Assets
import componentCSS from './text-viewer.css?inline';
import searchIcon from '../../images/icon-search.svg?raw';
import crossIcon from '../../images/icon-cross-thick.svg?raw';
import crossSmallIcon from '../../images/icon-cross.svg?raw';

import MememoWorkerInline from '../../workers/mememo-worker?worker&inline';
import paperDataJSON from '../../../notebooks/ml-arxiv-papers-1000.json';
const paperData = paperDataJSON as string[];

Expand All @@ -28,11 +31,28 @@ export class MememoTextViewer extends LitElement {
@state()
showSearchBarCancelButton = false;

loaderWorker: Worker;

//==========================================================================||
// Lifecycle Methods ||
//==========================================================================||
constructor() {
  super();

  // Create the inline worker and route every message it posts back to
  // loaderWorkerMessageHandler
  this.loaderWorker = new MememoWorkerInline();
  const onWorkerMessage = (e: MessageEvent<MememoWorkerMessage>) =>
    this.loaderWorkerMessageHandler(e);
  this.loaderWorker.addEventListener('message', onWorkerMessage);

  // Ask the worker to start loading the NDJSON data right away
  const startMessage: MememoWorkerMessage = {
    command: 'startLoadData',
    payload: { url: '/data/ml-arxiv-papers-1000.ndjson' }
  };
  this.loaderWorker.postMessage(startMessage);
}

/**
Expand All @@ -53,6 +73,19 @@ export class MememoTextViewer extends LitElement {

/**
 * Click handler for the search bar's cancel button (per its name).
 * Currently a no-op: the body is empty.
 */
showSearchBarCancelButtonClicked() {}

/**
 * Handle a message posted by the loader worker.
 * `transferLoadData` payloads are logged; any other command is reported as
 * an error.
 * @param e Message event from the loader worker
 */
loaderWorkerMessageHandler(e: MessageEvent<MememoWorkerMessage>) {
  const workerMessage = e.data;

  if (workerMessage.command === 'transferLoadData') {
    console.log(workerMessage.payload);
    return;
  }

  console.error(`Unknown command ${workerMessage.command}`);
}

//==========================================================================||
// Private Helpers ||
//==========================================================================||
Expand Down
47 changes: 8 additions & 39 deletions examples/rag-playground/src/types/common-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
* Type definitions
*/

/**
 * One record as it is streamed from the NDJSON data file: a
 * `[text, embedding]` tuple. The loader worker converts each tuple into a
 * `DocumentRecord`.
 */
export type DocumentRecordStreamData = [string, number[]];

/**
 * A document's text together with its embedding and a numeric id.
 */
export interface DocumentRecord {
/** Embedding values for `text` */
embedding: number[];
/** Numeric id; the loader worker uses the running count of loaded records */
id: number;
/** Document text */
text: string;
}

/**
 * Payload that carries a single string message.
 * NOTE(review): presumably used as a custom event detail; confirm against
 * the components that use it.
 */
export interface SimpleEventMessage {
message: string;
}
Expand Down Expand Up @@ -40,42 +48,3 @@ export interface Size {
width: number;
height: number;
}

/**
 * Description of a prompt and its generation settings.
 * NOTE(review): nothing in the visible code uses this type; the field notes
 * below are inferred from the names and should be confirmed.
 */
export interface PromptModel {
task: string;
/** Prompt text (presumably a template filled in from `variables`; confirm) */
prompt: string;
variables: string[];
temperature: number;
/** Optional; may be omitted */
stopSequences?: string[];
}

/**
 * Messages for a text generation worker, discriminated by `command`:
 * - `startTextGen`: a generation request (prompt, temperature, API key, ...)
 * - `finishTextGen`: the `result` for a request, echoing its `prompt`
 * - `error`: a failure report naming the `originalCommand`
 *
 * Every variant carries a `requestID`.
 * NOTE(review): presumably `startTextGen` goes to the worker and the other
 * two come back from it; confirm against the worker implementation.
 */
export type TextGenWorkerMessage =
| {
command: 'startTextGen';
payload: {
requestID: string;
apiKey: string;
prompt: string;
temperature: number;
stopSequences?: string[];
detail?: string;
};
}
| {
command: 'finishTextGen';
payload: {
requestID: string;
apiKey: string;
result: string;
prompt: string;
detail: string;
};
}
| {
command: 'error';
payload: {
requestID: string;
originalCommand: string;
message: string;
};
};
155 changes: 155 additions & 0 deletions examples/rag-playground/src/workers/mememo-worker.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import { HNSW } from '../../../../src/index';
import type {
DocumentRecord,
DocumentRecordStreamData
} from '../types/common-types';
import {
timeit,
splitStreamTransform,
parseJSONTransform
} from '@xiaohk/utils';

/**
 * Messages exchanged between the main thread and this worker, discriminated
 * by `command`:
 * - `startLoadData`: handled by this worker; starts streaming the NDJSON file
 * - `transferLoadData`: posted by this worker with one batch of loaded records
 */
export type MememoWorkerMessage =
| {
command: 'startLoadData';
payload: {
/** NDJSON data url */
url: string;
};
}
| {
command: 'transferLoadData';
payload: {
/** True when no batch has been sent before this one */
isFirstBatch: boolean;
/** True only for the batch sent after the stream has ended */
isLastBatch: boolean;
/** Records in this batch */
points: DocumentRecord[];
/** Total number of records read from the stream so far */
loadedPointCount: number;
};
};

// Dev flag from the build environment; passed to timeit when the stream ends
const DEV_MODE = import.meta.env.DEV;
// Number of buffered records that triggers sending a batch to the main thread
const POINT_THRESHOLD = 100;

// Records read from the stream but not yet sent to the main thread
let pendingDataPoints: DocumentRecord[] = [];
// Total number of records read so far; also used as the next record's id
let loadedPointCount = 0;
// Total number of records already posted to the main thread
let sentPointCount = 0;

// Copy of the most recently sent batch; null until the first batch is sent
let lastDrawnPoints: DocumentRecord[] | null = null;

/**
 * Entry point for messages posted by the main thread.
 * Only `startLoadData` is handled; any other command is logged as an error.
 * @param e Message event carrying a MememoWorkerMessage
 */
self.onmessage = (e: MessageEvent<MememoWorkerMessage>) => {
  const message = e.data;

  if (message.command === 'startLoadData') {
    // Stream point data from the requested URL
    console.log('Worker: start streaming data');
    timeit('Stream data', true);
    startLoadData(message.payload.url);
    return;
  }

  console.error('Worker: unknown message', message.command);
};

/**
 * Start loading the text data.
 *
 * Fetches the NDJSON file, decodes it as text, splits it on newlines, parses
 * each line as JSON, and passes every parsed record to processPointStream.
 * When the stream ends, pointStreamFinished is called to flush the remaining
 * records.
 *
 * Failures (non-OK response, missing response body, network or stream
 * errors) are logged to the console instead of surfacing as unhandled
 * promise rejections. On failure, records that are still buffered are not
 * flushed.
 * @param url URL to the NDJSON file
 */
const startLoadData = (url: string) => {
  const streamPoints = async (): Promise<void> => {
    const response = await fetch(url);
    if (!response.ok) {
      console.error('Failed to load data', response);
      return;
    }

    // response.body can be null, in which case the whole chain is undefined
    const reader = response.body
      ?.pipeThrough(new TextDecoderStream())
      .pipeThrough(splitStreamTransform('\n'))
      .pipeThrough(parseJSONTransform())
      .getReader();

    if (reader === undefined) {
      console.error('Failed to load data: response has no body', response);
      return;
    }

    let result = await reader.read();
    while (!result.done) {
      processPointStream(result.value as DocumentRecordStreamData);
      result = await reader.read();
    }

    timeit('Stream data', DEV_MODE);
    pointStreamFinished();
  };

  // The caller does not await the load, so report rejections here
  streamPoints().catch((error: unknown) => {
    console.error('Failed to load data', error);
  });
};

/**
 * Buffer one streamed record and, once the buffer holds at least
 * POINT_THRESHOLD records, send the buffered batch to the main thread.
 * @param point Streamed `[text, embedding]` tuple
 */
const processPointStream = (point: DocumentRecordStreamData) => {
  // The running load count is used as the record id
  const record: DocumentRecord = {
    text: point[0],
    embedding: point[1],
    id: loadedPointCount
  };
  pendingDataPoints.push(record);
  loadedPointCount += 1;

  // Keep buffering until enough data has been loaded
  if (pendingDataPoints.length < POINT_THRESHOLD) {
    return;
  }

  // Notify the main thread with the full batch
  const batch = pendingDataPoints;
  const batchMessage: MememoWorkerMessage = {
    command: 'transferLoadData',
    payload: {
      isFirstBatch: lastDrawnPoints === null,
      isLastBatch: false,
      points: batch,
      loadedPointCount
    }
  };
  postMessage(batchMessage);

  // Record what was sent and start a fresh buffer
  sentPointCount += batch.length;
  lastDrawnPoints = batch.slice();
  pendingDataPoints = [];
};

/**
 * Send the final batch to the main thread once all data has been read.
 * The batch contains whatever records are still buffered (possibly none) and
 * is flagged with `isLastBatch: true`.
 */
const pointStreamFinished = () => {
  // Whatever is still buffered becomes the last batch
  const leftoverPoints = pendingDataPoints;

  const finalMessage: MememoWorkerMessage = {
    command: 'transferLoadData',
    payload: {
      isFirstBatch: lastDrawnPoints === null,
      isLastBatch: true,
      points: leftoverPoints,
      loadedPointCount
    }
  };
  postMessage(finalMessage);

  // Record what was sent and reset the buffer
  sentPointCount += leftoverPoints.length;
  lastDrawnPoints = leftoverPoints.slice();
  pendingDataPoints = [];
};
8 changes: 4 additions & 4 deletions examples/rag-playground/vite.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ export default defineConfig(({ command, mode }) => {
// Development
return {
plugins: [
hmrPlugin({
include: ['./src/**/*.ts'],
presets: [presets.lit]
})
// hmrPlugin({
// include: ['./src/**/*.ts'],
// presets: [presets.lit]
// })
]
};
} else if (command === 'build') {
Expand Down

0 comments on commit 5c88e9e

Please sign in to comment.