diff --git a/Cargo.lock b/Cargo.lock index 249045ef28d8..278d9b9c2594 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -441,6 +441,17 @@ dependencies = [ "os_str_bytes", ] +[[package]] +name = "component-fuzz-util" +version = "0.1.0" +dependencies = [ + "anyhow", + "arbitrary", + "proc-macro2", + "quote", + "wasmtime-component-util", +] + [[package]] name = "component-macro-test" version = "0.1.0" @@ -450,6 +461,16 @@ dependencies = [ "syn", ] +[[package]] +name = "component-test-util" +version = "0.1.0" +dependencies = [ + "anyhow", + "arbitrary", + "env_logger 0.9.0", + "wasmtime", +] + [[package]] name = "console" version = "0.15.0" @@ -3415,6 +3436,7 @@ dependencies = [ "async-trait", "clap 3.2.8", "component-macro-test", + "component-test-util", "criterion", "env_logger 0.9.0", "filecheck", @@ -3434,6 +3456,7 @@ dependencies = [ "wasmtime", "wasmtime-cache", "wasmtime-cli-flags", + "wasmtime-component-util", "wasmtime-cranelift", "wasmtime-environ", "wasmtime-runtime", @@ -3520,6 +3543,7 @@ name = "wasmtime-environ-fuzz" version = "0.0.0" dependencies = [ "arbitrary", + "component-fuzz-util", "env_logger 0.9.0", "libfuzzer-sys", "wasmparser", @@ -3543,6 +3567,10 @@ dependencies = [ name = "wasmtime-fuzz" version = "0.0.0" dependencies = [ + "anyhow", + "arbitrary", + "component-fuzz-util", + "component-test-util", "cranelift-codegen", "cranelift-filetests", "cranelift-fuzzgen", @@ -3550,6 +3578,9 @@ dependencies = [ "cranelift-reader", "cranelift-wasm", "libfuzzer-sys", + "proc-macro2", + "quote", + "rand 0.8.5", "target-lexicon", "wasmtime", "wasmtime-fuzzing", @@ -3561,6 +3592,8 @@ version = "0.19.0" dependencies = [ "anyhow", "arbitrary", + "component-fuzz-util", + "component-test-util", "env_logger 0.9.0", "log", "rand 0.8.5", diff --git a/Cargo.toml b/Cargo.toml index 76484a280c85..6b6db578906a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,8 @@ once_cell = "1.9.0" rayon = "1.5.0" component-macro-test = { path = "crates/misc/component-macro-test" } wasmtime-wast = { path = "crates/wast", version = "=0.40.0", features = ['component-model'] } +component-test-util = { path = "crates/misc/component-test-util" } +wasmtime-component-util = { path = "crates/component-util" } [target.'cfg(windows)'.dev-dependencies] windows-sys = { version = "0.36.0", features = ["Win32_System_Memory"] } @@ -110,7 +112,11 @@ memory-init-cow = ["wasmtime/memory-init-cow", "wasmtime-cli-flags/memory-init-c pooling-allocator = ["wasmtime/pooling-allocator", "wasmtime-cli-flags/pooling-allocator"] all-arch = ["wasmtime/all-arch"] posix-signals-on-macos = ["wasmtime/posix-signals-on-macos"] -component-model = ["wasmtime/component-model", "wasmtime-wast/component-model", "wasmtime-cli-flags/component-model"] +component-model = [ + "wasmtime/component-model", + "wasmtime-wast/component-model", + "wasmtime-cli-flags/component-model" +] # Stub feature that does nothing, for Cargo-features compatibility: the new # backend is the default now. diff --git a/build.rs b/build.rs index 17ae15aa5476..d44b8c0c0e2d 100644 --- a/build.rs +++ b/build.rs @@ -41,7 +41,7 @@ fn main() -> anyhow::Result<()> { } else { println!( "cargo:warning=The spec testsuite is disabled. To enable, run `git submodule \ - update --remote`." + update --remote`." ); } Ok(()) diff --git a/crates/component-macro/src/lib.rs b/crates/component-macro/src/lib.rs index eb85f4010ce4..06f11c60211f 100644 --- a/crates/component-macro/src/lib.rs +++ b/crates/component-macro/src/lib.rs @@ -885,7 +885,10 @@ impl Expander for ComponentTypeExpander { const SIZE32: usize = { let mut size = 0; #sizes - #internal::align_to(#discriminant_size as usize, Self::ALIGN32) + size + #internal::align_to( + #internal::align_to(#discriminant_size as usize, Self::ALIGN32) + size, + Self::ALIGN32 + ) }; const ALIGN32: u32 = { diff --git a/crates/component-util/src/lib.rs b/crates/component-util/src/lib.rs index 3823abcedc79..409ba551c8f4 100644 --- a/crates/component-util/src/lib.rs +++ b/crates/component-util/src/lib.rs @@ -77,3 +77,62 @@ impl FlagsSize { fn ceiling_divide(n: usize, d: usize) -> usize { (n + d - 1) / d } + +/// A simple bump allocator which can be used with modules +pub const REALLOC_AND_FREE: &str = r#" + (global $last (mut i32) (i32.const 8)) + (func $realloc (export "realloc") + (param $old_ptr i32) + (param $old_size i32) + (param $align i32) + (param $new_size i32) + (result i32) + + ;; Test if the old pointer is non-null + local.get $old_ptr + if + ;; If the old size is bigger than the new size then + ;; this is a shrink and transparently allow it + local.get $old_size + local.get $new_size + i32.gt_u + if + local.get $old_ptr + return + end + + ;; ... otherwise this is unimplemented + unreachable + end + + ;; align up `$last` + (global.set $last + (i32.and + (i32.add + (global.get $last) + (i32.add + (local.get $align) + (i32.const -1))) + (i32.xor + (i32.add + (local.get $align) + (i32.const -1)) + (i32.const -1)))) + + ;; save the current value of `$last` as the return value + global.get $last + + ;; ensure anything necessary is set to valid data by spraying a bit + ;; pattern that is invalid + global.get $last + i32.const 0xde + local.get $new_size + memory.fill + + ;; bump our pointer + (global.set $last + (i32.add + (global.get $last) + (local.get $new_size))) + ) +"#; diff --git a/crates/environ/fuzz/Cargo.toml b/crates/environ/fuzz/Cargo.toml index 90086f36f24e..4cd5b1215148 100644 --- a/crates/environ/fuzz/Cargo.toml +++ b/crates/environ/fuzz/Cargo.toml @@ -15,6 +15,7 @@ libfuzzer-sys = "0.4" wasmparser = "0.88.0" wasmprinter = "0.2.37" wasmtime-environ = { path = ".." } +component-fuzz-util = { path = "../../misc/component-fuzz-util", optional = true } [[bin]] name = "fact-valid-module" @@ -24,4 +25,4 @@ doc = false required-features = ["component-model"] [features] -component-model = ["wasmtime-environ/component-model"] +component-model = ["wasmtime-environ/component-model", "dep:component-fuzz-util"] diff --git a/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs b/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs index 225efed322af..58f2488cf174 100644 --- a/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs +++ b/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs @@ -9,9 +9,9 @@ #![no_main] -use arbitrary::{Arbitrary, Unstructured}; +use arbitrary::Arbitrary; +use component_fuzz_util::Type as ValType; use libfuzzer_sys::fuzz_target; -use std::fmt; use wasmparser::{Validator, WasmFeatures}; use wasmtime_environ::component::*; use wasmtime_environ::fact::Module; @@ -38,34 +38,6 @@ struct FuncType { result: ValType, } -#[derive(Arbitrary, Debug)] -enum ValType { - Unit, - U8, - S8, - U16, - S16, - U32, - S32, - U64, - S64, - Float32, - Float64, - Char, - List(Box), - Record(Vec), - // Up to 65 flags to exercise up to 3 u32 values - Flags(UsizeInRange<0, 65>), - Tuple(Vec), - Variant(NonZeroLenVec), - Union(NonZeroLenVec), - // at least one enum variant but no more than what's necessary to inflate to - // 16 bits to keep this reasonably sized - Enum(UsizeInRange<1, 257>), - Option(Box), - Expected(Box, Box), -} - #[derive(Copy, Clone, Arbitrary, Debug)] enum GenStringEncoding { Utf8, @@ -73,39 +45,9 @@ enum GenStringEncoding { CompactUtf16, } -pub struct NonZeroLenVec(Vec); - -impl<'a, T: Arbitrary<'a>> Arbitrary<'a> for NonZeroLenVec { - fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { - let mut items = Vec::arbitrary(u)?; - if items.is_empty() { - items.push(u.arbitrary()?); - } - Ok(NonZeroLenVec(items)) - } -} - -impl fmt::Debug for NonZeroLenVec { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - -pub struct UsizeInRange(usize); +fuzz_target!(|module: GenAdapterModule| { drop(target(module)) }); -impl<'a, const L: usize, const H: usize> Arbitrary<'a> for UsizeInRange { - fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { - Ok(UsizeInRange(u.int_in_range(L..=H)?)) - } -} - -impl fmt::Debug for UsizeInRange { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - -fuzz_target!(|module: GenAdapterModule| { +fn target(module: GenAdapterModule) -> Result<(), ()> { drop(env_logger::try_init()); let mut types = ComponentTypesBuilder::default(); @@ -148,9 +90,9 @@ fuzz_target!(|module: GenAdapterModule| { for adapter in module.adapters.iter() { let mut params = Vec::new(); for param in adapter.ty.params.iter() { - params.push((None, intern(&mut types, param))); + params.push((None, intern(&mut types, param)?)); } - let result = intern(&mut types, &adapter.ty.result); + let result = intern(&mut types, &adapter.ty.result)?; let signature = types.add_func_type(TypeFunc { params: params.into(), result, @@ -201,7 +143,7 @@ fuzz_target!(|module: GenAdapterModule| { .validate_all(&wasm); let err = match result { - Ok(_) => return, + Ok(_) => return Ok(()), Err(e) => e, }; eprintln!("invalid wasm module: {err:?}"); @@ -215,11 +157,12 @@ fuzz_target!(|module: GenAdapterModule| { } panic!() -}); +} -fn intern(types: &mut ComponentTypesBuilder, ty: &ValType) -> InterfaceType { - match ty { +fn intern(types: &mut ComponentTypesBuilder, ty: &ValType) -> Result { + Ok(match ty { ValType::Unit => InterfaceType::Unit, + ValType::Bool => InterfaceType::Bool, ValType::U8 => InterfaceType::U8, ValType::S8 => InterfaceType::S8, ValType::U16 => InterfaceType::U16, @@ -232,7 +175,7 @@ fn intern(types: &mut ComponentTypesBuilder, ty: &ValType) -> InterfaceType { ValType::Float64 => InterfaceType::Float64, ValType::Char => InterfaceType::Char, ValType::List(ty) => { - let ty = intern(types, ty); + let ty = intern(types, ty)?; InterfaceType::List(types.add_interface_type(ty)) } ValType::Record(tys) => { @@ -240,61 +183,72 @@ fn intern(types: &mut ComponentTypesBuilder, ty: &ValType) -> InterfaceType { fields: tys .iter() .enumerate() - .map(|(i, ty)| RecordField { - name: format!("f{i}"), - ty: intern(types, ty), + .map(|(i, ty)| { + Ok(RecordField { + name: format!("f{i}"), + ty: intern(types, ty)?, + }) }) - .collect(), + .collect::>()?, }; InterfaceType::Record(types.add_record_type(ty)) } ValType::Flags(size) => { let ty = TypeFlags { - names: (0..size.0).map(|i| format!("f{i}")).collect(), + names: (0..size.as_usize()).map(|i| format!("f{i}")).collect(), }; InterfaceType::Flags(types.add_flags_type(ty)) } ValType::Tuple(tys) => { let ty = TypeTuple { - types: tys.iter().map(|ty| intern(types, ty)).collect(), + types: tys + .iter() + .map(|ty| intern(types, ty)) + .collect::>()?, }; InterfaceType::Tuple(types.add_tuple_type(ty)) } - ValType::Variant(NonZeroLenVec(cases)) => { + ValType::Variant(cases) => { let ty = TypeVariant { cases: cases .iter() .enumerate() - .map(|(i, ty)| VariantCase { - name: format!("c{i}"), - ty: intern(types, ty), + .map(|(i, ty)| { + Ok(VariantCase { + name: format!("c{i}"), + ty: intern(types, ty)?, + }) }) - .collect(), + .collect::>()?, }; InterfaceType::Variant(types.add_variant_type(ty)) } ValType::Union(tys) => { let ty = TypeUnion { - types: tys.0.iter().map(|ty| intern(types, ty)).collect(), + types: tys + .iter() + .map(|ty| intern(types, ty)) + .collect::>()?, }; InterfaceType::Union(types.add_union_type(ty)) } ValType::Enum(size) => { let ty = TypeEnum { - names: (0..size.0).map(|i| format!("c{i}")).collect(), + names: (0..size.as_usize()).map(|i| format!("c{i}")).collect(), }; InterfaceType::Enum(types.add_enum_type(ty)) } ValType::Option(ty) => { - let ty = intern(types, ty); + let ty = intern(types, ty)?; InterfaceType::Option(types.add_interface_type(ty)) } - ValType::Expected(ok, err) => { - let ok = intern(types, ok); - let err = intern(types, err); + ValType::Expected { ok, err } => { + let ok = intern(types, ok)?; + let err = intern(types, err)?; InterfaceType::Expected(types.add_expected_type(TypeExpected { ok, err })) } - } + ValType::String => return Err(()), + }) } impl From for StringEncoding { diff --git a/crates/fuzzing/Cargo.toml b/crates/fuzzing/Cargo.toml index 76ae1d269c0b..d1dc70130d70 100644 --- a/crates/fuzzing/Cargo.toml +++ b/crates/fuzzing/Cargo.toml @@ -10,6 +10,8 @@ license = "Apache-2.0 WITH LLVM-exception" [dependencies] anyhow = "1.0.22" arbitrary = { version = "1.1.0", features = ["derive"] } +component-test-util = { path = "../misc/component-test-util" } +component-fuzz-util = { path = "../misc/component-fuzz-util" } env_logger = "0.9.0" log = "0.4.8" rayon = "1.2.1" diff --git a/crates/fuzzing/src/generators.rs b/crates/fuzzing/src/generators.rs index 83492b2fda77..b4242d4302f9 100644 --- a/crates/fuzzing/src/generators.rs +++ b/crates/fuzzing/src/generators.rs @@ -10,6 +10,7 @@ pub mod api; mod codegen_settings; +pub mod component_types; mod config; mod instance_allocation_strategy; mod instance_limits; diff --git a/crates/fuzzing/src/generators/component_types.rs b/crates/fuzzing/src/generators/component_types.rs new file mode 100644 index 000000000000..2d93f29d726a --- /dev/null +++ b/crates/fuzzing/src/generators/component_types.rs @@ -0,0 +1,189 @@ +//! This module generates test cases for the Wasmtime component model function APIs, +//! e.g. `wasmtime::component::func::Func` and `TypedFunc`. +//! +//! Each case includes a list of arbitrary interface types to use as parameters, plus another one to use as a +//! result, and a component which exports a function and imports a function. The exported function forwards its +//! parameters to the imported one and forwards the result back to the caller. This serves to excercise Wasmtime's +//! lifting and lowering code and verify the values remain intact during both processes. + +use arbitrary::{Arbitrary, Unstructured}; +use component_fuzz_util::{Declarations, EXPORT_FUNCTION, IMPORT_FUNCTION}; +use std::fmt::Debug; +use std::ops::ControlFlow; +use wasmtime::component::{self, Component, Lift, Linker, Lower, Val}; +use wasmtime::{Config, Engine, Store, StoreContextMut}; + +/// Minimum length of an arbitrary list value generated for a test case +const MIN_LIST_LENGTH: u32 = 0; + +/// Maximum length of an arbitrary list value generated for a test case +const MAX_LIST_LENGTH: u32 = 10; + +/// Generate an arbitrary instance of the specified type. +pub fn arbitrary_val(ty: &component::Type, input: &mut Unstructured) -> arbitrary::Result { + use component::Type; + + Ok(match ty { + Type::Unit => Val::Unit, + Type::Bool => Val::Bool(input.arbitrary()?), + Type::S8 => Val::S8(input.arbitrary()?), + Type::U8 => Val::U8(input.arbitrary()?), + Type::S16 => Val::S16(input.arbitrary()?), + Type::U16 => Val::U16(input.arbitrary()?), + Type::S32 => Val::S32(input.arbitrary()?), + Type::U32 => Val::U32(input.arbitrary()?), + Type::S64 => Val::S64(input.arbitrary()?), + Type::U64 => Val::U64(input.arbitrary()?), + Type::Float32 => Val::Float32(input.arbitrary::()?.to_bits()), + Type::Float64 => Val::Float64(input.arbitrary::()?.to_bits()), + Type::Char => Val::Char(input.arbitrary()?), + Type::String => Val::String(input.arbitrary()?), + Type::List(list) => { + let mut values = Vec::new(); + input.arbitrary_loop(Some(MIN_LIST_LENGTH), Some(MAX_LIST_LENGTH), |input| { + values.push(arbitrary_val(&list.ty(), input)?); + + Ok(ControlFlow::Continue(())) + })?; + + list.new_val(values.into()).unwrap() + } + Type::Record(record) => record + .new_val( + record + .fields() + .map(|field| Ok((field.name, arbitrary_val(&field.ty, input)?))) + .collect::>>()?, + ) + .unwrap(), + Type::Tuple(tuple) => tuple + .new_val( + tuple + .types() + .map(|ty| arbitrary_val(&ty, input)) + .collect::>()?, + ) + .unwrap(), + Type::Variant(variant) => { + let mut cases = variant.cases(); + let discriminant = input.int_in_range(0..=cases.len() - 1)?; + variant + .new_val( + &format!("C{discriminant}"), + arbitrary_val(&cases.nth(discriminant).unwrap().ty, input)?, + ) + .unwrap() + } + Type::Enum(en) => { + let discriminant = input.int_in_range(0..=en.names().len() - 1)?; + en.new_val(&format!("C{discriminant}")).unwrap() + } + Type::Union(un) => { + let mut types = un.types(); + let discriminant = input.int_in_range(0..=types.len() - 1)?; + un.new_val( + discriminant.try_into().unwrap(), + arbitrary_val(&types.nth(discriminant).unwrap(), input)?, + ) + .unwrap() + } + Type::Option(option) => { + let discriminant = input.int_in_range(0..=1)?; + option + .new_val(match discriminant { + 0 => None, + 1 => Some(arbitrary_val(&option.ty(), input)?), + _ => unreachable!(), + }) + .unwrap() + } + Type::Expected(expected) => { + let discriminant = input.int_in_range(0..=1)?; + expected + .new_val(match discriminant { + 0 => Ok(arbitrary_val(&expected.ok(), input)?), + 1 => Err(arbitrary_val(&expected.err(), input)?), + _ => unreachable!(), + }) + .unwrap() + } + Type::Flags(flags) => flags + .new_val( + &flags + .names() + .filter_map(|name| { + input + .arbitrary() + .map(|p| if p { Some(name) } else { None }) + .transpose() + }) + .collect::>>()?, + ) + .unwrap(), + }) +} + +macro_rules! define_static_api_test { + ($name:ident $(($param:ident $param_name:ident $param_expected_name:ident))*) => { + #[allow(unused_parens)] + /// Generate zero or more sets of arbitrary argument and result values and execute the test using those + /// values, asserting that they flow from host-to-guest and guest-to-host unchanged. + pub fn $name<'a, $($param,)* R>( + input: &mut Unstructured<'a>, + declarations: &Declarations, + ) -> arbitrary::Result<()> + where + $($param: Lift + Lower + Clone + PartialEq + Debug + Arbitrary<'a> + 'static,)* + R: Lift + Lower + Clone + PartialEq + Debug + Arbitrary<'a> + 'static + { + crate::init_fuzzing(); + + let mut config = Config::new(); + config.wasm_component_model(true); + let engine = Engine::new(&config).unwrap(); + let component = Component::new( + &engine, + declarations.make_component().as_bytes() + ).unwrap(); + let mut linker = Linker::new(&engine); + linker + .root() + .func_wrap( + IMPORT_FUNCTION, + |cx: StoreContextMut<'_, ($(Option<$param>,)* Option)>, + $($param_name: $param,)*| + { + let ($($param_expected_name,)* result) = cx.data(); + $(assert_eq!($param_name, *$param_expected_name.as_ref().unwrap());)* + Ok(result.as_ref().unwrap().clone()) + }, + ) + .unwrap(); + let mut store = Store::new(&engine, Default::default()); + let instance = linker.instantiate(&mut store, &component).unwrap(); + let func = instance + .get_typed_func::<($($param,)*), R, _>(&mut store, EXPORT_FUNCTION) + .unwrap(); + + while input.arbitrary()? { + $(let $param_name = input.arbitrary::<$param>()?;)* + let result = input.arbitrary::()?; + *store.data_mut() = ($(Some($param_name.clone()),)* Some(result.clone())); + + assert_eq!(func.call(&mut store, ($($param_name,)*)).unwrap(), result); + func.post_return(&mut store).unwrap(); + } + + Ok(()) + } + } +} + +define_static_api_test!(static_api_test0); +define_static_api_test!(static_api_test1 (P0 p0 p0_expected)); +define_static_api_test!(static_api_test2 (P0 p0 p0_expected) (P1 p1 p1_expected)); +define_static_api_test!(static_api_test3 (P0 p0 p0_expected) (P1 p1 p1_expected) (P2 p2 p2_expected)); +define_static_api_test!(static_api_test4 (P0 p0 p0_expected) (P1 p1 p1_expected) (P2 p2 p2_expected) + (P3 p3 p3_expected)); +define_static_api_test!(static_api_test5 (P0 p0 p0_expected) (P1 p1 p1_expected) (P2 p2 p2_expected) + (P3 p3 p3_expected) (P4 p4 p4_expected)); diff --git a/crates/fuzzing/src/oracles.rs b/crates/fuzzing/src/oracles.rs index d2ff3283c6f5..4e7d090c4ff3 100644 --- a/crates/fuzzing/src/oracles.rs +++ b/crates/fuzzing/src/oracles.rs @@ -1073,3 +1073,60 @@ fn set_fuel(store: &mut Store, fuel: u64) { // double-check that the store has the expected amount of fuel remaining assert_eq!(store.consume_fuel(0).unwrap(), fuel); } + +/// Generate and execute a `crate::generators::component_types::TestCase` using the specified `input` to create +/// arbitrary types and values. +pub fn dynamic_component_api_target(input: &mut arbitrary::Unstructured) -> arbitrary::Result<()> { + use crate::generators::component_types; + use anyhow::Result; + use component_fuzz_util::{TestCase, EXPORT_FUNCTION, IMPORT_FUNCTION}; + use component_test_util::FuncExt; + use wasmtime::component::{Component, Linker, Val}; + + crate::init_fuzzing(); + + let case = input.arbitrary::()?; + + let engine = component_test_util::engine(); + let mut store = Store::new(&engine, (Box::new([]) as Box<[Val]>, None)); + let component = + Component::new(&engine, case.declarations().make_component().as_bytes()).unwrap(); + let mut linker = Linker::new(&engine); + + linker + .root() + .func_new(&component, IMPORT_FUNCTION, { + move |cx: StoreContextMut<'_, (Box<[Val]>, Option)>, args: &[Val]| -> Result { + let (expected_args, result) = cx.data(); + assert_eq!(args.len(), expected_args.len()); + for (expected, actual) in expected_args.iter().zip(args) { + assert_eq!(expected, actual); + } + Ok(result.as_ref().unwrap().clone()) + } + }) + .unwrap(); + + let instance = linker.instantiate(&mut store, &component).unwrap(); + let func = instance.get_func(&mut store, EXPORT_FUNCTION).unwrap(); + let params = func.params(&store); + let result = func.result(&store); + + while input.arbitrary()? { + let args = params + .iter() + .map(|ty| component_types::arbitrary_val(ty, input)) + .collect::>>()?; + + let result = component_types::arbitrary_val(&result, input)?; + + *store.data_mut() = (args.clone(), Some(result.clone())); + + assert_eq!( + func.call_and_post_return(&mut store, &args).unwrap(), + result + ); + } + + Ok(()) +} diff --git a/crates/misc/component-fuzz-util/Cargo.toml b/crates/misc/component-fuzz-util/Cargo.toml new file mode 100644 index 000000000000..e17332334049 --- /dev/null +++ b/crates/misc/component-fuzz-util/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "component-fuzz-util" +authors = ["The Wasmtime Project Developers"] +license = "Apache-2.0 WITH LLVM-exception" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +anyhow = { version = "1.0.19" } +arbitrary = { version = "1.1.0", features = ["derive"] } +proc-macro2 = "1.0" +quote = "1.0" +wasmtime-component-util = { path = "../../component-util" } diff --git a/crates/misc/component-fuzz-util/src/lib.rs b/crates/misc/component-fuzz-util/src/lib.rs new file mode 100644 index 000000000000..9b14266dcd92 --- /dev/null +++ b/crates/misc/component-fuzz-util/src/lib.rs @@ -0,0 +1,800 @@ +//! This module generates test cases for the Wasmtime component model function APIs, +//! e.g. `wasmtime::component::func::Func` and `TypedFunc`. +//! +//! Each case includes a list of arbitrary interface types to use as parameters, plus another one to use as a +//! result, and a component which exports a function and imports a function. The exported function forwards its +//! parameters to the imported one and forwards the result back to the caller. This serves to excercise Wasmtime's +//! lifting and lowering code and verify the values remain intact during both processes. + +use arbitrary::{Arbitrary, Unstructured}; +use proc_macro2::{Ident, TokenStream}; +use quote::{format_ident, quote}; +use std::fmt::{self, Debug, Write}; +use std::iter; +use std::ops::Deref; +use wasmtime_component_util::{DiscriminantSize, FlagsSize, REALLOC_AND_FREE}; + +const MAX_FLAT_PARAMS: usize = 16; +const MAX_FLAT_RESULTS: usize = 1; +const MAX_ARITY: usize = 5; + +/// The name of the imported host function which the generated component will call +pub const IMPORT_FUNCTION: &str = "echo"; + +/// The name of the exported guest function which the host should call +pub const EXPORT_FUNCTION: &str = "echo"; + +/// Maximum length of an arbitrary tuple type. As of this writing, the `wasmtime::component::func::typed` module +/// only implements the `ComponentType` trait for tuples up to this length. +const MAX_TUPLE_LENGTH: usize = 16; + +#[derive(Copy, Clone, PartialEq, Eq)] +enum CoreType { + I32, + I64, + F32, + F64, +} + +impl CoreType { + /// This is the `join` operation specified in [the canonical + /// ABI](https://github.com/WebAssembly/component-model/blob/main/design/mvp/CanonicalABI.md#flattening) for + /// variant types. + fn join(self, other: Self) -> Self { + match (self, other) { + _ if self == other => self, + (Self::I32, Self::F32) | (Self::F32, Self::I32) => Self::I32, + _ => Self::I64, + } + } +} + +impl fmt::Display for CoreType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::I32 => f.write_str("i32"), + Self::I64 => f.write_str("i64"), + Self::F32 => f.write_str("f32"), + Self::F64 => f.write_str("f64"), + } + } +} + +#[derive(Debug)] +pub struct UsizeInRange(usize); + +impl UsizeInRange { + pub fn as_usize(&self) -> usize { + self.0 + } +} + +impl<'a, const L: usize, const H: usize> Arbitrary<'a> for UsizeInRange { + fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { + Ok(UsizeInRange(u.int_in_range(L..=H)?)) + } +} + +/// Wraps a `Box<[T]>` and provides an `Arbitrary` implementation that always generates non-empty slices +#[derive(Debug)] +pub struct NonEmptyArray(Box<[T]>); + +impl<'a, T: Arbitrary<'a>> Arbitrary<'a> for NonEmptyArray { + fn arbitrary(input: &mut Unstructured<'a>) -> arbitrary::Result { + Ok(Self( + iter::once(input.arbitrary()) + .chain(input.arbitrary_iter()?) + .collect::>()?, + )) + } +} + +impl Deref for NonEmptyArray { + type Target = [T]; + + fn deref(&self) -> &[T] { + self.0.deref() + } +} + +/// Wraps a `Box<[T]>` and provides an `Arbitrary` implementation that always generates slices of length less than +/// or equal to the longest tuple for which Wasmtime generates a `ComponentType` impl +#[derive(Debug)] +pub struct TupleArray(Box<[T]>); + +impl<'a, T: Arbitrary<'a>> Arbitrary<'a> for TupleArray { + fn arbitrary(input: &mut Unstructured<'a>) -> arbitrary::Result { + Ok(Self( + input + .arbitrary_iter()? + .take(MAX_TUPLE_LENGTH) + .collect::>()?, + )) + } +} + +impl Deref for TupleArray { + type Target = [T]; + + fn deref(&self) -> &[T] { + self.0.deref() + } +} + +/// Represents a component model interface type +#[allow(missing_docs)] +#[derive(Arbitrary, Debug)] +pub enum Type { + Unit, + Bool, + S8, + U8, + S16, + U16, + S32, + U32, + S64, + U64, + Float32, + Float64, + Char, + String, + List(Box), + Record(Box<[Type]>), + Tuple(TupleArray), + Variant(NonEmptyArray), + Enum(UsizeInRange<1, 257>), + Union(NonEmptyArray), + Option(Box), + Expected { ok: Box, err: Box }, + Flags(UsizeInRange<0, 65>), +} + +fn lower_record<'a>(types: impl Iterator, vec: &mut Vec) { + for ty in types { + ty.lower(vec); + } +} + +fn lower_variant<'a>(types: impl Iterator, vec: &mut Vec) { + vec.push(CoreType::I32); + let offset = vec.len(); + for ty in types { + for (index, ty) in ty.lowered().iter().enumerate() { + let index = offset + index; + if index < vec.len() { + vec[index] = vec[index].join(*ty); + } else { + vec.push(*ty) + } + } + } +} + +fn u32_count_from_flag_count(count: usize) -> usize { + match FlagsSize::from_count(count) { + FlagsSize::Size0 => 0, + FlagsSize::Size1 | FlagsSize::Size2 => 1, + FlagsSize::Size4Plus(n) => n, + } +} + +struct SizeAndAlignment { + size: usize, + alignment: u32, +} + +impl Type { + fn lowered(&self) -> Vec { + let mut vec = Vec::new(); + self.lower(&mut vec); + vec + } + + fn lower(&self, vec: &mut Vec) { + match self { + Type::Unit => (), + Type::Bool + | Type::U8 + | Type::S8 + | Type::S16 + | Type::U16 + | Type::S32 + | Type::U32 + | Type::Char + | Type::Enum(_) => vec.push(CoreType::I32), + Type::S64 | Type::U64 => vec.push(CoreType::I64), + Type::Float32 => vec.push(CoreType::F32), + Type::Float64 => vec.push(CoreType::F64), + Type::String | Type::List(_) => { + vec.push(CoreType::I32); + vec.push(CoreType::I32); + } + Type::Record(types) => lower_record(types.iter(), vec), + Type::Tuple(types) => lower_record(types.0.iter(), vec), + Type::Variant(types) | Type::Union(types) => lower_variant(types.0.iter(), vec), + Type::Option(ty) => lower_variant([&Type::Unit, ty].into_iter(), vec), + Type::Expected { ok, err } => lower_variant([ok.deref(), err].into_iter(), vec), + Type::Flags(count) => { + vec.extend(iter::repeat(CoreType::I32).take(u32_count_from_flag_count(count.0))) + } + } + } + + fn size_and_alignment(&self) -> SizeAndAlignment { + match self { + Type::Unit => SizeAndAlignment { + size: 0, + alignment: 1, + }, + + Type::Bool | Type::S8 | Type::U8 => SizeAndAlignment { + size: 1, + alignment: 1, + }, + + Type::S16 | Type::U16 => SizeAndAlignment { + size: 2, + alignment: 2, + }, + + Type::S32 | Type::U32 | Type::Char | Type::Float32 => SizeAndAlignment { + size: 4, + alignment: 4, + }, + + Type::S64 | Type::U64 | Type::Float64 => SizeAndAlignment { + size: 8, + alignment: 8, + }, + + Type::String | Type::List(_) => SizeAndAlignment { + size: 8, + alignment: 4, + }, + + Type::Record(types) => record_size_and_alignment(types.iter()), + + Type::Tuple(types) => record_size_and_alignment(types.0.iter()), + + Type::Variant(types) | Type::Union(types) => variant_size_and_alignment(types.0.iter()), + + Type::Enum(count) => variant_size_and_alignment((0..count.0).map(|_| &Type::Unit)), + + Type::Option(ty) => variant_size_and_alignment([&Type::Unit, ty].into_iter()), + + Type::Expected { ok, err } => variant_size_and_alignment([ok.deref(), err].into_iter()), + + Type::Flags(count) => match FlagsSize::from_count(count.0) { + FlagsSize::Size0 => SizeAndAlignment { + size: 0, + alignment: 1, + }, + FlagsSize::Size1 => SizeAndAlignment { + size: 1, + alignment: 1, + }, + FlagsSize::Size2 => SizeAndAlignment { + size: 2, + alignment: 2, + }, + FlagsSize::Size4Plus(n) => SizeAndAlignment { + size: n * 4, + alignment: 4, + }, + }, + } + } +} + +fn align_to(a: usize, align: u32) -> usize { + let align = align as usize; + (a + (align - 1)) & !(align - 1) +} + +fn record_size_and_alignment<'a>(types: impl Iterator) -> SizeAndAlignment { + let mut offset = 0; + let mut align = 1; + for ty in types { + let SizeAndAlignment { size, alignment } = ty.size_and_alignment(); + offset = align_to(offset, alignment) + size; + align = align.max(alignment); + } + + SizeAndAlignment { + size: align_to(offset, align), + alignment: align, + } +} + +fn variant_size_and_alignment<'a>( + types: impl ExactSizeIterator, +) -> SizeAndAlignment { + let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap(); + let mut alignment = u32::from(discriminant_size); + let mut size = 0; + for ty in types { + let size_and_alignment = ty.size_and_alignment(); + alignment = alignment.max(size_and_alignment.alignment); + size = size.max(size_and_alignment.size); + } + + SizeAndAlignment { + size: align_to( + align_to(usize::from(discriminant_size), alignment) + size, + alignment, + ), + alignment, + } +} + +fn make_import_and_export(params: &[Type], result: &Type) -> Box { + let params_lowered = params + .iter() + .flat_map(|ty| ty.lowered()) + .collect::>(); + let result_lowered = result.lowered(); + + let mut core_params = String::new(); + let mut gets = String::new(); + + if params_lowered.len() <= MAX_FLAT_PARAMS { + for (index, param) in params_lowered.iter().enumerate() { + write!(&mut core_params, " {param}").unwrap(); + write!(&mut gets, "local.get {index} ").unwrap(); + } + } else { + write!(&mut core_params, " i32").unwrap(); + write!(&mut gets, "local.get 0 ").unwrap(); + } + + let maybe_core_params = if params_lowered.is_empty() { + String::new() + } else { + format!("(param{core_params})") + }; + + if result_lowered.len() <= MAX_FLAT_RESULTS { + let mut core_results = String::new(); + for result in result_lowered.iter() { + write!(&mut core_results, " {result}").unwrap(); + } + + let maybe_core_results = if result_lowered.is_empty() { + String::new() + } else { + format!("(result{core_results})") + }; + + format!( + r#" + (func $f (import "host" "{IMPORT_FUNCTION}") {maybe_core_params} {maybe_core_results}) + + (func (export "{EXPORT_FUNCTION}") {maybe_core_params} {maybe_core_results} + {gets} + + call $f + )"# + ) + } else { + let SizeAndAlignment { size, alignment } = result.size_and_alignment(); + + format!( + r#" + (func $f (import "host" "{IMPORT_FUNCTION}") (param{core_params} i32)) + + (func (export "{EXPORT_FUNCTION}") {maybe_core_params} (result i32) + (local $base i32) + (local.set $base + (call $realloc + (i32.const 0) + (i32.const 0) + (i32.const {alignment}) + (i32.const {size}))) + {gets} + local.get $base + + call $f + + local.get $base + )"# + ) + } + .into() +} + +fn make_rust_name(name_counter: &mut u32) -> Ident { + let name = format_ident!("Foo{name_counter}"); + *name_counter += 1; + name +} + +/// Generate a [`TokenStream`] containing the rust type name for a type. +/// +/// The `name_counter` parameter is used to generate names for each recursively visited type. The `declarations` +/// parameter is used to accumulate declarations for each recursively visited type. +pub fn rust_type(ty: &Type, name_counter: &mut u32, declarations: &mut TokenStream) -> TokenStream { + match ty { + Type::Unit => quote!(()), + Type::Bool => quote!(bool), + Type::S8 => quote!(i8), + Type::U8 => quote!(u8), + Type::S16 => quote!(i16), + Type::U16 => quote!(u16), + Type::S32 => quote!(i32), + Type::U32 => quote!(u32), + Type::S64 => quote!(i64), + Type::U64 => quote!(u64), + Type::Float32 => quote!(Float32), + Type::Float64 => quote!(Float64), + Type::Char => quote!(char), + Type::String => quote!(Box), + Type::List(ty) => { + let ty = rust_type(ty, name_counter, declarations); + quote!(Vec<#ty>) + } + Type::Record(types) => { + let fields = types + .iter() + .enumerate() + .map(|(index, ty)| { + let name = format_ident!("f{index}"); + let ty = rust_type(ty, name_counter, declarations); + quote!(#name: #ty,) + }) + .collect::(); + + let name = make_rust_name(name_counter); + + declarations.extend(quote! { + #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Clone, Arbitrary)] + #[component(record)] + struct #name { + #fields + } + }); + + quote!(#name) + } + Type::Tuple(types) => { + let fields = types + .0 + .iter() + .map(|ty| { + let ty = rust_type(ty, name_counter, declarations); + quote!(#ty,) + }) + .collect::(); + + quote!((#fields)) + } + Type::Variant(types) | Type::Union(types) => { + let cases = types + .0 + .iter() + .enumerate() + .map(|(index, ty)| { + let name = format_ident!("C{index}"); + let ty = rust_type(ty, name_counter, declarations); + quote!(#name(#ty),) + }) + .collect::(); + + let name = make_rust_name(name_counter); + + let which = if let Type::Variant(_) = ty { + quote!(variant) + } else { + quote!(union) + }; + + declarations.extend(quote! { + #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Clone, Arbitrary)] + #[component(#which)] + enum #name { + #cases + } + }); + + quote!(#name) + } + Type::Enum(count) => { + let cases = (0..count.0) + .map(|index| { + let name = format_ident!("C{index}"); + quote!(#name,) + }) + .collect::(); + + let name = make_rust_name(name_counter); + + declarations.extend(quote! { + #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Clone, Arbitrary)] + #[component(enum)] + enum #name { + #cases + } + }); + + quote!(#name) + } + Type::Option(ty) => { + let ty = rust_type(ty, name_counter, declarations); + quote!(Option<#ty>) + } + Type::Expected { ok, err } => { + let ok = rust_type(ok, name_counter, declarations); + let err = rust_type(err, name_counter, declarations); + quote!(Result<#ok, #err>) + } + Type::Flags(count) => { + let type_name = make_rust_name(name_counter); + + let mut flags = TokenStream::new(); + let mut names = TokenStream::new(); + + for index in 0..count.0 { + let name = format_ident!("F{index}"); + flags.extend(quote!(const #name;)); + names.extend(quote!(#type_name::#name,)) + } + + declarations.extend(quote! { + wasmtime::component::flags! { + #type_name { + #flags + } + } + + impl<'a> Arbitrary<'a> for #type_name { + fn arbitrary(input: &mut Unstructured<'a>) -> arbitrary::Result { + let mut flags = #type_name::default(); + for flag in [#names] { + if input.arbitrary()? { + flags |= flag; + } + } + Ok(flags) + } + } + }); + + quote!(#type_name) + } + } +} + +fn make_component_name(name_counter: &mut u32) -> String { + let name = format!("$Foo{name_counter}"); + *name_counter += 1; + name +} + +fn write_component_type( + ty: &Type, + f: &mut String, + name_counter: &mut u32, + declarations: &mut String, +) { + match ty { + Type::Unit => f.push_str("unit"), + Type::Bool => f.push_str("bool"), + Type::S8 => f.push_str("s8"), + Type::U8 => f.push_str("u8"), + Type::S16 => f.push_str("s16"), + Type::U16 => f.push_str("u16"), + Type::S32 => f.push_str("s32"), + Type::U32 => f.push_str("u32"), + Type::S64 => f.push_str("s64"), + Type::U64 => f.push_str("u64"), + Type::Float32 => f.push_str("float32"), + Type::Float64 => f.push_str("float64"), + Type::Char => f.push_str("char"), + Type::String => f.push_str("string"), + Type::List(ty) => { + let mut case = String::new(); + write_component_type(ty, &mut case, name_counter, declarations); + let name = make_component_name(name_counter); + write!(declarations, "(type {name} (list {case}))").unwrap(); + f.push_str(&name); + } + Type::Record(types) => { + let mut fields = String::new(); + for (index, ty) in types.iter().enumerate() { + write!(fields, r#" (field "f{index}" "#).unwrap(); + write_component_type(ty, &mut fields, name_counter, declarations); + fields.push_str(")"); + } + let name = make_component_name(name_counter); + write!(declarations, "(type {name} (record{fields}))").unwrap(); + f.push_str(&name); + } + Type::Tuple(types) => { + let mut fields = String::new(); + for ty in types.0.iter() { + fields.push_str(" "); + write_component_type(ty, &mut fields, name_counter, declarations); + } + let name = make_component_name(name_counter); + write!(declarations, "(type {name} (tuple{fields}))").unwrap(); + f.push_str(&name); + } + Type::Variant(types) => { + let mut cases = String::new(); + for (index, ty) in types.0.iter().enumerate() { + write!(cases, r#" (case "C{index}" "#).unwrap(); + write_component_type(ty, &mut cases, name_counter, declarations); + cases.push_str(")"); + } + let name = make_component_name(name_counter); + write!(declarations, "(type {name} (variant{cases}))").unwrap(); + f.push_str(&name); + } + Type::Enum(count) => { + f.push_str("(enum"); + for index in 0..count.0 { + write!(f, r#" "C{index}""#).unwrap(); + } + f.push_str(")"); + } + Type::Union(types) => { + let mut cases = String::new(); + for ty in types.0.iter() { + cases.push_str(" "); + write_component_type(ty, &mut cases, name_counter, declarations); + } + let name = make_component_name(name_counter); + write!(declarations, "(type {name} (union{cases}))").unwrap(); + f.push_str(&name); + } + Type::Option(ty) => { + let mut case = String::new(); + write_component_type(ty, &mut case, name_counter, declarations); + let name = make_component_name(name_counter); + write!(declarations, "(type {name} (option {case}))").unwrap(); + f.push_str(&name); + } + Type::Expected { ok, err } => { + let mut cases = String::new(); + write_component_type(ok, &mut cases, name_counter, declarations); + cases.push_str(" "); + write_component_type(err, &mut cases, name_counter, declarations); + let name = make_component_name(name_counter); + write!(declarations, "(type {name} (expected {cases}))").unwrap(); + f.push_str(&name); + } + Type::Flags(count) => { + f.push_str("(flags"); + for index in 0..count.0 { + write!(f, r#" "F{index}""#).unwrap(); + } + f.push_str(")"); + } + } +} + +/// Represents custom fragments of a WAT file which may be used to create a component for exercising [`TestCase`]s +#[derive(Debug)] +pub struct Declarations { + /// Type declarations (if any) referenced by `params` and/or `result` + pub types: Box, + /// Parameter declarations used for the imported and exported functions + pub params: Box, + /// Result declaration used for the imported and exported functions + pub result: Box, + /// A WAT fragment representing the core function import and export to use for testing + pub import_and_export: Box, +} + +impl Declarations { + /// Generate a complete WAT file based on the specified fragments. + pub fn make_component(&self) -> Box { + let Self { + types, + params, + result, + import_and_export, + } = self; + + format!( + r#" + (component + (core module $libc + (memory (export "memory") 1) + {REALLOC_AND_FREE} + ) + + (core instance $libc (instantiate $libc)) + + {types} + + (import "{IMPORT_FUNCTION}" (func $f {params} {result})) + + (core func $f_lower (canon lower + (func $f) + (memory $libc "memory") + (realloc (func $libc "realloc")) + )) + + (core module $m + (memory (import "libc" "memory") 1) + (func $realloc (import "libc" "realloc") (param i32 i32 i32 i32) (result i32)) + + {import_and_export} + ) + + (core instance $i (instantiate $m + (with "libc" (instance $libc)) + (with "host" (instance (export "{IMPORT_FUNCTION}" (func $f_lower)))) + )) + + (func (export "echo") {params} {result} + (canon lift + (core func $i "echo") + (memory $libc "memory") + (realloc (func $libc "realloc")) + ) + ) + )"#, + ) + .into() + } +} + +/// Represents a test case for calling a component function +#[derive(Debug)] +pub struct TestCase { + /// The types of parameters to pass to the function + pub params: Box<[Type]>, + /// The type of the result to be returned by the function + pub result: Type, +} + +impl TestCase { + /// Generate a `Declarations` for this `TestCase` which may be used to build a component to execute the case. + pub fn declarations(&self) -> Declarations { + let mut types = String::new(); + let name_counter = &mut 0; + + let params = self + .params + .iter() + .map(|ty| { + let mut tmp = String::new(); + write_component_type(ty, &mut tmp, name_counter, &mut types); + format!("(param {tmp})") + }) + .collect::>() + .join(" ") + .into(); + + let result = { + let mut tmp = String::new(); + write_component_type(&self.result, &mut tmp, name_counter, &mut types); + format!("(result {tmp})") + } + .into(); + + let import_and_export = make_import_and_export(&self.params, &self.result); + + Declarations { + types: types.into(), + params, + result, + import_and_export, + } + } +} + +impl<'a> Arbitrary<'a> for TestCase { + /// Generate an arbitrary [`TestCase`]. + fn arbitrary(input: &mut Unstructured<'a>) -> arbitrary::Result { + Ok(Self { + params: input + .arbitrary_iter()? + .take(MAX_ARITY) + .collect::>>()?, + result: input.arbitrary()?, + }) + } +} diff --git a/crates/misc/component-test-util/Cargo.toml b/crates/misc/component-test-util/Cargo.toml new file mode 100644 index 000000000000..1c5012d32fcb --- /dev/null +++ b/crates/misc/component-test-util/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "component-test-util" +authors = ["The Wasmtime Project Developers"] +license = "Apache-2.0 WITH LLVM-exception" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +env_logger = "0.9.0" +anyhow = "1.0.19" +arbitrary = { version = "1.1.0", features = ["derive"] } +wasmtime = { path = "../../wasmtime", features = ["component-model"] } diff --git a/crates/misc/component-test-util/src/lib.rs b/crates/misc/component-test-util/src/lib.rs new file mode 100644 index 000000000000..364708250941 --- /dev/null +++ b/crates/misc/component-test-util/src/lib.rs @@ -0,0 +1,112 @@ +use anyhow::Result; +use arbitrary::Arbitrary; +use std::mem::MaybeUninit; +use wasmtime::component::__internal::{ + ComponentTypes, InterfaceType, Memory, MemoryMut, Options, StoreOpaque, +}; +use wasmtime::component::{ComponentParams, ComponentType, Func, Lift, Lower, TypedFunc, Val}; +use wasmtime::{AsContextMut, Config, Engine, StoreContextMut}; + +pub trait TypedFuncExt { + fn call_and_post_return(&self, store: impl AsContextMut, params: P) -> Result; +} + +impl TypedFuncExt for TypedFunc +where + P: ComponentParams + Lower, + R: Lift, +{ + fn call_and_post_return(&self, mut store: impl AsContextMut, params: P) -> Result { + let result = self.call(&mut store, params)?; + self.post_return(&mut store)?; + Ok(result) + } +} + +pub trait FuncExt { + fn call_and_post_return(&self, store: impl AsContextMut, args: &[Val]) -> Result; +} + +impl FuncExt for Func { + fn call_and_post_return(&self, mut store: impl AsContextMut, args: &[Val]) -> Result { + let result = self.call(&mut store, args)?; + self.post_return(&mut store)?; + Ok(result) + } +} + +pub fn engine() -> Engine { + drop(env_logger::try_init()); + + let mut config = Config::new(); + config.wasm_component_model(true); + + // When `WASMTIME_TEST_NO_HOG_MEMORY` is set it means we're in qemu. The + // component model tests create a disproportionate number of instances so + // try to cut down on virtual memory usage by avoiding 4G reservations. + if std::env::var("WASMTIME_TEST_NO_HOG_MEMORY").is_ok() { + config.static_memory_maximum_size(0); + config.dynamic_memory_guard_size(0); + } + Engine::new(&config).unwrap() +} + +/// Newtype wrapper for `f32` whose `PartialEq` impl considers NaNs equal to each other. +#[derive(Copy, Clone, Debug, Arbitrary)] +pub struct Float32(pub f32); + +/// Newtype wrapper for `f64` whose `PartialEq` impl considers NaNs equal to each other. +#[derive(Copy, Clone, Debug, Arbitrary)] +pub struct Float64(pub f64); + +macro_rules! forward_impls { + ($($a:ty => $b:ty,)*) => ($( + unsafe impl ComponentType for $a { + type Lower = <$b as ComponentType>::Lower; + + const SIZE32: usize = <$b as ComponentType>::SIZE32; + const ALIGN32: u32 = <$b as ComponentType>::ALIGN32; + + #[inline] + fn typecheck(ty: &InterfaceType, types: &ComponentTypes) -> Result<()> { + <$b as ComponentType>::typecheck(ty, types) + } + } + + unsafe impl Lower for $a { + fn lower( + &self, + store: &mut StoreContextMut, + options: &Options, + dst: &mut MaybeUninit, + ) -> Result<()> { + <$b as Lower>::lower(&self.0, store, options, dst) + } + + fn store(&self, memory: &mut MemoryMut<'_, U>, offset: usize) -> Result<()> { + <$b as Lower>::store(&self.0, memory, offset) + } + } + + unsafe impl Lift for $a { + fn lift(store: &StoreOpaque, options: &Options, src: &Self::Lower) -> Result { + Ok(Self(<$b as Lift>::lift(store, options, src)?)) + } + + fn load(memory: &Memory<'_>, bytes: &[u8]) -> Result { + Ok(Self(<$b as Lift>::load(memory, bytes)?)) + } + } + + impl PartialEq for $a { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 || (self.0.is_nan() && other.0.is_nan()) + } + } + )*) +} + +forward_impls! { + Float32 => f32, + Float64 => f64, +} diff --git a/crates/wasmtime/src/component/func.rs b/crates/wasmtime/src/component/func.rs index 83db620cd528..0942cf74474f 100644 --- a/crates/wasmtime/src/component/func.rs +++ b/crates/wasmtime/src/component/func.rs @@ -257,6 +257,12 @@ impl Func { .collect() } + /// Get the result type for this function. + pub fn result(&self, store: impl AsContext) -> Type { + let data = &store.as_context()[self.0]; + Type::from(&data.types[data.ty].result, &data.types) + } + /// Invokes this function with the `params` given and returns the result. /// /// The `params` here must match the type signature of this `Func`, or this will return an error. If a trap @@ -307,11 +313,14 @@ impl Func { self.store_args(store, &options, ¶ms, args, dst) } else { dst.write([ValRaw::u64(0); MAX_FLAT_PARAMS]); - let dst = unsafe { + + let dst = &mut unsafe { mem::transmute::<_, &mut [MaybeUninit; MAX_FLAT_PARAMS]>(dst) - }; + } + .iter_mut(); + args.iter() - .try_for_each(|arg| arg.lower(store, &options, &mut dst.iter_mut())) + .try_for_each(|arg| arg.lower(store, &options, dst)) } }, |store, options, src: &[ValRaw; MAX_FLAT_RESULTS]| { diff --git a/crates/wasmtime/src/component/func/host.rs b/crates/wasmtime/src/component/func/host.rs index b76daeb494bc..b29321439f0f 100644 --- a/crates/wasmtime/src/component/func/host.rs +++ b/crates/wasmtime/src/component/func/host.rs @@ -1,9 +1,10 @@ use crate::component::func::{Memory, MemoryMut, Options}; -use crate::component::{ComponentParams, ComponentType, Lift, Lower}; +use crate::component::types::SizeAndAlignment; +use crate::component::{ComponentParams, ComponentType, Lift, Lower, Type, Val}; use crate::{AsContextMut, StoreContextMut, ValRaw}; -use anyhow::{bail, Context, Result}; +use anyhow::{anyhow, bail, Context, Result}; use std::any::Any; -use std::mem::MaybeUninit; +use std::mem::{self, MaybeUninit}; use std::panic::{self, AssertUnwindSafe}; use std::ptr::NonNull; use std::sync::Arc; @@ -43,7 +44,7 @@ pub trait IntoComponentFunc { pub struct HostFunc { entrypoint: VMLoweringCallee, - typecheck: fn(TypeFuncIndex, &ComponentTypes) -> Result<()>, + typecheck: Box) -> Result<()>) + Send + Sync>, func: Box, } @@ -51,17 +52,54 @@ impl HostFunc { fn new(func: F, entrypoint: VMLoweringCallee) -> Arc where F: Send + Sync + 'static, - P: ComponentParams + Lift, - R: Lower, + P: ComponentParams + Lift + 'static, + R: Lower + 'static, { Arc::new(HostFunc { entrypoint, - typecheck: typecheck::, + typecheck: Box::new(typecheck::), func: Box::new(func), }) } - pub fn typecheck(&self, ty: TypeFuncIndex, types: &ComponentTypes) -> Result<()> { + pub(crate) fn new_dynamic< + T, + F: Fn(StoreContextMut<'_, T>, &[Val]) -> Result + Send + Sync + 'static, + >( + func: F, + index: TypeFuncIndex, + types: &Arc, + ) -> Arc { + let ty = &types[index]; + + Arc::new(HostFunc { + entrypoint: dynamic_entrypoint::, + typecheck: Box::new({ + let types = types.clone(); + + move |expected_index, expected_types| { + if index == expected_index && Arc::ptr_eq(&types, expected_types) { + Ok(()) + } else { + Err(anyhow!("function type mismatch")) + } + } + }), + func: Box::new(DynamicContext { + func, + types: Types { + params: ty + .params + .iter() + .map(|(_, ty)| Type::from(ty, types)) + .collect(), + result: Type::from(&ty.result, types), + }, + }), + }) + } + + pub fn typecheck(&self, ty: TypeFuncIndex, types: &Arc) -> Result<()> { (self.typecheck)(ty, types) } @@ -74,7 +112,7 @@ impl HostFunc { } } -fn typecheck(ty: TypeFuncIndex, types: &ComponentTypes) -> Result<()> +fn typecheck(ty: TypeFuncIndex, types: &Arc) -> Result<()> where P: ComponentParams + Lift, R: Lower, @@ -256,8 +294,8 @@ macro_rules! impl_into_component_func { impl IntoComponentFunc for F where F: Fn($($args),*) -> Result + Send + Sync + 'static, - ($($args,)*): ComponentParams + Lift, - R: Lower, + ($($args,)*): ComponentParams + Lift + 'static, + R: Lower + 'static, { extern "C" fn entrypoint( cx: *mut VMOpaqueContext, @@ -294,8 +332,8 @@ macro_rules! impl_into_component_func { impl IntoComponentFunc, $($args,)*), R> for F where F: Fn(StoreContextMut<'_, T>, $($args),*) -> Result + Send + Sync + 'static, - ($($args,)*): ComponentParams + Lift, - R: Lower, + ($($args,)*): ComponentParams + Lift + 'static, + R: Lower + 'static, { extern "C" fn entrypoint( cx: *mut VMOpaqueContext, @@ -330,3 +368,154 @@ macro_rules! impl_into_component_func { } for_each_function_signature!(impl_into_component_func); + +unsafe fn call_host_dynamic( + Types { params, result }: &Types, + cx: *mut VMOpaqueContext, + mut flags: InstanceFlags, + memory: *mut VMMemoryDefinition, + realloc: *mut VMCallerCheckedAnyfunc, + string_encoding: StringEncoding, + storage: &mut [ValRaw], + closure: F, +) -> Result<()> +where + F: FnOnce(StoreContextMut<'_, T>, &[Val]) -> Result, +{ + let cx = VMComponentContext::from_opaque(cx); + let instance = (*cx).instance(); + let mut cx = StoreContextMut::from_raw((*instance).store()); + + let options = Options::new( + cx.0.id(), + NonNull::new(memory), + NonNull::new(realloc), + string_encoding, + ); + + // Perform a dynamic check that this instance can indeed be left. Exiting + // the component is disallowed, for example, when the `realloc` function + // calls a canonical import. + if !flags.may_leave() { + bail!("cannot leave component instance"); + } + + let param_count = params.iter().map(|ty| ty.flatten_count()).sum::(); + + let args; + let ret_index; + + if param_count <= MAX_FLAT_PARAMS { + let iter = &mut storage.iter(); + args = params + .iter() + .map(|ty| Val::lift(ty, cx.0, &options, iter)) + .collect::>>()?; + ret_index = param_count; + } else { + let param_layout = { + let mut size = 0; + let mut alignment = 1; + for ty in params.iter() { + alignment = alignment.max(ty.size_and_alignment().alignment); + ty.next_field(&mut size); + } + SizeAndAlignment { size, alignment } + }; + + let memory = Memory::new(cx.0, &options); + let mut offset = validate_inbounds_dynamic(param_layout, memory.as_slice(), &storage[0])?; + args = params + .iter() + .map(|ty| { + Val::load( + ty, + &memory, + &memory.as_slice()[ty.next_field(&mut offset)..] + [..ty.size_and_alignment().size], + ) + }) + .collect::>>()?; + ret_index = 1; + }; + + let ret = closure(cx.as_context_mut(), &args)?; + flags.set_may_leave(false); + result.check(&ret)?; + + let result_count = result.flatten_count(); + if result_count <= MAX_FLAT_RESULTS { + let dst = mem::transmute::<&mut [ValRaw], &mut [MaybeUninit]>(storage); + ret.lower(&mut cx, &options, &mut dst.iter_mut())?; + } else { + let ret_ptr = &storage[ret_index]; + let mut memory = MemoryMut::new(cx.as_context_mut(), &options); + let ptr = + validate_inbounds_dynamic(result.size_and_alignment(), memory.as_slice_mut(), ret_ptr)?; + ret.store(&mut memory, ptr)?; + } + + flags.set_may_leave(true); + + return Ok(()); +} + +fn validate_inbounds_dynamic( + SizeAndAlignment { size, alignment }: SizeAndAlignment, + memory: &[u8], + ptr: &ValRaw, +) -> Result { + // FIXME: needs memory64 support + let ptr = usize::try_from(ptr.get_u32())?; + if ptr % usize::try_from(alignment)? != 0 { + bail!("pointer not aligned"); + } + let end = match ptr.checked_add(size) { + Some(n) => n, + None => bail!("pointer size overflow"), + }; + if end > memory.len() { + bail!("pointer out of bounds") + } + Ok(ptr) +} + +struct Types { + params: Box<[Type]>, + result: Type, +} + +struct DynamicContext { + func: F, + types: Types, +} + +extern "C" fn dynamic_entrypoint< + T, + F: Fn(StoreContextMut<'_, T>, &[Val]) -> Result + Send + Sync + 'static, +>( + cx: *mut VMOpaqueContext, + data: *mut u8, + flags: InstanceFlags, + memory: *mut VMMemoryDefinition, + realloc: *mut VMCallerCheckedAnyfunc, + string_encoding: StringEncoding, + storage: *mut ValRaw, + storage_len: usize, +) { + let data = data as *const DynamicContext; + unsafe { + handle_result(|| { + call_host_dynamic::( + &(*data).types, + cx, + flags, + memory, + realloc, + string_encoding, + std::slice::from_raw_parts_mut(storage, storage_len), + |store, values| ((*data).func)(store, values), + ) + }) + } +} diff --git a/crates/wasmtime/src/component/func/typed.rs b/crates/wasmtime/src/component/func/typed.rs index 9ac129444c34..0300ac4e72e6 100644 --- a/crates/wasmtime/src/component/func/typed.rs +++ b/crates/wasmtime/src/component/func/typed.rs @@ -1731,7 +1731,7 @@ macro_rules! impl_component_ty_for_tuples { _size = align_to(_size, $t::ALIGN32); _size += $t::SIZE32; )* - _size + align_to(_size, Self::ALIGN32) }; const ALIGN32: u32 = { diff --git a/crates/wasmtime/src/component/instance.rs b/crates/wasmtime/src/component/instance.rs index fd0a98cb0247..1d5f24588be0 100644 --- a/crates/wasmtime/src/component/instance.rs +++ b/crates/wasmtime/src/component/instance.rs @@ -64,7 +64,7 @@ impl Instance { /// Looks up a function by name within this [`Instance`]. /// /// This is a convenience method for calling [`Instance::exports`] followed - /// by [`ExportInstance::get_func`]. + /// by [`ExportInstance::func`]. /// /// # Panics /// diff --git a/crates/wasmtime/src/component/linker.rs b/crates/wasmtime/src/component/linker.rs index 6feb7ff234b3..1289deb632ac 100644 --- a/crates/wasmtime/src/component/linker.rs +++ b/crates/wasmtime/src/component/linker.rs @@ -1,12 +1,13 @@ use crate::component::func::HostFunc; use crate::component::instance::RuntimeImport; use crate::component::matching::TypeChecker; -use crate::component::{Component, Instance, InstancePre, IntoComponentFunc}; -use crate::{AsContextMut, Engine, Module}; +use crate::component::{Component, Instance, InstancePre, IntoComponentFunc, Val}; +use crate::{AsContextMut, Engine, Module, StoreContextMut}; use anyhow::{anyhow, bail, Context, Result}; use std::collections::hash_map::{Entry, HashMap}; use std::marker; use std::sync::Arc; +use wasmtime_environ::component::TypeDef; use wasmtime_environ::PrimaryMap; /// A type used to instantiate [`Component`]s. @@ -230,6 +231,37 @@ impl LinkerInstance<'_, T> { self.insert(name, Definition::Func(func.into_host_func())) } + /// Define a new host-provided function using dynamic types. + /// + /// `name` must refer to a function type import in `component`. If and when + /// that import is invoked by the component, the specified `func` will be + /// called, which must return a `Val` which is an instance of the result + /// type of the import. + pub fn func_new< + F: Fn(StoreContextMut<'_, T>, &[Val]) -> Result + Send + Sync + 'static, + >( + &mut self, + component: &Component, + name: &str, + func: F, + ) -> Result<()> { + for (import_name, ty) in component.env_component().import_types.values() { + if name == import_name { + if let TypeDef::ComponentFunc(index) = ty { + let name = self.strings.intern(name); + return self.insert( + name, + Definition::Func(HostFunc::new_dynamic(func, *index, component.types())), + ); + } else { + bail!("import `{name}` has the wrong type (expected a function)"); + } + } + } + + Err(anyhow!("import `{name}` not found")) + } + /// Defines a [`Module`] within this instance. /// /// This can be used to provide a core wasm [`Module`] as an import to a diff --git a/crates/wasmtime/src/component/matching.rs b/crates/wasmtime/src/component/matching.rs index 7a6a8aca2dd9..e5012a504339 100644 --- a/crates/wasmtime/src/component/matching.rs +++ b/crates/wasmtime/src/component/matching.rs @@ -3,12 +3,13 @@ use crate::component::linker::{Definition, NameMap, Strings}; use crate::types::matching; use crate::Module; use anyhow::{anyhow, bail, Context, Result}; +use std::sync::Arc; use wasmtime_environ::component::{ ComponentTypes, TypeComponentInstance, TypeDef, TypeFuncIndex, TypeModule, }; pub struct TypeChecker<'a> { - pub types: &'a ComponentTypes, + pub types: &'a Arc, pub strings: &'a Strings, } diff --git a/crates/wasmtime/src/component/types.rs b/crates/wasmtime/src/component/types.rs index e90fe6e16688..e87a00cc88ff 100644 --- a/crates/wasmtime/src/component/types.rs +++ b/crates/wasmtime/src/component/types.rs @@ -222,6 +222,7 @@ impl Flags { } /// Represents the size and alignment requirements of the heap-serialized form of a type +#[derive(Debug)] pub(crate) struct SizeAndAlignment { pub(crate) size: usize, pub(crate) alignment: u32, @@ -662,7 +663,10 @@ fn variant_size_and_alignment(types: impl ExactSizeIterator) -> Siz } SizeAndAlignment { - size: func::align_to(usize::from(discriminant_size), alignment) + size, + size: func::align_to( + func::align_to(usize::from(discriminant_size), alignment) + size, + alignment, + ), alignment, } } diff --git a/crates/wasmtime/src/component/values.rs b/crates/wasmtime/src/component/values.rs index 3ef855ecc2e8..7c3f5501522a 100644 --- a/crates/wasmtime/src/component/values.rs +++ b/crates/wasmtime/src/component/values.rs @@ -604,8 +604,9 @@ impl Val { } Type::Flags(handle) => { let count = u32::try_from(handle.names().len()).unwrap(); - assert!(count <= 32); - let value = iter::once(u32::lift(store, options, next(src))?).collect(); + let value = iter::repeat_with(|| u32::lift(store, options, next(src))) + .take(u32_count_for_flag_count(count.try_into()?)) + .collect::>()?; Val::Flags(Flags { ty: handle.clone(), diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index f0a1f2a856bc..7f3ae4a34606 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -9,6 +9,8 @@ publish = false cargo-fuzz = true [dependencies] +anyhow = { version = "1.0.19" } +arbitrary = { version = "1.1.0", features = ["derive"] } cranelift-codegen = { path = "../cranelift/codegen" } cranelift-reader = { path = "../cranelift/reader" } cranelift-wasm = { path = "../cranelift/wasm" } @@ -19,6 +21,16 @@ libfuzzer-sys = "0.4.0" target-lexicon = "0.12" wasmtime = { path = "../crates/wasmtime" } wasmtime-fuzzing = { path = "../crates/fuzzing" } +component-test-util = { path = "../crates/misc/component-test-util" } +component-fuzz-util = { path = "../crates/misc/component-fuzz-util" } + +[build-dependencies] +anyhow = "1.0.19" +proc-macro2 = "1.0" +arbitrary = { version = "1.1.0", features = ["derive"] } +rand = { version = "0.8.0" } +quote = "1.0" +component-fuzz-util = { path = "../crates/misc/component-fuzz-util" } [features] default = ['fuzz-spec-interpreter'] @@ -102,3 +114,9 @@ name = "instantiate-many" path = "fuzz_targets/instantiate-many.rs" test = false doc = false + +[[bin]] +name = "component_api" +path = "fuzz_targets/component_api.rs" +test = false +doc = false diff --git a/fuzz/build.rs b/fuzz/build.rs new file mode 100644 index 000000000000..b8b45a36f45f --- /dev/null +++ b/fuzz/build.rs @@ -0,0 +1,144 @@ +fn main() -> anyhow::Result<()> { + component::generate_static_api_tests()?; + + Ok(()) +} + +mod component { + use anyhow::{anyhow, Context, Error, Result}; + use arbitrary::{Arbitrary, Unstructured}; + use component_fuzz_util::{self, Declarations, TestCase}; + use proc_macro2::TokenStream; + use quote::{format_ident, quote}; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + use std::env; + use std::fmt::Write; + use std::fs; + use std::iter; + use std::path::PathBuf; + use std::process::Command; + + pub fn generate_static_api_tests() -> Result<()> { + println!("cargo:rerun-if-changed=build.rs"); + let out_dir = PathBuf::from( + env::var_os("OUT_DIR").expect("The OUT_DIR environment variable must be set"), + ); + + let mut out = String::new(); + write_static_api_tests(&mut out)?; + + let output = out_dir.join("static_component_api.rs"); + fs::write(&output, out)?; + + drop(Command::new("rustfmt").arg(&output).status()); + + Ok(()) + } + + fn write_static_api_tests(out: &mut String) -> Result<()> { + let seed = if let Ok(seed) = env::var("WASMTIME_FUZZ_SEED") { + seed.parse::() + .with_context(|| anyhow!("expected u64 in WASMTIME_FUZZ_SEED"))? + } else { + StdRng::from_entropy().gen() + }; + + eprintln!( + "using seed {seed} (set WASMTIME_FUZZ_SEED={seed} in your environment to reproduce)" + ); + + let mut rng = StdRng::seed_from_u64(seed); + + const TEST_CASE_COUNT: usize = 100; + + let mut tests = TokenStream::new(); + + let name_counter = &mut 0; + + let mut declarations = TokenStream::new(); + + for index in 0..TEST_CASE_COUNT { + let mut bytes = Vec::new(); + + let case = loop { + let count = rng.gen_range(1000..2000); + bytes.extend(iter::repeat_with(|| rng.gen::()).take(count)); + + match TestCase::arbitrary(&mut Unstructured::new(&bytes)) { + Ok(case) => break case, + Err(arbitrary::Error::NotEnoughData) => (), + Err(error) => return Err(Error::from(error)), + } + }; + + let Declarations { + types, + params, + result, + import_and_export, + } = case.declarations(); + + let test = format_ident!("static_api_test{}", case.params.len()); + + let rust_params = case + .params + .iter() + .map(|ty| { + let ty = component_fuzz_util::rust_type(&ty, name_counter, &mut declarations); + quote!(#ty,) + }) + .collect::(); + + let rust_result = + component_fuzz_util::rust_type(&case.result, name_counter, &mut declarations); + + let test = quote!(#index => component_types::#test::<#rust_params #rust_result>( + input, + &Declarations { + types: #types.into(), + params: #params.into(), + result: #result.into(), + import_and_export: #import_and_export.into() + } + ),); + + tests.extend(test); + } + + let module = quote! { + #[allow(unused_imports)] + fn static_component_api_target(input: &mut arbitrary::Unstructured) -> arbitrary::Result<()> { + use anyhow::Result; + use arbitrary::{Unstructured, Arbitrary}; + use component_test_util::{self, Float32, Float64}; + use component_fuzz_util::Declarations; + use std::sync::{Arc, Once}; + use wasmtime::component::{ComponentType, Lift, Lower}; + use wasmtime_fuzzing::generators::component_types; + + const SEED: u64 = #seed; + + static ONCE: Once = Once::new(); + + ONCE.call_once(|| { + eprintln!( + "Seed {SEED} was used to generate static component API fuzz tests.\n\ + Set WASMTIME_FUZZ_SEED={SEED} in your environment at build time to reproduce." + ); + }); + + #declarations + + match input.int_in_range(0..=(#TEST_CASE_COUNT-1))? { + #tests + _ => unreachable!() + } + } + }; + + write!(out, "{module}")?; + + Ok(()) + } +} diff --git a/fuzz/fuzz_targets/component_api.rs b/fuzz/fuzz_targets/component_api.rs new file mode 100644 index 000000000000..7dc76dc4db07 --- /dev/null +++ b/fuzz/fuzz_targets/component_api.rs @@ -0,0 +1,22 @@ +#![no_main] + +use libfuzzer_sys::fuzz_target; +use wasmtime_fuzzing::oracles; + +include!(concat!(env!("OUT_DIR"), "/static_component_api.rs")); + +#[allow(unused_imports)] +fn target(input: &mut arbitrary::Unstructured) -> arbitrary::Result<()> { + if input.arbitrary()? { + static_component_api_target(input) + } else { + oracles::dynamic_component_api_target(input) + } +} + +fuzz_target!(|bytes: &[u8]| { + match target(&mut arbitrary::Unstructured::new(bytes)) { + Ok(()) | Err(arbitrary::Error::NotEnoughData) => (), + Err(error) => panic!("{}", error), + } +}); diff --git a/tests/all/component_model.rs b/tests/all/component_model.rs index ada10f49c9f0..a9a2117ad151 100644 --- a/tests/all/component_model.rs +++ b/tests/all/component_model.rs @@ -1,8 +1,9 @@ use anyhow::Result; +use component_test_util::{engine, TypedFuncExt}; use std::fmt::Write; use std::iter; -use wasmtime::component::{Component, ComponentParams, Lift, Lower, TypedFunc}; -use wasmtime::{AsContextMut, Config, Engine}; +use wasmtime::component::Component; +use wasmtime_component_util::REALLOC_AND_FREE; mod dynamic; mod func; @@ -12,97 +13,6 @@ mod macros; mod nested; mod post_return; -trait TypedFuncExt { - fn call_and_post_return(&self, store: impl AsContextMut, params: P) -> Result; -} - -impl TypedFuncExt for TypedFunc -where - P: ComponentParams + Lower, - R: Lift, -{ - fn call_and_post_return(&self, mut store: impl AsContextMut, params: P) -> Result { - let result = self.call(&mut store, params)?; - self.post_return(&mut store)?; - Ok(result) - } -} - -// A simple bump allocator which can be used with modules -const REALLOC_AND_FREE: &str = r#" - (global $last (mut i32) (i32.const 8)) - (func $realloc (export "realloc") - (param $old_ptr i32) - (param $old_size i32) - (param $align i32) - (param $new_size i32) - (result i32) - - ;; Test if the old pointer is non-null - local.get $old_ptr - if - ;; If the old size is bigger than the new size then - ;; this is a shrink and transparently allow it - local.get $old_size - local.get $new_size - i32.gt_u - if - local.get $old_ptr - return - end - - ;; ... otherwise this is unimplemented - unreachable - end - - ;; align up `$last` - (global.set $last - (i32.and - (i32.add - (global.get $last) - (i32.add - (local.get $align) - (i32.const -1))) - (i32.xor - (i32.add - (local.get $align) - (i32.const -1)) - (i32.const -1)))) - - ;; save the current value of `$last` as the return value - global.get $last - - ;; ensure anything necessary is set to valid data by spraying a bit - ;; pattern that is invalid - global.get $last - i32.const 0xde - local.get $new_size - memory.fill - - ;; bump our pointer - (global.set $last - (i32.add - (global.get $last) - (local.get $new_size))) - ) -"#; - -fn engine() -> Engine { - drop(env_logger::try_init()); - - let mut config = Config::new(); - config.wasm_component_model(true); - - // When pooling allocator tests are skipped it means we're in qemu. The - // component model tests create a disproportionate number of instances so - // try to cut down on virtual memory usage by avoiding 4G reservations. - if crate::skip_pooling_allocator_tests() { - config.static_memory_maximum_size(0); - config.dynamic_memory_guard_size(0); - } - Engine::new(&config).unwrap() -} - #[test] fn components_importing_modules() -> Result<()> { let engine = engine(); @@ -113,49 +23,49 @@ fn components_importing_modules() -> Result<()> { Component::new( &engine, r#" - (component - (import "" (core module)) - ) + (component + (import "" (core module)) + ) "#, )?; Component::new( &engine, r#" - (component - (import "" (core module $m1 - (import "" "" (func)) - (import "" "x" (global i32)) - - (export "a" (table 1 funcref)) - (export "b" (memory 1)) - (export "c" (func (result f32))) - (export "d" (global i64)) - )) - - (core module $m2 - (func (export "")) - (global (export "x") i32 i32.const 0) - ) - (core instance $i2 (instantiate (module $m2))) - (core instance $i1 (instantiate (module $m1) (with "" (instance $i2)))) - - (core module $m3 - (import "mod" "1" (memory 1)) - (import "mod" "2" (table 1 funcref)) - (import "mod" "3" (global i64)) - (import "mod" "4" (func (result f32))) - ) + (component + (import "" (core module $m1 + (import "" "" (func)) + (import "" "x" (global i32)) + + (export "a" (table 1 funcref)) + (export "b" (memory 1)) + (export "c" (func (result f32))) + (export "d" (global i64)) + )) + + (core module $m2 + (func (export "")) + (global (export "x") i32 i32.const 0) + ) + (core instance $i2 (instantiate (module $m2))) + (core instance $i1 (instantiate (module $m1) (with "" (instance $i2)))) + + (core module $m3 + (import "mod" "1" (memory 1)) + (import "mod" "2" (table 1 funcref)) + (import "mod" "3" (global i64)) + (import "mod" "4" (func (result f32))) + ) - (core instance $i3 (instantiate (module $m3) - (with "mod" (instance - (export "1" (memory $i1 "b")) - (export "2" (table $i1 "a")) - (export "3" (global $i1 "d")) - (export "4" (func $i1 "c")) - )) + (core instance $i3 (instantiate (module $m3) + (with "mod" (instance + (export "1" (memory $i1 "b")) + (export "2" (table $i1 "a")) + (export "3" (global $i1 "d")) + (export "4" (func $i1 "c")) )) - ) + )) + ) "#, )?; diff --git a/tests/all/component_model/dynamic.rs b/tests/all/component_model/dynamic.rs index 66ab8c0fe6cf..f9ac63fdb871 100644 --- a/tests/all/component_model/dynamic.rs +++ b/tests/all/component_model/dynamic.rs @@ -1,19 +1,8 @@ use super::{make_echo_component, make_echo_component_with_params, Param, Type}; use anyhow::Result; -use wasmtime::component::{self, Component, Func, Linker, Val}; -use wasmtime::{AsContextMut, Store}; - -trait FuncExt { - fn call_and_post_return(&self, store: impl AsContextMut, args: &[Val]) -> Result; -} - -impl FuncExt for Func { - fn call_and_post_return(&self, mut store: impl AsContextMut, args: &[Val]) -> Result { - let result = self.call(&mut store, args)?; - self.post_return(&mut store)?; - Ok(result) - } -} +use component_test_util::FuncExt; +use wasmtime::component::{self, Component, Linker, Val}; +use wasmtime::Store; #[test] fn primitives() -> Result<()> { diff --git a/tests/all/component_model/import.rs b/tests/all/component_model/import.rs index 5e67d92df5af..5bba4fcfe016 100644 --- a/tests/all/component_model/import.rs +++ b/tests/all/component_model/import.rs @@ -1,5 +1,6 @@ use super::REALLOC_AND_FREE; use anyhow::Result; +use std::ops::Deref; use wasmtime::component::*; use wasmtime::{Store, StoreContextMut, Trap}; @@ -117,6 +118,12 @@ fn simple() -> Result<()> { "#; let engine = super::engine(); + let component = Component::new(&engine, component)?; + let mut store = Store::new(&engine, None); + assert!(store.data().is_none()); + + // First, test the static API + let mut linker = Linker::new(&engine); linker.root().func_wrap( "", @@ -127,15 +134,36 @@ fn simple() -> Result<()> { Ok(()) }, )?; - let component = Component::new(&engine, component)?; - let mut store = Store::new(&engine, None); let instance = linker.instantiate(&mut store, &component)?; - assert!(store.data().is_none()); instance .get_typed_func::<(), (), _>(&mut store, "call")? .call(&mut store, ())?; assert_eq!(store.data().as_ref().unwrap(), "hello world"); + // Next, test the dynamic API + + *store.data_mut() = None; + let mut linker = Linker::new(&engine); + linker.root().func_new( + &component, + "", + |mut store: StoreContextMut<'_, Option>, args| { + if let Val::String(s) = &args[0] { + assert!(store.data().is_none()); + *store.data_mut() = Some(s.to_string()); + Ok(Val::Unit) + } else { + panic!() + } + }, + )?; + let instance = linker.instantiate(&mut store, &component)?; + instance + .get_func(&mut store, "call") + .unwrap() + .call(&mut store, &[])?; + assert_eq!(store.data().as_ref().unwrap(), "hello world"); + Ok(()) } @@ -299,15 +327,20 @@ fn attempt_to_reenter_during_host() -> Result<()> { ) "#; - struct State { + let engine = super::engine(); + let component = Component::new(&engine, component)?; + + // First, test the static API + + struct StaticState { func: Option>, } - let engine = super::engine(); + let mut store = Store::new(&engine, StaticState { func: None }); let mut linker = Linker::new(&engine); linker.root().func_wrap( "thunk", - |mut store: StoreContextMut<'_, State>| -> Result<()> { + |mut store: StoreContextMut<'_, StaticState>| -> Result<()> { let func = store.data_mut().func.take().unwrap(); let trap = func.call(&mut store, ()).unwrap_err(); assert!( @@ -319,12 +352,39 @@ fn attempt_to_reenter_during_host() -> Result<()> { Ok(()) }, )?; - let component = Component::new(&engine, component)?; - let mut store = Store::new(&engine, State { func: None }); let instance = linker.instantiate(&mut store, &component)?; let func = instance.get_typed_func::<(), (), _>(&mut store, "run")?; store.data_mut().func = Some(func); func.call(&mut store, ())?; + + // Next, test the dynamic API + + struct DynamicState { + func: Option, + } + + let mut store = Store::new(&engine, DynamicState { func: None }); + let mut linker = Linker::new(&engine); + linker.root().func_new( + &component, + "thunk", + |mut store: StoreContextMut<'_, DynamicState>, _| { + let func = store.data_mut().func.take().unwrap(); + let trap = func.call(&mut store, &[]).unwrap_err(); + assert!( + trap.to_string() + .contains("cannot reenter component instance"), + "bad trap: {}", + trap, + ); + Ok(Val::Unit) + }, + )?; + let instance = linker.instantiate(&mut store, &component)?; + let func = instance.get_func(&mut store, "run").unwrap(); + store.data_mut().func = Some(func); + func.call(&mut store, &[])?; + Ok(()) } @@ -466,6 +526,11 @@ fn stack_and_heap_args_and_rets() -> Result<()> { ); let engine = super::engine(); + let component = Component::new(&engine, component)?; + let mut store = Store::new(&engine, ()); + + // First, test the static API + let mut linker = Linker::new(&engine); linker.root().func_wrap("f1", |x: u32| -> Result { assert_eq!(x, 1); @@ -515,12 +580,60 @@ fn stack_and_heap_args_and_rets() -> Result<()> { Ok("xyz".to_string()) }, )?; - let component = Component::new(&engine, component)?; - let mut store = Store::new(&engine, ()); let instance = linker.instantiate(&mut store, &component)?; instance .get_typed_func::<(), (), _>(&mut store, "run")? .call(&mut store, ())?; + + // Next, test the dynamic API + + let mut linker = Linker::new(&engine); + linker.root().func_new(&component, "f1", |_, args| { + if let Val::U32(x) = &args[0] { + assert_eq!(*x, 1); + Ok(Val::U32(2)) + } else { + panic!() + } + })?; + linker.root().func_new(&component, "f2", |_, args| { + if let Val::Tuple(tuple) = &args[0] { + if let Val::String(s) = &tuple.values()[0] { + assert_eq!(s.deref(), "abc"); + Ok(Val::U32(3)) + } else { + panic!() + } + } else { + panic!() + } + })?; + linker.root().func_new(&component, "f3", |_, args| { + if let Val::U32(x) = &args[0] { + assert_eq!(*x, 8); + Ok(Val::String("xyz".into())) + } else { + panic!(); + } + })?; + linker.root().func_new(&component, "f4", |_, args| { + if let Val::Tuple(tuple) = &args[0] { + if let Val::String(s) = &tuple.values()[0] { + assert_eq!(s.deref(), "abc"); + Ok(Val::String("xyz".into())) + } else { + panic!() + } + } else { + panic!() + } + })?; + let instance = linker.instantiate(&mut store, &component)?; + instance + .get_func(&mut store, "run") + .unwrap() + .call(&mut store, &[])?; + Ok(()) } @@ -648,6 +761,9 @@ fn no_actual_wasm_code() -> Result<()> { let engine = super::engine(); let component = Component::new(&engine, component)?; let mut store = Store::new(&engine, 0); + + // First, test the static API + let mut linker = Linker::new(&engine); linker .root() @@ -663,5 +779,23 @@ fn no_actual_wasm_code() -> Result<()> { thunk.call(&mut store, ())?; assert_eq!(*store.data(), 1); + // Next, test the dynamic API + + *store.data_mut() = 0; + let mut linker = Linker::new(&engine); + linker + .root() + .func_new(&component, "f", |mut store: StoreContextMut<'_, u32>, _| { + *store.data_mut() += 1; + Ok(Val::Unit) + })?; + + let instance = linker.instantiate(&mut store, &component)?; + let thunk = instance.get_func(&mut store, "thunk").unwrap(); + + assert_eq!(*store.data(), 0); + thunk.call(&mut store, &[])?; + assert_eq!(*store.data(), 1); + Ok(()) }