riscv64: Implement SIMD swizzle and shuffle (bytecodealliance#6515)
* riscv64: Implement SIMD `swizzle`

* riscv64: Implement SIMD `shuffle`

* wasmtime: Enable more RISC-V SIMD tests

* riscv64: Add TODO issue numbers

* riscv64: Fix trailing newline issues
afonso360 authored Jun 6, 2023
1 parent 81cd998 commit 579918c
Showing 12 changed files with 339 additions and 16 deletions.
1 change: 0 additions & 1 deletion build.rs
@@ -266,7 +266,6 @@ fn ignore(testsuite: &str, testname: &str, strategy: &str) -> bool {
"simd_i8x16_arith2",
"simd_i8x16_cmp",
"simd_int_to_int_extend",
"simd_lane",
"simd_load",
"simd_load_extend",
"simd_load_zero",
8 changes: 7 additions & 1 deletion cranelift/codegen/src/isa/riscv64/inst.isle
@@ -1566,13 +1566,19 @@
(extern constructor imm5_from_i8 imm5_from_i8)

;; Extractor that matches a `Value` equivalent to a replicated Imm5 on all lanes.
;; TODO: Try matching vconst here as well
;; TODO(#6527): Try matching vconst here as well
(decl replicated_imm5 (Imm5) Value)
(extractor (replicated_imm5 n)
(def_inst (splat (iconst (u64_from_imm64 (imm5_from_u64 n))))))

;; UImm5 Helpers

;; Extractor that matches a `Value` equivalent to a replicated UImm5 on all lanes.
;; TODO(#6527): Try matching vconst here as well
(decl replicated_uimm5 (UImm5) Value)
(extractor (replicated_uimm5 n)
(def_inst (splat (uimm5_from_value n))))

;; Helper to go directly from a `Value`, when it's an `iconst`, to an `UImm5`.
(decl uimm5_from_value (UImm5) Value)
(extractor (uimm5_from_value n)
28 changes: 25 additions & 3 deletions cranelift/codegen/src/isa/riscv64/inst/mod.rs
@@ -654,17 +654,39 @@ fn riscv64_get_operands<F: Fn(VReg) -> VReg>(inst: &Inst, collector: &mut Operan

collector.reg_use(vs1);
collector.reg_use(vs2);
collector.reg_def(vd);

// If the operation forbids source/destination overlap, then we must
// register it as an early_def. This encodes the constraint that
// these must not overlap.
if op.forbids_src_dst_overlaps() {
collector.reg_early_def(vd);
} else {
collector.reg_def(vd);
}

vec_mask_operands(mask, collector);
}
&Inst::VecAluRRImm5 {
vd, vs2, ref mask, ..
op,
vd,
vs2,
ref mask,
..
} => {
debug_assert_eq!(vd.to_reg().class(), RegClass::Vector);
debug_assert_eq!(vs2.class(), RegClass::Vector);

collector.reg_use(vs2);
collector.reg_def(vd);

// If the operation forbids source/destination overlap, then we must
// register it as an early_def. This encodes the constraint that
// these must not overlap.
if op.forbids_src_dst_overlaps() {
collector.reg_early_def(vd);
} else {
collector.reg_def(vd);
}

vec_mask_operands(mask, collector);
}
&Inst::VecAluRR {
30 changes: 26 additions & 4 deletions cranelift/codegen/src/isa/riscv64/inst/vector.rs
@@ -296,6 +296,7 @@ impl VecAluOpRRR {
VecAluOpRRR::VssubuVV | VecAluOpRRR::VssubuVX => 0b100010,
VecAluOpRRR::VssubVV | VecAluOpRRR::VssubVX => 0b100011,
VecAluOpRRR::VfsgnjnVV => 0b001001,
VecAluOpRRR::VrgatherVV | VecAluOpRRR::VrgatherVX => 0b001100,
VecAluOpRRR::VmsltVX => 0b011011,
}
}
@@ -318,7 +319,8 @@ impl VecAluOpRRR {
| VecAluOpRRR::VminVV
| VecAluOpRRR::VmaxuVV
| VecAluOpRRR::VmaxVV
| VecAluOpRRR::VmergeVVM => VecOpCategory::OPIVV,
| VecAluOpRRR::VmergeVVM
| VecAluOpRRR::VrgatherVV => VecOpCategory::OPIVV,
VecAluOpRRR::VmulVV
| VecAluOpRRR::VmulhVV
| VecAluOpRRR::VmulhuVV
@@ -343,7 +345,8 @@
| VecAluOpRRR::VmaxVX
| VecAluOpRRR::VslidedownVX
| VecAluOpRRR::VmergeVXM
| VecAluOpRRR::VmsltVX => VecOpCategory::OPIVX,
| VecAluOpRRR::VmsltVX
| VecAluOpRRR::VrgatherVX => VecOpCategory::OPIVX,
VecAluOpRRR::VfaddVV
| VecAluOpRRR::VfsubVV
| VecAluOpRRR::VfmulVV
@@ -368,6 +371,14 @@
_ => unreachable!(),
}
}

/// Some instructions do not allow the source and destination registers to overlap.
pub fn forbids_src_dst_overlaps(&self) -> bool {
match self {
VecAluOpRRR::VrgatherVV | VecAluOpRRR::VrgatherVX => true,
_ => false,
}
}
}

impl fmt::Display for VecAluOpRRR {
@@ -408,6 +419,7 @@ impl VecAluOpRRImm5 {
VecAluOpRRImm5::VmergeVIM => 0b010111,
VecAluOpRRImm5::VsadduVI => 0b100000,
VecAluOpRRImm5::VsaddVI => 0b100001,
VecAluOpRRImm5::VrgatherVI => 0b001100,
}
}

@@ -424,7 +436,8 @@
| VecAluOpRRImm5::VslidedownVI
| VecAluOpRRImm5::VmergeVIM
| VecAluOpRRImm5::VsadduVI
| VecAluOpRRImm5::VsaddVI => VecOpCategory::OPIVI,
| VecAluOpRRImm5::VsaddVI
| VecAluOpRRImm5::VrgatherVI => VecOpCategory::OPIVI,
}
}

@@ -433,7 +446,8 @@
VecAluOpRRImm5::VsllVI
| VecAluOpRRImm5::VsrlVI
| VecAluOpRRImm5::VsraVI
| VecAluOpRRImm5::VslidedownVI => true,
| VecAluOpRRImm5::VslidedownVI
| VecAluOpRRImm5::VrgatherVI => true,
VecAluOpRRImm5::VaddVI
| VecAluOpRRImm5::VrsubVI
| VecAluOpRRImm5::VandVI
@@ -444,6 +458,14 @@
| VecAluOpRRImm5::VsaddVI => false,
}
}

/// Some instructions do not allow the source and destination registers to overlap.
pub fn forbids_src_dst_overlaps(&self) -> bool {
match self {
VecAluOpRRImm5::VrgatherVI => true,
_ => false,
}
}
}

impl fmt::Display for VecAluOpRRImm5 {
40 changes: 39 additions & 1 deletion cranelift/codegen/src/isa/riscv64/inst_vector.isle
@@ -117,6 +117,7 @@
(VmergeVVM)
(VredmaxuVS)
(VredminuVS)
(VrgatherVV)

;; Vector-Scalar Opcodes
(VaddVX)
@@ -145,6 +146,7 @@
(VfrdivVF)
(VmergeVXM)
(VfmergeVFM)
(VrgatherVX)
(VmsltVX)
))

@@ -163,6 +165,7 @@
(VxorVI)
(VslidedownVI)
(VmergeVIM)
(VrgatherVI)
))

;; Imm only ALU Ops
@@ -718,6 +721,25 @@
(rule (rv_vredmaxu_vs vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VredmaxuVS) vs2 vs1 mask vstate))

;; Helper for emitting the `vrgather.vv` instruction.
;;
;; vd[i] = (vs1[i] >= VLMAX) ? 0 : vs2[vs1[i]];
(decl rv_vrgather_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vrgather_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VrgatherVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vrgather.vx` instruction.
;;
;; vd[i] = (x[rs1] >= VLMAX) ? 0 : vs2[x[rs1]]
(decl rv_vrgather_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vrgather_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VrgatherVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vrgather.vi` instruction.
(decl rv_vrgather_vi (VReg UImm5 VecOpMasking VState) VReg)
(rule (rv_vrgather_vi vs2 imm mask vstate)
(vec_alu_rr_uimm5 (VecAluOpRRImm5.VrgatherVI) vs2 imm mask vstate))
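
A minimal scalar sketch of the gather semantics described in the comments above, assuming e8 elements at m1 with 128-bit vectors so that VLMAX is 16; the function names are illustrative only, not Cranelift or RVV APIs.

// Scalar model of `vrgather.vv`: vd[i] = (vs1[i] >= VLMAX) ? 0 : vs2[vs1[i]].
fn vrgather_vv(vs2: [u8; 16], vs1: [u8; 16]) -> [u8; 16] {
    let mut vd = [0u8; 16];
    for i in 0..16 {
        let idx = vs1[i] as usize;
        // Out-of-range indices produce 0 rather than trapping.
        vd[i] = if idx < 16 { vs2[idx] } else { 0 };
    }
    vd
}

// Scalar model of `vrgather.vx`: every lane reads the element selected by the
// scalar index (`vrgather.vi` is the same with a 5-bit immediate index).
fn vrgather_vx(vs2: [u8; 16], rs1: u64) -> [u8; 16] {
    let idx = rs1 as usize;
    [if idx < 16 { vs2[idx] } else { 0 }; 16]
}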

;; Helper for emitting the `vmslt.vx` (Vector Mask Set Less Than) instruction.
(decl rv_vmslt_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmslt_vx vs2 vs1 mask vstate)
@@ -757,4 +779,20 @@
;; Materialize the mask into an X register, and move it into the bottom of
;; the vector register.
(rule (gen_vec_mask mask)
(rv_vmv_sx (imm $I64 mask) (vstate_from_type $I64X2)))
(rv_vmv_sx (imm $I64 mask) (vstate_from_type $I64X2)))


;; Loads a `VCodeConstant` value into a vector register. For some special `VCodeConstant`s
;; we can use a dedicated instruction, otherwise we load the value from the pool.
;;
;; Type is the preferred type to use when loading the constant.
(decl gen_constant (Type VCodeConstant) VReg)

;; The fallback case is to load the constant from the pool.
(rule (gen_constant ty n)
(vec_load
(element_width_from_type ty)
(VecAMode.UnitStride (gen_const_amode n))
(mem_flags_trusted)
(unmasked)
ty))
33 changes: 27 additions & 6 deletions cranelift/codegen/src/isa/riscv64/lower.isle
@@ -12,12 +12,7 @@
;; ;;;; Rules for `vconst` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule (lower (has_type (ty_vec_fits_in_register ty) (vconst n)))
(vec_load
(element_width_from_type ty)
(VecAMode.UnitStride (gen_const_amode (const_to_vconst n)))
(mem_flags_trusted)
(unmasked)
ty))
(gen_constant ty (const_to_vconst n)))

;;;; Rules for `f32const` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

@@ -1407,3 +1402,29 @@
;; use the original type as a VState and avoid a state change.
(x_mask XReg (rv_vmv_xs mask (vstate_from_type $I64X2))))
(gen_andi x_mask (ty_lane_mask ty))))

;;;; Rules for `swizzle` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule 0 (lower (has_type (ty_vec_fits_in_register ty) (swizzle x y)))
(rv_vrgather_vv x y (unmasked) ty))

(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (swizzle x (splat y))))
(rv_vrgather_vx x y (unmasked) ty))

(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (swizzle x (replicated_uimm5 y))))
(rv_vrgather_vi x y (unmasked) ty))
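
The two higher-priority rules cover the case where every index lane is known to hold the same value. A rough scalar sketch of why a single scalar or immediate index then suffices (i8x16 assumed; `swizzle_splat` is an illustrative name, not a backend helper):

// When the swizzle's index vector is a splat of `k`, every output lane selects
// the same element of `x` (or 0 when `k` is out of range), so one
// `vrgather.vx`/`vrgather.vi` with scalar index `k` computes the whole result.
fn swizzle_splat(x: [u8; 16], k: u8) -> [u8; 16] {
    let lane = if (k as usize) < 16 { x[k as usize] } else { 0 };
    [lane; 16]
}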

;;;; Rules for `shuffle` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; Use a vrgather to gather the lanes that indices 0-15 select from x. Then modify the
;; mask so that a second vrgather gathers the lanes that indices 16-31 select from y.
;; Finally, use a vor to combine the two vectors.
;;
;; vrgather will insert a 0 for lanes that are out of bounds, so we can let it load
;; negative and out of bounds indexes.
(rule (lower (has_type (ty_vec_fits_in_register ty @ $I8X16) (shuffle x y (vconst_from_immediate mask))))
(if-let neg16 (imm5_from_i8 -16))
(let ((x_mask VReg (gen_constant ty mask))
(x_lanes VReg (rv_vrgather_vv x x_mask (unmasked) ty))
(y_mask VReg (rv_vadd_vi x_mask neg16 (unmasked) ty))
(y_lanes VReg (rv_vrgather_vv y y_mask (unmasked) ty)))
(rv_vor_vv x_lanes y_lanes (unmasked) ty)))
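
A scalar sketch of the two-gather-plus-or scheme used by the rule above, assuming i8x16 shuffles; `shuffle_ref` and `gather` are illustrative reference helpers, not the code the backend emits:

fn shuffle_ref(x: [u8; 16], y: [u8; 16], mask: [u8; 16]) -> [u8; 16] {
    // Mirrors vrgather: out-of-range indices produce 0.
    let gather = |src: [u8; 16], idx: [u8; 16]| -> [u8; 16] {
        core::array::from_fn(|i| {
            let j = idx[i] as usize;
            if j < 16 { src[j] } else { 0 }
        })
    };
    // Indices 0-15 pick lanes of `x`; indices 16-31 are out of range here and yield 0.
    let x_lanes = gather(x, mask);
    // `vadd.vi` by -16 wraps at e8, so indices 16-31 become 0-15 (selecting `y`)
    // while indices 0-15 become 240-255, which the gather turns into 0.
    let y_mask: [u8; 16] = core::array::from_fn(|i| mask[i].wrapping_sub(16));
    let y_lanes = gather(y, y_mask);
    // For each lane, one gather contributes the selected byte and the other
    // contributes 0, so the OR acts as a merge.
    core::array::from_fn(|i| x_lanes[i] | y_lanes[i])
}

In the rule itself, `vconst_from_immediate` turns the shuffle's immediate lane list into the `VCodeConstant` that `gen_constant` loads as `x_mask`.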
7 changes: 7 additions & 0 deletions cranelift/codegen/src/machinst/isle.rs
@@ -342,6 +342,13 @@ macro_rules! isle_lower_prelude_methods {
Some(u128::from_le_bytes(bytes.try_into().ok()?))
}

#[inline]
fn vconst_from_immediate(&mut self, imm: Immediate) -> Option<VCodeConstant> {
Some(self.lower_ctx.use_constant(VCodeConstantData::Generated(
self.lower_ctx.get_immediate_data(imm).clone(),
)))
}

#[inline]
fn vec_mask_from_immediate(&mut self, imm: Immediate) -> Option<VecMask> {
let data = self.lower_ctx.get_immediate_data(imm);
5 changes: 5 additions & 0 deletions cranelift/codegen/src/prelude_lower.isle
@@ -854,6 +854,11 @@
(decl u128_from_immediate (u128) Immediate)
(extern extractor u128_from_immediate u128_from_immediate)

;; Extracts an `Immediate` as a `VCodeConstant`.

(decl vconst_from_immediate (VCodeConstant) Immediate)
(extern extractor vconst_from_immediate vconst_from_immediate)

;; Accessor for `Constant` as u128.

(decl u128_from_constant (u128) Constant)
61 changes: 61 additions & 0 deletions cranelift/filetests/filetests/isa/riscv64/simd-shuffle.clif
@@ -0,0 +1,61 @@
test compile precise-output
set unwind_info=false
target riscv64 has_v

function %shuffle_i8x16(i8x16, i8x16) -> i8x16 {
block0(v0: i8x16, v1: i8x16):
v2 = shuffle v0, v1, [3 0 31 26 4 6 12 11 23 13 24 4 2 15 17 5]
return v2
}

; VCode:
; add sp,-16
; sd ra,8(sp)
; sd fp,0(sp)
; mv fp,sp
; block0:
; vle8.v v1,16(fp) #avl=16, #vtype=(e8, m1, ta, ma)
; vle8.v v3,32(fp) #avl=16, #vtype=(e8, m1, ta, ma)
; vle8.v v6,[const(0)] #avl=16, #vtype=(e8, m1, ta, ma)
; vrgather.vv v8,v1,v6 #avl=16, #vtype=(e8, m1, ta, ma)
; vadd.vi v10,v6,-16 #avl=16, #vtype=(e8, m1, ta, ma)
; vrgather.vv v12,v3,v10 #avl=16, #vtype=(e8, m1, ta, ma)
; vor.vv v14,v8,v12 #avl=16, #vtype=(e8, m1, ta, ma)
; vse8.v v14,0(a0) #avl=16, #vtype=(e8, m1, ta, ma)
; ld ra,8(sp)
; ld fp,0(sp)
; add sp,+16
; ret
;
; Disassembled:
; block0: ; offset 0x0
; addi sp, sp, -0x10
; sd ra, 8(sp)
; sd s0, 0(sp)
; ori s0, sp, 0
; block1: ; offset 0x10
; .byte 0x57, 0x70, 0x08, 0xcc
; addi t6, s0, 0x10
; .byte 0x87, 0x80, 0x0f, 0x02
; addi t6, s0, 0x20
; .byte 0x87, 0x81, 0x0f, 0x02
; auipc t6, 0
; addi t6, t6, 0x3c
; .byte 0x07, 0x83, 0x0f, 0x02
; .byte 0x57, 0x04, 0x13, 0x32
; .byte 0x57, 0x35, 0x68, 0x02
; .byte 0x57, 0x06, 0x35, 0x32
; .byte 0x57, 0x07, 0x86, 0x2a
; .byte 0x27, 0x07, 0x05, 0x02
; ld ra, 8(sp)
; ld s0, 0(sp)
; addi sp, sp, 0x10
; ret
; .byte 0x00, 0x00, 0x00, 0x00
; .byte 0x00, 0x00, 0x00, 0x00
; .byte 0x00, 0x00, 0x00, 0x00
; lb zero, 0x1a1(t5)
; .byte 0x04, 0x06, 0x0c, 0x0b
; auipc s10, 0x4180
; .byte 0x02, 0x0f, 0x11, 0x05
