Skip to content

Commit

Permalink
Adding in thread_pool::spawn function
Browse files Browse the repository at this point in the history
  • Loading branch information
Pauan committed May 14, 2024
1 parent 9396331 commit 2956ee4
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 184 deletions.
316 changes: 144 additions & 172 deletions sdk/src/polyfill/worker.ts
Original file line number Diff line number Diff line change
@@ -1,216 +1,188 @@
function patch($worker: typeof import("node:worker_threads"), $os: typeof import("node:os")) {
// This is technically not a part of the Worker polyfill,
// but Workers are used for multi-threading, so this is often
// needed when writing Worker code.
if (globalThis.navigator == null) {
globalThis.navigator = {
hardwareConcurrency: $os.cpus().length,
} as Navigator;
}

globalThis.Worker = class Worker extends EventTarget {
private _worker: import("node:worker_threads").Worker;
import * as $worker from "node:worker_threads";
import * as $os from "node:os";

// This is technically not a part of the Worker polyfill,
// but Workers are used for multi-threading, so this is often
// needed when writing Worker code.
if (globalThis.navigator == null) {
globalThis.navigator = {
hardwareConcurrency: $os.cpus().length,
} as Navigator;
}

constructor(url: string | URL, options?: WorkerOptions | undefined) {
super();
globalThis.Worker = class Worker extends EventTarget {
private _worker: import("node:worker_threads").Worker;

if (url instanceof URL) {
if (url.protocol !== "file:") {
throw new Error("Worker only supports file: URLs");
}
constructor(url: string | URL, options?: WorkerOptions | undefined) {
super();

url = url.href;

} else {
throw new Error("Filepaths are unreliable, use `new URL(\"...\", import.meta.url)` instead.");
if (url instanceof URL) {
if (url.protocol !== "file:") {
throw new Error("Worker only supports file: URLs");
}

if (!options || options.type !== "module") {
throw new Error("Workers must use \`type: \"module\"\`");
}
url = url.href;

// This uses some funky stuff like `patch.toString()`.
//
// This is needed so that it can synchronously run the polyfill code
// inside of the worker.
//
// It can't use `require` because the file doesn't have a `.cjs` file extension.
//
// It can't use `import` because that's asynchronous, and the file path
// might be different if using a bundler.
const code = `
${patch.toString()}
// Inject the polyfill into the worker
patch(require("node:worker_threads"), require("node:os"));
const { workerData } = require("node:worker_threads");
// This actually loads and runs the worker file
import(workerData.url)
.catch((e) => {
// TODO maybe it should send a message to the parent?
console.error(e.stack);
});
`;

this._worker = new $worker.Worker(code, {
eval: true,
workerData: {
url,
},
});

this._worker.on("message", (data) => {
this.dispatchEvent(new MessageEvent("message", { data }));
});

this._worker.on("messageerror", (error) => {
throw new Error("UNIMPLEMENTED");
});

this._worker.on("error", (error) => {
// TODO attach the error to the event somehow
const event = new Event("error");
this.dispatchEvent(event);
});
} else {
throw new Error("Filepaths are unreliable, use `new URL(\"...\", import.meta.url)` instead.");
}

set onmessage(f: () => void) {
throw new Error("UNIMPLEMENTED");
if (!options || options.type !== "module") {
throw new Error("Workers must use \`type: \"module\"\`");
}

set onmessageerror(f: () => void) {
throw new Error("UNIMPLEMENTED");
}
const code = `
const { workerData } = require("node:worker_threads");
import(workerData.polyfill)
.then(() => import(workerData.url))
.catch((e) => {
// TODO maybe it should send a message to the parent?
console.error(e.stack);
});
`;

this._worker = new $worker.Worker(code, {
eval: true,
workerData: {
url,
polyfill: new URL("node-polyfill.js", import.meta.url).href,
},
});

set onerror(f: () => void) {
this._worker.on("message", (data) => {
this.dispatchEvent(new MessageEvent("message", { data }));
});

this._worker.on("messageerror", (error) => {
throw new Error("UNIMPLEMENTED");
}
});

postMessage(message: any, transfer: Array<Transferable>): void;
postMessage(message: any, options?: StructuredSerializeOptions | undefined): void;
postMessage(value: any, transfer: any) {
this._worker.postMessage(value, transfer);
}
this._worker.on("error", (error) => {
// TODO attach the error to the event somehow
const event = new Event("error");
this.dispatchEvent(event);
});
}

terminate() {
this._worker.terminate();
}
set onmessage(f: () => void) {
throw new Error("UNIMPLEMENTED");
}

// This is Node-specific, it allows the process to exit
// even if the Worker is still running.
unref() {
this._worker.unref();
}
};
set onmessageerror(f: () => void) {
throw new Error("UNIMPLEMENTED");
}

set onerror(f: () => void) {
throw new Error("UNIMPLEMENTED");
}

if (!$worker.isMainThread) {
const globals = globalThis as unknown as DedicatedWorkerGlobalScope;
postMessage(message: any, transfer: Array<Transferable>): void;
postMessage(message: any, options?: StructuredSerializeOptions | undefined): void;
postMessage(value: any, transfer: any) {
this._worker.postMessage(value, transfer);
}

// This is used to create the onmessage, onmessageerror, and onerror setters
const makeSetter = (prop: string, event: string) => {
let oldvalue: () => void;
terminate() {
this._worker.terminate();
}

Object.defineProperty(globals, prop, {
get() {
return oldvalue;
},
set(value) {
if (oldvalue) {
globals.removeEventListener(event, oldvalue);
}
// This is Node-specific, it allows the process to exit
// even if the Worker is still running.
unref() {
this._worker.unref();
}
};

oldvalue = value;

if (oldvalue) {
globals.addEventListener(event, oldvalue);
}
},
});
};
if (!$worker.isMainThread) {
const globals = globalThis as unknown as DedicatedWorkerGlobalScope;

// This makes sure that `f` is only run once
const memoize = (f: () => void) => {
let run = false;
// This is used to create the onmessage, onmessageerror, and onerror setters
const makeSetter = (prop: string, event: string) => {
let oldvalue: () => void;

return () => {
if (!run) {
run = true;
f();
Object.defineProperty(globals, prop, {
get() {
return oldvalue;
},
set(value) {
if (oldvalue) {
globals.removeEventListener(event, oldvalue);
}
};
};

oldvalue = value;

// We only start listening for messages / errors when the worker calls addEventListener
const startOnMessage = memoize(() => {
$worker.parentPort!.on("message", (data) => {
workerEvents.dispatchEvent(new MessageEvent("message", { data }));
});
if (oldvalue) {
globals.addEventListener(event, oldvalue);
}
},
});
};

const startOnMessageError = memoize(() => {
throw new Error("UNIMPLEMENTED");
});
// This makes sure that `f` is only run once
const memoize = (f: () => void) => {
let run = false;

const startOnError = memoize(() => {
$worker.parentPort!.on("error", (data) => {
workerEvents.dispatchEvent(new Event("error"));
});
return () => {
if (!run) {
run = true;
f();
}
};
};


// We only start listening for messages / errors when the worker calls addEventListener
const startOnMessage = memoize(() => {
$worker.parentPort!.on("message", (data) => {
workerEvents.dispatchEvent(new MessageEvent("message", { data }));
});
});

const startOnMessageError = memoize(() => {
throw new Error("UNIMPLEMENTED");
});

// Node workers don't have top-level events, so we have to make our own
const workerEvents = new EventTarget();
const startOnError = memoize(() => {
$worker.parentPort!.on("error", (data) => {
workerEvents.dispatchEvent(new Event("error"));
});
});

globals.close = () => {
process.exit();
};

globals.addEventListener = (type: string, callback: EventListenerOrEventListenerObject | null, options?: boolean | EventListenerOptions | undefined) => {
workerEvents.addEventListener(type, callback, options);
// Node workers don't have top-level events, so we have to make our own
const workerEvents = new EventTarget();

if (type === "message") {
startOnMessage();
} else if (type === "messageerror") {
startOnMessageError();
} else if (type === "error") {
startOnError();
}
};
globals.close = () => {
process.exit();
};

globals.removeEventListener = (type: string, callback: EventListenerOrEventListenerObject | null, options?: boolean | EventListenerOptions | undefined) => {
workerEvents.removeEventListener(type, callback, options);
};
globals.addEventListener = (type: string, callback: EventListenerOrEventListenerObject | null, options?: boolean | EventListenerOptions | undefined) => {
workerEvents.addEventListener(type, callback, options);

function postMessage(message: any, transfer: Transferable[]): void;
function postMessage(message: any, options?: StructuredSerializeOptions | undefined): void;
function postMessage(value: any, transfer: any) {
$worker.parentPort!.postMessage(value, transfer);
if (type === "message") {
startOnMessage();
} else if (type === "messageerror") {
startOnMessageError();
} else if (type === "error") {
startOnError();
}
};

globals.postMessage = postMessage;
globals.removeEventListener = (type: string, callback: EventListenerOrEventListenerObject | null, options?: boolean | EventListenerOptions | undefined) => {
workerEvents.removeEventListener(type, callback, options);
};

makeSetter("onmessage", "message");
makeSetter("onmessageerror", "messageerror");
makeSetter("onerror", "error");
function postMessage(message: any, transfer: Transferable[]): void;
function postMessage(message: any, options?: StructuredSerializeOptions | undefined): void;
function postMessage(value: any, transfer: any) {
$worker.parentPort!.postMessage(value, transfer);
}
}


async function polyfill() {
const [$worker, $os] = await Promise.all([
import("node:worker_threads"),
import("node:os"),
]);
globals.postMessage = postMessage;

patch($worker, $os);
makeSetter("onmessage", "message");
makeSetter("onmessageerror", "messageerror");
makeSetter("onerror", "error");
}

if (globalThis.Worker == null) {
await polyfill();
}

export {};
18 changes: 14 additions & 4 deletions wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,20 @@ pub use types::Field;
#[cfg(not(test))]
mod thread_pool;

use wasm_bindgen::prelude::*;
#[cfg(test)]
mod thread_pool {
use std::future::Future;

pub fn spawn<A, F>(f: F) -> impl Future<Output = A>
where
A: Send + 'static,
F: FnOnce() -> A + Send + 'static,
{
async move { f() }
}
}

#[cfg(not(test))]
use thread_pool::ThreadPool;
use wasm_bindgen::prelude::*;

use std::str::FromStr;

Expand Down Expand Up @@ -219,7 +229,7 @@ use types::native;
pub async fn init_thread_pool(url: web_sys::Url, num_threads: usize) -> Result<(), JsValue> {
console_error_panic_hook::set_once();

ThreadPool::builder().url(url).num_threads(num_threads).build_global().await?;
thread_pool::ThreadPool::builder().url(url).num_threads(num_threads).build_global().await?;

Ok(())
}
Loading

0 comments on commit 2956ee4

Please sign in to comment.