riscv64: Implement SIMD icmp (bytecodealliance#6609)
These are implemented as a combination of two steps: mask generation and
mask expansion. Our comparison rules only return their results in a mask
register, so we need to expand the mask into lane-sized elements.

We have 20 (!) comparison instructions, nearly the full table of all IntCC codes
in VV, VX and VI formats. However, there are some holes in this table.

They are:
* `vmsltu.vi`
* `vmslt.vi`
* `vmsgtu.vv`
* `vmsgt.vv`
* `vmsgeu.*`
* `vmsge.*`

Most of these can be replaced with the inverted IntCC instruction (for example,
`vmsgtu.vv vd, vs2, vs1` is equivalent to `vmsltu.vv vd, vs1, vs2`). However,
this commit only implements the existing instructions without any inversion,
plus the inverted VV versions of `sgtu`/`sgt`/`sgeu`/`sge`, since we need those
to get full icmp functionality.

I've split the actual mask expansion into its own separate rule since we are
going to need it for the `fcmp` rules as well.

The instruction selection for `icmp` is in a separate rule simply because the
rules end up less verbose than if they were inlined directly into the `icmp` rule.
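
Roughly, the two steps compose like this (a sketch, not the literal rule from
this commit; the actual `icmp` lowering rule lives in lower.isle, which is not
shown in this excerpt):

    ;; Sketch only: generate the comparison mask, then expand it into
    ;; lane-sized elements of all ones / all zeros.
    (rule (lower (icmp cc x @ (value_type (ty_vec_fits_in_register ty)) y))
          (gen_expand_mask ty (gen_icmp_mask ty cc x y)))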
afonso360 authored Jun 21, 2023
1 parent 1bc4ff3 commit b05a09c
Showing 24 changed files with 3,999 additions and 12 deletions.
4 changes: 0 additions & 4 deletions build.rs
@@ -241,12 +241,8 @@ fn ignore(testsuite: &str, testname: &str, strategy: &str) -> bool {
"simd_f64x2_cmp",
"simd_f64x2_pmin_pmax",
"simd_f64x2_rounding",
"simd_i16x8_cmp",
"simd_i32x4_cmp",
"simd_i32x4_trunc_sat_f32x4",
"simd_i32x4_trunc_sat_f64x2",
"simd_i64x2_cmp",
"simd_i8x16_cmp",
"simd_load",
"simd_splat",
]
48 changes: 43 additions & 5 deletions cranelift/codegen/src/isa/riscv64/inst/vector.rs
@@ -357,7 +357,14 @@ impl VecAluOpRRR {
VecAluOpRRR::VwaddWV | VecAluOpRRR::VwaddWX => 0b110101,
VecAluOpRRR::VwsubuWV | VecAluOpRRR::VwsubuWX => 0b110110,
VecAluOpRRR::VwsubWV | VecAluOpRRR::VwsubWX => 0b110111,
VecAluOpRRR::VmsltVX => 0b011011,
VecAluOpRRR::VmseqVV | VecAluOpRRR::VmseqVX => 0b011000,
VecAluOpRRR::VmsneVV | VecAluOpRRR::VmsneVX => 0b011001,
VecAluOpRRR::VmsltuVV | VecAluOpRRR::VmsltuVX => 0b011010,
VecAluOpRRR::VmsltVV | VecAluOpRRR::VmsltVX => 0b011011,
VecAluOpRRR::VmsleuVV | VecAluOpRRR::VmsleuVX => 0b011100,
VecAluOpRRR::VmsleVV | VecAluOpRRR::VmsleVX => 0b011101,
VecAluOpRRR::VmsgtuVX => 0b011110,
VecAluOpRRR::VmsgtVX => 0b011111,
}
}

@@ -381,7 +388,13 @@ impl VecAluOpRRR {
| VecAluOpRRR::VmaxuVV
| VecAluOpRRR::VmaxVV
| VecAluOpRRR::VmergeVVM
| VecAluOpRRR::VrgatherVV => VecOpCategory::OPIVV,
| VecAluOpRRR::VrgatherVV
| VecAluOpRRR::VmseqVV
| VecAluOpRRR::VmsneVV
| VecAluOpRRR::VmsltuVV
| VecAluOpRRR::VmsltVV
| VecAluOpRRR::VmsleuVV
| VecAluOpRRR::VmsleVV => VecOpCategory::OPIVV,
VecAluOpRRR::VwaddVV
| VecAluOpRRR::VwaddWV
| VecAluOpRRR::VwadduVV
@@ -427,8 +440,15 @@ impl VecAluOpRRR {
| VecAluOpRRR::VmaxVX
| VecAluOpRRR::VslidedownVX
| VecAluOpRRR::VmergeVXM
| VecAluOpRRR::VrgatherVX
| VecAluOpRRR::VmseqVX
| VecAluOpRRR::VmsneVX
| VecAluOpRRR::VmsltuVX
| VecAluOpRRR::VmsltVX
| VecAluOpRRR::VrgatherVX => VecOpCategory::OPIVX,
| VecAluOpRRR::VmsleuVX
| VecAluOpRRR::VmsleVX
| VecAluOpRRR::VmsgtuVX
| VecAluOpRRR::VmsgtVX => VecOpCategory::OPIVX,
VecAluOpRRR::VfaddVV
| VecAluOpRRR::VfsubVV
| VecAluOpRRR::VfmulVV
@@ -522,6 +542,12 @@ impl VecAluOpRRImm5 {
VecAluOpRRImm5::VsaddVI => 0b100001,
VecAluOpRRImm5::VrgatherVI => 0b001100,
VecAluOpRRImm5::VmvrV => 0b100111,
VecAluOpRRImm5::VmseqVI => 0b011000,
VecAluOpRRImm5::VmsneVI => 0b011001,
VecAluOpRRImm5::VmsleuVI => 0b011100,
VecAluOpRRImm5::VmsleVI => 0b011101,
VecAluOpRRImm5::VmsgtuVI => 0b011110,
VecAluOpRRImm5::VmsgtVI => 0b011111,
}
}

@@ -541,7 +567,13 @@
| VecAluOpRRImm5::VsadduVI
| VecAluOpRRImm5::VsaddVI
| VecAluOpRRImm5::VrgatherVI
| VecAluOpRRImm5::VmvrV => VecOpCategory::OPIVI,
| VecAluOpRRImm5::VmvrV
| VecAluOpRRImm5::VmseqVI
| VecAluOpRRImm5::VmsneVI
| VecAluOpRRImm5::VmsleuVI
| VecAluOpRRImm5::VmsleVI
| VecAluOpRRImm5::VmsgtuVI
| VecAluOpRRImm5::VmsgtVI => VecOpCategory::OPIVI,
}
}

@@ -561,7 +593,13 @@
| VecAluOpRRImm5::VxorVI
| VecAluOpRRImm5::VmergeVIM
| VecAluOpRRImm5::VsadduVI
| VecAluOpRRImm5::VsaddVI => false,
| VecAluOpRRImm5::VsaddVI
| VecAluOpRRImm5::VmseqVI
| VecAluOpRRImm5::VmsneVI
| VecAluOpRRImm5::VmsleuVI
| VecAluOpRRImm5::VmsleVI
| VecAluOpRRImm5::VmsgtuVI
| VecAluOpRRImm5::VmsgtVI => false,
}
}

254 changes: 253 additions & 1 deletion cranelift/codegen/src/isa/riscv64/inst_vector.isle
@@ -128,6 +128,13 @@
(VredminuVS)
(VrgatherVV)
(VcompressVM)
(VmseqVV)
(VmsneVV)
(VmsltuVV)
(VmsltVV)
(VmsleuVV)
(VmsleVV)


;; Vector-Scalar Opcodes
(VaddVX)
@@ -169,7 +176,14 @@
(VmergeVXM)
(VfmergeVFM)
(VrgatherVX)
(VmseqVX)
(VmsneVX)
(VmsltuVX)
(VmsltVX)
(VmsleuVX)
(VmsleVX)
(VmsgtuVX)
(VmsgtVX)
))


@@ -199,6 +213,12 @@
;; This opcode represents multiple instructions `vmv1r`/`vmv2r`/`vmv4r`/etc...
;; The immediate field specifies how many registers should be copied.
(VmvrV)
(VmseqVI)
(VmsneVI)
(VmsleuVI)
(VmsleVI)
(VmsgtuVI)
(VmsgtVI)
))

;; Imm only ALU Ops
@@ -969,11 +989,126 @@
(rule (rv_vcompress_vm vs2 vs1 vstate)
(vec_alu_rrr (VecAluOpRRR.VcompressVM) vs2 vs1 (unmasked) vstate))

;; Helper for emitting the `vmslt.vx` (Vector Mask Set Less Than) instruction.
;; Helper for emitting the `vmseq.vv` (Vector Mask Set If Equal) instruction.
(decl rv_vmseq_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmseq_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmseqVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmseq.vx` (Vector Mask Set If Equal) instruction.
(decl rv_vmseq_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmseq_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmseqVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmseq.vi` (Vector Mask Set If Equal) instruction.
(decl rv_vmseq_vi (VReg Imm5 VecOpMasking VState) VReg)
(rule (rv_vmseq_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VmseqVI) vs2 imm mask vstate))

;; Helper for emitting the `vmsne.vv` (Vector Mask Set If Not Equal) instruction.
(decl rv_vmsne_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmsne_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsneVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsne.vx` (Vector Mask Set If Not Equal) instruction.
(decl rv_vmsne_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmsne_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsneVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsne.vi` (Vector Mask Set If Not Equal) instruction.
(decl rv_vmsne_vi (VReg Imm5 VecOpMasking VState) VReg)
(rule (rv_vmsne_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VmsneVI) vs2 imm mask vstate))

;; Helper for emitting the `vmsltu.vv` (Vector Mask Set If Less Than, Unsigned) instruction.
(decl rv_vmsltu_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmsltu_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsltuVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsltu.vx` (Vector Mask Set If Less Than, Unsigned) instruction.
(decl rv_vmsltu_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmsltu_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsltuVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmslt.vv` (Vector Mask Set If Less Than) instruction.
(decl rv_vmslt_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmslt_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsltVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmslt.vx` (Vector Mask Set If Less Than) instruction.
(decl rv_vmslt_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmslt_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsltVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsleu.vv` (Vector Mask Set If Less Than or Equal, Unsigned) instruction.
(decl rv_vmsleu_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmsleu_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsleuVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsleu.vx` (Vector Mask Set If Less Than or Equal, Unsigned) instruction.
(decl rv_vmsleu_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmsleu_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsleuVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsleu.vi` (Vector Mask Set If Less Than or Equal, Unsigned) instruction.
(decl rv_vmsleu_vi (VReg Imm5 VecOpMasking VState) VReg)
(rule (rv_vmsleu_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VmsleuVI) vs2 imm mask vstate))

;; Helper for emitting the `vmsle.vv` (Vector Mask Set If Less Than or Equal) instruction.
(decl rv_vmsle_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmsle_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsleVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsle.vx` (Vector Mask Set If Less Than or Equal) instruction.
(decl rv_vmsle_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmsle_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsleVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsle.vi` (Vector Mask Set If Less Than or Equal) instruction.
(decl rv_vmsle_vi (VReg Imm5 VecOpMasking VState) VReg)
(rule (rv_vmsle_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VmsleVI) vs2 imm mask vstate))

;; Helper for emitting the `vmsgtu.vv` (Vector Mask Set If Greater Than, Unsigned) instruction.
;; This is an alias for `vmsltu.vv` with the operands inverted.
(decl rv_vmsgtu_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmsgtu_vv vs2 vs1 mask vstate) (rv_vmsltu_vv vs1 vs2 mask vstate))

;; Helper for emitting the `vmsgtu.vx` (Vector Mask Set If Greater Than, Unsigned) instruction.
(decl rv_vmsgtu_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmsgtu_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsgtuVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsgtu.vi` (Vector Mask Set If Greater Than, Unsigned) instruction.
(decl rv_vmsgtu_vi (VReg Imm5 VecOpMasking VState) VReg)
(rule (rv_vmsgtu_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VmsgtuVI) vs2 imm mask vstate))

;; Helper for emitting the `vmsgt.vv` (Vector Mask Set If Greater Than) instruction.
;; This is an alias for `vmslt.vv` with the operands inverted.
(decl rv_vmsgt_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmsgt_vv vs2 vs1 mask vstate) (rv_vmslt_vv vs1 vs2 mask vstate))

;; Helper for emitting the `vmsgt.vx` (Vector Mask Set If Greater Than) instruction.
(decl rv_vmsgt_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmsgt_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsgtVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsgt.vi` (Vector Mask Set If Greater Than) instruction.
(decl rv_vmsgt_vi (VReg Imm5 VecOpMasking VState) VReg)
(rule (rv_vmsgt_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VmsgtVI) vs2 imm mask vstate))

;; Helper for emitting the `vmsgeu.vv` (Vector Mask Set If Greater Than or Equal, Unsigned) instruction.
;; This is an alias for `vmsleu.vv` with the operands inverted.
(decl rv_vmsgeu_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmsgeu_vv vs2 vs1 mask vstate) (rv_vmsleu_vv vs1 vs2 mask vstate))

;; Helper for emitting the `vmsge.vv` (Vector Mask Set If Greater Than or Equal) instruction.
;; This is an alias for `vmsle.vv` with the operands inverted.
(decl rv_vmsge_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmsge_vv vs2 vs1 mask vstate) (rv_vmsle_vv vs1 vs2 mask vstate))

;; Helper for emitting the `vzext.vf2` instruction.
;; Zero-extend SEW/2 source to SEW destination
(decl rv_vzext_vf2 (VReg VecOpMasking VState) VReg)
@@ -1078,3 +1213,120 @@
(rule 0 (gen_slidedown_half (ty_vec_fits_in_register ty) src)
(if-let amt (u64_udiv (ty_lane_count ty) 2))
(rv_vslidedown_vx src (imm $I64 amt) (unmasked) ty))


;; Expands a mask into SEW wide lanes. Enabled lanes are set to all ones, disabled
;; lanes are set to all zeros.
(decl gen_expand_mask (Type VReg) VReg)
(rule (gen_expand_mask ty mask)
(if-let zero (imm5_from_i8 0))
(if-let neg1 (imm5_from_i8 -1))
(rv_vmerge_vim (rv_vmv_vi zero ty) neg1 mask ty))
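
;; A sketch of what the expansion above emits, assuming the mask lives in
;; `v0` and `v8` is the destination (register names are illustrative only):
;;   vmv.v.i    v8, 0           ;; splat zero into every lane
;;   vmerge.vim v8, v8, -1, v0  ;; write -1 (all ones) into lanes enabled by v0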


;; Builds a vector mask corresponding to the IntCC operation.
(decl gen_icmp_mask (Type IntCC Value Value) VReg)

;; IntCC.Equal

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.Equal) x y)
(rv_vmseq_vv x y (unmasked) ty))

(rule 1 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.Equal) x (splat y))
(rv_vmseq_vx x y (unmasked) ty))

(rule 2 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.Equal) (splat x) y)
(rv_vmseq_vx y x (unmasked) ty))

(rule 3 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.Equal) x (replicated_imm5 y))
(rv_vmseq_vi x y (unmasked) ty))

(rule 4 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.Equal) (replicated_imm5 x) y)
(rv_vmseq_vi y x (unmasked) ty))

;; IntCC.NotEqual

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.NotEqual) x y)
(rv_vmsne_vv x y (unmasked) ty))

(rule 1 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.NotEqual) x (splat y))
(rv_vmsne_vx x y (unmasked) ty))

(rule 2 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.NotEqual) (splat x) y)
(rv_vmsne_vx y x (unmasked) ty))

(rule 3 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.NotEqual) x (replicated_imm5 y))
(rv_vmsne_vi x y (unmasked) ty))

(rule 4 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.NotEqual) (replicated_imm5 x) y)
(rv_vmsne_vi y x (unmasked) ty))

;; IntCC.UnsignedLessThan

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedLessThan) x y)
(rv_vmsltu_vv x y (unmasked) ty))

(rule 1 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedLessThan) x (splat y))
(rv_vmsltu_vx x y (unmasked) ty))

;; IntCC.SignedLessThan

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedLessThan) x y)
(rv_vmslt_vv x y (unmasked) ty))

(rule 1 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedLessThan) x (splat y))
(rv_vmslt_vx x y (unmasked) ty))

;; IntCC.UnsignedLessThanOrEqual

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedLessThanOrEqual) x y)
(rv_vmsleu_vv x y (unmasked) ty))

(rule 1 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedLessThanOrEqual) x (splat y))
(rv_vmsleu_vx x y (unmasked) ty))

(rule 3 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedLessThanOrEqual) x (replicated_imm5 y))
(rv_vmsleu_vi x y (unmasked) ty))

;; IntCC.SignedLessThanOrEqual

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedLessThanOrEqual) x y)
(rv_vmsle_vv x y (unmasked) ty))

(rule 1 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedLessThanOrEqual) x (splat y))
(rv_vmsle_vx x y (unmasked) ty))

(rule 3 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedLessThanOrEqual) x (replicated_imm5 y))
(rv_vmsle_vi x y (unmasked) ty))

;; IntCC.UnsignedGreaterThan

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedGreaterThan) x y)
(rv_vmsgtu_vv x y (unmasked) ty))

(rule 1 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedGreaterThan) x (splat y))
(rv_vmsgtu_vx x y (unmasked) ty))

(rule 3 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedGreaterThan) x (replicated_imm5 y))
(rv_vmsgtu_vi x y (unmasked) ty))

;; IntCC.SignedGreaterThan

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedGreaterThan) x y)
(rv_vmsgt_vv x y (unmasked) ty))

(rule 1 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedGreaterThan) x (splat y))
(rv_vmsgt_vx x y (unmasked) ty))

(rule 3 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedGreaterThan) x (replicated_imm5 y))
(rv_vmsgt_vi x y (unmasked) ty))

;; IntCC.UnsignedGreaterThanOrEqual

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedGreaterThanOrEqual) x y)
(rv_vmsgeu_vv x y (unmasked) ty))

;; IntCC.SignedGreaterThanOrEqual

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedGreaterThanOrEqual) x y)
(rv_vmsge_vv x y (unmasked) ty))