diff --git a/utils/full_math.go b/utils/full_math.go index b8da971..e12886c 100644 --- a/utils/full_math.go +++ b/utils/full_math.go @@ -12,22 +12,10 @@ var ( One = big.NewInt(1) ) -// Calculates ceil(a×b÷denominator) with full precision +// MulDivRoundingUp Calculates ceil(a×b÷denominator) with full precision func MulDivRoundingUp(a, b, denominator *uint256.Int) (*uint256.Int, error) { - // the product can overflow so need to use big.Int here - // TODO: optimize this - var product, rem, result big.Int - product.Mul(a.ToBig(), b.ToBig()) - result.DivMod(&product, denominator.ToBig(), &rem) - if rem.Sign() != 0 { - result.Add(&result, One) - } - - resultU, overflow := uint256.FromBig(&result) - if overflow { - return nil, ErrMulDivOverflow - } - return resultU, nil + var result Uint256 + return &result, MulDivRoundingUpV2(a, b, denominator, &result) } func MulDivRoundingUpV2(a, b, denominator, result *uint256.Int) error { @@ -46,130 +34,40 @@ func MulDivRoundingUpV2(a, b, denominator, result *uint256.Int) error { return nil } -// result=floor(a×b÷denominator), remainder=a×b%denominator +// MulDivV2 z=floor(a×b÷denominator), r=a×b%denominator // (pass remainder=nil if not required) // (the main usage for `remainder` is to be used in `MulDivRoundingUpV2` to determine if we need to round up, so it won't have to call MulMod again) -func MulDivV2(a, b, denominator, result, remainder *uint256.Int) error { - // https://github.com/Uniswap/v3-core/blob/main/contracts/libraries/FullMath.sol - // 512-bit multiply [prod1 prod0] = a * b - // Compute the product mod 2**256 and mod 2**256 - 1 - // then use the Chinese Remainder Theorem to reconstruct - // the 512 bit result. The result is stored in two 256 - // variables such that product = prod1 * 2**256 + prod0 - var prod0 Uint256 // Least significant 256 bits of the product - var prod1 Uint256 // Most significant 256 bits of the product - - var denominatorTmp Uint256 // temp var (need to modify denominator along the way) - denominatorTmp.Set(denominator) - - var mm Uint256 - mm.MulMod(a, b, MaxUint256) - prod0.Mul(a, b) - prod1.Sub(&mm, &prod0) - if mm.Cmp(&prod0) < 0 { - prod1.SubUint64(&prod1, 1) - } - - // Handle non-overflow cases, 256 by 256 division - if prod1.IsZero() { - if denominatorTmp.IsZero() { - return ErrInvariant - } - - if remainder != nil { - // if the caller request then calculate remainder - remainder.MulMod(a, b, &denominatorTmp) - } - result.Div(&prod0, &denominatorTmp) +func MulDivV2(x, y, d, z, r *uint256.Int) error { + if x.IsZero() || y.IsZero() || d.IsZero() { + z.Clear() return nil } + p := umul(x, y) - // Make sure the result is less than 2**256. - // Also prevents denominator == 0 - if denominatorTmp.Cmp(&prod1) <= 0 { - return ErrInvariant + var quot [8]uint64 + rem := udivrem(quot[:], p[:], d) + if r != nil { + r.Set(&rem) } - /////////////////////////////////////////////// - // 512 by 256 division. - /////////////////////////////////////////////// + copy(z[:], quot[:4]) - // Make division exact by subtracting the remainder from [prod1 prod0] - // Compute remainder using mulmod - if remainder == nil { - // the caller doesn't request but we need it so use a temporary variable here - var remainderTmp Uint256 - remainder = &remainderTmp + if (quot[4] | quot[5] | quot[6] | quot[7]) != 0 { + return ErrMulDivOverflow } - remainder.MulMod(a, b, &denominatorTmp) - // Subtract 256 bit number from 512 bit number - if remainder.Cmp(&prod0) > 0 { - prod1.SubUint64(&prod1, 1) - } - prod0.Sub(&prod0, remainder) - - // Factor powers of two out of denominator - // Compute largest power of two divisor of denominator. - // Always >= 1. - var twos, tmp, tmp1, zero, two, three Uint256 - twos.And(tmp.Neg(&denominatorTmp), &denominatorTmp) - // Divide denominator by power of two - denominatorTmp.Div(&denominatorTmp, &twos) - - // Divide [prod1 prod0] by the factors of two - prod0.Div(&prod0, &twos) - // Shift in bits from prod1 into prod0. For this we need - // to flip `twos` such that it is 2**256 / twos. - // If twos is zero, then it becomes one - zero.Clear() - twos.AddUint64(tmp.Div(tmp1.Sub(&zero, &twos), &twos), 1) - prod0.Or(&prod0, tmp.Mul(&prod1, &twos)) - - // Invert denominator mod 2**256 - // Now that denominator is an odd number, it has an inverse - // modulo 2**256 such that denominator * inv = 1 mod 2**256. - // Compute the inverse by starting with a seed that is correct - // correct for four bits. That is, denominator * inv = 1 mod 2**4 - var inv Uint256 - two.SetUint64(2) - three.SetUint64(3) - inv.Xor(tmp.Mul(&denominatorTmp, &three), &two) - // Now use Newton-Raphson iteration to improve the precision. - // Thanks to Hensel's lifting lemma, this also works in modular - // arithmetic, doubling the correct bits in each step. - inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**8 - inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**16 - inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**32 - inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**64 - inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**128 - inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**256 - - // Because the division is now exact we can divide by multiplying - // with the modular inverse of denominator. This will give us the - // correct result modulo 2**256. Since the precoditions guarantee - // that the outcome is less than 2**256, this is the final result. - // We don't need to compute the high bits of the result and prod1 - // is no longer required. - result.Mul(&prod0, &inv) return nil } -// Calculates floor(a×b÷denominator) with full precision +// MulDiv Calculates floor(a×b÷denominator) with full precision func MulDiv(a, b, denominator *uint256.Int) (*uint256.Int, error) { - // the product can overflow so need to use big.Int here - // TODO: optimize this follow univ3 code - var product, result big.Int - product.Mul(a.ToBig(), b.ToBig()) - result.Div(&product, denominator.ToBig()) - - resultU, overflow := uint256.FromBig(&result) + result, overflow := new(uint256.Int).MulDivOverflow(a, b, denominator) if overflow { return nil, ErrMulDivOverflow } - return resultU, nil + return result, nil } -// Returns ceil(x / y) +// DivRoundingUp Returns ceil(x / y) func DivRoundingUp(a, denominator, result *uint256.Int) { var rem uint256.Int result.DivMod(a, denominator, &rem) diff --git a/utils/full_math_test.go b/utils/full_math_test.go index d215867..8c35341 100644 --- a/utils/full_math_test.go +++ b/utils/full_math_test.go @@ -20,11 +20,15 @@ func TestMulDiv(t *testing.T) { expResult string }{ {MaxUint256.Hex(), MaxUint256.Hex(), MaxUint256.Hex(), MaxUint256.Dec()}, - {"0x100000000000000000000000000000000", "0x80000000000000000000000000000000", "0x180000000000000000000000000000000", "113427455640312821154458202477256070485"}, - {"0x100000000000000000000000000000000", "0x2300000000000000000000000000000000", "0x800000000000000000000000000000000", "1488735355279105777652263907513985925120"}, - {"0x100000000000000000000000000000000", "0x3e800000000000000000000000000000000", "0xbb800000000000000000000000000000000", "113427455640312821154458202477256070485"}, + {"0x100000000000000000000000000000000", "0x80000000000000000000000000000000", + "0x180000000000000000000000000000000", "113427455640312821154458202477256070485"}, + {"0x100000000000000000000000000000000", "0x2300000000000000000000000000000000", + "0x800000000000000000000000000000000", "1488735355279105777652263907513985925120"}, + {"0x100000000000000000000000000000000", "0x3e800000000000000000000000000000000", + "0xbb800000000000000000000000000000000", "113427455640312821154458202477256070485"}, - {"0x61ae64157b363469ec1e000000000000000000000000", "0x5d5502f19f7baee2e5fa2", "0x69b797741ba66bda48a81e9", "126036350226489723925526476841950279379016090973169"}, + {"0x61ae64157b363469ec1e000000000000000000000000", "0x5d5502f19f7baee2e5fa2", "0x69b797741ba66bda48a81e9", + "126036350226489723925526476841950279379016090973169"}, } for i, tt := range tests { t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { @@ -111,6 +115,57 @@ func TestMulDivV2(t *testing.T) { } } +// BenchmarkMulDivV2 +// old version (using univ3 math lib (?)): +// BenchmarkMulDivV2/test_0xa9ab0f5808bbc8ef22eb4ac1a00fcd41c618ed2fc4af748e9e0a03496c8_0x3dae37091f647259cfbe64b41a2464804821425_0xfdd6776a6c2c +// BenchmarkMulDivV2/test_0xa9ab0f5808bbc8ef22eb4ac1a00fcd41c618ed2fc4af748e9e0a03496c8_0x3dae37091f647259cfbe64b41a2464804821425_0xfdd6776a6c2c-16 10213753 124.9 ns/op +// BenchmarkMulDivV2/test_0x43f001346c79fd41_0xabcec79aae99f1d096bf86eba29af446ee3c99c69cac66993c35509_0x86dbc64474f66ca6bca0e1a49615e845e0f5ed60e2a52333a8c2b2bf691b271 +// BenchmarkMulDivV2/test_0x43f001346c79fd41_0xabcec79aae99f1d096bf86eba29af446ee3c99c69cac66993c35509_0x86dbc64474f66ca6bca0e1a49615e845e0f5ed60e2a52333a8c2b2bf691b271-16 2472106 488.8 ns/op +// BenchmarkMulDivV2/test_0x7328a5a0_0xbf42f97a8b8c6fba95e767cdfcbba62c36c378ccc619c39ceb5ffd8cd_0xed14bf3a06155275a2c4680b39ffce5620af83345f9a5198fde9 +// BenchmarkMulDivV2/test_0x7328a5a0_0xbf42f97a8b8c6fba95e767cdfcbba62c36c378ccc619c39ceb5ffd8cd_0xed14bf3a06155275a2c4680b39ffce5620af83345f9a5198fde9-16 2332716 529.5 ns/op +// BenchmarkMulDivV2/test_0x72b9b242544fb8fc7f5_0x8f945829b2f3890e_0xf04677b55a48822f663e0450310052ae +// BenchmarkMulDivV2/test_0x72b9b242544fb8fc7f5_0x8f945829b2f3890e_0xf04677b55a48822f663e0450310052ae-16 6543910 201.9 ns/op +// BenchmarkMulDivV2/test_0x43a131_0x50a329770_0x9e6d9c6a06552e8399e1c98125246037a2ab7781b005b13eb7474d5 +// BenchmarkMulDivV2/test_0x43a131_0x50a329770_0x9e6d9c6a06552e8399e1c98125246037a2ab7781b005b13eb7474d5-16 7596972 148.4 ns/op +// new version (using uint256's code): +// BenchmarkMulDivV2/test_0xa9ab0f5808bbc8ef22eb4ac1a00fcd41c618ed2fc4af748e9e0a03496c8_0x3dae37091f647259cfbe64b41a2464804821425_0xfdd6776a6c2c +// BenchmarkMulDivV2/test_0xa9ab0f5808bbc8ef22eb4ac1a00fcd41c618ed2fc4af748e9e0a03496c8_0x3dae37091f647259cfbe64b41a2464804821425_0xfdd6776a6c2c-16 18522574 65.71 ns/op +// BenchmarkMulDivV2/test_0x43f001346c79fd41_0xabcec79aae99f1d096bf86eba29af446ee3c99c69cac66993c35509_0x86dbc64474f66ca6bca0e1a49615e845e0f5ed60e2a52333a8c2b2bf691b271 +// BenchmarkMulDivV2/test_0x43f001346c79fd41_0xabcec79aae99f1d096bf86eba29af446ee3c99c69cac66993c35509_0x86dbc64474f66ca6bca0e1a49615e845e0f5ed60e2a52333a8c2b2bf691b271-16 20101362 67.59 ns/op +// BenchmarkMulDivV2/test_0x7328a5a0_0xbf42f97a8b8c6fba95e767cdfcbba62c36c378ccc619c39ceb5ffd8cd_0xed14bf3a06155275a2c4680b39ffce5620af83345f9a5198fde9 +// BenchmarkMulDivV2/test_0x7328a5a0_0xbf42f97a8b8c6fba95e767cdfcbba62c36c378ccc619c39ceb5ffd8cd_0xed14bf3a06155275a2c4680b39ffce5620af83345f9a5198fde9-16 18418658 59.09 ns/op +// BenchmarkMulDivV2/test_0x72b9b242544fb8fc7f5_0x8f945829b2f3890e_0xf04677b55a48822f663e0450310052ae +// BenchmarkMulDivV2/test_0x72b9b242544fb8fc7f5_0x8f945829b2f3890e_0xf04677b55a48822f663e0450310052ae-16 24769188 50.33 ns/op +// BenchmarkMulDivV2/test_0x43a131_0x50a329770_0x9e6d9c6a06552e8399e1c98125246037a2ab7781b005b13eb7474d5 +// BenchmarkMulDivV2/test_0x43a131_0x50a329770_0x9e6d9c6a06552e8399e1c98125246037a2ab7781b005b13eb7474d5-16 41646259 33.35 ns/op +func BenchmarkMulDivV2(tb *testing.B) { + rand.Seed(0) + for i := 0; i < 5; i++ { + a := RandUint256() + b := RandUint256() + deno := RandUint256() + + tb.Run(fmt.Sprintf("test %s %s %s", a.Hex(), b.Hex(), deno.Hex()), func(tb *testing.B) { + r, err := MulDiv(a, b, deno) + + var rv2 Uint256 + var errv2 error + tb.ResetTimer() + for i := 0; i < tb.N; i++ { + errv2 = MulDivV2(a, b, deno, &rv2, nil) + } + tb.StopTimer() + + if err != nil { + require.NotNil(tb, errv2) + } else { + require.Nil(tb, errv2) + assert.Equal(tb, r.Dec(), rv2.Dec()) + } + }) + } +} + func TestMulDivRoundingUp(t *testing.T) { // https://github.com/Uniswap/v3-core/blob/main/test/FullMath.spec.ts @@ -121,11 +176,15 @@ func TestMulDivRoundingUp(t *testing.T) { expResult string }{ {MaxUint256.Hex(), MaxUint256.Hex(), MaxUint256.Hex(), MaxUint256.Dec()}, - {"0x100000000000000000000000000000000", "0x80000000000000000000000000000000", "0x180000000000000000000000000000000", "113427455640312821154458202477256070486"}, - {"0x100000000000000000000000000000000", "0x2300000000000000000000000000000000", "0x800000000000000000000000000000000", "1488735355279105777652263907513985925120"}, - {"0x100000000000000000000000000000000", "0x3e800000000000000000000000000000000", "0xbb800000000000000000000000000000000", "113427455640312821154458202477256070486"}, + {"0x100000000000000000000000000000000", "0x80000000000000000000000000000000", + "0x180000000000000000000000000000000", "113427455640312821154458202477256070486"}, + {"0x100000000000000000000000000000000", "0x2300000000000000000000000000000000", + "0x800000000000000000000000000000000", "1488735355279105777652263907513985925120"}, + {"0x100000000000000000000000000000000", "0x3e800000000000000000000000000000000", + "0xbb800000000000000000000000000000000", "113427455640312821154458202477256070486"}, - {"0x2a60f4810d72e89eaee06f20122f1de80adc64777e121", "0xfd21718acef075500c6395ba922064220", "0xd195e7433221b9e4b6ef3f19b457c9c9797ae6b5eaacb402113dce147e97979f", "14406918379743960"}, + {"0x2a60f4810d72e89eaee06f20122f1de80adc64777e121", "0xfd21718acef075500c6395ba922064220", + "0xd195e7433221b9e4b6ef3f19b457c9c9797ae6b5eaacb402113dce147e97979f", "14406918379743960"}, } for i, tt := range tests { t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { @@ -154,8 +213,11 @@ func TestMulDivRoundingUp(t *testing.T) { // {"0x100000000000000000000000000000000", "0x100000000000000000000000000000000", "0x0"}, {"0x100000000000000000000000000000000", "0x100000000000000000000000000000000", "0x1"}, {MaxUint256.Hex(), MaxUint256.Hex(), new(Uint256).SubUint64(MaxUint256, 1).Hex()}, - {"0x1e695d2db4f97", "0x10d5effea103c44aaf18a26b449186a7de3dd6c1ce3d26d03dfd9", "0x2"}, // mulDiv overflows 256 bits after rounding up - {"0xffffffffffffffffffffffffffffffffffffffb07f6d608e4dcc38020b140b35", "0xffffffffffffffffffffffffffffffffffffffb07f6d608e4dcc38020b140b36", "0xffffffffffffffffffffffffffffffffffffff60fedac11c9b9870041628166c"}, // mulDiv overflows 256 bits after rounding up case 2 + {"0x1e695d2db4f97", "0x10d5effea103c44aaf18a26b449186a7de3dd6c1ce3d26d03dfd9", + "0x2"}, // mulDiv overflows 256 bits after rounding up + {"0xffffffffffffffffffffffffffffffffffffffb07f6d608e4dcc38020b140b35", + "0xffffffffffffffffffffffffffffffffffffffb07f6d608e4dcc38020b140b36", + "0xffffffffffffffffffffffffffffffffffffff60fedac11c9b9870041628166c"}, // mulDiv overflows 256 bits after rounding up case 2 } for i, tt := range failTests { t.Run(fmt.Sprintf("fail test %d", i), func(t *testing.T) { diff --git a/utils/uint256.go b/utils/uint256.go new file mode 100644 index 0000000..4ab39a2 --- /dev/null +++ b/utils/uint256.go @@ -0,0 +1,251 @@ +// BSD 3-Clause License +// +// Copyright 2020 uint256 Authors +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package utils + +import ( + "math/bits" + + "github.com/holiman/uint256" +) + +// umul computes full 256 x 256 -> 512 multiplication. +func umul(x, y *uint256.Int) [8]uint64 { + var ( + res [8]uint64 + carry, carry4, carry5, carry6 uint64 + res1, res2, res3, res4, res5 uint64 + ) + + carry, res[0] = bits.Mul64(x[0], y[0]) + carry, res1 = umulHop(carry, x[1], y[0]) + carry, res2 = umulHop(carry, x[2], y[0]) + carry4, res3 = umulHop(carry, x[3], y[0]) + + carry, res[1] = umulHop(res1, x[0], y[1]) + carry, res2 = umulStep(res2, x[1], y[1], carry) + carry, res3 = umulStep(res3, x[2], y[1], carry) + carry5, res4 = umulStep(carry4, x[3], y[1], carry) + + carry, res[2] = umulHop(res2, x[0], y[2]) + carry, res3 = umulStep(res3, x[1], y[2], carry) + carry, res4 = umulStep(res4, x[2], y[2], carry) + carry6, res5 = umulStep(carry5, x[3], y[2], carry) + + carry, res[3] = umulHop(res3, x[0], y[3]) + carry, res[4] = umulStep(res4, x[1], y[3], carry) + carry, res[5] = umulStep(res5, x[2], y[3], carry) + res[7], res[6] = umulStep(carry6, x[3], y[3], carry) + + return res +} + +// umulHop computes (hi * 2^64 + lo) = z + (x * y) +func umulHop(z, x, y uint64) (hi, lo uint64) { + hi, lo = bits.Mul64(x, y) + lo, carry := bits.Add64(lo, z, 0) + hi, _ = bits.Add64(hi, 0, carry) + return hi, lo +} + +// umulStep computes (hi * 2^64 + lo) = z + (x * y) + carry. +func umulStep(z, x, y, carry uint64) (hi, lo uint64) { + hi, lo = bits.Mul64(x, y) + lo, carry = bits.Add64(lo, carry, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, z, 0) + hi, _ = bits.Add64(hi, 0, carry) + return hi, lo +} + +// udivrem divides u by d and produces both quotient and remainder. +// The quotient is stored in provided quot - len(u)-len(d)+1 words. +// It loosely follows the Knuth's division algorithm (sometimes referenced as "schoolbook" division) using 64-bit words. +// See Knuth, Volume 2, section 4.3.1, Algorithm D. +func udivrem(quot, u []uint64, d *uint256.Int) (rem uint256.Int) { + var dLen int + for i := len(d) - 1; i >= 0; i-- { + if d[i] != 0 { + dLen = i + 1 + break + } + } + + shift := uint(bits.LeadingZeros64(d[dLen-1])) + + var dnStorage uint256.Int + dn := dnStorage[:dLen] + for i := dLen - 1; i > 0; i-- { + dn[i] = (d[i] << shift) | (d[i-1] >> (64 - shift)) + } + dn[0] = d[0] << shift + + var uLen int + for i := len(u) - 1; i >= 0; i-- { + if u[i] != 0 { + uLen = i + 1 + break + } + } + + if uLen < dLen { + copy(rem[:], u) + return rem + } + + var unStorage [9]uint64 + un := unStorage[:uLen+1] + un[uLen] = u[uLen-1] >> (64 - shift) + for i := uLen - 1; i > 0; i-- { + un[i] = (u[i] << shift) | (u[i-1] >> (64 - shift)) + } + un[0] = u[0] << shift + + // TODO: Skip the highest word of numerator if not significant. + + if dLen == 1 { + r := udivremBy1(quot, un, dn[0]) + rem.SetUint64(r >> shift) + return rem + } + + udivremKnuth(quot, un, dn) + + for i := 0; i < dLen-1; i++ { + rem[i] = (un[i] >> shift) | (un[i+1] << (64 - shift)) + } + rem[dLen-1] = un[dLen-1] >> shift + + return rem +} + +// udivremBy1 divides u by single normalized word d and produces both quotient and remainder. +// The quotient is stored in provided quot. +func udivremBy1(quot, u []uint64, d uint64) (rem uint64) { + reciprocal := reciprocal2by1(d) + rem = u[len(u)-1] // Set the top word as remainder. + for j := len(u) - 2; j >= 0; j-- { + quot[j], rem = udivrem2by1(rem, u[j], d, reciprocal) + } + return rem +} + +// reciprocal2by1 computes <^d, ^0> / d. +func reciprocal2by1(d uint64) uint64 { + reciprocal, _ := bits.Div64(^d, ^uint64(0), d) + return reciprocal +} + +// udivrem2by1 divides / d and produces both quotient and remainder. +// It uses the provided d's reciprocal. +// Implementation ported from https://github.com/chfast/intx and is based on +// "Improved division by invariant integers", Algorithm 4. +func udivrem2by1(uh, ul, d, reciprocal uint64) (quot, rem uint64) { + qh, ql := bits.Mul64(reciprocal, uh) + ql, carry := bits.Add64(ql, ul, 0) + qh, _ = bits.Add64(qh, uh, carry) + qh++ + + r := ul - qh*d + + if r > ql { + qh-- + r += d + } + + if r >= d { + qh++ + r -= d + } + + return qh, r +} + +// udivremKnuth implements the division of u by normalized multiple word d from the Knuth's division algorithm. +// The quotient is stored in provided quot - len(u)-len(d) words. +// Updates u to contain the remainder - len(d) words. +func udivremKnuth(quot, u, d []uint64) { + dh := d[len(d)-1] + dl := d[len(d)-2] + reciprocal := reciprocal2by1(dh) + + for j := len(u) - len(d) - 1; j >= 0; j-- { + u2 := u[j+len(d)] + u1 := u[j+len(d)-1] + u0 := u[j+len(d)-2] + + var qhat, rhat uint64 + if u2 >= dh { // Division overflows. + qhat = ^uint64(0) + // TODO: Add "qhat one to big" adjustment (not needed for correctness, but helps avoiding "add back" case). + } else { + qhat, rhat = udivrem2by1(u2, u1, dh, reciprocal) + ph, pl := bits.Mul64(qhat, dl) + if ph > rhat || (ph == rhat && pl > u0) { + qhat-- + // TODO: Add "qhat one to big" adjustment (not needed for correctness, but helps avoiding "add back" case). + } + } + + // Multiply and subtract. + borrow := subMulTo(u[j:], d, qhat) + u[j+len(d)] = u2 - borrow + if u2 < borrow { // Too much subtracted, add back. + qhat-- + u[j+len(d)] += addTo(u[j:], d) + } + + quot[j] = qhat // Store quotient digit. + } +} + +// subMulTo computes x -= y * multiplier. +// Requires len(x) >= len(y). +func subMulTo(x, y []uint64, multiplier uint64) uint64 { + + var borrow uint64 + for i := 0; i < len(y); i++ { + s, carry1 := bits.Sub64(x[i], borrow, 0) + ph, pl := bits.Mul64(y[i], multiplier) + t, carry2 := bits.Sub64(s, pl, 0) + x[i] = t + borrow = ph + carry1 + carry2 + } + return borrow +} + +// addTo computes x += y. +// Requires len(x) >= len(y). +func addTo(x, y []uint64) uint64 { + var carry uint64 + for i := 0; i < len(y); i++ { + x[i], carry = bits.Add64(x[i], y[i], carry) + } + return carry +}