Skip to content

Commit

Permalink
ft: optimize MulDiv(RoundingUp) (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
NgoKimPhu authored Dec 23, 2024
1 parent 5c17d0b commit bf7bcb8
Show file tree
Hide file tree
Showing 3 changed files with 342 additions and 131 deletions.
140 changes: 19 additions & 121 deletions utils/full_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,10 @@ var (
One = big.NewInt(1)
)

// Calculates ceil(a×b÷denominator) with full precision
// MulDivRoundingUp Calculates ceil(a×b÷denominator) with full precision
func MulDivRoundingUp(a, b, denominator *uint256.Int) (*uint256.Int, error) {
// the product can overflow so need to use big.Int here
// TODO: optimize this
var product, rem, result big.Int
product.Mul(a.ToBig(), b.ToBig())
result.DivMod(&product, denominator.ToBig(), &rem)
if rem.Sign() != 0 {
result.Add(&result, One)
}

resultU, overflow := uint256.FromBig(&result)
if overflow {
return nil, ErrMulDivOverflow
}
return resultU, nil
var result Uint256
return &result, MulDivRoundingUpV2(a, b, denominator, &result)
}

func MulDivRoundingUpV2(a, b, denominator, result *uint256.Int) error {
Expand All @@ -46,130 +34,40 @@ func MulDivRoundingUpV2(a, b, denominator, result *uint256.Int) error {
return nil
}

// result=floor(a×b÷denominator), remainder=a×b%denominator
// MulDivV2 z=floor(a×b÷denominator), r=a×b%denominator
// (pass remainder=nil if not required)
// (the main usage for `remainder` is to be used in `MulDivRoundingUpV2` to determine if we need to round up, so it won't have to call MulMod again)
func MulDivV2(a, b, denominator, result, remainder *uint256.Int) error {
// https://github.com/Uniswap/v3-core/blob/main/contracts/libraries/FullMath.sol
// 512-bit multiply [prod1 prod0] = a * b
// Compute the product mod 2**256 and mod 2**256 - 1
// then use the Chinese Remainder Theorem to reconstruct
// the 512 bit result. The result is stored in two 256
// variables such that product = prod1 * 2**256 + prod0
var prod0 Uint256 // Least significant 256 bits of the product
var prod1 Uint256 // Most significant 256 bits of the product

var denominatorTmp Uint256 // temp var (need to modify denominator along the way)
denominatorTmp.Set(denominator)

var mm Uint256
mm.MulMod(a, b, MaxUint256)
prod0.Mul(a, b)
prod1.Sub(&mm, &prod0)
if mm.Cmp(&prod0) < 0 {
prod1.SubUint64(&prod1, 1)
}

// Handle non-overflow cases, 256 by 256 division
if prod1.IsZero() {
if denominatorTmp.IsZero() {
return ErrInvariant
}

if remainder != nil {
// if the caller request then calculate remainder
remainder.MulMod(a, b, &denominatorTmp)
}
result.Div(&prod0, &denominatorTmp)
func MulDivV2(x, y, d, z, r *uint256.Int) error {
if x.IsZero() || y.IsZero() || d.IsZero() {
z.Clear()
return nil
}
p := umul(x, y)

// Make sure the result is less than 2**256.
// Also prevents denominator == 0
if denominatorTmp.Cmp(&prod1) <= 0 {
return ErrInvariant
var quot [8]uint64
rem := udivrem(quot[:], p[:], d)
if r != nil {
r.Set(&rem)
}

///////////////////////////////////////////////
// 512 by 256 division.
///////////////////////////////////////////////
copy(z[:], quot[:4])

// Make division exact by subtracting the remainder from [prod1 prod0]
// Compute remainder using mulmod
if remainder == nil {
// the caller doesn't request but we need it so use a temporary variable here
var remainderTmp Uint256
remainder = &remainderTmp
if (quot[4] | quot[5] | quot[6] | quot[7]) != 0 {
return ErrMulDivOverflow
}
remainder.MulMod(a, b, &denominatorTmp)
// Subtract 256 bit number from 512 bit number
if remainder.Cmp(&prod0) > 0 {
prod1.SubUint64(&prod1, 1)
}
prod0.Sub(&prod0, remainder)

// Factor powers of two out of denominator
// Compute largest power of two divisor of denominator.
// Always >= 1.
var twos, tmp, tmp1, zero, two, three Uint256
twos.And(tmp.Neg(&denominatorTmp), &denominatorTmp)
// Divide denominator by power of two
denominatorTmp.Div(&denominatorTmp, &twos)

// Divide [prod1 prod0] by the factors of two
prod0.Div(&prod0, &twos)
// Shift in bits from prod1 into prod0. For this we need
// to flip `twos` such that it is 2**256 / twos.
// If twos is zero, then it becomes one
zero.Clear()
twos.AddUint64(tmp.Div(tmp1.Sub(&zero, &twos), &twos), 1)
prod0.Or(&prod0, tmp.Mul(&prod1, &twos))

// Invert denominator mod 2**256
// Now that denominator is an odd number, it has an inverse
// modulo 2**256 such that denominator * inv = 1 mod 2**256.
// Compute the inverse by starting with a seed that is correct
// correct for four bits. That is, denominator * inv = 1 mod 2**4
var inv Uint256
two.SetUint64(2)
three.SetUint64(3)
inv.Xor(tmp.Mul(&denominatorTmp, &three), &two)
// Now use Newton-Raphson iteration to improve the precision.
// Thanks to Hensel's lifting lemma, this also works in modular
// arithmetic, doubling the correct bits in each step.
inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**8
inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**16
inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**32
inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**64
inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**128
inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**256

// Because the division is now exact we can divide by multiplying
// with the modular inverse of denominator. This will give us the
// correct result modulo 2**256. Since the precoditions guarantee
// that the outcome is less than 2**256, this is the final result.
// We don't need to compute the high bits of the result and prod1
// is no longer required.
result.Mul(&prod0, &inv)
return nil
}

// Calculates floor(a×b÷denominator) with full precision
// MulDiv Calculates floor(a×b÷denominator) with full precision
func MulDiv(a, b, denominator *uint256.Int) (*uint256.Int, error) {
// the product can overflow so need to use big.Int here
// TODO: optimize this follow univ3 code
var product, result big.Int
product.Mul(a.ToBig(), b.ToBig())
result.Div(&product, denominator.ToBig())

resultU, overflow := uint256.FromBig(&result)
result, overflow := new(uint256.Int).MulDivOverflow(a, b, denominator)
if overflow {
return nil, ErrMulDivOverflow
}
return resultU, nil
return result, nil
}

// Returns ceil(x / y)
// DivRoundingUp Returns ceil(x / y)
func DivRoundingUp(a, denominator, result *uint256.Int) {
var rem uint256.Int
result.DivMod(a, denominator, &rem)
Expand Down
82 changes: 72 additions & 10 deletions utils/full_math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,15 @@ func TestMulDiv(t *testing.T) {
expResult string
}{
{MaxUint256.Hex(), MaxUint256.Hex(), MaxUint256.Hex(), MaxUint256.Dec()},
{"0x100000000000000000000000000000000", "0x80000000000000000000000000000000", "0x180000000000000000000000000000000", "113427455640312821154458202477256070485"},
{"0x100000000000000000000000000000000", "0x2300000000000000000000000000000000", "0x800000000000000000000000000000000", "1488735355279105777652263907513985925120"},
{"0x100000000000000000000000000000000", "0x3e800000000000000000000000000000000", "0xbb800000000000000000000000000000000", "113427455640312821154458202477256070485"},
{"0x100000000000000000000000000000000", "0x80000000000000000000000000000000",
"0x180000000000000000000000000000000", "113427455640312821154458202477256070485"},
{"0x100000000000000000000000000000000", "0x2300000000000000000000000000000000",
"0x800000000000000000000000000000000", "1488735355279105777652263907513985925120"},
{"0x100000000000000000000000000000000", "0x3e800000000000000000000000000000000",
"0xbb800000000000000000000000000000000", "113427455640312821154458202477256070485"},

{"0x61ae64157b363469ec1e000000000000000000000000", "0x5d5502f19f7baee2e5fa2", "0x69b797741ba66bda48a81e9", "126036350226489723925526476841950279379016090973169"},
{"0x61ae64157b363469ec1e000000000000000000000000", "0x5d5502f19f7baee2e5fa2", "0x69b797741ba66bda48a81e9",
"126036350226489723925526476841950279379016090973169"},
}
for i, tt := range tests {
t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
Expand Down Expand Up @@ -111,6 +115,57 @@ func TestMulDivV2(t *testing.T) {
}
}

// BenchmarkMulDivV2
// old version (using univ3 math lib (?)):
// BenchmarkMulDivV2/test_0xa9ab0f5808bbc8ef22eb4ac1a00fcd41c618ed2fc4af748e9e0a03496c8_0x3dae37091f647259cfbe64b41a2464804821425_0xfdd6776a6c2c
// BenchmarkMulDivV2/test_0xa9ab0f5808bbc8ef22eb4ac1a00fcd41c618ed2fc4af748e9e0a03496c8_0x3dae37091f647259cfbe64b41a2464804821425_0xfdd6776a6c2c-16 10213753 124.9 ns/op
// BenchmarkMulDivV2/test_0x43f001346c79fd41_0xabcec79aae99f1d096bf86eba29af446ee3c99c69cac66993c35509_0x86dbc64474f66ca6bca0e1a49615e845e0f5ed60e2a52333a8c2b2bf691b271
// BenchmarkMulDivV2/test_0x43f001346c79fd41_0xabcec79aae99f1d096bf86eba29af446ee3c99c69cac66993c35509_0x86dbc64474f66ca6bca0e1a49615e845e0f5ed60e2a52333a8c2b2bf691b271-16 2472106 488.8 ns/op
// BenchmarkMulDivV2/test_0x7328a5a0_0xbf42f97a8b8c6fba95e767cdfcbba62c36c378ccc619c39ceb5ffd8cd_0xed14bf3a06155275a2c4680b39ffce5620af83345f9a5198fde9
// BenchmarkMulDivV2/test_0x7328a5a0_0xbf42f97a8b8c6fba95e767cdfcbba62c36c378ccc619c39ceb5ffd8cd_0xed14bf3a06155275a2c4680b39ffce5620af83345f9a5198fde9-16 2332716 529.5 ns/op
// BenchmarkMulDivV2/test_0x72b9b242544fb8fc7f5_0x8f945829b2f3890e_0xf04677b55a48822f663e0450310052ae
// BenchmarkMulDivV2/test_0x72b9b242544fb8fc7f5_0x8f945829b2f3890e_0xf04677b55a48822f663e0450310052ae-16 6543910 201.9 ns/op
// BenchmarkMulDivV2/test_0x43a131_0x50a329770_0x9e6d9c6a06552e8399e1c98125246037a2ab7781b005b13eb7474d5
// BenchmarkMulDivV2/test_0x43a131_0x50a329770_0x9e6d9c6a06552e8399e1c98125246037a2ab7781b005b13eb7474d5-16 7596972 148.4 ns/op
// new version (using uint256's code):
// BenchmarkMulDivV2/test_0xa9ab0f5808bbc8ef22eb4ac1a00fcd41c618ed2fc4af748e9e0a03496c8_0x3dae37091f647259cfbe64b41a2464804821425_0xfdd6776a6c2c
// BenchmarkMulDivV2/test_0xa9ab0f5808bbc8ef22eb4ac1a00fcd41c618ed2fc4af748e9e0a03496c8_0x3dae37091f647259cfbe64b41a2464804821425_0xfdd6776a6c2c-16 18522574 65.71 ns/op
// BenchmarkMulDivV2/test_0x43f001346c79fd41_0xabcec79aae99f1d096bf86eba29af446ee3c99c69cac66993c35509_0x86dbc64474f66ca6bca0e1a49615e845e0f5ed60e2a52333a8c2b2bf691b271
// BenchmarkMulDivV2/test_0x43f001346c79fd41_0xabcec79aae99f1d096bf86eba29af446ee3c99c69cac66993c35509_0x86dbc64474f66ca6bca0e1a49615e845e0f5ed60e2a52333a8c2b2bf691b271-16 20101362 67.59 ns/op
// BenchmarkMulDivV2/test_0x7328a5a0_0xbf42f97a8b8c6fba95e767cdfcbba62c36c378ccc619c39ceb5ffd8cd_0xed14bf3a06155275a2c4680b39ffce5620af83345f9a5198fde9
// BenchmarkMulDivV2/test_0x7328a5a0_0xbf42f97a8b8c6fba95e767cdfcbba62c36c378ccc619c39ceb5ffd8cd_0xed14bf3a06155275a2c4680b39ffce5620af83345f9a5198fde9-16 18418658 59.09 ns/op
// BenchmarkMulDivV2/test_0x72b9b242544fb8fc7f5_0x8f945829b2f3890e_0xf04677b55a48822f663e0450310052ae
// BenchmarkMulDivV2/test_0x72b9b242544fb8fc7f5_0x8f945829b2f3890e_0xf04677b55a48822f663e0450310052ae-16 24769188 50.33 ns/op
// BenchmarkMulDivV2/test_0x43a131_0x50a329770_0x9e6d9c6a06552e8399e1c98125246037a2ab7781b005b13eb7474d5
// BenchmarkMulDivV2/test_0x43a131_0x50a329770_0x9e6d9c6a06552e8399e1c98125246037a2ab7781b005b13eb7474d5-16 41646259 33.35 ns/op
func BenchmarkMulDivV2(tb *testing.B) {
rand.Seed(0)
for i := 0; i < 5; i++ {
a := RandUint256()
b := RandUint256()
deno := RandUint256()

tb.Run(fmt.Sprintf("test %s %s %s", a.Hex(), b.Hex(), deno.Hex()), func(tb *testing.B) {
r, err := MulDiv(a, b, deno)

var rv2 Uint256
var errv2 error
tb.ResetTimer()
for i := 0; i < tb.N; i++ {
errv2 = MulDivV2(a, b, deno, &rv2, nil)
}
tb.StopTimer()

if err != nil {
require.NotNil(tb, errv2)
} else {
require.Nil(tb, errv2)
assert.Equal(tb, r.Dec(), rv2.Dec())
}
})
}
}

func TestMulDivRoundingUp(t *testing.T) {
// https://github.com/Uniswap/v3-core/blob/main/test/FullMath.spec.ts

Expand All @@ -121,11 +176,15 @@ func TestMulDivRoundingUp(t *testing.T) {
expResult string
}{
{MaxUint256.Hex(), MaxUint256.Hex(), MaxUint256.Hex(), MaxUint256.Dec()},
{"0x100000000000000000000000000000000", "0x80000000000000000000000000000000", "0x180000000000000000000000000000000", "113427455640312821154458202477256070486"},
{"0x100000000000000000000000000000000", "0x2300000000000000000000000000000000", "0x800000000000000000000000000000000", "1488735355279105777652263907513985925120"},
{"0x100000000000000000000000000000000", "0x3e800000000000000000000000000000000", "0xbb800000000000000000000000000000000", "113427455640312821154458202477256070486"},
{"0x100000000000000000000000000000000", "0x80000000000000000000000000000000",
"0x180000000000000000000000000000000", "113427455640312821154458202477256070486"},
{"0x100000000000000000000000000000000", "0x2300000000000000000000000000000000",
"0x800000000000000000000000000000000", "1488735355279105777652263907513985925120"},
{"0x100000000000000000000000000000000", "0x3e800000000000000000000000000000000",
"0xbb800000000000000000000000000000000", "113427455640312821154458202477256070486"},

{"0x2a60f4810d72e89eaee06f20122f1de80adc64777e121", "0xfd21718acef075500c6395ba922064220", "0xd195e7433221b9e4b6ef3f19b457c9c9797ae6b5eaacb402113dce147e97979f", "14406918379743960"},
{"0x2a60f4810d72e89eaee06f20122f1de80adc64777e121", "0xfd21718acef075500c6395ba922064220",
"0xd195e7433221b9e4b6ef3f19b457c9c9797ae6b5eaacb402113dce147e97979f", "14406918379743960"},
}
for i, tt := range tests {
t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
Expand Down Expand Up @@ -154,8 +213,11 @@ func TestMulDivRoundingUp(t *testing.T) {
// {"0x100000000000000000000000000000000", "0x100000000000000000000000000000000", "0x0"},
{"0x100000000000000000000000000000000", "0x100000000000000000000000000000000", "0x1"},
{MaxUint256.Hex(), MaxUint256.Hex(), new(Uint256).SubUint64(MaxUint256, 1).Hex()},
{"0x1e695d2db4f97", "0x10d5effea103c44aaf18a26b449186a7de3dd6c1ce3d26d03dfd9", "0x2"}, // mulDiv overflows 256 bits after rounding up
{"0xffffffffffffffffffffffffffffffffffffffb07f6d608e4dcc38020b140b35", "0xffffffffffffffffffffffffffffffffffffffb07f6d608e4dcc38020b140b36", "0xffffffffffffffffffffffffffffffffffffff60fedac11c9b9870041628166c"}, // mulDiv overflows 256 bits after rounding up case 2
{"0x1e695d2db4f97", "0x10d5effea103c44aaf18a26b449186a7de3dd6c1ce3d26d03dfd9",
"0x2"}, // mulDiv overflows 256 bits after rounding up
{"0xffffffffffffffffffffffffffffffffffffffb07f6d608e4dcc38020b140b35",
"0xffffffffffffffffffffffffffffffffffffffb07f6d608e4dcc38020b140b36",
"0xffffffffffffffffffffffffffffffffffffff60fedac11c9b9870041628166c"}, // mulDiv overflows 256 bits after rounding up case 2
}
for i, tt := range failTests {
t.Run(fmt.Sprintf("fail test %d", i), func(t *testing.T) {
Expand Down
Loading

0 comments on commit bf7bcb8

Please sign in to comment.