From a8de2c88e0bc8a591b2104b2e3fe63467262e6d8 Mon Sep 17 00:00:00 2001
From: Troy Hinckley
Date: Fri, 21 Apr 2023 18:36:47 -0600
Subject: [PATCH] Optimize lines_crlf

This change is not ready to be merged because it does not yet support
x86. I don't know how to do the operation we need with SSE intrinsics.

This new approach adds another method to ByteChunk that lets you shift
a value across two vectors. That lets us turn a two-pass counting
algorithm into a single pass. Making this change alone resulted in a
40% speedup. There is also some loop unrolling, but unrolling beyond a
factor of 2 made little difference.
---
 src/byte_chunk.rs                              |  27 +++
 src/lines_crlf.rs                              | 165 ++++++++++--------
 .../proptests_lines_crlf.proptest-regressions |   3 +
 3 files changed, 119 insertions(+), 76 deletions(-)

diff --git a/src/byte_chunk.rs b/src/byte_chunk.rs
index 6f99310..3ff4e96 100644
--- a/src/byte_chunk.rs
+++ b/src/byte_chunk.rs
@@ -37,6 +37,9 @@ pub(crate) trait ByteChunk: Copy + Clone {
     /// Shifts bytes back lexographically by n bytes.
     fn shift_back_lex(&self, n: usize) -> Self;
 
+    /// Shifts the last byte of `self` (lexicographically) into the first byte of `n`.
+    fn shift_across(&self, n: Self) -> Self;
+
     /// Shifts bits to the right by n bits.
     fn shr(&self, n: usize) -> Self;
 
@@ -100,6 +103,16 @@ impl ByteChunk for usize {
         }
     }
 
+    #[inline(always)]
+    fn shift_across(&self, n: Self) -> Self {
+        let size = (Self::SIZE - 1) * 8;
+        if cfg!(target_endian = "little") {
+            (*self >> size) | (n << 8)
+        } else {
+            (*self << size) | (n >> 8)
+        }
+    }
+
     #[inline(always)]
     fn shr(&self, n: usize) -> Self {
         *self >> n
@@ -198,6 +211,15 @@ impl ByteChunk for x86_64::__m128i {
         }
     }
 
+    #[inline(always)]
+    fn shift_across(&self, n: Self) -> Self {
+        unsafe {
+            let bottom_byte = x86_64::_mm_srli_si128(*self, 15);
+            let rest_shifted = x86_64::_mm_slli_si128(n, 1);
+            x86_64::_mm_or_si128(bottom_byte, rest_shifted)
+        }
+    }
+
     #[inline(always)]
     fn shr(&self, n: usize) -> Self {
         match n {
@@ -292,6 +314,11 @@ impl ByteChunk for aarch64::uint8x16_t {
         }
     }
 
+    #[inline(always)]
+    fn shift_across(&self, n: Self) -> Self {
+        unsafe { aarch64::vextq_u8(*self, n, 15) }
+    }
+
     #[inline(always)]
     fn shr(&self, n: usize) -> Self {
         unsafe {
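The semantics of `shift_across` are easiest to see on the `usize` chunk: treating the two chunks as one contiguous byte sequence, the last in-memory byte of `self` slides into the first byte position of `n`. Here is a minimal standalone sketch of the little-endian `usize` arm (the helper name `shift_across_u64` and the test values are illustrative only, not part of the patch):

    // Mirrors the little-endian `usize` arm of `shift_across` above.
    fn shift_across_u64(a: u64, b: u64) -> u64 {
        let size = (std::mem::size_of::<u64>() - 1) * 8; // 56 bits
        // Memory (lexicographic) order of the result: [a[7], b[0], ..., b[6]].
        (a >> size) | (b << 8)
    }

    fn main() {
        let a = u64::from_le_bytes([1, 2, 3, 4, 5, 6, 7, 8]);
        let b = u64::from_le_bytes([9, 10, 11, 12, 13, 14, 15, 16]);
        // `a`'s last in-memory byte (8) lands in front of `b`'s first seven bytes.
        assert_eq!(
            shift_across_u64(a, b).to_le_bytes(),
            [8, 9, 10, 11, 12, 13, 14, 15]
        );
    }

The `vextq_u8(self, n, 15)` call does the same thing for 16-byte NEON vectors, as does the `_mm_srli_si128`/`_mm_slli_si128`/`_mm_or_si128` sequence in the SSE2 arm above.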
diff --git a/src/lines_crlf.rs b/src/lines_crlf.rs
index 435012b..8774ebd 100644
--- a/src/lines_crlf.rs
+++ b/src/lines_crlf.rs
@@ -55,6 +55,8 @@ pub fn to_byte_idx(text: &str, line_idx: usize) -> usize {
 }
 
 //-------------------------------------------------------------
+const LF: u8 = b'\n';
+const CR: u8 = b'\r';
 
 #[inline(always)]
 fn to_byte_idx_impl(text: &[u8], line_idx: usize) -> usize {
@@ -67,61 +69,57 @@ fn to_byte_idx_impl(text: &[u8], line_idx: usize) -> usize {
     let mut byte_count = 0;
     let mut break_count = 0;
 
-    // Take care of any unaligned bytes at the beginning.
-    for byte in start.iter() {
+    let mut last_was_cr = false;
+    for byte in start.iter().copied() {
+        let is_lf = byte == LF;
+        let is_cr = byte == CR;
         if break_count == line_idx {
-            break;
+            if last_was_cr && is_lf {
+                byte_count += 1;
+            }
+            return byte_count;
         }
-        break_count +=
-            (*byte == 0x0A || (*byte == 0x0D && text.get(byte_count + 1) != Some(&0x0A))) as usize;
+        if is_cr || (is_lf && !last_was_cr) {
+            break_count += 1;
+        }
+        last_was_cr = is_cr;
         byte_count += 1;
     }
 
-    // Process chunks in the fast path.
-    let mut chunks = middle;
-    let mut max_round_len = (line_idx - break_count) / T::MAX_ACC;
-    while max_round_len > 0 && !chunks.is_empty() {
-        // Choose the largest number of chunks we can do this round
-        // that will neither overflow `max_acc` nor blast past the
-        // remaining line breaks we're looking for.
-        let round_len = T::MAX_ACC.min(max_round_len).min(chunks.len());
-        max_round_len -= round_len;
-        let round = &chunks[..round_len];
-        chunks = &chunks[round_len..];
-
-        // Process the chunks in this round.
-        let mut acc = T::zero();
-        for chunk in round.iter() {
-            let lf_flags = chunk.cmp_eq_byte(0x0A);
-            let cr_flags = chunk.cmp_eq_byte(0x0D);
-            let crlf_flags = cr_flags.bitand(lf_flags.shift_back_lex(1));
-            acc = acc.add(lf_flags).add(cr_flags.sub(crlf_flags));
-        }
-        break_count += acc.sum_bytes();
-
-        // Handle CRLFs at chunk boundaries in this round.
-        let mut i = byte_count;
-        while i < (byte_count + T::SIZE * round_len) {
-            i += T::SIZE;
-            break_count -= (text[i - 1] == 0x0D && text.get(i) == Some(&0x0A)) as usize;
+    // Process the chunks two at a time.
+    let mut chunk_count = 0;
+    let mut prev = T::splat(last_was_cr as u8);
+    for chunks in middle.chunks_exact(2) {
+        let lf_flags0 = chunks[0].cmp_eq_byte(LF);
+        let cr_flags0 = chunks[0].cmp_eq_byte(CR);
+        let crlf_flags0 = prev.shift_across(cr_flags0).bitand(lf_flags0);
+
+        let lf_flags1 = chunks[1].cmp_eq_byte(LF);
+        let cr_flags1 = chunks[1].cmp_eq_byte(CR);
+        let crlf_flags1 = cr_flags0.shift_across(cr_flags1).bitand(lf_flags1);
+        let new_break_count = break_count
+            + lf_flags0
+                .add(cr_flags0)
+                .add(lf_flags1)
+                .add(cr_flags1)
+                .sub(crlf_flags0)
+                .sub(crlf_flags1)
+                .sum_bytes();
+        if new_break_count >= line_idx {
+            break;
         }
-
-        byte_count += T::SIZE * round_len;
+        break_count = new_break_count;
+        byte_count += T::SIZE * 2;
+        chunk_count += 2;
+        prev = cr_flags1;
     }
 
-    // Process chunks in the slow path.
-    for chunk in chunks.iter() {
-        let breaks = {
-            let lf_flags = chunk.cmp_eq_byte(0x0A);
-            let cr_flags = chunk.cmp_eq_byte(0x0D);
-            let crlf_flags = cr_flags.bitand(lf_flags.shift_back_lex(1));
-            lf_flags.add(cr_flags.sub(crlf_flags)).sum_bytes()
-        };
-        let boundary_crlf = {
-            let i = byte_count + T::SIZE;
-            (text[i - 1] == 0x0D && text.get(i) == Some(&0x0A)) as usize
-        };
-        let new_break_count = break_count + breaks - boundary_crlf;
+    // Process the rest of the chunks one at a time.
+    for chunk in middle[chunk_count..].iter() {
+        let lf_flags = chunk.cmp_eq_byte(LF);
+        let cr_flags = chunk.cmp_eq_byte(CR);
+        let crlf_flags = prev.shift_across(cr_flags).bitand(lf_flags);
+        let new_break_count = break_count + lf_flags.add(cr_flags).sub(crlf_flags).sum_bytes();
         if new_break_count >= line_idx {
             break;
         }
@@ -129,14 +127,20 @@ fn to_byte_idx_impl(text: &[u8], line_idx: usize) -> usize {
         byte_count += T::SIZE;
     }
 
-    // Take care of any unaligned bytes at the end.
-    let end = &text[byte_count..];
-    for byte in end.iter() {
+    last_was_cr = text.get(byte_count.saturating_sub(1)) == Some(&CR);
+    for byte in text[byte_count..].iter().copied() {
+        let is_lf = byte == LF;
+        let is_cr = byte == CR;
         if break_count == line_idx {
+            if last_was_cr && is_lf {
+                byte_count += 1;
+            }
             break;
         }
-        break_count +=
-            (*byte == 0x0A || (*byte == 0x0D && text.get(byte_count + 1) != Some(&0x0A))) as usize;
+        if is_cr || (is_lf && !last_was_cr) {
+            break_count += 1;
+        }
+        last_was_cr = is_cr;
         byte_count += 1;
     }
 
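The reason one pass now suffices: per chunk, breaks = LF + CR - CRLF, because a CRLF pair contributes a single break, and `prev.shift_across(cr_flags)` aligns each byte with the CR flag of the byte before it, so a CRLF straddling a chunk boundary is still subtracted exactly once. A scalar model of that identity (illustrative sketch only; `breaks_identity` is not a name in the crate):

    // Counts line breaks in `chunk`, given whether the byte immediately
    // before the chunk was a CR, using the same LF + CR - CRLF identity
    // as the vectorized loops above.
    fn breaks_identity(chunk: &[u8], prev_was_cr: bool) -> usize {
        let lf = chunk.iter().filter(|&&b| b == b'\n').count();
        let cr = chunk.iter().filter(|&&b| b == b'\r').count();
        let crlf = chunk
            .iter()
            .enumerate()
            .filter(|&(i, &b)| {
                b == b'\n' && (if i == 0 { prev_was_cr } else { chunk[i - 1] == b'\r' })
            })
            .count();
        lf + cr - crlf
    }

    fn main() {
        // One CRLF, one lone LF, one lone CR: three breaks total.
        assert_eq!(breaks_identity(b"a\r\nb\nc\r", false), 3);
        // A CRLF split across the chunk boundary is not double-counted.
        assert_eq!(breaks_identity(b"\nabc", true), 0);
    }

This replaces the old two-pass scheme, which summed the flags first and then walked every chunk boundary to subtract straddling CRLFs.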
@@ -159,39 +163,48 @@ fn count_breaks_impl(text: &[u8]) -> usize {
 
     // Take care of unaligned bytes at the beginning.
     let mut last_was_cr = false;
     for byte in start.iter().copied() {
-        let is_lf = byte == 0x0A;
-        let is_cr = byte == 0x0D;
-        count += (is_cr | (is_lf & !last_was_cr)) as usize;
+        let is_lf = byte == LF;
+        let is_cr = byte == CR;
+        if is_cr || (is_lf && !last_was_cr) {
+            count += 1;
+        }
         last_was_cr = is_cr;
     }
 
-    // Take care of the middle bytes in big chunks.
-    for chunks in middle.chunks(T::MAX_ACC) {
-        let mut acc = T::zero();
-        for chunk in chunks.iter() {
-            let lf_flags = chunk.cmp_eq_byte(0x0A);
-            let cr_flags = chunk.cmp_eq_byte(0x0D);
-            let crlf_flags = cr_flags.bitand(lf_flags.shift_back_lex(1));
-            acc = acc.add(lf_flags).add(cr_flags.sub(crlf_flags));
-        }
-        count += acc.sum_bytes();
+    let mut prev = T::splat(last_was_cr as u8);
+    for chunks in middle.chunks_exact(2) {
+        let lf_flags0 = chunks[0].cmp_eq_byte(LF);
+        let cr_flags0 = chunks[0].cmp_eq_byte(CR);
+        let crlf_flags0 = prev.shift_across(cr_flags0).bitand(lf_flags0);
+
+        let lf_flags1 = chunks[1].cmp_eq_byte(LF);
+        let cr_flags1 = chunks[1].cmp_eq_byte(CR);
+        let crlf_flags1 = cr_flags0.shift_across(cr_flags1).bitand(lf_flags1);
+        count += lf_flags0
+            .add(cr_flags0)
+            .sub(crlf_flags0)
+            .add(lf_flags1)
+            .add(cr_flags1)
+            .sub(crlf_flags1)
+            .sum_bytes();
+        prev = cr_flags1;
     }
 
-    // Check chunk boundaries for CRLF.
-    let mut i = start.len();
-    while i < (text.len() - end.len()) {
-        if text[i] == 0x0A {
-            count -= (text.get(i.saturating_sub(1)) == Some(&0x0D)) as usize;
-        }
-        i += T::SIZE;
+    if let Some(chunk) = middle.chunks_exact(2).remainder().iter().next() {
+        let lf_flags = chunk.cmp_eq_byte(LF);
+        let cr_flags = chunk.cmp_eq_byte(CR);
+        let crlf_flags = prev.shift_across(cr_flags).bitand(lf_flags);
+        count += lf_flags.add(cr_flags).sub(crlf_flags).sum_bytes();
     }
 
     // Take care of unaligned bytes at the end.
-    let mut last_was_cr = text.get((text.len() - end.len()).saturating_sub(1)) == Some(&0x0D);
+    last_was_cr = text.get((text.len() - end.len()).saturating_sub(1)) == Some(&CR);
     for byte in end.iter().copied() {
-        let is_lf = byte == 0x0A;
-        let is_cr = byte == 0x0D;
-        count += (is_cr | (is_lf & !last_was_cr)) as usize;
+        let is_lf = byte == LF;
+        let is_cr = byte == CR;
+        if is_cr || (is_lf && !last_was_cr) {
+            count += 1;
+        }
         last_was_cr = is_cr;
     }
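For reference, the scalar rule that the unaligned-byte loops and the vectorized middle of `count_breaks_impl` must both agree on can be written out directly. This is an illustrative sketch (`count_breaks_ref` is not part of the crate); it is the behavior the proptest regressions below exercise:

    // CR always starts a break; LF starts one only when it is not the
    // second half of a CRLF pair.
    fn count_breaks_ref(text: &[u8]) -> usize {
        let mut count = 0;
        let mut last_was_cr = false;
        for &byte in text {
            let is_lf = byte == b'\n';
            let is_cr = byte == b'\r';
            if is_cr || (is_lf && !last_was_cr) {
                count += 1;
            }
            last_was_cr = is_cr;
        }
        count
    }

    fn main() {
        assert_eq!(count_breaks_ref(b"a\nb\r\nc\rd"), 3); // LF, CRLF, CR
        assert_eq!(count_breaks_ref(b"\r\r\n\n"), 3); // CR, CRLF, LF
    }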
diff --git a/tests/proptests_lines_crlf.proptest-regressions b/tests/proptests_lines_crlf.proptest-regressions
index 9e251a6..323c9e0 100644
--- a/tests/proptests_lines_crlf.proptest-regressions
+++ b/tests/proptests_lines_crlf.proptest-regressions
@@ -6,3 +6,6 @@
 # everyone who runs the test benefits from these saved cases.
 cc 7332c47dc044064cfb86ac6d94c45441f6dcdab699402718874116b512062e0a # shrinks to ref text = "a\n\rあ\r\n\rあa\n\rあ\n\nあ\r\nああ\r\nあ🐸a\ra\r🐸", idx = 48
 cc f9a60685ebb7fc415f3d941598e75c90a0a7db0f0cd6673e092cf78f43a60fa3 # shrinks to ref text = "あああ\r\r\r\n\naaあ\naあ\r🐸🐸🐸\n\naあaaa\n🐸\na\n\na\n\n\n\r\n🐸\r\nあ\raaああ🐸a\naa🐸\n🐸\ra\na🐸a🐸\n\ra\n🐸🐸🐸\r\ra🐸あ\n\n🐸🐸aaあ\r\rあ🐸\rあa\n\r\n🐸🐸\n\nあ\rあa\nあa\rあaa\nあa\r\r\r\n\r\n\r\nあa🐸\r\n\r\r\na\r"
+cc 3e5415f317b24c22a479d0e56a4705ad4c4c0cb060ee5610f94ef78fe3fda588 # shrinks to ref text = "\n\r\raa🐸\ra\n\raa🐸あ\r\n\n\nあ\nあ🐸🐸あ🐸\n🐸あああ🐸aaあ🐸a\n\rあああ🐸あ🐸\n\n\r🐸a\rああ\r🐸ああ\n\r\r\r🐸🐸あ🐸\r\r\n🐸🐸あaa\r\na\r\n🐸\na🐸\raあ\r\naあa\nああ\r\r🐸a\n\raa\rあ🐸\rあ\n\n\rあ\n\r\n🐸ああ\n🐸aあ\r\r\n\nあ🐸a\naaa🐸🐸あa", idx = 272
+cc 9fe71a1c06b7b4791fab564d43a9fe7e0d3047302e6c9228852e91f638833aae # shrinks to ref text = "\n\n\nあああaa\n🐸\rああ\r🐸🐸あ\n\nああ🐸\n\rああ\r\n🐸\nああ🐸\n\ra\r\r\naあ\n🐸\rあ\n\r", idx = 96
+cc b06906825a8db53b794a8dfbd9fc17ee80d3f722528a86644ad4457ba8ab12d8 # shrinks to ref text = "🐸\n\ra🐸\n🐸aaa\r🐸aaaa🐸🐸\nあ\r\naあ\rあa\raあ\rあ\rあa\ra\n\nあa\rあ\r\rあa\nああaa\n\r🐸\na\n🐸あ🐸a\r\n\n\naあa\raaa\r🐸🐸\r\rあaあ🐸\n🐸🐸\rあ\nああああaa\na\r\n\raa\n🐸\naあ\raaaa\n\n\naaあa🐸🐸🐸\ra\nあ\n\nあ🐸ああ🐸a🐸a\rあ\rあ🐸🐸あ🐸あa\n\raaあ\r🐸あ\ra\na🐸🐸\r\rあa\r\n\n\nあ\r\r\r🐸あ🐸🐸🐸\r\rあaあ🐸\n\n\r🐸\n🐸あ\r"