diff --git a/go.mod b/go.mod index d269f50b..68539814 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,12 @@ module github.com/scylladb/scylla-go-driver -go 1.18 +go 1.19 require ( github.com/google/go-cmp v0.5.6 github.com/klauspost/compress v1.15.1 github.com/pierrec/lz4/v4 v4.1.14 - go.uber.org/atomic v1.9.0 + go.uber.org/atomic v1.10.0 go.uber.org/goleak v1.1.12 ) diff --git a/go.sum b/go.sum index 1a791895..2babdf6f 100644 --- a/go.sum +++ b/go.sum @@ -21,6 +21,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= +go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/session.go b/session.go index b68c2b15..4f37edd1 100644 --- a/session.go +++ b/session.go @@ -161,7 +161,7 @@ func (s *Session) Prepare(ctx context.Context, content string) (Query, error) { stmt := transport.Statement{Content: content, Consistency: frame.ALL} // Prepare on all nodes concurrently. - nodes := s.cluster.Topology().Nodes + nodes := s.cluster.Nodes() resStmt := make([]transport.Statement, len(nodes)) resErr := make([]error, len(nodes)) var wg sync.WaitGroup @@ -234,7 +234,7 @@ func (s *Session) handleAutoAwaitSchemaAgreement(ctx context.Context, stmt strin func (s *Session) CheckSchemaAgreement(ctx context.Context) (bool, error) { // Get schema version from all nodes concurrently. - nodes := s.cluster.Topology().Nodes + nodes := s.cluster.Nodes() versions := make([]frame.UUID, len(nodes)) errors := make([]error, len(nodes)) var wg sync.WaitGroup diff --git a/transport/cluster.go b/transport/cluster.go index e95511ff..e8d818af 100644 --- a/transport/cluster.go +++ b/transport/cluster.go @@ -26,7 +26,7 @@ type ( ) type Cluster struct { - topology atomic.Value // *topology + topology atomic.Pointer[topology] control *Conn cfg ConnConfig handledEvents []frame.EventType // This will probably be moved to config. @@ -38,6 +38,10 @@ type Cluster struct { queryInfoCounter atomic.Uint64 } +func (c *Cluster) Nodes() []*Node { + return c.topology.Load().Nodes +} + type topology struct { localDC string peers peerMap @@ -86,13 +90,13 @@ type QueryInfo struct { func (c *Cluster) NewQueryInfo() QueryInfo { return QueryInfo{ tokenAware: false, - topology: c.Topology(), + topology: c.topology.Load(), offset: c.generateOffset(), } } func (c *Cluster) NewTokenAwareQueryInfo(t Token, ks string) (QueryInfo, error) { - top := c.Topology() + top := c.topology.Load() // When keyspace is not specified, we take default keyspace from ConnConfig. if ks == "" { if c.cfg.Keyspace == "" { @@ -144,7 +148,7 @@ func NewCluster(ctx context.Context, cfg ConnConfig, p HostSelectionPolicy, e [] if p, ok := p.(*TokenAwarePolicy); ok { localDC = p.localDC } - c.setTopology(&topology{localDC: localDC}) + c.topology.Store(&topology{localDC: localDC}) if control, err := c.NewControl(ctx); err != nil { return nil, fmt.Errorf("create control connection: %w", err) @@ -190,9 +194,9 @@ func (c *Cluster) refreshTopology(ctx context.Context) error { return fmt.Errorf("query info about nodes in cluster: %w", err) } - old := c.Topology().peers + old := c.topology.Load().peers t := newTopology() - t.localDC = c.Topology().localDC + t.localDC = c.topology.Load().localDC t.keyspaces, err = c.updateKeyspace(ctx) if err != nil { return fmt.Errorf("query keyspaces: %w", err) @@ -247,7 +251,7 @@ func (c *Cluster) refreshTopology(ctx context.Context) error { t.policyInfo.Preprocess(t, keyspace{}) } - c.setTopology(t) + c.topology.Store(t) drainChan(c.refreshChan) return nil } @@ -442,14 +446,6 @@ func parseTokensFromRow(n *Node, r frame.Row, ring *Ring) error { return nil } -func (c *Cluster) Topology() *topology { - return c.topology.Load().(*topology) -} - -func (c *Cluster) setTopology(t *topology) { - c.topology.Store(t) -} - // handleEvent creates function which is passed to control connection // via registerEvents in order to handle events right away instead // of registering handlers for them. @@ -478,7 +474,7 @@ func (c *Cluster) handleTopologyChange(v *TopologyChange) { func (c *Cluster) handleStatusChange(v *StatusChange) { log.Printf("cluster: handle status change: %+#v", v) - m := c.Topology().peers + m := c.topology.Load().peers addr := v.Address.String() if n, ok := m[addr]; ok { switch v.Status { @@ -549,7 +545,7 @@ func (c *Cluster) tryReopenControl(ctx context.Context) { func (c *Cluster) handleClose() { log.Printf("cluster: handle cluster close") c.control.Close() - m := c.Topology().peers + m := c.topology.Load().peers for _, v := range m { if v.pool != nil { v.pool.Close() diff --git a/transport/cluster_integration_test.go b/transport/cluster_integration_test.go index d8f29c13..ab93d483 100644 --- a/transport/cluster_integration_test.go +++ b/transport/cluster_integration_test.go @@ -17,7 +17,7 @@ import ( const awaitingChanges = 100 * time.Millisecond func compareNodes(c *Cluster, addr string, expected *Node) error { - m := c.Topology().peers + m := c.topology.Load().peers got, ok := m[addr] switch { case !ok: @@ -80,7 +80,7 @@ func TestClusterIntegration(t *testing.T) { } // There should be at least system keyspaces present. - if len(c.topology.Load().(*topology).keyspaces) == 0 { + if len(c.topology.Load().keyspaces) == 0 { t.Fatalf("Keyspaces failed to load") } diff --git a/transport/export_test.go b/transport/export_test.go index 84a4a911..f9258628 100644 --- a/transport/export_test.go +++ b/transport/export_test.go @@ -2,8 +2,8 @@ package transport func (p *ConnPool) AllConns() []*Conn { var conns = make([]*Conn, len(p.conns)) - for i, v := range p.conns { - conns[i], _ = v.Load().(*Conn) + for i := range conns { + conns[i] = p.loadConn(i) } return conns } diff --git a/transport/policy_test.go b/transport/policy_test.go index d00f0d2a..f7c59786 100644 --- a/transport/policy_test.go +++ b/transport/policy_test.go @@ -33,7 +33,7 @@ func mockCluster(t *topology, ks, localDC string) *Cluster { } else { t.policyInfo.Preprocess(t, keyspace{}) } - c.setTopology(t) + c.topology.Store(t) return &c } @@ -139,16 +139,16 @@ func TestDCAwareRoundRobinPolicy(t *testing.T) { //nolint:paralleltest // Can't } /* - mockTopologyTokenAwareSimpleStrategy creates cluster topology with info about 3 nodes living in the same datacenter. +mockTopologyTokenAwareSimpleStrategy creates cluster topology with info about 3 nodes living in the same datacenter. - Ring field is populated as follows: - ring tokens: 50 100 150 200 250 300 400 500 - corresponding node ids: 2 1 2 3 1 2 3 1 +Ring field is populated as follows: +ring tokens: 50 100 150 200 250 300 400 500 +corresponding node ids: 2 1 2 3 1 2 3 1 - Keyspaces: - names: "rf2" "rf3" - strategies: simple simple - rep factors: 2 3 +Keyspaces: +names: "rf2" "rf3" +strategies: simple simple +rep factors: 2 3. */ func mockTopologyTokenAwareSimpleStrategy() *topology { dummyNodes := []*Node{ @@ -239,24 +239,24 @@ func TestTokenAwareSimpleStrategyPolicy(t *testing.T) { //nolint:paralleltest // } /* - mockTopologyTokenAwareNetworkStrategy creates cluster topology with info about 8 nodes - living in two different datacenters. +mockTopologyTokenAwareNetworkStrategy creates cluster topology with info about 8 nodes +living in two different datacenters. - Ring field is populated as follows: - ring tokens: 50 100 150 200 250 300 400 500 510 - corresponding node ids: 1 5 2 1 6 4 8 7 3 +Ring field is populated as follows: +ring tokens: 50 100 150 200 250 300 400 500 510 +corresponding node ids: 1 5 2 1 6 4 8 7 3 - Datacenter: waw - nodes in rack r1: 1 2 - nodes in rack r2: 3 4 +Datacenter: waw +nodes in rack r1: 1 2 +nodes in rack r2: 3 4 - Datacenter: her - nodes in rack r3: 5 6 - nodes in rack r4: 7 8 +Datacenter: her +nodes in rack r3: 5 6 +nodes in rack r4: 7 8 - Keyspace: "waw/her" - strategy: network topology - replication factors: waw: 2 her: 3 +Keyspace: "waw/her" +strategy: network topology +replication factors: waw: 2 her: 3. */ func mockTopologyTokenAwareDCAwareStrategy() *topology { dummyNodes := []*Node{ diff --git a/transport/pool.go b/transport/pool.go index b41a21d4..f0f770cc 100644 --- a/transport/pool.go +++ b/transport/pool.go @@ -19,7 +19,7 @@ type ConnPool struct { host string nrShards int msbIgnore uint8 - conns []atomic.Value + conns []atomic.Pointer[Conn] connClosedCh chan int // notification channel for when connection is closed connObs ConnObserver } @@ -99,13 +99,11 @@ func (p *ConnPool) storeConn(conn *Conn) { } func (p *ConnPool) loadConn(shard int) *Conn { - conn, _ := p.conns[shard].Load().(*Conn) - return conn + return p.conns[shard].Load() } func (p *ConnPool) clearConn(shard int) bool { - conn, _ := p.conns[shard].Swap((*Conn)(nil)).(*Conn) - return conn != nil + return p.conns[shard].Swap(nil) != nil } func (p *ConnPool) Close() { @@ -115,7 +113,7 @@ func (p *ConnPool) Close() { // closeAll is called by PoolRefiller. func (p *ConnPool) closeAll() { for i := range p.conns { - if conn, ok := p.conns[i].Swap((*Conn)(nil)).(*Conn); ok { + if conn := p.conns[i].Swap(nil); conn != nil { conn.Close() } } @@ -168,7 +166,7 @@ func (r *PoolRefiller) init(ctx context.Context, host string) error { host: host, nrShards: int(ss.NrShards), msbIgnore: ss.MsbIgnore, - conns: make([]atomic.Value, int(ss.NrShards)), + conns: make([]atomic.Pointer[Conn], int(ss.NrShards)), connClosedCh: make(chan int, int(ss.NrShards)+1), connObs: r.cfg.ConnObserver, }