Optimize compression by avoiding unpredictable branches
Avoid an unpredictable branch: use a conditional move to generate an address
that is guaranteed to be safe, and compare unconditionally.
Instead of

if (idx < limit && x[idx] == val) // mispredicted idx < limit branch

Do

addr = cmov(safe, x + idx)
if (*addr == val && idx < limit) // almost always false so well predicted
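
For illustration, here is the same transformation as compilable C; the function
and variable names (lookupHit_*, sentinel) are mine, not the patch's:

#include <stdint.h>

/* Before: the guard branch depends on the data and mispredicts often. */
int lookupHit_branchy(const uint8_t* x, uint32_t idx, uint32_t limit, uint8_t val)
{
    return idx < limit && x[idx] == val;
}

/* After: always load from a safe address (a local sentinel when idx is out
 * of range), then compare. The byte compare is almost always false, so the
 * branch is well predicted; idx < limit is only evaluated on a hit. */
int lookupHit_branchless(const uint8_t* x, uint32_t idx, uint32_t limit, uint8_t val)
{
    static const uint8_t sentinel = 0x5a;                    /* arbitrary */
    const uint8_t* addr = idx < limit ? x + idx : &sentinel; /* typically a cmov */
    return *addr == val && idx < limit;
}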

Using microbenchmarks from https://github.com/google/fleetbench,
I get a ~10% speed-up:

name                                                                                          old cpu/op   new cpu/op    delta
BM_ZSTD_COMPRESS_Fleet/compression_level:-7/window_log:15                                     1.46ns ± 3%   1.31ns ± 7%   -9.88%  (p=0.000 n=35+38)
BM_ZSTD_COMPRESS_Fleet/compression_level:-7/window_log:16                                     1.41ns ± 3%   1.28ns ± 3%   -9.56%  (p=0.000 n=36+39)
BM_ZSTD_COMPRESS_Fleet/compression_level:-5/window_log:15                                     1.61ns ± 1%   1.43ns ± 3%  -10.70%  (p=0.000 n=30+39)
BM_ZSTD_COMPRESS_Fleet/compression_level:-5/window_log:16                                     1.54ns ± 2%   1.39ns ± 3%   -9.21%  (p=0.000 n=37+39)
BM_ZSTD_COMPRESS_Fleet/compression_level:-3/window_log:15                                     1.82ns ± 2%   1.61ns ± 3%  -11.31%  (p=0.000 n=37+40)
BM_ZSTD_COMPRESS_Fleet/compression_level:-3/window_log:16                                     1.73ns ± 3%   1.56ns ± 3%   -9.50%  (p=0.000 n=38+39)
BM_ZSTD_COMPRESS_Fleet/compression_level:-1/window_log:15                                     2.12ns ± 2%   1.79ns ± 3%  -15.55%  (p=0.000 n=34+39)
BM_ZSTD_COMPRESS_Fleet/compression_level:-1/window_log:16                                     1.99ns ± 3%   1.72ns ± 3%  -13.70%  (p=0.000 n=38+38)
BM_ZSTD_COMPRESS_Fleet/compression_level:0/window_log:15                                      3.22ns ± 3%   2.94ns ± 3%   -8.67%  (p=0.000 n=38+40)
BM_ZSTD_COMPRESS_Fleet/compression_level:0/window_log:16                                      3.19ns ± 4%   2.86ns ± 4%  -10.55%  (p=0.000 n=40+38)
BM_ZSTD_COMPRESS_Fleet/compression_level:1/window_log:15                                      2.60ns ± 3%   2.22ns ± 3%  -14.53%  (p=0.000 n=40+39)
BM_ZSTD_COMPRESS_Fleet/compression_level:1/window_log:16                                      2.46ns ± 3%   2.13ns ± 2%  -13.67%  (p=0.000 n=39+36)
BM_ZSTD_COMPRESS_Fleet/compression_level:2/window_log:15                                      2.69ns ± 3%   2.46ns ± 3%   -8.63%  (p=0.000 n=37+39)
BM_ZSTD_COMPRESS_Fleet/compression_level:2/window_log:16                                      2.63ns ± 3%   2.36ns ± 3%  -10.47%  (p=0.000 n=40+40)
BM_ZSTD_COMPRESS_Fleet/compression_level:3/window_log:15                                      3.20ns ± 2%   2.95ns ± 3%   -7.94%  (p=0.000 n=35+40)
BM_ZSTD_COMPRESS_Fleet/compression_level:3/window_log:16                                      3.20ns ± 4%   2.87ns ± 4%  -10.33%  (p=0.000 n=40+40)

I've also measured the impact on internal workloads and saw a similar
~10% improvement in performance, measured by CPU usage per byte of data.
TocarIP committed Sep 20, 2024
1 parent 10e2a80 commit e8fce38
Showing 3 changed files with 61 additions and 25 deletions.
17 changes: 17 additions & 0 deletions lib/compress/zstd_compress_internal.h
@@ -557,6 +557,23 @@ MEM_STATIC int ZSTD_cParam_withinBounds(ZSTD_cParameter cParam, int value)
return 1;
}

/* ZSTD_selectAddr:
* @return a >= b ? trueAddr : falseAddr,
* tries to force branchless codegen. */
MEM_STATIC const BYTE* ZSTD_selectAddr(U32 a, U32 b, const BYTE* trueAddr, const BYTE* falseAddr) {
#if defined(__GNUC__) && defined(__x86_64__)
__asm__ (
"cmp %1, %2\n"
"cmova %3, %0\n"
: "+r"(trueAddr)
: "r"(a), "r"(b), "r"(falseAddr)
);
return trueAddr;
#else
return a >= b ? trueAddr : falseAddr;
#endif
}

/* ZSTD_noCompressBlock() :
* Writes uncompressed block to dst buffer from given src.
* Returns the size of the block */
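
As a quick, self-contained sanity check of ZSTD_selectAddr's contract (this
harness is illustrative, not part of the commit, and restates the portable
fallback rather than the inline-asm path):

#include <assert.h>
#include <stdint.h>

/* Portable restatement: returns a >= b ? trueAddr : falseAddr. */
static const unsigned char* selectAddr(uint32_t a, uint32_t b,
                                       const unsigned char* trueAddr,
                                       const unsigned char* falseAddr)
{
    return a >= b ? trueAddr : falseAddr;
}

int main(void)
{
    unsigned char t = 0, f = 0;
    assert(selectAddr(5, 3, &t, &f) == &t); /* a >  b -> trueAddr  */
    assert(selectAddr(3, 3, &t, &f) == &t); /* a == b -> trueAddr  */
    assert(selectAddr(2, 3, &t, &f) == &f); /* a <  b -> falseAddr */
    return 0;
}
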
37 changes: 24 additions & 13 deletions lib/compress/zstd_double_fast.c
@@ -140,11 +140,17 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic(
U32 idxl1; /* the long match index for ip1 */

const BYTE* matchl0; /* the long match for ip */
const BYTE* matchl0_safe; /* matchl0 or safe address */
const BYTE* matchs0; /* the short match for ip */
const BYTE* matchl1; /* the long match for ip1 */
const BYTE* matchs0_safe; /* matchs0 or safe address */

const BYTE* ip = istart; /* the current position */
const BYTE* ip1; /* the next position */
/* Array of ~random data with a low probability of matching real input;
 * we load from here instead of from the tables when matchl0/matchl1 are
 * invalid indices. Used to avoid unpredictable branches. */
const BYTE dummy[] = {0x12,0x34,0x56,0x78,0x9a,0xbc,0xde,0xf0,0xe2,0xb4};

DEBUGLOG(5, "ZSTD_compressBlock_doubleFast_noDict_generic");

@@ -191,24 +197,29 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic(

hl1 = ZSTD_hashPtr(ip1, hBitsL, 8);

if (idxl0 > prefixLowestIndex) {
/* check prefix long match */
if (MEM_read64(matchl0) == MEM_read64(ip)) {
mLength = ZSTD_count(ip+8, matchl0+8, iend) + 8;
offset = (U32)(ip-matchl0);
while (((ip>anchor) & (matchl0>prefixLowest)) && (ip[-1] == matchl0[-1])) { ip--; matchl0--; mLength++; } /* catch up */
goto _match_found;
}
/* idxl0 > prefixLowestIndex is a (somewhat) unpredictable branch.
 * However, the expression below compiles into a conditional move. Since
 * a match is unlikely, and we only *branch* on idxl0 > prefixLowestIndex
 * when there is a match, all branches become predictable. */
matchl0_safe = ZSTD_selectAddr(prefixLowestIndex, idxl0, &dummy[0], matchl0);

/* check prefix long match */
if (MEM_read64(matchl0_safe) == MEM_read64(ip) && matchl0_safe == matchl0) {
mLength = ZSTD_count(ip+8, matchl0+8, iend) + 8;
offset = (U32)(ip-matchl0);
while (((ip>anchor) & (matchl0>prefixLowest)) && (ip[-1] == matchl0[-1])) { ip--; matchl0--; mLength++; } /* catch up */
goto _match_found;
}

idxl1 = hashLong[hl1];
matchl1 = base + idxl1;

if (idxs0 > prefixLowestIndex) {
/* check prefix short match */
if (MEM_read32(matchs0) == MEM_read32(ip)) {
goto _search_next_long;
}
/* Same optimization as matchl0 above */
matchs0_safe = ZSTD_selectAddr(prefixLowestIndex, idxs0, &dummy[0], matchs0);

/* check prefix short match */
if (MEM_read32(matchs0_safe) == MEM_read32(ip) && matchs0_safe == matchs0) {
goto _search_next_long;
}

if (ip1 >= nextStep) {
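
One subtlety in the hunks above: the unconditional MEM_read64 against &dummy[0]
must stay inside the array. dummy has 10 bytes and the widest unconditional
read is 8, so the load is in bounds; a minimal compile-time statement of that
invariant (illustrative, not in the patch):

/* C89-style static assert: an array type of negative size fails to compile. */
static const unsigned char dummy_local[] =
    {0x12,0x34,0x56,0x78,0x9a,0xbc,0xde,0xf0,0xe2,0xb4};
typedef char dummy_wide_enough[(sizeof(dummy_local) >= 8) ? 1 : -1];
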
32 changes: 20 additions & 12 deletions lib/compress/zstd_fast.c
@@ -162,6 +162,11 @@ size_t ZSTD_compressBlock_fast_noDict_generic(
const BYTE* const prefixStart = base + prefixStartIndex;
const BYTE* const iend = istart + srcSize;
const BYTE* const ilimit = iend - HASH_READ_SIZE;
/* Array of ~random data with a low probability of matching real input;
 * we load from here instead of from the tables when the index is invalid.
 * Used to avoid unpredictable branches. */
const BYTE dummy[] = {0x12,0x34,0x56,0x78,0x9a,0xbc,0xde,0xf0,0xe2,0xb4};
const BYTE* mvalAddr;

const BYTE* anchor = istart;
const BYTE* ip0 = istart;
@@ -246,15 +251,18 @@ size_t ZSTD_compressBlock_fast_noDict_generic(
goto _match;
}

/* idx >= prefixStartIndex is a (somewhat) unpredictable branch.
 * However, the expression below compiles into a conditional move. Since
 * a match is unlikely, and we only *branch* on idx >= prefixStartIndex
 * when there is a match, all branches become predictable. */
mvalAddr = base + idx;
mvalAddr = ZSTD_selectAddr(idx, prefixStartIndex, mvalAddr, &dummy[0]);

/* load match for ip[0] */
if (idx >= prefixStartIndex) {
mval = MEM_read32(base + idx);
} else {
mval = MEM_read32(ip0) ^ 1; /* guaranteed to not match. */
}
mval = MEM_read32(mvalAddr);

/* check match at ip[0] */
if (MEM_read32(ip0) == mval) {
if (MEM_read32(ip0) == mval && idx >= prefixStartIndex) {
/* found a match! */

/* First write next hash table entry; we've already calculated it.
@@ -281,15 +289,15 @@
current0 = (U32)(ip0 - base);
hashTable[hash0] = current0;

mvalAddr = base + idx;
mvalAddr = ZSTD_selectAddr(idx, prefixStartIndex, mvalAddr, &dummy[0]);

/* load match for ip[0] */
if (idx >= prefixStartIndex) {
mval = MEM_read32(base + idx);
} else {
mval = MEM_read32(ip0) ^ 1; /* guaranteed to not match. */
}
mval = MEM_read32(mvalAddr);


/* check match at ip[0] */
if (MEM_read32(ip0) == mval) {
if (MEM_read32(ip0) == mval && idx >= prefixStartIndex) {
/* found a match! */

/* first write next hash table entry; we've already calculated it */
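
Condensed, the zstd_fast.c change amounts to the pair below; these functions
are an illustrative restatement (read32 stands in for MEM_read32), not zstd
code:

#include <stdint.h>
#include <string.h>

static uint32_t read32(const void* p) { uint32_t v; memcpy(&v, p, sizeof v); return v; }

/* Old shape: branch on index validity before every probe. */
static int probe_old(const uint8_t* base, uint32_t idx,
                     uint32_t prefixStartIndex, const uint8_t* ip0)
{
    uint32_t mval;
    if (idx >= prefixStartIndex) {
        mval = read32(base + idx);
    } else {
        mval = read32(ip0) ^ 1;  /* guaranteed to not match */
    }
    return read32(ip0) == mval;
}

/* New shape: load unconditionally from a safe address and re-check validity
 * only on the rare byte-compare hit. */
static int probe_new(const uint8_t* base, uint32_t idx,
                     uint32_t prefixStartIndex, const uint8_t* ip0)
{
    static const uint8_t dummy[] = {0x12,0x34,0x56,0x78,0x9a,0xbc,0xde,0xf0,0xe2,0xb4};
    const uint8_t* mvalAddr = idx >= prefixStartIndex ? base + idx : &dummy[0];
    return read32(ip0) == read32(mvalAddr) && idx >= prefixStartIndex;
}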
