Skip to content

Commit

Permalink
ft: optimize NewPoolV2 to not do unnecessary checksums (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
NgoKimPhu authored Dec 30, 2024
1 parent bf7bcb8 commit 4809511
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 22 deletions.
42 changes: 27 additions & 15 deletions entities/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ import (
"math/big"

"github.com/KyberNetwork/int256"
"github.com/KyberNetwork/uniswapv3-sdk-uint256/constants"
"github.com/KyberNetwork/uniswapv3-sdk-uint256/utils"
"github.com/daoleno/uniswap-sdk-core/entities"
"github.com/ethereum/go-ethereum/common"
"github.com/holiman/uint256"

"github.com/KyberNetwork/uniswapv3-sdk-uint256/constants"
"github.com/KyberNetwork/uniswapv3-sdk-uint256/utils"
)

var (
Expand Down Expand Up @@ -69,12 +70,14 @@ type GetAmountResultV2 struct {
CrossInitTickLoops int
}

func GetAddress(tokenA, tokenB *entities.Token, fee constants.FeeAmount, initCodeHashManualOverride string) (common.Address, error) {
func GetAddress(tokenA, tokenB *entities.Token, fee constants.FeeAmount,
initCodeHashManualOverride string) (common.Address, error) {
return utils.ComputePoolAddress(constants.FactoryAddress, tokenA, tokenB, fee, initCodeHashManualOverride)
}

// deprecated
func NewPool(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRatioX96 *big.Int, liquidity *big.Int, tickCurrent int, ticks TickDataProvider) (*Pool, error) {
func NewPool(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRatioX96 *big.Int, liquidity *big.Int,
tickCurrent int, ticks TickDataProvider) (*Pool, error) {
return NewPoolV2(
tokenA, tokenB, fee,
uint256.MustFromBig(sqrtRatioX96),
Expand All @@ -94,7 +97,8 @@ func NewPool(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRatioX
* @param tickCurrent The current tick of the pool
* @param ticks The current state of the pool ticks or a data provider that can return tick data
*/
func NewPoolV2(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRatioX96 *utils.Uint160, liquidity *utils.Uint128, tickCurrent int, ticks TickDataProvider) (*Pool, error) {
func NewPoolV2(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRatioX96 *utils.Uint160,
liquidity *utils.Uint128, tickCurrent int, ticks TickDataProvider) (*Pool, error) {
if fee >= constants.FeeMax {
return nil, ErrFeeTooHigh
}
Expand All @@ -114,7 +118,7 @@ func NewPoolV2(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRati
}
token0 := tokenA
token1 := tokenB
isSorted, err := tokenA.SortsBefore(tokenB)
isSorted, err := SortsBefore(tokenA, tokenB)
if err != nil {
return nil, err
}
Expand All @@ -130,7 +134,7 @@ func NewPoolV2(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRati
SqrtRatioX96: sqrtRatioX96,
Liquidity: liquidity,
TickCurrent: tickCurrent,
TickDataProvider: ticks, // TODO: new tick data provider
TickDataProvider: ticks,
}, nil
}

Expand All @@ -148,7 +152,8 @@ func (p *Pool) Token0Price() *entities.Price {
if p.token0Price != nil {
return p.token0Price
}
p.token0Price = entities.NewPrice(p.Token0, p.Token1, constants.Q192, new(uint256.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96).ToBig())
p.token0Price = entities.NewPrice(p.Token0, p.Token1, constants.Q192,
new(uint256.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96).ToBig())
return p.token0Price
}

Expand All @@ -157,7 +162,8 @@ func (p *Pool) Token1Price() *entities.Price {
if p.token1Price != nil {
return p.token1Price
}
p.token1Price = entities.NewPrice(p.Token1, p.Token0, new(uint256.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96).ToBig(), constants.Q192)
p.token1Price = entities.NewPrice(p.Token1, p.Token0, new(uint256.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96).ToBig(),
constants.Q192)
return p.token1Price
}

Expand Down Expand Up @@ -187,7 +193,8 @@ func (p *Pool) ChainID() uint {
* @param sqrtPriceLimitX96 The Q64.96 sqrt price limit
* @returns The output amount and the pool with updated state
*/
func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *utils.Uint160) (*GetAmountResult, error) {
func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount,
sqrtPriceLimitX96 *utils.Uint160) (*GetAmountResult, error) {
if !(inputAmount.Currency.IsToken() && p.InvolvesToken(inputAmount.Currency.Wrapped())) {
return nil, ErrTokenNotInvolved
}
Expand Down Expand Up @@ -219,14 +226,16 @@ func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLi
return nil, err
}
return &GetAmountResult{
ReturnedAmount: entities.FromRawAmount(outputToken, new(utils.Int256).Neg(swapResult.amountCalculated).ToBig()),
ReturnedAmount: entities.FromRawAmount(outputToken,
new(utils.Int256).Neg(swapResult.amountCalculated).ToBig()),
RemainingAmountIn: entities.FromRawAmount(inputAmount.Currency, swapResult.remainingAmountIn.ToBig()),
NewPoolState: pool,
CrossInitTickLoops: swapResult.crossInitTickLoops,
}, nil
}

func (p *Pool) GetOutputAmountV2(inputAmount *utils.Int256, zeroForOne bool, sqrtPriceLimitX96 *utils.Uint160) (*GetAmountResultV2, error) {
func (p *Pool) GetOutputAmountV2(inputAmount *utils.Int256, zeroForOne bool,
sqrtPriceLimitX96 *utils.Uint160) (*GetAmountResultV2, error) {
swapResult, err := p.swap(zeroForOne, inputAmount, sqrtPriceLimitX96)
if err != nil {
return nil, err
Expand All @@ -247,7 +256,8 @@ func (p *Pool) GetOutputAmountV2(inputAmount *utils.Int256, zeroForOne bool, sqr
* @param sqrtPriceLimitX96 The Q64.96 sqrt price limit. If zero for one, the price cannot be less than this value after the swap. If one for zero, the price cannot be greater than this value after the swap
* @returns The input amount and the pool with updated state
*/
func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *utils.Uint160) (*entities.CurrencyAmount, *Pool, error) {
func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount,
sqrtPriceLimitX96 *utils.Uint160) (*entities.CurrencyAmount, *Pool, error) {
if !(outputAmount.Currency.IsToken() && p.InvolvesToken(outputAmount.Currency.Wrapped())) {
return nil, nil, ErrTokenNotInvolved
}
Expand Down Expand Up @@ -292,7 +302,8 @@ func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLi
* @returns swapResult.liquidity
* @returns swapResult.tickCurrent
*/
func (p *Pool) swap(zeroForOne bool, amountSpecified *utils.Int256, sqrtPriceLimitX96 *utils.Uint160) (*SwapResult, error) {
func (p *Pool) swap(zeroForOne bool, amountSpecified *utils.Int256, sqrtPriceLimitX96 *utils.Uint160) (*SwapResult,
error) {
var err error
if sqrtPriceLimitX96 == nil {
if zeroForOne {
Expand Down Expand Up @@ -379,7 +390,8 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified *utils.Int256, sqrtPriceLim
}

var nxtSqrtPriceX96 utils.Uint160
err = utils.ComputeSwapStep(state.sqrtPriceX96, &targetValue, state.liquidity, state.amountSpecifiedRemaining, p.Fee,
err = utils.ComputeSwapStep(state.sqrtPriceX96, &targetValue, state.liquidity, state.amountSpecifiedRemaining,
p.Fee,
&nxtSqrtPriceX96, &step.amountIn, &step.amountOut, &step.feeAmount)
if err != nil {
return nil, err
Expand Down
18 changes: 18 additions & 0 deletions entities/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package entities

import (
"bytes"

"github.com/daoleno/uniswap-sdk-core/entities"
)

// SortsBefore returns true if the address of token a sorts before the address of the token b.
func SortsBefore(a, b *entities.Token) (bool, error) {
if a.ChainId() != b.ChainId() {
return false, entities.ErrDifferentChain
}
if a.Address == b.Address {
return false, entities.ErrSameAddress
}
return bytes.Compare(a.Address[:], b.Address[:]) < 0, nil
}
41 changes: 41 additions & 0 deletions entities/token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package entities

import (
"testing"

"github.com/daoleno/uniswap-sdk-core/entities"
"github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/assert"
)

// BenchmarkSortsBefore
// BenchmarkSortsBefore/daoleno_SortsBefore
// BenchmarkSortsBefore/daoleno_SortsBefore-16 565605 2026 ns/op
// BenchmarkSortsBefore/KyberSwap_SortsBefore
// BenchmarkSortsBefore/KyberSwap_SortsBefore-16 160554603 7.101 ns/op
func BenchmarkSortsBefore(b *testing.B) {
tokenA := entities.NewToken(1, common.HexToAddress("0xB8c77482e45F1F44dE1745F52C74426C631bDD52"),
18, "BNB", "BNB")
tokenB := entities.NewToken(1, common.HexToAddress("0xb62132e35a6c13ee1ee0f84dc5d40bad8d815206"),
18, "NEXO", "Nexo")
var before bool
var err error

b.Run("daoleno SortsBefore", func(b *testing.B) {
for i := 0; i < b.N; i++ {
before, err = tokenA.SortsBefore(tokenB)
}
b.StopTimer()
assert.False(b, before)
assert.NoError(b, err)
})

b.Run("KyberSwap SortsBefore", func(b *testing.B) {
for i := 0; i < b.N; i++ {
before, err = SortsBefore(tokenA, tokenB)
}
b.StopTimer()
assert.False(b, before)
assert.NoError(b, err)
})
}
12 changes: 8 additions & 4 deletions utils/compute_pool_address.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package utils
import (
"math/big"

"github.com/KyberNetwork/uniswapv3-sdk-uint256/constants"
"github.com/daoleno/uniswap-sdk-core/entities"
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"

"github.com/KyberNetwork/uniswapv3-sdk-uint256/constants"
kentities "github.com/KyberNetwork/uniswapv3-sdk-uint256/entities"
)

/**
Expand All @@ -18,8 +20,9 @@ import (
* @param fee The fee tier of the pool
* @returns The pool address
*/
func ComputePoolAddress(factoryAddress common.Address, tokenA *entities.Token, tokenB *entities.Token, fee constants.FeeAmount, initCodeHashManualOverride string) (common.Address, error) {
isSorted, err := tokenA.SortsBefore(tokenB)
func ComputePoolAddress(factoryAddress common.Address, tokenA *entities.Token, tokenB *entities.Token,
fee constants.FeeAmount, initCodeHashManualOverride string) (common.Address, error) {
isSorted, err := kentities.SortsBefore(tokenA, tokenB)
if err != nil {
return common.Address{}, err
}
Expand All @@ -37,7 +40,8 @@ func ComputePoolAddress(factoryAddress common.Address, tokenA *entities.Token, t
return getCreate2Address(factoryAddress, token0.Address, token1.Address, fee, initCodeHashManualOverride), nil
}

func getCreate2Address(factoyAddress, addressA, addressB common.Address, fee constants.FeeAmount, initCodeHashManualOverride string) common.Address {
func getCreate2Address(factoyAddress, addressA, addressB common.Address, fee constants.FeeAmount,
initCodeHashManualOverride string) common.Address {
var salt [32]byte
copy(salt[:], crypto.Keccak256(abiEncode(addressA, addressB, fee)))

Expand Down
8 changes: 5 additions & 3 deletions utils/price_tick_conversions.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package utils
import (
"math/big"

"github.com/KyberNetwork/uniswapv3-sdk-uint256/constants"
"github.com/daoleno/uniswap-sdk-core/entities"

"github.com/KyberNetwork/uniswapv3-sdk-uint256/constants"
kentities "github.com/KyberNetwork/uniswapv3-sdk-uint256/entities"
)

/**
Expand All @@ -21,7 +23,7 @@ func TickToPrice(baseToken *entities.Token, quoteToken *entities.Token, tick int
}
ratioX192 := new(big.Int).Mul(sqrtRatioX96, sqrtRatioX96)

sorted, err := baseToken.SortsBefore(quoteToken)
sorted, err := kentities.SortsBefore(baseToken, quoteToken)
if err != nil {
return nil, err
}
Expand All @@ -37,7 +39,7 @@ func TickToPrice(baseToken *entities.Token, quoteToken *entities.Token, tick int
* i.e. the price of the returned tick is less than or equal to the input price
*/
func PriceToClosestTick(price *entities.Price, baseToken, quoteToken *entities.Token) (int, error) {
sorted, err := baseToken.SortsBefore(quoteToken)
sorted, err := kentities.SortsBefore(baseToken, quoteToken)
if err != nil {
return 0, err
}
Expand Down

0 comments on commit 4809511

Please sign in to comment.