Skip to content

Commit

Permalink
Enhance async configuration of bindgen! macro (bytecodealliance#6942
Browse files Browse the repository at this point in the history
)

This commit takes a leaf out of `wiggle`'s book to enable bindings
generation for async host functions where only some host functions are
async instead of all of them. This enhances the `async` key with a few
more options:

    async: {
        except_imports: ["foo"],
        only_imports: ["bar"],
    }

This is beyond what `wiggle` supports where either an allow-list or
deny-list can be specified (although only one can be specified). This
can be useful if either the list of sync imports or the list of async
imports is small.
  • Loading branch information
alexcrichton authored Aug 31, 2023
1 parent c56cc24 commit 326837d
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 21 deletions.
56 changes: 50 additions & 6 deletions crates/component-macro/src/bindgen.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use proc_macro2::{Span, TokenStream};
use std::collections::HashMap;
use std::collections::HashSet;
use std::path::{Path, PathBuf};
use syn::parse::{Error, Parse, ParseStream, Result};
use syn::punctuated::Punctuated;
use syn::{braced, token, Ident, Token};
use wasmtime_wit_bindgen::{Opts, Ownership, TrappableError};
use wasmtime_wit_bindgen::{AsyncConfig, Opts, Ownership, TrappableError};
use wit_parser::{PackageId, Resolve, UnresolvedPackage, WorldId};

