diff --git a/session.go b/session.go index b68c2b15..4f37edd1 100644 --- a/session.go +++ b/session.go @@ -161,7 +161,7 @@ func (s *Session) Prepare(ctx context.Context, content string) (Query, error) { stmt := transport.Statement{Content: content, Consistency: frame.ALL} // Prepare on all nodes concurrently. - nodes := s.cluster.Topology().Nodes + nodes := s.cluster.Nodes() resStmt := make([]transport.Statement, len(nodes)) resErr := make([]error, len(nodes)) var wg sync.WaitGroup @@ -234,7 +234,7 @@ func (s *Session) handleAutoAwaitSchemaAgreement(ctx context.Context, stmt strin func (s *Session) CheckSchemaAgreement(ctx context.Context) (bool, error) { // Get schema version from all nodes concurrently. - nodes := s.cluster.Topology().Nodes + nodes := s.cluster.Nodes() versions := make([]frame.UUID, len(nodes)) errors := make([]error, len(nodes)) var wg sync.WaitGroup diff --git a/transport/cluster.go b/transport/cluster.go index 0a6aad56..e8d818af 100644 --- a/transport/cluster.go +++ b/transport/cluster.go @@ -38,6 +38,10 @@ type Cluster struct { queryInfoCounter atomic.Uint64 } +func (c *Cluster) Nodes() []*Node { + return c.topology.Load().Nodes +} + type topology struct { localDC string peers peerMap @@ -86,13 +90,13 @@ type QueryInfo struct { func (c *Cluster) NewQueryInfo() QueryInfo { return QueryInfo{ tokenAware: false, - topology: c.Topology(), + topology: c.topology.Load(), offset: c.generateOffset(), } } func (c *Cluster) NewTokenAwareQueryInfo(t Token, ks string) (QueryInfo, error) { - top := c.Topology() + top := c.topology.Load() // When keyspace is not specified, we take default keyspace from ConnConfig. if ks == "" { if c.cfg.Keyspace == "" { @@ -144,7 +148,7 @@ func NewCluster(ctx context.Context, cfg ConnConfig, p HostSelectionPolicy, e [] if p, ok := p.(*TokenAwarePolicy); ok { localDC = p.localDC } - c.setTopology(&topology{localDC: localDC}) + c.topology.Store(&topology{localDC: localDC}) if control, err := c.NewControl(ctx); err != nil { return nil, fmt.Errorf("create control connection: %w", err) @@ -190,9 +194,9 @@ func (c *Cluster) refreshTopology(ctx context.Context) error { return fmt.Errorf("query info about nodes in cluster: %w", err) } - old := c.Topology().peers + old := c.topology.Load().peers t := newTopology() - t.localDC = c.Topology().localDC + t.localDC = c.topology.Load().localDC t.keyspaces, err = c.updateKeyspace(ctx) if err != nil { return fmt.Errorf("query keyspaces: %w", err) @@ -247,7 +251,7 @@ func (c *Cluster) refreshTopology(ctx context.Context) error { t.policyInfo.Preprocess(t, keyspace{}) } - c.setTopology(t) + c.topology.Store(t) drainChan(c.refreshChan) return nil } @@ -442,14 +446,6 @@ func parseTokensFromRow(n *Node, r frame.Row, ring *Ring) error { return nil } -func (c *Cluster) Topology() *topology { - return c.topology.Load() -} - -func (c *Cluster) setTopology(t *topology) { - c.topology.Store(t) -} - // handleEvent creates function which is passed to control connection // via registerEvents in order to handle events right away instead // of registering handlers for them. @@ -478,7 +474,7 @@ func (c *Cluster) handleTopologyChange(v *TopologyChange) { func (c *Cluster) handleStatusChange(v *StatusChange) { log.Printf("cluster: handle status change: %+#v", v) - m := c.Topology().peers + m := c.topology.Load().peers addr := v.Address.String() if n, ok := m[addr]; ok { switch v.Status { @@ -549,7 +545,7 @@ func (c *Cluster) tryReopenControl(ctx context.Context) { func (c *Cluster) handleClose() { log.Printf("cluster: handle cluster close") c.control.Close() - m := c.Topology().peers + m := c.topology.Load().peers for _, v := range m { if v.pool != nil { v.pool.Close() diff --git a/transport/cluster_integration_test.go b/transport/cluster_integration_test.go index 3a752e2a..ab93d483 100644 --- a/transport/cluster_integration_test.go +++ b/transport/cluster_integration_test.go @@ -17,7 +17,7 @@ import ( const awaitingChanges = 100 * time.Millisecond func compareNodes(c *Cluster, addr string, expected *Node) error { - m := c.Topology().peers + m := c.topology.Load().peers got, ok := m[addr] switch { case !ok: @@ -80,7 +80,7 @@ func TestClusterIntegration(t *testing.T) { } // There should be at least system keyspaces present. - if len(c.Topology().keyspaces) == 0 { + if len(c.topology.Load().keyspaces) == 0 { t.Fatalf("Keyspaces failed to load") } diff --git a/transport/policy_test.go b/transport/policy_test.go index 966ef3c0..f7c59786 100644 --- a/transport/policy_test.go +++ b/transport/policy_test.go @@ -33,7 +33,7 @@ func mockCluster(t *topology, ks, localDC string) *Cluster { } else { t.policyInfo.Preprocess(t, keyspace{}) } - c.setTopology(t) + c.topology.Store(t) return &c }