diff --git a/frame/buffer_read.go b/frame/buffer_read.go index ad2178c8..e411b5e0 100644 --- a/frame/buffer_read.go +++ b/frame/buffer_read.go @@ -3,6 +3,7 @@ package frame import ( "fmt" "log" + "net/netip" ) // All the read functions call readByte or readInto as they would want to read a single byte or copy a slice of bytes. @@ -160,7 +161,9 @@ func (b *Buffer) ReadInet() Inet { log.Printf("unknown ip length") } } - return Inet{IP: b.readCopy(int(n)), Port: b.ReadInt()} + + ip, _ := netip.AddrFromSlice(b.readCopy(int(n))) + return Inet{IP: ip, Port: b.ReadInt()} } func (b *Buffer) ReadString() string { diff --git a/frame/buffer_write.go b/frame/buffer_write.go index 6ed788be..b12fb0de 100644 --- a/frame/buffer_write.go +++ b/frame/buffer_write.go @@ -119,13 +119,14 @@ func (b *Buffer) WriteValue(v Value) { } func (b *Buffer) WriteInet(v Inet) { + addr := v.IP.AsSlice() if Debug { - if l := len(v.IP); l != 4 && l != 16 { - log.Printf("unknown IP length") + if len(addr) != 4 && len(addr) != 16 { + log.Printf("unknown ip length") } } - b.WriteByte(Byte(len(v.IP))) - b.Write(v.IP) + b.WriteByte(Byte(len(addr))) + b.Write(addr) b.WriteInt(v.Port) } diff --git a/frame/cqlvalue.go b/frame/cqlvalue.go index 16c81dc2..73be14d5 100644 --- a/frame/cqlvalue.go +++ b/frame/cqlvalue.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "fmt" "math" - "net" + "net/netip" "unicode" "unicode/utf8" ) @@ -153,16 +153,17 @@ func (c CqlValue) AsText() (string, error) { return string(c.Value), nil } -func (c CqlValue) AsIP() (net.IP, error) { +func (c CqlValue) AsIP() (netip.Addr, error) { if c.Type.ID != InetID { - return nil, fmt.Errorf("%v is not of Inet type", c) + return netip.Addr{}, fmt.Errorf("%v is not of Inet type", c) } - if len(c.Value) != 4 && len(c.Value) != 16 { - return nil, fmt.Errorf("invalid ip length") + ret, ok := netip.AddrFromSlice(c.Value) + if !ok { + return netip.Addr{}, fmt.Errorf("invalid ip length") } - return c.Value, nil + return ret, nil } func (c CqlValue) AsFloat32() (float32, error) { @@ -414,17 +415,15 @@ func CqlFromTimeUUID(b [16]byte) (CqlValue, error) { return c, nil } -func CqlFromIP(ip net.IP) (CqlValue, error) { - if len(ip) != 4 || len(ip) != 16 { - return CqlValue{}, fmt.Errorf("invalid ip address") +func CqlFromIP(ip netip.Addr) (CqlValue, error) { + if ip.BitLen() == 0 { + return CqlValue{}, fmt.Errorf("zero addr is not supported") } - c := CqlValue{ + return CqlValue{ Type: &Option{ID: InetID}, - Value: make(Bytes, len(ip)), - } - copy(c.Value, ip) - return c, nil + Value: ip.AsSlice(), + }, nil } func CqlFromFloat32(v float32) CqlValue { diff --git a/frame/cqlvalue_fuzz_test.go b/frame/cqlvalue_fuzz_test.go index 47065c2b..9f948a54 100644 --- a/frame/cqlvalue_fuzz_test.go +++ b/frame/cqlvalue_fuzz_test.go @@ -3,6 +3,7 @@ package frame import ( "math" "net" + "net/netip" "testing" "github.com/google/go-cmp/cmp" @@ -142,12 +143,17 @@ func FuzzCqlValueText(f *testing.F) { } func FuzzCqlValueIP(f *testing.F) { - testCases := [][]byte{{1, 2, 3}, net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 1}.To16()} + testCases := [][]byte{net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 1}.To16()} for _, tc := range testCases { f.Add(tc) } f.Fuzz(func(t *testing.T, data []byte) { - in, err := CqlFromIP(data) + ip, ok := netip.AddrFromSlice(data) + if !ok { + t.Skip() + } + + in, err := CqlFromIP(ip) if err != nil { // We skip tests with incorrect CqlValue. t.Skip() diff --git a/frame/cqlvalue_test.go b/frame/cqlvalue_test.go index f5b72035..8d4e3b05 100644 --- a/frame/cqlvalue_test.go +++ b/frame/cqlvalue_test.go @@ -2,7 +2,7 @@ package frame import ( "math" - "net" + "net/netip" "testing" "github.com/google/go-cmp/cmp" @@ -509,7 +509,7 @@ func TestCqlValueAsIP(t *testing.T) { name string content CqlValue valid bool - expected net.IP + expected netip.Addr }{ { name: "wrong length", @@ -530,19 +530,19 @@ func TestCqlValueAsIP(t *testing.T) { name: "valid v4", content: CqlValue{ Type: &Option{ID: InetID}, - Value: Bytes(net.IP{127, 0, 0, 1}), + Value: Bytes{127, 0, 0, 1}, }, valid: true, - expected: net.IP{127, 0, 0, 1}, + expected: netip.AddrFrom4([4]byte{127, 0, 0, 1}), }, { name: "valid v6", content: CqlValue{ Type: &Option{ID: InetID}, - Value: Bytes(net.IP{127, 0, 0, 1}.To16()), + Value: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 127, 0, 0, 1}, }, valid: true, - expected: net.IP{127, 0, 0, 1}.To16(), + expected: netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 127, 0, 0, 1}), }, } @@ -557,8 +557,8 @@ func TestCqlValueAsIP(t *testing.T) { } return } - if diff := cmp.Diff(v, tc.expected); diff != "" { - t.Fatalf(diff) + if v != tc.expected { + t.Fatalf("expected %v, got %v", tc.expected, v) } }) } diff --git a/frame/response/event_test.go b/frame/response/event_test.go index 6718f56f..8985b4ec 100644 --- a/frame/response/event_test.go +++ b/frame/response/event_test.go @@ -1,6 +1,7 @@ package response import ( + "net/netip" "testing" "github.com/scylladb/scylla-go-driver/frame" @@ -21,7 +22,7 @@ func TestStatusChangeEvent(t *testing.T) { // nolint:dupl // Tests are different var b frame.Buffer b.WriteString("UP") b.WriteInet(frame.Inet{ - IP: []byte{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), Port: 9042, }) return b.Bytes() @@ -29,7 +30,7 @@ func TestStatusChangeEvent(t *testing.T) { // nolint:dupl // Tests are different expected: StatusChange{ Status: "UP", Address: frame.Inet{ - IP: []byte{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), Port: 9042, }, }, @@ -42,8 +43,8 @@ func TestStatusChangeEvent(t *testing.T) { // nolint:dupl // Tests are different var buf frame.Buffer buf.Write(tc.content) a := ParseStatusChange(&buf) - if diff := cmp.Diff(*a, tc.expected); diff != "" { - t.Fatal(diff) + if *a != tc.expected { + t.Fatalf("expected %v, got %v", tc.expected, *a) } }) } @@ -62,7 +63,7 @@ func TestTopologyChangeEvent(t *testing.T) { //nolint:dupl // Tests are differen var b frame.Buffer b.WriteString("NEW_NODE") b.WriteInet(frame.Inet{ - IP: []byte{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), Port: 9042, }) return b.Bytes() @@ -70,7 +71,7 @@ func TestTopologyChangeEvent(t *testing.T) { //nolint:dupl // Tests are differen expected: TopologyChange{ Change: "NEW_NODE", Address: frame.Inet{ - IP: []byte{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), Port: 9042, }, }, @@ -83,8 +84,8 @@ func TestTopologyChangeEvent(t *testing.T) { //nolint:dupl // Tests are differen var buf frame.Buffer buf.Write(tc.content) a := ParseTopologyChange(&buf) - if diff := cmp.Diff(*a, tc.expected); diff != "" { - t.Fatal(diff) + if *a != tc.expected { + t.Fatalf("expected %v, got %v", tc.expected, *a) } }) } diff --git a/frame/types.go b/frame/types.go index deec5664..3fd87034 100644 --- a/frame/types.go +++ b/frame/types.go @@ -2,7 +2,7 @@ package frame import ( "errors" - "net" + "net/netip" ) // Generic types from CQL binary protocol. @@ -40,13 +40,13 @@ func (v Value) Clone() Value { // https://github.com/apache/cassandra/blob/adcff3f630c0d07d1ba33bf23fcb11a6db1b9af1/doc/native_protocol_v4.spec#L241-L245 type Inet struct { - IP Bytes + IP netip.Addr Port Int } // String only takes care of IP part of the address. func (i Inet) String() string { - return net.IP(i.IP).String() + return i.IP.String() } // https://github.com/apache/cassandra/blob/adcff3f630c0d07d1ba33bf23fcb11a6db1b9af1/doc/native_protocol_v4.spec#L183-L201 diff --git a/transport/cluster.go b/transport/cluster.go index e95511ff..c53c29b9 100644 --- a/transport/cluster.go +++ b/transport/cluster.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "net" + "net/netip" "sort" "strconv" "strings" @@ -320,7 +321,7 @@ func (c *Cluster) parseNodeFromRow(r frame.Row) (*Node, error) { } // Possible IP addresses starts from addrIndex in both system.local and system.peers queries. // They are grouped with decreasing priority. - var addr net.IP + var addr netip.Addr for i := addrIndex; i < len(r); i++ { addr, err = r[i].AsIP() if err == nil && !addr.IsUnspecified() { @@ -328,13 +329,16 @@ func (c *Cluster) parseNodeFromRow(r frame.Row) (*Node, error) { } else if err == nil && addr.IsUnspecified() { host, _, err := net.SplitHostPort(c.control.conn.RemoteAddr().String()) if err == nil { - addr = net.ParseIP(host) + addr, err = netip.ParseAddr(host) + if err != nil { + addr = netip.AddrFrom4([4]byte{0, 0, 0, 0}) + } break } } } - if addr == nil || addr.IsUnspecified() { - return nil, fmt.Errorf("all addr columns conatin invalid IP") + if addr.IsUnspecified() { + return nil, fmt.Errorf("all addr columns contain invalid IP") } return &Node{ hostID: hostID, diff --git a/transport/cluster_integration_test.go b/transport/cluster_integration_test.go index d8f29c13..33949d04 100644 --- a/transport/cluster_integration_test.go +++ b/transport/cluster_integration_test.go @@ -5,6 +5,7 @@ package transport import ( "context" "fmt" + "net/netip" "os/signal" "syscall" "testing" @@ -40,7 +41,7 @@ func TestClusterIntegration(t *testing.T) { defer cancel() addr := frame.Inet{ - IP: []byte{192, 168, 100, 100}, + IP: netip.MustParseAddr(TestHost), Port: 9042, }