-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Jay Wang <[email protected]>
- Loading branch information
Showing
5 changed files
with
201 additions
and
43 deletions.
There are no files selected for viewing
1 change: 1 addition & 0 deletions
1
examples/rag-playground/public/data/ml-arxiv-papers-1000.ndjson
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../notebooks/ml-arxiv-papers-1000.ndjson |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
import { HNSW } from '../../../../src/index'; | ||
import type { | ||
DocumentRecord, | ||
DocumentRecordStreamData | ||
} from '../types/common-types'; | ||
import { | ||
timeit, | ||
splitStreamTransform, | ||
parseJSONTransform | ||
} from '@xiaohk/utils'; | ||
|
||
export type MememoWorkerMessage = | ||
| { | ||
command: 'startLoadData'; | ||
payload: { | ||
/** NDJSON data url */ | ||
url: string; | ||
}; | ||
} | ||
| { | ||
command: 'transferLoadData'; | ||
payload: { | ||
isFirstBatch: boolean; | ||
isLastBatch: boolean; | ||
points: DocumentRecord[]; | ||
loadedPointCount: number; | ||
}; | ||
}; | ||
|
||
const DEV_MODE = import.meta.env.DEV; | ||
const POINT_THRESHOLD = 100; | ||
|
||
let pendingDataPoints: DocumentRecord[] = []; | ||
let loadedPointCount = 0; | ||
let sentPointCount = 0; | ||
|
||
let lastDrawnPoints: DocumentRecord[] | null = null; | ||
|
||
/** | ||
* Handle message events from the main thread | ||
* @param e Message event | ||
*/ | ||
self.onmessage = (e: MessageEvent<MememoWorkerMessage>) => { | ||
// Stream point data | ||
switch (e.data.command) { | ||
case 'startLoadData': { | ||
console.log('Worker: start streaming data'); | ||
timeit('Stream data', true); | ||
|
||
const url = e.data.payload.url; | ||
startLoadData(url); | ||
break; | ||
} | ||
|
||
default: { | ||
console.error('Worker: unknown message', e.data.command); | ||
break; | ||
} | ||
} | ||
}; | ||
|
||
/** | ||
* Start loading the text data | ||
* @param url URL to the NDJSON file | ||
*/ | ||
const startLoadData = (url: string) => { | ||
fetch(url).then(async response => { | ||
if (!response.ok) { | ||
console.error('Failed to load data', response); | ||
return; | ||
} | ||
|
||
const reader = response.body | ||
?.pipeThrough(new TextDecoderStream()) | ||
?.pipeThrough(splitStreamTransform('\n')) | ||
?.pipeThrough(parseJSONTransform()) | ||
?.getReader(); | ||
|
||
while (true && reader !== undefined) { | ||
const result = await reader.read(); | ||
const point = result.value as DocumentRecordStreamData; | ||
const done = result.done; | ||
|
||
if (done) { | ||
timeit('Stream data', DEV_MODE); | ||
pointStreamFinished(); | ||
break; | ||
} else { | ||
processPointStream(point); | ||
|
||
// // TODO: Remove me in prod | ||
// if (loadedPointCount >= 305000) { | ||
// pointStreamFinished(); | ||
// timeit('Stream data', DEBUG); | ||
// break; | ||
// } | ||
} | ||
} | ||
}); | ||
}; | ||
|
||
/** | ||
* Process one data point | ||
* @param point Loaded data point | ||
*/ | ||
const processPointStream = (point: DocumentRecordStreamData) => { | ||
const promptPoint: DocumentRecord = { | ||
text: point[0], | ||
embedding: point[1], | ||
id: loadedPointCount | ||
}; | ||
|
||
pendingDataPoints.push(promptPoint); | ||
loadedPointCount += 1; | ||
|
||
// Notify the main thread if we have load enough data | ||
if (pendingDataPoints.length >= POINT_THRESHOLD) { | ||
const result: MememoWorkerMessage = { | ||
command: 'transferLoadData', | ||
payload: { | ||
isFirstBatch: lastDrawnPoints === null, | ||
isLastBatch: false, | ||
points: pendingDataPoints, | ||
loadedPointCount | ||
} | ||
}; | ||
postMessage(result); | ||
|
||
sentPointCount += pendingDataPoints.length; | ||
lastDrawnPoints = pendingDataPoints.slice(); | ||
pendingDataPoints = []; | ||
} | ||
}; | ||
|
||
/** | ||
* Construct tree and notify the main thread when finish reading all data | ||
*/ | ||
const pointStreamFinished = () => { | ||
// Send any left over points | ||
|
||
const result: MememoWorkerMessage = { | ||
command: 'transferLoadData', | ||
payload: { | ||
isFirstBatch: lastDrawnPoints === null, | ||
isLastBatch: true, | ||
points: pendingDataPoints, | ||
loadedPointCount | ||
} | ||
}; | ||
postMessage(result); | ||
|
||
sentPointCount += pendingDataPoints.length; | ||
lastDrawnPoints = pendingDataPoints.slice(); | ||
pendingDataPoints = []; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters