diff --git a/build.rs b/build.rs index a2a4628ecbbc..22b53276d10e 100644 --- a/build.rs +++ b/build.rs @@ -31,6 +31,7 @@ fn main() -> anyhow::Result<()> { test_directory(out, "tests/misc_testsuite", strategy)?; test_directory_module(out, "tests/misc_testsuite/multi-memory", strategy)?; test_directory_module(out, "tests/misc_testsuite/simd", strategy)?; + test_directory_module(out, "tests/misc_testsuite/tail-call", strategy)?; test_directory_module(out, "tests/misc_testsuite/threads", strategy)?; test_directory_module(out, "tests/misc_testsuite/memory64", strategy)?; test_directory_module(out, "tests/misc_testsuite/component-model", strategy)?; @@ -61,6 +62,7 @@ fn main() -> anyhow::Result<()> { "tests/spec_testsuite/proposals/relaxed-simd", strategy, )?; + test_directory_module(out, "tests/spec_testsuite/proposals/tail-call", strategy)?; } else { println!( "cargo:warning=The spec testsuite is disabled. To enable, run `git submodule \ @@ -213,11 +215,6 @@ fn ignore(testsuite: &str, testname: &str, strategy: &str) -> bool { return true; } - // Tail calls are not yet implemented. - if testname.contains("return_call") { - return true; - } - if testsuite == "function_references" { // The following tests fail due to function references not yet // being exposed in the public API. @@ -241,6 +238,9 @@ fn ignore(testsuite: &str, testname: &str, strategy: &str) -> bool { "s390x" => { // FIXME: These tests fail under qemu due to a qemu bug. testname == "simd_f32x4_pmin_pmax" || testname == "simd_f64x2_pmin_pmax" + // TODO(#6530): These tests require tail calls, but s390x + // doesn't support them yet. + || testsuite == "function_references" || testsuite == "tail_call" } "riscv64" => { diff --git a/cranelift/codegen/src/cursor.rs b/cranelift/codegen/src/cursor.rs index 1a09d6f37646..a0176e6296f5 100644 --- a/cranelift/codegen/src/cursor.rs +++ b/cranelift/codegen/src/cursor.rs @@ -593,7 +593,7 @@ impl<'f> FuncCursor<'f> { } /// Create an instruction builder that inserts an instruction at the current position. - pub fn ins(&mut self) -> ir::InsertBuilder<&mut FuncCursor<'f>> { + pub fn ins(&mut self) -> ir::InsertBuilder<'_, &mut FuncCursor<'f>> { ir::InsertBuilder::new(self) } } diff --git a/cranelift/codegen/src/isa/riscv64/abi.rs b/cranelift/codegen/src/isa/riscv64/abi.rs index 50fa73b6cf9c..e4cfb0e101c7 100644 --- a/cranelift/codegen/src/isa/riscv64/abi.rs +++ b/cranelift/codegen/src/isa/riscv64/abi.rs @@ -702,13 +702,13 @@ impl Riscv64ABICallSite { }); match dest { - // TODO: Our riscv64 backend doesn't have relocs for direct calls, - // the callee is always put in a register and then the register is - // relocated, so we don't currently differentiate between - // `RelocDistance::Near` and `RelocDistance::Far`. We just always - // use indirect calls. We should eventually add a non-indirect - // `return_call` instruction and path. - CallDest::ExtName(name, _) => { + CallDest::ExtName(name, RelocDistance::Near) => { + ctx.emit(Inst::ReturnCall { + callee: Box::new(name), + info, + }); + } + CallDest::ExtName(name, RelocDistance::Far) => { let callee = ctx.alloc_tmp(ir::types::I64).only_reg().unwrap(); ctx.emit(Inst::LoadExtName { rd: callee, diff --git a/cranelift/codegen/src/isa/riscv64/inst.isle b/cranelift/codegen/src/isa/riscv64/inst.isle index 233bc9a82734..264880d0def9 100644 --- a/cranelift/codegen/src/isa/riscv64/inst.isle +++ b/cranelift/codegen/src/isa/riscv64/inst.isle @@ -98,6 +98,11 @@ (CallInd (info BoxCallIndInfo)) + ;; A direct return-call macro instruction. + (ReturnCall + (callee BoxExternalName) + (info BoxReturnCallInfo)) + ;; An indirect return-call macro instruction. (ReturnCallInd (callee Reg) diff --git a/cranelift/codegen/src/isa/riscv64/inst/emit.rs b/cranelift/codegen/src/isa/riscv64/inst/emit.rs index 15f44ebfafe9..348007d2ffb0 100644 --- a/cranelift/codegen/src/isa/riscv64/inst/emit.rs +++ b/cranelift/codegen/src/isa/riscv64/inst/emit.rs @@ -424,6 +424,7 @@ impl Inst { | Inst::AdjustSp { .. } | Inst::Call { .. } | Inst::CallInd { .. } + | Inst::ReturnCall { .. } | Inst::ReturnCallInd { .. } | Inst::TrapIf { .. } | Inst::Jal { .. } @@ -884,6 +885,32 @@ impl MachInstEmit for Inst { ); } + &Inst::ReturnCall { + ref callee, + ref info, + } => { + emit_return_call_common_sequence( + &mut allocs, + sink, + emit_info, + state, + info.new_stack_arg_size, + info.old_stack_arg_size, + &info.uses, + ); + + sink.add_call_site(ir::Opcode::ReturnCall); + sink.add_reloc(Reloc::RiscvCall, &callee, 0); + Inst::construct_auipc_and_jalr(None, writable_spilltmp_reg(), 0) + .into_iter() + .for_each(|i| i.emit(&[], sink, emit_info, state)); + + // `emit_return_call_common_sequence` emits an island if + // necessary, so we can safely disable the worst-case-size check + // in this case. + start_off = sink.cur_offset(); + } + &Inst::ReturnCallInd { callee, ref info } => { let callee = allocs.next(callee); @@ -903,6 +930,11 @@ impl MachInstEmit for Inst { offset: Imm12::zero(), } .emit(&[], sink, emit_info, state); + + // `emit_return_call_common_sequence` emits an island if + // necessary, so we can safely disable the worst-case-size check + // in this case. + start_off = sink.cur_offset(); } &Inst::Jal { dest } => { @@ -3089,9 +3121,10 @@ fn emit_return_call_common_sequence( // We are emitting a dynamic number of instructions and might need an // island. We emit four instructions regardless of how many stack arguments - // we have, and then two instructions per word of stack argument space. + // we have, up to two instructions for the actual call, and then two + // instructions per word of stack argument space. let new_stack_words = new_stack_arg_size / 8; - let insts = 4 + 2 * new_stack_words; + let insts = 4 + 2 + 2 * new_stack_words; let space_needed = insts * u32::try_from(Inst::INSTRUCTION_SIZE).unwrap(); if sink.island_needed(space_needed) { let jump_around_label = sink.get_label(); diff --git a/cranelift/codegen/src/isa/riscv64/inst/mod.rs b/cranelift/codegen/src/isa/riscv64/inst/mod.rs index 9b96832299cd..ca6bfe4fb537 100644 --- a/cranelift/codegen/src/isa/riscv64/inst/mod.rs +++ b/cranelift/codegen/src/isa/riscv64/inst/mod.rs @@ -459,6 +459,14 @@ fn riscv64_get_operands VReg>(inst: &Inst, collector: &mut Operan } collector.reg_clobbers(info.clobbers); } + &Inst::ReturnCall { + callee: _, + ref info, + } => { + for u in &info.uses { + collector.reg_fixed_use(u.vreg, u.preg); + } + } &Inst::ReturnCallInd { ref info, callee } => { collector.reg_use(callee); for u in &info.uses { @@ -880,7 +888,7 @@ impl MachInst for Inst { &Inst::Jalr { .. } => MachTerminator::Uncond, &Inst::Ret { .. } => MachTerminator::Ret, &Inst::BrTable { .. } => MachTerminator::Indirect, - &Inst::ReturnCallInd { .. } => MachTerminator::RetCall, + &Inst::ReturnCall { .. } | &Inst::ReturnCallInd { .. } => MachTerminator::RetCall, _ => MachTerminator::None, } } @@ -1602,6 +1610,21 @@ impl Inst { let rd = format_reg(info.rn, allocs); format!("callind {}", rd) } + &MInst::ReturnCall { + ref callee, + ref info, + } => { + let mut s = format!( + "return_call {callee:?} old_stack_arg_size:{} new_stack_arg_size:{}", + info.old_stack_arg_size, info.new_stack_arg_size + ); + for ret in &info.uses { + let preg = format_reg(ret.preg, &mut empty_allocs); + let vreg = format_reg(ret.vreg, allocs); + write!(&mut s, " {vreg}={preg}").unwrap(); + } + s + } &MInst::ReturnCallInd { callee, ref info } => { let callee = format_reg(callee, allocs); let mut s = format!( diff --git a/cranelift/codegen/src/machinst/buffer.rs b/cranelift/codegen/src/machinst/buffer.rs index 08ef25f0b67e..9f458052e185 100644 --- a/cranelift/codegen/src/machinst/buffer.rs +++ b/cranelift/codegen/src/machinst/buffer.rs @@ -1781,6 +1781,9 @@ impl TextSectionBuilder for MachTextSectionBuilder { } fn resolve_reloc(&mut self, offset: u64, reloc: Reloc, addend: Addend, target: usize) -> bool { + crate::trace!( + "Resolving relocation @ {offset:#x} + {addend:#x} to target {target} of kind {reloc:?}" + ); let label = MachLabel::from_block(BlockIndex::new(target)); let offset = u32::try_from(offset).unwrap(); match I::LabelUse::from_reloc(reloc, addend) { diff --git a/cranelift/filetests/filetests/isa/riscv64/return-call.clif b/cranelift/filetests/filetests/isa/riscv64/return-call.clif index df5f8fb3481d..4d9972f5ae29 100644 --- a/cranelift/filetests/filetests/isa/riscv64/return-call.clif +++ b/cranelift/filetests/filetests/isa/riscv64/return-call.clif @@ -69,8 +69,7 @@ block0(v0: i64): ; sd fp,0(sp) ; mv fp,sp ; block0: -; load_sym t2,%callee_i64+0 -; return_call_ind t2 old_stack_arg_size:0 new_stack_arg_size:0 s1=s1 +; return_call TestCase(%callee_i64) old_stack_arg_size:0 new_stack_arg_size:0 s1=s1 ; ; Disassembled: ; block0: ; offset 0x0 @@ -79,16 +78,12 @@ block0(v0: i64): ; sd s0, 0(sp) ; ori s0, sp, 0 ; block1: ; offset 0x10 -; auipc t2, 0 -; ld t2, 0xc(t2) -; j 0xc -; .byte 0x00, 0x00, 0x00, 0x00 ; reloc_external Abs8 %callee_i64 0 -; .byte 0x00, 0x00, 0x00, 0x00 ; ld ra, 8(s0) ; ld t6, 0(s0) ; addi sp, s0, 0x10 ; ori s0, t6, 0 -; jr t2 +; auipc t6, 0 ; reloc_external RiscvCall %callee_i64 0 +; jr t6 ;;;; Test passing `f64`s ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; diff --git a/cranelift/filetests/src/test_wasm/env.rs b/cranelift/filetests/src/test_wasm/env.rs index c0b97613e52f..89c7ad2119eb 100644 --- a/cranelift/filetests/src/test_wasm/env.rs +++ b/cranelift/filetests/src/test_wasm/env.rs @@ -417,6 +417,17 @@ impl<'a> FuncEnvironment for FuncEnv<'a> { ) } + fn translate_return_call_ref( + &mut self, + builder: &mut cranelift_frontend::FunctionBuilder, + sig_ref: ir::SigRef, + callee: ir::Value, + call_args: &[ir::Value], + ) -> cranelift_wasm::WasmResult<()> { + self.inner + .translate_return_call_ref(builder, sig_ref, callee, call_args) + } + fn translate_memory_grow( &mut self, pos: cranelift_codegen::cursor::FuncCursor, diff --git a/cranelift/frontend/src/lib.rs b/cranelift/frontend/src/lib.rs index fbae35afb602..4487d0116597 100644 --- a/cranelift/frontend/src/lib.rs +++ b/cranelift/frontend/src/lib.rs @@ -191,7 +191,7 @@ use hashbrown::HashMap; #[cfg(feature = "std")] use std::collections::HashMap; -pub use crate::frontend::{FunctionBuilder, FunctionBuilderContext}; +pub use crate::frontend::{FuncInstBuilder, FunctionBuilder, FunctionBuilderContext}; pub use crate::switch::Switch; pub use crate::variable::Variable; diff --git a/cranelift/wasm/src/code_translator.rs b/cranelift/wasm/src/code_translator.rs index 3a83b7f1bf45..0ddfce74623b 100644 --- a/cranelift/wasm/src/code_translator.rs +++ b/cranelift/wasm/src/code_translator.rs @@ -624,7 +624,7 @@ pub fn translate_operator( ); let call = environ.translate_call( - builder.cursor(), + builder, FuncIndex::from_u32(*function_index), fref, args, @@ -694,7 +694,7 @@ pub fn translate_operator( ); environ.translate_return_call( - builder.cursor(), + builder, FuncIndex::from_u32(*function_index), fref, args, @@ -731,11 +731,21 @@ pub fn translate_operator( state.popn(num_args); state.reachable = false; } - Operator::ReturnCallRef { type_index: _ } => { - return Err(wasm_unsupported!( - "proposed tail-call operator for function references {:?}", - op - )); + Operator::ReturnCallRef { type_index } => { + // Get function signature + // `index` is the index of the function's signature and `table_index` is the index of + // the table to search the function in. + let (sigref, num_args) = state.get_indirect_sig(builder.func, *type_index, environ)?; + let callee = state.pop1(); + + // Bitcast any vector arguments to their default type, I8X16, before calling. + let args = state.peekn_mut(num_args); + bitcast_wasm_params(environ, sigref, args, builder); + + environ.translate_return_call_ref(builder, sigref, callee, state.peekn(num_args))?; + + state.popn(num_args); + state.reachable = false; } /******************************* Memory management *********************************** * Memory management is handled by environment. It is usually translated into calls to diff --git a/cranelift/wasm/src/environ/dummy.rs b/cranelift/wasm/src/environ/dummy.rs index 09c50a59e8af..86622c98c41e 100644 --- a/cranelift/wasm/src/environ/dummy.rs +++ b/cranelift/wasm/src/environ/dummy.rs @@ -457,15 +457,25 @@ impl<'dummy_environment> FuncEnvironment for DummyFuncEnvironment<'dummy_environ unimplemented!() } + fn translate_return_call_ref( + &mut self, + _builder: &mut FunctionBuilder, + _sig_ref: ir::SigRef, + _callee: ir::Value, + _call_args: &[ir::Value], + ) -> WasmResult<()> { + unimplemented!() + } + fn translate_call( &mut self, - mut pos: FuncCursor, + builder: &mut FunctionBuilder, _callee_index: FuncIndex, callee: ir::FuncRef, call_args: &[ir::Value], ) -> WasmResult { // Pass the current function's vmctx parameter on to the callee. - let vmctx = pos + let vmctx = builder .func .special_param(ir::ArgumentPurpose::VMContext) .expect("Missing vmctx parameter"); @@ -473,10 +483,13 @@ impl<'dummy_environment> FuncEnvironment for DummyFuncEnvironment<'dummy_environ // Build a value list for the call instruction containing the call_args and the vmctx // parameter. let mut args = ir::ValueList::default(); - args.extend(call_args.iter().cloned(), &mut pos.func.dfg.value_lists); - args.push(vmctx, &mut pos.func.dfg.value_lists); + args.extend(call_args.iter().cloned(), &mut builder.func.dfg.value_lists); + args.push(vmctx, &mut builder.func.dfg.value_lists); - Ok(pos.ins().Call(ir::Opcode::Call, INVALID, callee, args).0) + Ok(builder + .ins() + .Call(ir::Opcode::Call, INVALID, callee, args) + .0) } fn translate_call_ref( diff --git a/cranelift/wasm/src/environ/spec.rs b/cranelift/wasm/src/environ/spec.rs index 2a0e43f0d977..9dca6d761abc 100644 --- a/cranelift/wasm/src/environ/spec.rs +++ b/cranelift/wasm/src/environ/spec.rs @@ -179,12 +179,12 @@ pub trait FuncEnvironment: TargetEnvironment { /// Return the call instruction whose results are the WebAssembly return values. fn translate_call( &mut self, - mut pos: FuncCursor, + builder: &mut FunctionBuilder, _callee_index: FuncIndex, callee: ir::FuncRef, call_args: &[ir::Value], ) -> WasmResult { - Ok(pos.ins().call(callee, call_args)) + Ok(builder.ins().call(callee, call_args)) } /// Translate a `call_indirect` WebAssembly instruction at `pos`. @@ -208,29 +208,33 @@ pub trait FuncEnvironment: TargetEnvironment { call_args: &[ir::Value], ) -> WasmResult; - /// Translate a `return_call` WebAssembly instruction at `pos`. + /// Translate a `return_call` WebAssembly instruction at the builder's + /// current position. /// - /// Insert instructions at `pos` for a direct tail call to the function `callee_index`. + /// Insert instructions at the builder's current position for a direct tail + /// call to the function `callee_index`. /// /// The function reference `callee` was previously created by `make_direct_func()`. /// /// Return the call instruction whose results are the WebAssembly return values. fn translate_return_call( &mut self, - mut pos: FuncCursor, + builder: &mut FunctionBuilder, _callee_index: FuncIndex, callee: ir::FuncRef, call_args: &[ir::Value], ) -> WasmResult<()> { - pos.ins().return_call(callee, call_args); + builder.ins().return_call(callee, call_args); Ok(()) } - /// Translate a `return_call_indirect` WebAssembly instruction at `pos`. + /// Translate a `return_call_indirect` WebAssembly instruction at the + /// builder's current position. /// - /// Insert instructions at `pos` for an indirect tail call to the function - /// `callee` in the table `table_index` with WebAssembly signature - /// `sig_index`. The `callee` value will have type `i32`. + /// Insert instructions at the builder's current position for an indirect + /// tail call to the function `callee` in the table `table_index` with + /// WebAssembly signature `sig_index`. The `callee` value will have type + /// `i32`. /// /// The signature `sig_ref` was previously created by `make_indirect_sig()`. #[allow(clippy::too_many_arguments)] @@ -245,10 +249,30 @@ pub trait FuncEnvironment: TargetEnvironment { call_args: &[ir::Value], ) -> WasmResult<()>; - /// Translate a `call_ref` WebAssembly instruction at `pos`. + /// Translate a `return_call_ref` WebAssembly instruction at the builder's + /// given position. + /// + /// Insert instructions at the builder's current position for an indirect + /// tail call to the function `callee`. The `callee` value will be a Wasm + /// funcref that may need to be translated to a native function address + /// depending on your implementation of this trait. + /// + /// The signature `sig_ref` was previously created by `make_indirect_sig()`. + fn translate_return_call_ref( + &mut self, + builder: &mut FunctionBuilder, + sig_ref: ir::SigRef, + callee: ir::Value, + call_args: &[ir::Value], + ) -> WasmResult<()>; + + /// Translate a `call_ref` WebAssembly instruction at the builder's current + /// position. /// - /// Insert instructions at `pos` for an indirect call to the - /// function `callee`. The `callee` value will have type `Ref`. + /// Insert instructions at the builder's current position for an indirect + /// call to the function `callee`. The `callee` value will be a Wasm funcref + /// that may need to be translated to a native function address depending on + /// your implementation of this trait. /// /// The signature `sig_ref` was previously created by `make_indirect_sig()`. /// diff --git a/crates/cli-flags/src/lib.rs b/crates/cli-flags/src/lib.rs index 5bb8e351ddcb..9febfcd17893 100644 --- a/crates/cli-flags/src/lib.rs +++ b/crates/cli-flags/src/lib.rs @@ -39,6 +39,7 @@ pub const SUPPORTED_WASM_FEATURES: &[(&str, &str)] = &[ "relaxed-simd", "enables support for the relaxed simd proposal", ), + ("tail-call", "enables support for WebAssembly tail calls"), ("threads", "enables support for WebAssembly threads"), ("memory64", "enables support for 64-bit memories"), #[cfg(feature = "component-model")] @@ -371,6 +372,7 @@ impl CommonOptions { bulk_memory, reference_types, multi_value, + tail_call, threads, multi_memory, memory64, @@ -397,6 +399,9 @@ impl CommonOptions { if let Some(enable) = multi_value { config.wasm_multi_value(enable); } + if let Some(enable) = tail_call { + config.wasm_tail_call(enable); + } if let Some(enable) = threads { config.wasm_threads(enable); } @@ -441,6 +446,7 @@ pub struct WasmFeatures { pub bulk_memory: Option, pub simd: Option, pub relaxed_simd: Option, + pub tail_call: Option, pub threads: Option, pub multi_memory: Option, pub memory64: Option, @@ -493,6 +499,7 @@ fn parse_wasm_features(features: &str) -> Result { bulk_memory: all.or(values["bulk-memory"]), simd: all.or(values["simd"]), relaxed_simd: all.or(values["relaxed-simd"]), + tail_call: all.or(values["tail-call"]), threads: all.or(values["threads"]), multi_memory: all.or(values["multi-memory"]), memory64: all.or(values["memory64"]), @@ -611,6 +618,7 @@ mod test { bulk_memory, simd, relaxed_simd, + tail_call, threads, multi_memory, memory64, @@ -621,6 +629,7 @@ mod test { assert_eq!(multi_value, Some(true)); assert_eq!(bulk_memory, Some(true)); assert_eq!(simd, Some(true)); + assert_eq!(tail_call, Some(true)); assert_eq!(threads, Some(true)); assert_eq!(multi_memory, Some(true)); assert_eq!(memory64, Some(true)); @@ -640,6 +649,7 @@ mod test { bulk_memory, simd, relaxed_simd, + tail_call, threads, multi_memory, memory64, @@ -650,6 +660,7 @@ mod test { assert_eq!(multi_value, Some(false)); assert_eq!(bulk_memory, Some(false)); assert_eq!(simd, Some(false)); + assert_eq!(tail_call, Some(false)); assert_eq!(threads, Some(false)); assert_eq!(multi_memory, Some(false)); assert_eq!(memory64, Some(false)); @@ -672,6 +683,7 @@ mod test { bulk_memory, simd, relaxed_simd, + tail_call, threads, multi_memory, memory64, @@ -682,6 +694,7 @@ mod test { assert_eq!(multi_value, None); assert_eq!(bulk_memory, None); assert_eq!(simd, Some(true)); + assert_eq!(tail_call, None); assert_eq!(threads, None); assert_eq!(multi_memory, Some(true)); assert_eq!(memory64, Some(true)); @@ -725,6 +738,7 @@ mod test { feature_test!(test_bulk_memory_feature, bulk_memory, "bulk-memory"); feature_test!(test_simd_feature, simd, "simd"); feature_test!(test_relaxed_simd_feature, relaxed_simd, "relaxed-simd"); + feature_test!(test_tail_call_feature, tail_call, "tail-call"); feature_test!(test_threads_feature, threads, "threads"); feature_test!(test_multi_memory_feature, multi_memory, "multi-memory"); feature_test!(test_memory64_feature, memory64, "memory64"); diff --git a/crates/cranelift/src/builder.rs b/crates/cranelift/src/builder.rs index 5fcf8e05fbc3..cd255eccd999 100644 --- a/crates/cranelift/src/builder.rs +++ b/crates/cranelift/src/builder.rs @@ -12,9 +12,10 @@ use std::fmt; use std::path; use std::sync::Arc; use wasmtime_cranelift_shared::isa_builder::IsaBuilder; -use wasmtime_environ::{CacheStore, CompilerBuilder, Setting}; +use wasmtime_environ::{CacheStore, CompilerBuilder, Setting, Tunables}; struct Builder { + tunables: Tunables, inner: IsaBuilder>, linkopts: LinkOptions, cache_store: Option>, @@ -36,6 +37,7 @@ pub struct LinkOptions { pub fn builder() -> Box { Box::new(Builder { + tunables: Tunables::default(), inner: IsaBuilder::new(|triple| isa::lookup(triple).map_err(|e| e.into())), linkopts: LinkOptions::default(), cache_store: None, @@ -76,9 +78,15 @@ impl CompilerBuilder for Builder { self.inner.enable(name) } + fn set_tunables(&mut self, tunables: Tunables) -> Result<()> { + self.tunables = tunables; + Ok(()) + } + fn build(&self) -> Result> { let isa = self.inner.build()?; Ok(Box::new(crate::compiler::Compiler::new( + self.tunables.clone(), isa, self.cache_store.clone(), self.linkopts.clone(), diff --git a/crates/cranelift/src/compiler.rs b/crates/cranelift/src/compiler.rs index 3f6aa8dffbd4..3578cad1374e 100644 --- a/crates/cranelift/src/compiler.rs +++ b/crates/cranelift/src/compiler.rs @@ -65,6 +65,7 @@ impl Default for CompilerContext { /// A compiler that compiles a WebAssembly module with Compiler, translating /// the Wasm to Compiler IR, optimizing it and then translating to assembly. pub(crate) struct Compiler { + tunables: Tunables, contexts: Mutex>, isa: OwnedTargetIsa, linkopts: LinkOptions, @@ -102,6 +103,7 @@ impl Drop for Compiler { impl Compiler { pub(crate) fn new( + tunables: Tunables, isa: OwnedTargetIsa, cache_store: Option>, linkopts: LinkOptions, @@ -109,6 +111,7 @@ impl Compiler { ) -> Compiler { Compiler { contexts: Default::default(), + tunables, isa, linkopts, cache_store, @@ -123,7 +126,6 @@ impl wasmtime_environ::Compiler for Compiler { translation: &ModuleTranslation<'_>, func_index: DefinedFuncIndex, input: FunctionBodyData<'_>, - tunables: &Tunables, types: &ModuleTypes, ) -> Result<(WasmFunctionInfo, Box), CompileError> { let isa = &*self.isa; @@ -135,17 +137,17 @@ impl wasmtime_environ::Compiler for Compiler { let mut compiler = self.function_compiler(); let context = &mut compiler.cx.codegen_context; - context.func.signature = wasm_call_signature(isa, wasm_func_ty); + context.func.signature = wasm_call_signature(isa, wasm_func_ty, &self.tunables); context.func.name = UserFuncName::User(UserExternalName { namespace: 0, index: func_index.as_u32(), }); - if tunables.generate_native_debuginfo { + if self.tunables.generate_native_debuginfo { context.func.collect_debug_info(); } - let mut func_env = FuncEnvironment::new(isa, translation, types, tunables); + let mut func_env = FuncEnvironment::new(isa, translation, types, &self.tunables); // The `stack_limit` global value below is the implementation of stack // overflow checks in Wasmtime. @@ -220,7 +222,7 @@ impl wasmtime_environ::Compiler for Compiler { write!(output, "{}", context.func.display()).unwrap(); } - let (info, func) = compiler.finish_with_info(Some((&body, tunables)))?; + let (info, func) = compiler.finish_with_info(Some((&body, &self.tunables)))?; let timing = cranelift_codegen::timing::take_current(); log::debug!("{:?} translated in {:?}", func_index, timing.total()); @@ -241,7 +243,7 @@ impl wasmtime_environ::Compiler for Compiler { let isa = &*self.isa; let pointer_type = isa.pointer_type(); - let wasm_call_sig = wasm_call_signature(isa, wasm_func_ty); + let wasm_call_sig = wasm_call_signature(isa, wasm_func_ty, &self.tunables); let array_call_sig = array_call_signature(isa); let mut compiler = self.function_compiler(); @@ -310,7 +312,7 @@ impl wasmtime_environ::Compiler for Compiler { let isa = &*self.isa; let pointer_type = isa.pointer_type(); let func_index = translation.module.func_index(def_func_index); - let wasm_call_sig = wasm_call_signature(isa, wasm_func_ty); + let wasm_call_sig = wasm_call_signature(isa, wasm_func_ty, &self.tunables); let native_call_sig = native_call_signature(isa, wasm_func_ty); let mut compiler = self.function_compiler(); @@ -355,7 +357,7 @@ impl wasmtime_environ::Compiler for Compiler { ) -> Result, CompileError> { let isa = &*self.isa; let pointer_type = isa.pointer_type(); - let wasm_call_sig = wasm_call_signature(isa, wasm_func_ty); + let wasm_call_sig = wasm_call_signature(isa, wasm_func_ty, &self.tunables); let native_call_sig = native_call_signature(isa, wasm_func_ty); let mut compiler = self.function_compiler(); @@ -441,7 +443,6 @@ impl wasmtime_environ::Compiler for Compiler { &self, obj: &mut Object<'static>, funcs: &[(String, Box)], - tunables: &Tunables, resolve_reloc: &dyn Fn(usize, FuncIndex) -> usize, ) -> Result> { let mut builder = @@ -458,7 +459,7 @@ impl wasmtime_environ::Compiler for Compiler { .downcast_ref::>() .unwrap(); let (sym, range) = builder.append_func(&sym, func, |idx| resolve_reloc(i, idx)); - if tunables.generate_address_map { + if self.tunables.generate_address_map { let addr = func.address_map(); addrs.push(range.clone(), &addr.instructions); } @@ -473,7 +474,7 @@ impl wasmtime_environ::Compiler for Compiler { builder.finish(); - if tunables.generate_address_map { + if self.tunables.generate_address_map { addrs.append_to(obj); } traps.append_to(obj); @@ -783,7 +784,7 @@ impl Compiler { ) -> Result, CompileError> { let isa = &*self.isa; let pointer_type = isa.pointer_type(); - let wasm_call_sig = wasm_call_signature(isa, ty); + let wasm_call_sig = wasm_call_signature(isa, ty, &self.tunables); let array_call_sig = array_call_signature(isa); let mut compiler = self.function_compiler(); diff --git a/crates/cranelift/src/compiler/component.rs b/crates/cranelift/src/compiler/component.rs index c61ce3a74f0a..e914d8a26805 100644 --- a/crates/cranelift/src/compiler/component.rs +++ b/crates/cranelift/src/compiler/component.rs @@ -44,7 +44,7 @@ impl<'a> TrampolineCompiler<'a> { let func = ir::Function::with_name_signature( ir::UserFuncName::user(0, 0), match abi { - Abi::Wasm => crate::wasm_call_signature(isa, ty), + Abi::Wasm => crate::wasm_call_signature(isa, ty, &compiler.tunables), Abi::Native => crate::native_call_signature(isa, ty), Abi::Array => crate::array_call_signature(isa), }, @@ -487,8 +487,14 @@ impl<'a> TrampolineCompiler<'a> { dtor_func_ref, i32::from(self.offsets.ptr.vm_func_ref_vmctx()), ); - let sig = crate::wasm_call_signature(self.isa, &self.types[self.signature]); + + let sig = crate::wasm_call_signature( + self.isa, + &self.types[self.signature], + &self.compiler.tunables, + ); let sig_ref = self.builder.import_signature(sig); + // NB: note that the "caller" vmctx here is the caller of this // intrinsic itself, not the `VMComponentContext`. This effectively // takes ourselves out of the chain here but that's ok since the diff --git a/crates/cranelift/src/func_environ.rs b/crates/cranelift/src/func_environ.rs index 1b951c422929..d7b20d98e54b 100644 --- a/crates/cranelift/src/func_environ.rs +++ b/crates/cranelift/src/func_environ.rs @@ -292,52 +292,6 @@ impl<'module_environment> FuncEnvironment<'module_environment> { (base, func_addr) } - /// This calls a function by reference without checking the signature. It - /// gets the function address, sets relevant flags, and passes the special - /// callee/caller vmctxs. It is used by both call_indirect (which checks the - /// signature) and call_ref (which doesn't). - fn call_function_unchecked( - &mut self, - builder: &mut FunctionBuilder, - sig_ref: ir::SigRef, - callee: ir::Value, - call_args: &[ir::Value], - ) -> WasmResult { - let pointer_type = self.pointer_type(); - - // Dereference callee pointer to get the function address. - let mem_flags = ir::MemFlags::trusted().with_readonly(); - let func_addr = builder.ins().load( - pointer_type, - mem_flags, - callee, - i32::from(self.offsets.ptr.vm_func_ref_wasm_call()), - ); - - let mut real_call_args = Vec::with_capacity(call_args.len() + 2); - let caller_vmctx = builder - .func - .special_param(ArgumentPurpose::VMContext) - .unwrap(); - - // First append the callee vmctx address. - let vmctx = builder.ins().load( - pointer_type, - mem_flags, - callee, - i32::from(self.offsets.ptr.vm_func_ref_vmctx()), - ); - real_call_args.push(vmctx); - real_call_args.push(caller_vmctx); - - // Then append the regular call arguments. - real_call_args.extend_from_slice(call_args); - - Ok(builder - .ins() - .call_indirect(sig_ref, func_addr, &real_call_args)) - } - /// Generate code to increment or decrement the given `externref`'s /// reference count. /// @@ -860,6 +814,266 @@ impl<'module_environment> FuncEnvironment<'module_environment> { } } +struct Call<'a, 'func, 'module_env> { + builder: &'a mut FunctionBuilder<'func>, + env: &'a mut FuncEnvironment<'module_env>, + tail: bool, +} + +impl<'a, 'func, 'module_env> Call<'a, 'func, 'module_env> { + /// Create a new `Call` site that will do regular, non-tail calls. + pub fn new( + builder: &'a mut FunctionBuilder<'func>, + env: &'a mut FuncEnvironment<'module_env>, + ) -> Self { + Call { + builder, + env, + tail: false, + } + } + + /// Create a new `Call` site that will perform tail calls. + pub fn new_tail( + builder: &'a mut FunctionBuilder<'func>, + env: &'a mut FuncEnvironment<'module_env>, + ) -> Self { + Call { + builder, + env, + tail: true, + } + } + + /// Do a direct call to the given callee function. + pub fn direct_call( + mut self, + callee_index: FuncIndex, + callee: ir::FuncRef, + call_args: &[ir::Value], + ) -> WasmResult { + let mut real_call_args = Vec::with_capacity(call_args.len() + 2); + let caller_vmctx = self + .builder + .func + .special_param(ArgumentPurpose::VMContext) + .unwrap(); + + // Handle direct calls to locally-defined functions. + if !self.env.module.is_imported_function(callee_index) { + // First append the callee vmctx address, which is the same as the caller vmctx in + // this case. + real_call_args.push(caller_vmctx); + + // Then append the caller vmctx address. + real_call_args.push(caller_vmctx); + + // Then append the regular call arguments. + real_call_args.extend_from_slice(call_args); + + // Finally, make the direct call! + return Ok(self.direct_call_inst(callee, &real_call_args)); + } + + // Handle direct calls to imported functions. We use an indirect call + // so that we don't have to patch the code at runtime. + let pointer_type = self.env.pointer_type(); + let sig_ref = self.builder.func.dfg.ext_funcs[callee].signature; + let vmctx = self.env.vmctx(self.builder.func); + let base = self.builder.ins().global_value(pointer_type, vmctx); + + let mem_flags = ir::MemFlags::trusted().with_readonly(); + + // Load the callee address. + let body_offset = i32::try_from( + self.env + .offsets + .vmctx_vmfunction_import_wasm_call(callee_index), + ) + .unwrap(); + let func_addr = self + .builder + .ins() + .load(pointer_type, mem_flags, base, body_offset); + + // First append the callee vmctx address. + let vmctx_offset = + i32::try_from(self.env.offsets.vmctx_vmfunction_import_vmctx(callee_index)).unwrap(); + let vmctx = self + .builder + .ins() + .load(pointer_type, mem_flags, base, vmctx_offset); + real_call_args.push(vmctx); + real_call_args.push(caller_vmctx); + + // Then append the regular call arguments. + real_call_args.extend_from_slice(call_args); + + // Finally, make the indirect call! + Ok(self.indirect_call_inst(sig_ref, func_addr, &real_call_args)) + } + + /// Do an indirect call through the given funcref table. + pub fn indirect_call( + mut self, + table_index: TableIndex, + table: ir::Table, + ty_index: TypeIndex, + sig_ref: ir::SigRef, + callee: ir::Value, + call_args: &[ir::Value], + ) -> WasmResult { + let pointer_type = self.env.pointer_type(); + + // Get the funcref pointer from the table. + let funcref_ptr = + self.env + .get_or_init_func_ref_table_elem(self.builder, table_index, table, callee); + + // Check for whether the table element is null, and trap if so. + self.builder + .ins() + .trapz(funcref_ptr, ir::TrapCode::IndirectCallToNull); + + // If necessary, check the signature. + match self.env.module.table_plans[table_index].style { + TableStyle::CallerChecksSignature => { + let sig_id_size = self.env.offsets.size_of_vmshared_signature_index(); + let sig_id_type = Type::int(u16::from(sig_id_size) * 8).unwrap(); + let vmctx = self.env.vmctx(self.builder.func); + let base = self.builder.ins().global_value(pointer_type, vmctx); + + // Load the caller ID. This requires loading the `*mut + // VMFuncRef` base pointer from `VMContext` and then loading, + // based on `SignatureIndex`, the corresponding entry. + let mem_flags = ir::MemFlags::trusted().with_readonly(); + let signatures = self.builder.ins().load( + pointer_type, + mem_flags, + base, + i32::try_from(self.env.offsets.vmctx_signature_ids_array()).unwrap(), + ); + let sig_index = self.env.module.types[ty_index].unwrap_function(); + let offset = + i32::try_from(sig_index.as_u32().checked_mul(sig_id_type.bytes()).unwrap()) + .unwrap(); + let caller_sig_id = + self.builder + .ins() + .load(sig_id_type, mem_flags, signatures, offset); + + // Load the callee ID. + let mem_flags = ir::MemFlags::trusted().with_readonly(); + let callee_sig_id = self.builder.ins().load( + sig_id_type, + mem_flags, + funcref_ptr, + i32::from(self.env.offsets.ptr.vm_func_ref_type_index()), + ); + + // Check that they match. + let cmp = self + .builder + .ins() + .icmp(IntCC::Equal, callee_sig_id, caller_sig_id); + self.builder.ins().trapz(cmp, ir::TrapCode::BadSignature); + } + } + + self.unchecked_call(sig_ref, funcref_ptr, call_args) + } + + /// Call a typed function reference. + pub fn call_ref( + mut self, + sig_ref: ir::SigRef, + callee: ir::Value, + args: &[ir::Value], + ) -> WasmResult { + // Check for whether the callee is null, and trap if so. + // + // FIXME: the wasm type system tracks enough information to know whether + // `callee` is a null reference or not. In some situations it can be + // statically known here that `callee` cannot be null in which case this + // null check can be elided. This requires feeding type information from + // wasmparser's validator into this function, however, which is not + // easily done at this time. + self.builder + .ins() + .trapz(callee, ir::TrapCode::NullReference); + + self.unchecked_call(sig_ref, callee, args) + } + + /// This calls a function by reference without checking the signature. + /// + /// It gets the function address, sets relevant flags, and passes the + /// special callee/caller vmctxs. It is used by both call_indirect (which + /// checks the signature) and call_ref (which doesn't). + fn unchecked_call( + &mut self, + sig_ref: ir::SigRef, + callee: ir::Value, + call_args: &[ir::Value], + ) -> WasmResult { + let pointer_type = self.env.pointer_type(); + + // Dereference callee pointer to get the function address. + let mem_flags = ir::MemFlags::trusted().with_readonly(); + let func_addr = self.builder.ins().load( + pointer_type, + mem_flags, + callee, + i32::from(self.env.offsets.ptr.vm_func_ref_wasm_call()), + ); + + let mut real_call_args = Vec::with_capacity(call_args.len() + 2); + let caller_vmctx = self + .builder + .func + .special_param(ArgumentPurpose::VMContext) + .unwrap(); + + // First append the callee vmctx address. + let vmctx = self.builder.ins().load( + pointer_type, + mem_flags, + callee, + i32::from(self.env.offsets.ptr.vm_func_ref_vmctx()), + ); + real_call_args.push(vmctx); + real_call_args.push(caller_vmctx); + + // Then append the regular call arguments. + real_call_args.extend_from_slice(call_args); + + Ok(self.indirect_call_inst(sig_ref, func_addr, &real_call_args)) + } + + fn direct_call_inst(&mut self, callee: ir::FuncRef, args: &[ir::Value]) -> ir::Inst { + if self.tail { + self.builder.ins().return_call(callee, args) + } else { + self.builder.ins().call(callee, args) + } + } + + fn indirect_call_inst( + &mut self, + sig_ref: ir::SigRef, + func_addr: ir::Value, + args: &[ir::Value], + ) -> ir::Inst { + if self.tail { + self.builder + .ins() + .return_call_indirect(sig_ref, func_addr, args) + } else { + self.builder.ins().call_indirect(sig_ref, func_addr, args) + } + } +} + impl TypeConvert for FuncEnvironment<'_> { fn lookup_heap_type(&self, ty: TypeIndex) -> WasmHeapType { self.module.lookup_heap_type(ty) @@ -1584,7 +1798,7 @@ impl<'module_environment> cranelift_wasm::FuncEnvironment for FuncEnvironment<'m index: TypeIndex, ) -> WasmResult { let index = self.module.types[index].unwrap_function(); - let sig = crate::wasm_call_signature(self.isa, &self.types[index]); + let sig = crate::wasm_call_signature(self.isa, &self.types[index], &self.tunables); Ok(func.import_signature(sig)) } @@ -1594,7 +1808,7 @@ impl<'module_environment> cranelift_wasm::FuncEnvironment for FuncEnvironment<'m index: FuncIndex, ) -> WasmResult { let sig = self.module.functions[index].signature; - let sig = crate::wasm_call_signature(self.isa, &self.types[sig]); + let sig = crate::wasm_call_signature(self.isa, &self.types[sig], &self.tunables); let signature = func.import_signature(sig); let name = ir::ExternalName::User(func.declare_imported_user_function(ir::UserExternalName { @@ -1633,112 +1847,24 @@ impl<'module_environment> cranelift_wasm::FuncEnvironment for FuncEnvironment<'m callee: ir::Value, call_args: &[ir::Value], ) -> WasmResult { - let pointer_type = self.pointer_type(); - - // Get the funcref pointer from the table. - let funcref_ptr = self.get_or_init_func_ref_table_elem(builder, table_index, table, callee); - - // Check for whether the table element is null, and trap if so. - builder - .ins() - .trapz(funcref_ptr, ir::TrapCode::IndirectCallToNull); - - // If necessary, check the signature. - match self.module.table_plans[table_index].style { - TableStyle::CallerChecksSignature => { - let sig_id_size = self.offsets.size_of_vmshared_signature_index(); - let sig_id_type = Type::int(u16::from(sig_id_size) * 8).unwrap(); - let vmctx = self.vmctx(builder.func); - let base = builder.ins().global_value(pointer_type, vmctx); - - // Load the caller ID. This requires loading the `*mut - // VMFuncRef` base pointer from `VMContext` and then loading, - // based on `SignatureIndex`, the corresponding entry. - let mem_flags = ir::MemFlags::trusted().with_readonly(); - let signatures = builder.ins().load( - pointer_type, - mem_flags, - base, - i32::try_from(self.offsets.vmctx_signature_ids_array()).unwrap(), - ); - let sig_index = self.module.types[ty_index].unwrap_function(); - let offset = - i32::try_from(sig_index.as_u32().checked_mul(sig_id_type.bytes()).unwrap()) - .unwrap(); - let caller_sig_id = builder - .ins() - .load(sig_id_type, mem_flags, signatures, offset); - - // Load the callee ID. - let mem_flags = ir::MemFlags::trusted().with_readonly(); - let callee_sig_id = builder.ins().load( - sig_id_type, - mem_flags, - funcref_ptr, - i32::from(self.offsets.ptr.vm_func_ref_type_index()), - ); - - // Check that they match. - let cmp = builder - .ins() - .icmp(IntCC::Equal, callee_sig_id, caller_sig_id); - builder.ins().trapz(cmp, ir::TrapCode::BadSignature); - } - } - - self.call_function_unchecked(builder, sig_ref, funcref_ptr, call_args) + Call::new(builder, self).indirect_call( + table_index, + table, + ty_index, + sig_ref, + callee, + call_args, + ) } fn translate_call( &mut self, - mut pos: FuncCursor<'_>, + builder: &mut FunctionBuilder, callee_index: FuncIndex, callee: ir::FuncRef, call_args: &[ir::Value], ) -> WasmResult { - let mut real_call_args = Vec::with_capacity(call_args.len() + 2); - let caller_vmctx = pos.func.special_param(ArgumentPurpose::VMContext).unwrap(); - - // Handle direct calls to locally-defined functions. - if !self.module.is_imported_function(callee_index) { - // First append the callee vmctx address, which is the same as the caller vmctx in - // this case. - real_call_args.push(caller_vmctx); - - // Then append the caller vmctx address. - real_call_args.push(caller_vmctx); - - // Then append the regular call arguments. - real_call_args.extend_from_slice(call_args); - - return Ok(pos.ins().call(callee, &real_call_args)); - } - - // Handle direct calls to imported functions. We use an indirect call - // so that we don't have to patch the code at runtime. - let pointer_type = self.pointer_type(); - let sig_ref = pos.func.dfg.ext_funcs[callee].signature; - let vmctx = self.vmctx(&mut pos.func); - let base = pos.ins().global_value(pointer_type, vmctx); - - let mem_flags = ir::MemFlags::trusted().with_readonly(); - - // Load the callee address. - let body_offset = - i32::try_from(self.offsets.vmctx_vmfunction_import_wasm_call(callee_index)).unwrap(); - let func_addr = pos.ins().load(pointer_type, mem_flags, base, body_offset); - - // First append the callee vmctx address. - let vmctx_offset = - i32::try_from(self.offsets.vmctx_vmfunction_import_vmctx(callee_index)).unwrap(); - let vmctx = pos.ins().load(pointer_type, mem_flags, base, vmctx_offset); - real_call_args.push(vmctx); - real_call_args.push(caller_vmctx); - - // Then append the regular call arguments. - real_call_args.extend_from_slice(call_args); - - Ok(pos.ins().call_indirect(sig_ref, func_addr, &real_call_args)) + Call::new(builder, self).direct_call(callee_index, callee, call_args) } fn translate_call_ref( @@ -1748,40 +1874,50 @@ impl<'module_environment> cranelift_wasm::FuncEnvironment for FuncEnvironment<'m callee: ir::Value, call_args: &[ir::Value], ) -> WasmResult { - // Check for whether the callee is null, and trap if so. - // - // FIXME: the wasm type system tracks enough information to know whether - // `callee` is a null reference or not. In some situations it can be - // statically known here that `callee` cannot be null in which case this - // null check can be elided. This requires feeding type information from - // wasmparser's validator into this function, however, which is not - // easily done at this time. - builder.ins().trapz(callee, ir::TrapCode::NullReference); - - self.call_function_unchecked(builder, sig_ref, callee, call_args) + Call::new(builder, self).call_ref(sig_ref, callee, call_args) } fn translate_return_call( &mut self, - _pos: FuncCursor, - _callee_index: FuncIndex, - _callee: ir::FuncRef, - _call_args: &[ir::Value], + builder: &mut FunctionBuilder, + callee_index: FuncIndex, + callee: ir::FuncRef, + call_args: &[ir::Value], ) -> WasmResult<()> { - unimplemented!() + Call::new_tail(builder, self).direct_call(callee_index, callee, call_args)?; + Ok(()) } fn translate_return_call_indirect( &mut self, - _builder: &mut FunctionBuilder, - _table_index: TableIndex, - _table: ir::Table, - _sig_index: TypeIndex, - _sig_ref: ir::SigRef, - _callee: ir::Value, - _call_args: &[ir::Value], + builder: &mut FunctionBuilder, + table_index: TableIndex, + table: ir::Table, + ty_index: TypeIndex, + sig_ref: ir::SigRef, + callee: ir::Value, + call_args: &[ir::Value], + ) -> WasmResult<()> { + Call::new_tail(builder, self).indirect_call( + table_index, + table, + ty_index, + sig_ref, + callee, + call_args, + )?; + Ok(()) + } + + fn translate_return_call_ref( + &mut self, + builder: &mut FunctionBuilder, + sig_ref: ir::SigRef, + callee: ir::Value, + call_args: &[ir::Value], ) -> WasmResult<()> { - unimplemented!() + Call::new_tail(builder, self).call_ref(sig_ref, callee, call_args)?; + Ok(()) } fn translate_memory_grow( diff --git a/crates/cranelift/src/lib.rs b/crates/cranelift/src/lib.rs index 87adcf182f07..439437261bce 100644 --- a/crates/cranelift/src/lib.rs +++ b/crates/cranelift/src/lib.rs @@ -11,6 +11,7 @@ use target_lexicon::Architecture; use wasmtime_cranelift_shared::CompiledFunctionMetadata; pub use builder::builder; +use wasmtime_environ::Tunables; mod builder; mod compiler; @@ -125,7 +126,11 @@ fn array_call_signature(isa: &dyn TargetIsa) -> ir::Signature { } /// Get the internal Wasm calling convention signature for the given type. -fn wasm_call_signature(isa: &dyn TargetIsa, wasm_func_ty: &WasmFuncType) -> ir::Signature { +fn wasm_call_signature( + isa: &dyn TargetIsa, + wasm_func_ty: &WasmFuncType, + tunables: &Tunables, +) -> ir::Signature { // NB: this calling convention in the near future is expected to be // unconditionally switched to the "tail" calling convention once all // platforms have support for tail calls. @@ -136,6 +141,18 @@ fn wasm_call_signature(isa: &dyn TargetIsa, wasm_func_ty: &WasmFuncType) -> ir:: // operates through trampolines either using the `array_call_signature` or // `native_call_signature` where the default platform ABI is used. let call_conv = match isa.triple().architecture { + // If the tail calls proposal is enabled, we must use the tail calling + // convention. We don't use it by default yet because of + // https://github.com/bytecodealliance/wasmtime/issues/6759 + arch if tunables.tail_callable => { + assert_ne!( + arch, + Architecture::S390x, + "https://github.com/bytecodealliance/wasmtime/issues/6530" + ); + CallConv::Tail + } + // On s390x the "wasmtime" calling convention is used to give vectors // little-endian lane order at the ABI layer which should reduce the // need for conversion when operating on vector function arguments. By diff --git a/crates/environ/src/compilation.rs b/crates/environ/src/compilation.rs index 84fe3559dee3..9b7964327889 100644 --- a/crates/environ/src/compilation.rs +++ b/crates/environ/src/compilation.rs @@ -1,10 +1,10 @@ //! A `Compilation` contains the compiled function bodies for a WebAssembly //! module. -use crate::obj; +use crate::{obj, Tunables}; use crate::{ DefinedFuncIndex, FilePos, FuncIndex, FunctionBodyData, ModuleTranslation, ModuleTypes, - PrimaryMap, StackMap, Tunables, WasmError, WasmFuncType, + PrimaryMap, StackMap, WasmError, WasmFuncType, }; use anyhow::Result; use object::write::{Object, SymbolId}; @@ -122,6 +122,9 @@ pub trait CompilerBuilder: Send + Sync + fmt::Debug { /// This will return an error if the compiler does not support incremental compilation. fn enable_incremental_compilation(&mut self, cache_store: Arc) -> Result<()>; + /// Set the tunables for this compiler. + fn set_tunables(&mut self, tunables: Tunables) -> Result<()>; + /// Builds a new [`Compiler`] object from this configuration. fn build(&self) -> Result>; } @@ -174,7 +177,6 @@ pub trait Compiler: Send + Sync { translation: &ModuleTranslation<'_>, index: DefinedFuncIndex, data: FunctionBodyData<'_>, - tunables: &Tunables, types: &ModuleTypes, ) -> Result<(WasmFunctionInfo, Box), CompileError>; @@ -243,7 +245,6 @@ pub trait Compiler: Send + Sync { &self, obj: &mut Object<'static>, funcs: &[(String, Box)], - tunables: &Tunables, resolve_reloc: &dyn Fn(usize, FuncIndex) -> usize, ) -> Result>; diff --git a/crates/environ/src/tunables.rs b/crates/environ/src/tunables.rs index 97a9d0802ff6..49b21fe24897 100644 --- a/crates/environ/src/tunables.rs +++ b/crates/environ/src/tunables.rs @@ -49,6 +49,9 @@ pub struct Tunables { /// Whether or not lowerings for relaxed simd instructions are forced to /// be deterministic. pub relaxed_simd_deterministic: bool, + + /// Whether or not Wasm functions can be tail-called or not. + pub tail_callable: bool, } impl Default for Tunables { @@ -106,6 +109,7 @@ impl Default for Tunables { generate_address_map: true, debug_adapter_modules: false, relaxed_simd_deterministic: false, + tail_callable: false, } } } diff --git a/crates/fuzzing/src/generators/config.rs b/crates/fuzzing/src/generators/config.rs index f54d1fdbc97a..f65e22c62f6f 100644 --- a/crates/fuzzing/src/generators/config.rs +++ b/crates/fuzzing/src/generators/config.rs @@ -152,6 +152,7 @@ impl Config { .wasm_multi_memory(self.module_config.config.max_memories > 1) .wasm_simd(self.module_config.config.simd_enabled) .wasm_memory64(self.module_config.config.memory64_enabled) + .wasm_tail_call(self.module_config.config.tail_call_enabled) .wasm_threads(self.module_config.config.threads_enabled) .native_unwind_info(self.wasmtime.native_unwind_info) .cranelift_nan_canonicalization(self.wasmtime.canonicalize_nans) diff --git a/crates/wasmtime/src/compiler.rs b/crates/wasmtime/src/compiler.rs index de0f693847b0..33856f4b2198 100644 --- a/crates/wasmtime/src/compiler.rs +++ b/crates/wasmtime/src/compiler.rs @@ -32,8 +32,7 @@ use wasmtime_environ::{ }; use wasmtime_jit::{CompiledFunctionInfo, CompiledModuleInfo}; -type CompileInput<'a> = - Box Result + Send + 'a>; +type CompileInput<'a> = Box Result + Send + 'a>; /// A sortable, comparable key for a compilation output. /// @@ -168,10 +167,7 @@ pub struct CompileInputs<'a> { } impl<'a> CompileInputs<'a> { - fn push_input( - &mut self, - f: impl FnOnce(&Tunables, &dyn Compiler) -> Result + Send + 'a, - ) { + fn push_input(&mut self, f: impl FnOnce(&dyn Compiler) -> Result + Send + 'a) { self.inputs.push(Box::new(f)); } @@ -207,7 +203,7 @@ impl<'a> CompileInputs<'a> { ret.collect_inputs_in_translations(types.module_types(), module_translations); for (idx, trampoline) in component.trampolines.iter() { - ret.push_input(move |_tunables, compiler| { + ret.push_input(move |compiler| { Ok(CompileOutput { key: CompileKey::trampoline(idx), symbol: trampoline.symbol_name(), @@ -228,7 +224,7 @@ impl<'a> CompileInputs<'a> { // requested through initializers above or such. if component.component.num_resources > 0 { if let Some(sig) = types.find_resource_drop_signature() { - ret.push_input(move |_tunables, compiler| { + ret.push_input(move |compiler| { let trampoline = compiler.compile_wasm_to_native_trampoline(&types[sig])?; Ok(CompileOutput { key: CompileKey::resource_drop_wasm_to_native_trampoline(), @@ -258,15 +254,10 @@ impl<'a> CompileInputs<'a> { for (module, translation, functions) in translations { for (def_func_index, func_body) in functions { - self.push_input(move |tunables, compiler| { + self.push_input(move |compiler| { let func_index = translation.module.func_index(def_func_index); - let (info, function) = compiler.compile_function( - translation, - def_func_index, - func_body, - tunables, - types, - )?; + let (info, function) = + compiler.compile_function(translation, def_func_index, func_body, types)?; Ok(CompileOutput { key: CompileKey::wasm_function(module, def_func_index), symbol: format!( @@ -281,7 +272,7 @@ impl<'a> CompileInputs<'a> { let func_index = translation.module.func_index(def_func_index); if translation.module.functions[func_index].is_escaping() { - self.push_input(move |_tunables, compiler| { + self.push_input(move |compiler| { let func_index = translation.module.func_index(def_func_index); let trampoline = compiler.compile_array_to_wasm_trampoline( translation, @@ -300,7 +291,7 @@ impl<'a> CompileInputs<'a> { }) }); - self.push_input(move |_tunables, compiler| { + self.push_input(move |compiler| { let func_index = translation.module.func_index(def_func_index); let trampoline = compiler.compile_native_to_wasm_trampoline( translation, @@ -327,7 +318,7 @@ impl<'a> CompileInputs<'a> { } for signature in sigs { - self.push_input(move |_tunables, compiler| { + self.push_input(move |compiler| { let wasm_func_ty = &types[signature]; let trampoline = compiler.compile_wasm_to_native_trampoline(wasm_func_ty)?; Ok(CompileOutput { @@ -346,11 +337,10 @@ impl<'a> CompileInputs<'a> { /// Compile these `CompileInput`s (maybe in parallel) and return the /// resulting `UnlinkedCompileOutput`s. pub fn compile(self, engine: &Engine) -> Result { - let tunables = &engine.config().tunables; let compiler = engine.compiler(); // Compile each individual input in parallel. - let raw_outputs = engine.run_maybe_parallel(self.inputs, |f| f(tunables, compiler))?; + let raw_outputs = engine.run_maybe_parallel(self.inputs, |f| f(compiler))?; // Bucket the outputs by kind. let mut outputs: BTreeMap> = BTreeMap::new(); @@ -470,7 +460,6 @@ impl FunctionIndices { let symbol_ids_and_locs = compiler.append_code( &mut obj, &compiled_funcs, - tunables, &|caller_index: usize, callee_index: FuncIndex| { let module = self .compiled_func_index_to_module diff --git a/crates/wasmtime/src/component/component.rs b/crates/wasmtime/src/component/component.rs index 81548bcfcc53..0d3af1916c75 100644 --- a/crates/wasmtime/src/component/component.rs +++ b/crates/wasmtime/src/component/component.rs @@ -205,7 +205,7 @@ impl Component { let (mut object, compilation_artifacts) = function_indices.link_and_append_code( object, - tunables, + &engine.config().tunables, compiler, compiled_funcs, module_translations, diff --git a/crates/wasmtime/src/config.rs b/crates/wasmtime/src/config.rs index 42e88137535d..de6514e300e5 100644 --- a/crates/wasmtime/src/config.rs +++ b/crates/wasmtime/src/config.rs @@ -1,6 +1,6 @@ use crate::memory::MemoryCreator; use crate::trampoline::MemoryCreatorProxy; -use anyhow::{bail, Result}; +use anyhow::{bail, ensure, Result}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::fmt; @@ -622,6 +622,23 @@ impl Config { self } + /// Configures whether the WebAssembly tail calls proposal will be enabled + /// for compilation or not. + /// + /// The [WebAssembly tail calls proposal] introduces the `return_call` and + /// `return_call_indirect` instructions. These instructions allow for Wasm + /// programs to implement some recursive algorithms with *O(1)* stack space + /// usage. + /// + /// This feature is disabled by default. + /// + /// [WebAssembly tail calls proposal]: https://github.com/WebAssembly/tail-call + pub fn wasm_tail_call(&mut self, enable: bool) -> &mut Self { + self.features.tail_call = enable; + self.tunables.tail_callable = enable; + self + } + /// Configures whether the WebAssembly threads proposal will be enabled for /// compilation. /// @@ -1592,6 +1609,14 @@ impl Config { .insert("enable_probestack".into()); } + if self.features.tail_call { + ensure!( + target.architecture != Architecture::S390x, + "Tail calls are not supported on s390x yet: \ + https://github.com/bytecodealliance/wasmtime/issues/6530" + ); + } + if self.native_unwind_info || // Windows always needs unwind info, since it is part of the ABI. target.operating_system == target_lexicon::OperatingSystem::Windows @@ -1637,6 +1662,8 @@ impl Config { compiler.enable_incremental_compilation(cache_store.clone())?; } + compiler.set_tunables(self.tunables.clone())?; + Ok((self, compiler.build()?)) } diff --git a/crates/wasmtime/src/engine/serialization.rs b/crates/wasmtime/src/engine/serialization.rs index 7f3cb2fdf927..31d87765c346 100644 --- a/crates/wasmtime/src/engine/serialization.rs +++ b/crates/wasmtime/src/engine/serialization.rs @@ -162,6 +162,7 @@ struct WasmFeatures { bulk_memory: bool, component_model: bool, simd: bool, + tail_call: bool, threads: bool, multi_memory: bool, exceptions: bool, @@ -200,7 +201,6 @@ impl Metadata { } = engine.config().features; assert!(!memory_control); - assert!(!tail_call); assert!(!gc); assert!(!component_model_values); @@ -216,6 +216,7 @@ impl Metadata { component_model, simd, threads, + tail_call, multi_memory, exceptions, memory64, @@ -315,6 +316,7 @@ impl Metadata { static_memory_bound_is_maximum, guard_before_linear_memory, relaxed_simd_deterministic, + tail_callable, // This doesn't affect compilation, it's just a runtime setting. dynamic_memory_growth_reserve: _, @@ -375,6 +377,7 @@ impl Metadata { other.relaxed_simd_deterministic, "relaxed simd deterministic semantics", )?; + Self::check_bool(tail_callable, other.tail_callable, "WebAssembly tail calls")?; Ok(()) } @@ -386,6 +389,7 @@ impl Metadata { bulk_memory, component_model, simd, + tail_call, threads, multi_memory, exceptions, @@ -416,6 +420,7 @@ impl Metadata { "WebAssembly component model support", )?; Self::check_bool(simd, other.simd, "WebAssembly SIMD support")?; + Self::check_bool(tail_call, other.tail_call, "WebAssembly tail calls support")?; Self::check_bool(threads, other.threads, "WebAssembly threads support")?; Self::check_bool( multi_memory, diff --git a/crates/winch/src/builder.rs b/crates/winch/src/builder.rs index d8f17d5bae6a..d9d534cc7a74 100644 --- a/crates/winch/src/builder.rs +++ b/crates/winch/src/builder.rs @@ -38,6 +38,11 @@ impl CompilerBuilder for Builder { self.inner.settings() } + fn set_tunables(&mut self, tunables: wasmtime_environ::Tunables) -> Result<()> { + let _ = tunables; + Ok(()) + } + fn build(&self) -> Result> { let isa = self.inner.build()?; diff --git a/crates/winch/src/compiler.rs b/crates/winch/src/compiler.rs index 3137038a6aac..73605c2c3761 100644 --- a/crates/winch/src/compiler.rs +++ b/crates/winch/src/compiler.rs @@ -6,7 +6,7 @@ use wasmparser::FuncValidatorAllocations; use wasmtime_cranelift_shared::{CompiledFunction, ModuleTextBuilder}; use wasmtime_environ::{ CompileError, DefinedFuncIndex, FilePos, FuncIndex, FunctionBodyData, FunctionLoc, - ModuleTranslation, ModuleTypes, PrimaryMap, TrapEncodingBuilder, Tunables, WasmFunctionInfo, + ModuleTranslation, ModuleTypes, PrimaryMap, TrapEncodingBuilder, WasmFunctionInfo, }; use winch_codegen::{TargetIsa, TrampolineKind}; @@ -53,7 +53,6 @@ impl wasmtime_environ::Compiler for Compiler { translation: &ModuleTranslation<'_>, index: DefinedFuncIndex, data: FunctionBodyData<'_>, - _tunables: &Tunables, types: &ModuleTypes, ) -> Result<(WasmFunctionInfo, Box), CompileError> { let index = translation.module.func_index(index); @@ -144,7 +143,6 @@ impl wasmtime_environ::Compiler for Compiler { &self, obj: &mut Object<'static>, funcs: &[(String, Box)], - _tunables: &Tunables, resolve_reloc: &dyn Fn(usize, FuncIndex) -> usize, ) -> Result> { let mut builder = diff --git a/tests/all/gc.rs b/tests/all/gc.rs index 096a3bec2952..c8ef7eaf1fbf 100644 --- a/tests/all/gc.rs +++ b/tests/all/gc.rs @@ -496,3 +496,158 @@ fn no_gc_middle_of_args() -> anyhow::Result<()> { Ok(()) } + +#[test] +#[cfg_attr(any( + miri, + // TODO(6530): s390x doesn't support tail calls yet. + target_arch = "s390x" +), ignore)] +fn gc_and_tail_calls_and_stack_arguments() -> anyhow::Result<()> { + // Test that GC refs in tail-calls' stack arguments get properly accounted + // for in stack maps. + // + // What we do _not_ want to happen is for tail callers to be responsible for + // including stack arguments in their stack maps (and therefore whether or + // not they get marked at runtime). If that was the case, then we could have + // the following scenario: + // + // * `f` calls `g` without any stack arguments, + // * `g` tail calls `h` with GC ref stack arguments, + // * and then `h` triggers a GC. + // + // Because `g`, who is responsible for including the GC refs in its stack + // map in this hypothetical scenario, is no longer on the stack, we never + // see its stack map, and therefore never mark the GC refs, and then we + // collect them too early, and then we can get user-after-free bugs. Not + // good! Note also that `f`, which is the frame that `h` will return to, + // _cannot_ be responsible for including these stack arguments in its stack + // map, because it has no idea what frame will be returning to it, and it + // could be any number of different functions using that frame for long (and + // indirect!) tail-call chains. + // + // In Cranelift we avoid this scenario because stack arguments are eagerly + // loaded into virtual registers, and then when we insert a GC safe point, + // we spill these virtual registers to the callee stack frame, and the stack + // map includes entries for these stack slots. + // + // Nonetheless, this test exercises the above scenario just in case we do + // something in the future like lazily load stack arguments into virtual + // registers, to make sure that everything shows up in stack maps like they + // are supposed to. + + let (mut store, module) = ref_types_module( + false, + r#" + (module + (import "" "make_some" (func $make (result externref externref externref))) + (import "" "take_some" (func $take (param externref externref externref))) + (import "" "gc" (func $gc)) + + (func $stack_args (param externref externref externref externref externref externref externref externref externref externref externref externref externref externref externref externref externref externref externref externref externref externref externref externref externref externref externref externref externref externref) + call $gc + ;; Make sure all these GC refs are live, so that they need to + ;; be put into the stack map. + local.get 0 + local.get 1 + local.get 2 + call $take + local.get 3 + local.get 4 + local.get 5 + call $take + local.get 6 + local.get 7 + local.get 8 + call $take + local.get 9 + local.get 10 + local.get 11 + call $take + local.get 12 + local.get 13 + local.get 14 + call $take + local.get 15 + local.get 16 + local.get 17 + call $take + local.get 18 + local.get 19 + local.get 20 + call $take + local.get 21 + local.get 22 + local.get 23 + call $take + local.get 24 + local.get 25 + local.get 26 + call $take + local.get 27 + local.get 28 + local.get 29 + call $take + ) + + (func $no_stack_args + call $make + call $make + call $make + call $make + call $make + call $make + call $make + call $make + call $make + call $make + return_call $stack_args + ) + + (func (export "run") + (local i32) + i32.const 1000 + local.set 0 + loop + call $no_stack_args + local.get 0 + i32.const -1 + i32.add + local.tee 0 + br_if 0 + end + ) + ) + "#, + )?; + + let mut linker = Linker::new(store.engine()); + linker.func_wrap("", "make_some", || { + ( + Some(ExternRef::new("a".to_string())), + Some(ExternRef::new("b".to_string())), + Some(ExternRef::new("c".to_string())), + ) + })?; + linker.func_wrap( + "", + "take_some", + |a: Option, b: Option, c: Option| { + let a = a.unwrap(); + let b = b.unwrap(); + let c = c.unwrap(); + assert_eq!(a.data().downcast_ref::().unwrap(), "a"); + assert_eq!(b.data().downcast_ref::().unwrap(), "b"); + assert_eq!(c.data().downcast_ref::().unwrap(), "c"); + }, + )?; + linker.func_wrap("", "gc", |mut caller: Caller<()>| { + caller.gc(); + })?; + + let instance = linker.instantiate(&mut store, &module)?; + let func = instance.get_typed_func::<(), ()>(&mut store, "run")?; + func.call(&mut store, ())?; + + Ok(()) +} diff --git a/tests/all/main.rs b/tests/all/main.rs index e9941cc73710..55d425bed125 100644 --- a/tests/all/main.rs +++ b/tests/all/main.rs @@ -51,6 +51,12 @@ pub(crate) fn ref_types_module( let mut config = Config::new(); config.wasm_reference_types(true); + + if !cfg!(target_arch = "s390x") { + // TODO(6530): s390x doesn't support tail calls yet. + config.wasm_tail_call(true); + } + if use_epochs { config.epoch_interruption(true); } diff --git a/tests/all/wast.rs b/tests/all/wast.rs index b43c17c2b8d2..1dbd4371d7cf 100644 --- a/tests/all/wast.rs +++ b/tests/all/wast.rs @@ -27,6 +27,7 @@ fn run_wast(wast: &str, strategy: Strategy, pooling: bool) -> anyhow::Result<()> let function_references = feature_found(wast, "function-references"); let reference_types = !(threads && feature_found(wast, "proposals")); let relaxed_simd = feature_found(wast, "relaxed-simd"); + let tail_call = feature_found(wast, "tail-call") || feature_found(wast, "function-references"); let use_shared_memory = feature_found_src(&wast_bytes, "shared_memory") || feature_found_src(&wast_bytes, "shared)"); @@ -47,6 +48,7 @@ fn run_wast(wast: &str, strategy: Strategy, pooling: bool) -> anyhow::Result<()> .wasm_function_references(function_references) .wasm_reference_types(reference_types) .wasm_relaxed_simd(relaxed_simd) + .wasm_tail_call(tail_call) .strategy(strategy); if is_cranelift { diff --git a/tests/misc_testsuite/tail-call/loop-across-modules.wast b/tests/misc_testsuite/tail-call/loop-across-modules.wast new file mode 100644 index 000000000000..1b4122cdaa31 --- /dev/null +++ b/tests/misc_testsuite/tail-call/loop-across-modules.wast @@ -0,0 +1,43 @@ +;; Do the following loop: `A.f` indirect tail calls through the table, which is +;; populated by `B.start` to contain `B.g`, which in turn tail calls `A.f` and +;; the loop begins again. +;; +;; This is smoke testing that tail call chains across Wasm modules really do +;; have O(1) stack usage. + +(module $A + (type (func (param i32) (result i32))) + + (table (export "table") 1 1 funcref) + + (func (export "f") (param i32) (result i32) + local.get 0 + i32.eqz + if + (return (i32.const 42)) + else + (i32.sub (local.get 0) (i32.const 1)) + i32.const 0 + return_call_indirect (type 0) + end + unreachable + ) +) + +(module $B + (import "A" "table" (table $table 1 1 funcref)) + (import "A" "f" (func $f (param i32) (result i32))) + + (func $g (export "g") (param i32) (result i32) + local.get 0 + return_call $f + ) + + (func $start + (table.set $table (i32.const 0) (ref.func $g)) + ) + (start $start) +) + +(assert_return (invoke $B "g" (i32.const 100000000)) + (i32.const 42))