pub struct Config {
Expand All @@ -15,7 +16,7 @@ pub struct Config {
}

pub fn expand(input: &Config) -> Result<TokenStream> {
if !cfg!(feature = "async") && input.opts.async_ {
if !cfg!(feature = "async") && input.opts.async_.maybe_async() {
return Err(Error::new(
Span::call_site(),
"cannot enable async bindings unless `async` crate feature is active",
Expand Down Expand Up @@ -45,6 +46,7 @@ impl Parse for Config {
let mut world = None;
let mut inline = None;
let mut path = None;
let mut async_configured = false;

if input.peek(token::Brace) {
let content;
Expand All @@ -71,7 +73,13 @@ impl Parse for Config {
inline = Some(s.value());
}
Opt::Tracing(val) => opts.tracing = val,
Opt::Async(val) => opts.async_ = val,
Opt::Async(val, span) => {
if async_configured {
return Err(Error::new(span, "cannot specify second async config"));
}
async_configured = true;
opts.async_ = val;
}
Opt::TrappableErrorType(val) => opts.trappable_error_type = val,
Opt::Ownership(val) => opts.ownership = val,
Opt::Interfaces(s) => {
Expand Down Expand Up @@ -171,14 +179,16 @@ mod kw {
syn::custom_keyword!(ownership);
syn::custom_keyword!(interfaces);
syn::custom_keyword!(with);
syn::custom_keyword!(except_imports);
syn::custom_keyword!(only_imports);
}

enum Opt {
World(syn::LitStr),
Path(syn::LitStr),
Inline(syn::LitStr),
Tracing(bool),
Async(bool),
Async(AsyncConfig, Span),
TrappableErrorType(Vec<TrappableError>),
Ownership(Ownership),
Interfaces(syn::LitStr),
Expand All @@ -205,9 +215,43 @@ impl Parse for Opt {
input.parse::<Token![:]>()?;
Ok(Opt::Tracing(input.parse::<syn::LitBool>()?.value))
} else if l.peek(Token![async]) {
input.parse::<Token![async]>()?;
let span = input.parse::<Token![async]>()?.span;
input.parse::<Token![:]>()?;
Ok(Opt::Async(input.parse::<syn::LitBool>()?.value))
if input.peek(syn::LitBool) {
match input.parse::<syn::LitBool>()?.value {
true => Ok(Opt::Async(AsyncConfig::All, span)),
false => Ok(Opt::Async(AsyncConfig::None, span)),
}
} else {
let contents;
syn::braced!(contents in input);

let l = contents.lookahead1();
let ctor: fn(HashSet<String>) -> AsyncConfig = if l.peek(kw::except_imports) {
contents.parse::<kw::except_imports>()?;
contents.parse::<Token![:]>()?;
AsyncConfig::AllExceptImports
} else if l.peek(kw::only_imports) {
contents.parse::<kw::only_imports>()?;
contents.parse::<Token![:]>()?;
AsyncConfig::OnlyImports
} else {
return Err(l.error());
};

let list;
syn::bracketed!(list in contents);
let fields: Punctuated<syn::LitStr, Token![,]> =
list.parse_terminated(Parse::parse, Token![,])?;

if contents.peek(Token![,]) {
contents.parse::<Token![,]>()?;
}
Ok(Opt::Async(
ctor(fields.iter().map(|s| s.value()).collect()),
span,
))
}
} else if l.peek(kw::ownership) {
input.parse::<kw::ownership>()?;
input.parse::<Token![:]>()?;
Expand Down
16 changes: 16 additions & 0 deletions crates/wasmtime/src/component/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,22 @@ pub(crate) use self::store::ComponentStoreData;
/// // This option defaults to `false`.
/// async: true,
///
/// // Alternative mode of async configuration where this still implies
/// // async instantiation happens, for example, but more control is
/// // provided over which imports are async and which aren't.
/// //
/// // Note that in this mode all exports are still async.
/// async: {
/// // All imports are async except for functions with these names
/// except_imports: ["foo", "bar"],
///
/// // All imports are synchronous except for functions with these names
/// //
/// // Note that this key cannot be specified with `except_imports`,
/// // only one or the other is accepted.
/// only_imports: ["foo", "bar"],
/// },
///
/// // This can be used to translate WIT return values of the form
/// // `result<T, error-type>` into `Result<T, RustErrorType>` in Rust.
/// // The `RustErrorType` structure will have an automatically generated
Expand Down
70 changes: 55 additions & 15 deletions crates/wit-bindgen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::types::{TypeInfo, Types};
use anyhow::{anyhow, bail, Context};
use heck::*;
use indexmap::IndexMap;
use std::collections::{BTreeMap, HashMap};
use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt::Write as _;
use std::io::{Read, Write};
use std::mem;
Expand Down Expand Up @@ -94,7 +94,7 @@ pub struct Opts {
pub tracing: bool,

/// Whether or not to use async rust functions and traits.
pub async_: bool,
pub async_: AsyncConfig,

/// A list of "trappable errors" which are used to replace the `E` in
/// `result<T, E>` found in WIT.
Expand Down Expand Up @@ -123,6 +123,42 @@ pub struct TrappableError {
pub rust_type_name: String,
}

#[derive(Default, Debug, Clone)]
pub enum AsyncConfig {
/// No functions are `async`.
#[default]
None,
/// All generated functions should be `async`.
All,
/// These imported functions should not be async, but everything else is.
AllExceptImports(HashSet<String>),
/// These functions are the only imports that are async, all other imports
/// are sync.
///
/// Note that all exports are still async in this situation.
OnlyImports(HashSet<String>),
}

impl AsyncConfig {
pub fn is_import_async(&self, f: &str) -> bool {
match self {
AsyncConfig::None => false,
AsyncConfig::All => true,
AsyncConfig::AllExceptImports(set) => !set.contains(f),
AsyncConfig::OnlyImports(set) => set.contains(f),
}
}

pub fn maybe_async(&self) -> bool {
match self {
AsyncConfig::None => false,
AsyncConfig::All | AsyncConfig::AllExceptImports(_) | AsyncConfig::OnlyImports(_) => {
true
}
}
}
}

impl Opts {
pub fn generate(&self, resolve: &Resolve, world: WorldId) -> String {
let mut r = Wasmtime::default();
Expand Down Expand Up @@ -412,7 +448,7 @@ impl Wasmtime {
}
self.src.push_str("}\n");

let (async_, async__, send, await_) = if self.opts.async_ {
let (async_, async__, send, await_) = if self.opts.async_.maybe_async() {
("async", "_async", ":Send", ".await")
} else {
("", "", "", "")
Expand Down Expand Up @@ -577,7 +613,7 @@ impl Wasmtime {
}

let world_camel = to_rust_upper_camel_case(&resolve.worlds[world].name);
if self.opts.async_ {
if self.opts.async_.maybe_async() {
uwriteln!(self.src, "#[wasmtime::component::__internal::async_trait]")
}
uwrite!(self.src, "pub trait {world_camel}Imports");
Expand Down Expand Up @@ -646,7 +682,7 @@ impl Wasmtime {
self.src.push_str(&name);
}

let maybe_send = if self.opts.async_ {
let maybe_send = if self.opts.async_.maybe_async() {
" + Send, T: Send"
} else {
""
Expand Down Expand Up @@ -854,7 +890,7 @@ impl<'a> InterfaceGenerator<'a> {
self.rustdoc(docs);
uwriteln!(self.src, "pub enum {camel} {{}}");

if self.gen.opts.async_ {
if self.gen.opts.async_.maybe_async() {
uwriteln!(self.src, "#[wasmtime::component::__internal::async_trait]")
}

Expand Down Expand Up @@ -1375,7 +1411,7 @@ impl<'a> InterfaceGenerator<'a> {
let iface = &self.resolve.interfaces[id];
let owner = TypeOwner::Interface(id);

if self.gen.opts.async_ {
if self.gen.opts.async_.maybe_async() {
uwriteln!(self.src, "#[wasmtime::component::__internal::async_trait]")
}
// Generate the `pub trait` which represents the host functionality for
Expand All @@ -1400,7 +1436,7 @@ impl<'a> InterfaceGenerator<'a> {
}
uwriteln!(self.src, "}}");

let where_clause = if self.gen.opts.async_ {
let where_clause = if self.gen.opts.async_.maybe_async() {
"T: Send, U: Host + Send".to_string()
} else {
"U: Host".to_string()
Expand Down Expand Up @@ -1443,7 +1479,7 @@ impl<'a> InterfaceGenerator<'a> {
uwrite!(
self.src,
"{linker}.{}(\"{}\", ",
if self.gen.opts.async_ {
if self.gen.opts.async_.is_import_async(&func.name) {
"func_wrap_async"
} else {
"func_wrap"
Expand Down Expand Up @@ -1472,7 +1508,7 @@ impl<'a> InterfaceGenerator<'a> {
self.src.push_str(", ");
}
self.src.push_str(") |");
if self.gen.opts.async_ {
if self.gen.opts.async_.is_import_async(&func.name) {
self.src.push_str(" Box::new(async move { \n");
} else {
self.src.push_str(" { \n");
Expand Down Expand Up @@ -1541,7 +1577,7 @@ impl<'a> InterfaceGenerator<'a> {
for (i, _) in func.params.iter().enumerate() {
uwrite!(self.src, "arg{},", i);
}
if self.gen.opts.async_ {
if self.gen.opts.async_.is_import_async(&func.name) {
uwrite!(self.src, ").await;\n");
} else {
uwrite!(self.src, ");\n");
Expand Down Expand Up @@ -1571,7 +1607,7 @@ impl<'a> InterfaceGenerator<'a> {
uwrite!(self.src, "r\n");
}

if self.gen.opts.async_ {
if self.gen.opts.async_.is_import_async(&func.name) {
// Need to close Box::new and async block
self.src.push_str("})");
} else {
Expand All @@ -1582,7 +1618,7 @@ impl<'a> InterfaceGenerator<'a> {
fn generate_function_trait_sig(&mut self, func: &Function) {
self.rustdoc(&func.docs);

if self.gen.opts.async_ {
if self.gen.opts.async_.is_import_async(&func.name) {
self.push_str("async ");
}
self.push_str("fn ");
Expand Down Expand Up @@ -1658,7 +1694,11 @@ impl<'a> InterfaceGenerator<'a> {
ns: Option<&WorldKey>,
func: &Function,
) {
let (async_, async__, await_) = if self.gen.opts.async_ {
// Exports must be async if anything could be async, it's just imports
// that get to be optionally async/sync.
let is_async = self.gen.opts.async_.maybe_async();

let (async_, async__, await_) = if is_async {
("async", "_async", ".await")
} else {
("", "", "")
Expand All @@ -1681,7 +1721,7 @@ impl<'a> InterfaceGenerator<'a> {
self.src.push_str(") -> wasmtime::Result<");
self.print_result_ty(&func.results, TypeMode::Owned);

if self.gen.opts.async_ {
if is_async {
self.src
.push_str("> where <S as wasmtime::AsContext>::Data: Send {\n");
} else {
Expand Down
Loading

0 comments on commit 326837d

Please sign in to comment.