diff --git a/conf/config.go b/conf/config.go index b600456da..d44a2d22e 100644 --- a/conf/config.go +++ b/conf/config.go @@ -9,11 +9,13 @@ import ( "crypto/rand" "crypto/subtle" "encoding/base64" + "errors" "fmt" - "net" "strings" "time" + "golang.zx2c4.com/go118/netip" + "golang.org/x/crypto/curve25519" "golang.zx2c4.com/wireguard/windows/l18n" @@ -22,8 +24,7 @@ import ( const KeyLength = 32 type IPCidr struct { - IP net.IP - Cidr uint8 + netip.Prefix } type Endpoint struct { @@ -46,7 +47,7 @@ type Interface struct { Addresses []IPCidr ListenPort uint16 MTU uint16 - DNS []net.IP + DNS []netip.Addr DNSSearch []string PreUp string PostUp string @@ -67,62 +68,28 @@ type Peer struct { LastHandshakeTime HandshakeTime } -func (r *IPCidr) String() string { - return fmt.Sprintf("%s/%d", r.IP.String(), r.Cidr) -} - -func (r *IPCidr) Bits() uint8 { - if r.IP.To4() != nil { - return 32 - } - return 128 -} - -func (r *IPCidr) IPNet() net.IPNet { - return net.IPNet{ - IP: r.IP, - Mask: net.CIDRMask(int(r.Cidr), int(r.Bits())), - } -} - -func (r *IPCidr) MaskSelf() { - bits := int(r.Bits()) - mask := net.CIDRMask(int(r.Cidr), bits) - for i := 0; i < bits/8; i++ { - r.IP[i] &= mask[i] - } -} - func (conf *Config) IntersectsWith(other *Config) bool { - type hashableIPCidr struct { - ip string - cidr byte - } - allRoutes := make(map[hashableIPCidr]bool, len(conf.Interface.Addresses)*2+len(conf.Peers)*3) + allRoutes := make(map[netip.Prefix]bool, len(conf.Interface.Addresses)*2+len(conf.Peers)*3) for _, a := range conf.Interface.Addresses { - allRoutes[hashableIPCidr{string(a.IP), byte(len(a.IP) * 8)}] = true - a.MaskSelf() - allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] = true + allRoutes[netip.PrefixFrom(a.Addr(), a.Addr().BitLen())] = true + allRoutes[a.Masked()] = true } for i := range conf.Peers { for _, a := range conf.Peers[i].AllowedIPs { - a.MaskSelf() - allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] = true + allRoutes[a.Masked()] = true } } for _, a := range other.Interface.Addresses { - if allRoutes[hashableIPCidr{string(a.IP), byte(len(a.IP) * 8)}] { + if allRoutes[netip.PrefixFrom(a.Addr(), a.Addr().BitLen())] { return true } - a.MaskSelf() - if allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] { + if allRoutes[a.Masked()] { return true } } for i := range other.Peers { for _, a := range other.Peers[i].AllowedIPs { - a.MaskSelf() - if allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] { + if allRoutes[a.Masked()] { return true } } @@ -233,6 +200,27 @@ func (b Bytes) String() string { return l18n.Sprintf("%.2f\u00a0TiB", float64(b)/(1024*1024*1024)/1024) } +func (p IPCidr) MarshalBinary() ([]byte, error) { + b, err := p.Addr().MarshalBinary() + if err != nil { + return nil, err + } + return append(b, uint8(p.Bits())), nil +} + +func (p *IPCidr) UnmarshalBinary(b []byte) error { + if len(b) < 1 { + return errors.New("unexpected byte slice") + } + var addr netip.Addr + err := addr.UnmarshalBinary(b[:len(b)-1]) + if err != nil { + return err + } + *p = IPCidr{netip.PrefixFrom(addr, int(b[len(b)-1]))} + return nil +} + func (conf *Config) DeduplicateNetworkEntries() { m := make(map[string]bool, len(conf.Interface.Addresses)) i := 0 diff --git a/conf/dnsresolver_windows.go b/conf/dnsresolver_windows.go index 094b10291..8a16aa855 100644 --- a/conf/dnsresolver_windows.go +++ b/conf/dnsresolver_windows.go @@ -8,11 +8,12 @@ package conf import ( "fmt" "log" - "net" "syscall" "time" "unsafe" + "golang.zx2c4.com/go118/netip" + "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/services" ) @@ -66,24 +67,24 @@ func resolveHostnameOnce(name string) (resolvedIPString string, err error) { return } defer windows.FreeAddrInfoW(result) - ipv6 := "" + var v6 netip.Addr for ; result != nil; result = result.Next { switch result.Family { case windows.AF_INET: - return (net.IP)((*syscall.RawSockaddrInet4)(unsafe.Pointer(result.Addr)).Addr[:]).String(), nil + return netip.AddrFrom4((*syscall.RawSockaddrInet4)(unsafe.Pointer(result.Addr)).Addr).String(), nil case windows.AF_INET6: - if len(ipv6) != 0 { + if v6.IsValid() { continue } a := (*syscall.RawSockaddrInet6)(unsafe.Pointer(result.Addr)) - ipv6 = (net.IP)(a.Addr[:]).String() + v6 = netip.AddrFrom16(a.Addr) if a.Scope_id != 0 { - ipv6 += fmt.Sprintf("%%%d", a.Scope_id) + v6 = v6.WithZone(fmt.Sprint(a.Scope_id)) } } } - if len(ipv6) != 0 { - return ipv6, nil + if v6.IsValid() { + return v6.String(), nil } err = windows.WSAHOST_NOT_FOUND return diff --git a/conf/parser.go b/conf/parser.go index 83f259646..1199c2e78 100644 --- a/conf/parser.go +++ b/conf/parser.go @@ -7,10 +7,11 @@ package conf import ( "encoding/base64" - "net" "strconv" "strings" + "golang.zx2c4.com/go118/netip" + "golang.org/x/sys/windows" "golang.org/x/text/encoding/unicode" @@ -27,43 +28,16 @@ func (e *ParseError) Error() string { return l18n.Sprintf("%s: %q", e.why, e.offender) } -func parseIPCidr(s string) (ipcidr *IPCidr, err error) { - var addrStr, cidrStr string - var cidr int - - i := strings.IndexByte(s, '/') - if i < 0 { - addrStr = s - } else { - addrStr, cidrStr = s[:i], s[i+1:] - } - - err = &ParseError{l18n.Sprintf("Invalid IP address"), s} - addr := net.ParseIP(addrStr) - if addr == nil { - return - } - maybeV4 := addr.To4() - if maybeV4 != nil { - addr = maybeV4 +func parseIPCidr(s string) (IPCidr, error) { + ipcidr, err := netip.ParsePrefix(s) + if err == nil { + return IPCidr{ipcidr}, nil } - if len(cidrStr) > 0 { - err = &ParseError{l18n.Sprintf("Invalid network prefix length"), s} - cidr, err = strconv.Atoi(cidrStr) - if err != nil || cidr < 0 || cidr > 128 { - return - } - if cidr > 32 && maybeV4 != nil { - return - } - } else { - if maybeV4 != nil { - cidr = 32 - } else { - cidr = 128 - } + addr, err := netip.ParseAddr(s) + if err != nil { + return IPCidr{}, &ParseError{l18n.Sprintf("Invalid IP address: "), s} } - return &IPCidr{addr, uint8(cidr)}, nil + return IPCidr{netip.PrefixFrom(addr, addr.BitLen())}, nil } func parseEndpoint(s string) (*Endpoint, error) { @@ -87,8 +61,8 @@ func parseEndpoint(s string) (*Endpoint, error) { if i := strings.LastIndexByte(host, '%'); i > 1 { end = i } - maybeV6 := net.ParseIP(host[1:end]) - if maybeV6 == nil || len(maybeV6) != net.IPv6len { + maybeV6, err2 := netip.ParseAddr(host[1:end]) + if err2 != nil || !maybeV6.Is6() { return nil, err } } else { @@ -96,7 +70,7 @@ func parseEndpoint(s string) (*Endpoint, error) { } host = host[1 : len(host)-1] } - return &Endpoint{host, uint16(port)}, nil + return &Endpoint{host, port}, nil } func parseMTU(s string) (uint16, error) { @@ -256,7 +230,7 @@ func FromWgQuick(s string, name string) (*Config, error) { if err != nil { return nil, err } - conf.Interface.Addresses = append(conf.Interface.Addresses, *a) + conf.Interface.Addresses = append(conf.Interface.Addresses, a) } case "dns": addresses, err := splitList(val) @@ -264,8 +238,8 @@ func FromWgQuick(s string, name string) (*Config, error) { return nil, err } for _, address := range addresses { - a := net.ParseIP(address) - if a == nil { + a, err := netip.ParseAddr(address) + if err != nil { conf.Interface.DNSSearch = append(conf.Interface.DNSSearch, address) } else { conf.Interface.DNS = append(conf.Interface.DNS, a) @@ -312,7 +286,7 @@ func FromWgQuick(s string, name string) (*Config, error) { if err != nil { return nil, err } - peer.AllowedIPs = append(peer.AllowedIPs, *a) + peer.AllowedIPs = append(peer.AllowedIPs, a) } case "persistentkeepalive": p, err := parsePersistentKeepalive(val) @@ -399,7 +373,7 @@ func FromDriverConfiguration(interfaze *driver.Interface, existingConfig *Config } if p.Flags&driver.PeerHasEndpoint != 0 { peer.Endpoint.Port = p.Endpoint.Port() - peer.Endpoint.Host = p.Endpoint.IP().String() + peer.Endpoint.Host = p.Endpoint.Addr().String() } if p.Flags&driver.PeerHasPersistentKeepalive != 0 { peer.PersistentKeepalive = p.PersistentKeepalive @@ -416,16 +390,13 @@ func FromDriverConfiguration(interfaze *driver.Interface, existingConfig *Config } else { a = a.NextAllowedIP() } - var ip net.IP + var ip netip.Addr if a.AddressFamily == windows.AF_INET { - ip = a.Address[:4] + ip = netip.AddrFrom4(*(*[4]byte)(a.Address[:4])) } else if a.AddressFamily == windows.AF_INET6 { - ip = a.Address[:16] + ip = netip.AddrFrom16(*(*[16]byte)(a.Address[:16])) } - peer.AllowedIPs = append(peer.AllowedIPs, IPCidr{ - IP: ip, - Cidr: a.Cidr, - }) + peer.AllowedIPs = append(peer.AllowedIPs, IPCidr{netip.PrefixFrom(ip, int(a.Cidr))}) } conf.Peers = append(conf.Peers, peer) } diff --git a/conf/parser_test.go b/conf/parser_test.go index f80d6d186..c2f757c74 100644 --- a/conf/parser_test.go +++ b/conf/parser_test.go @@ -6,10 +6,11 @@ package conf import ( - "net" "reflect" "runtime" "testing" + + "golang.zx2c4.com/go118/netip" ) const testInput = ` @@ -77,10 +78,9 @@ func contains(t *testing.T, list, element interface{}) bool { func TestFromWgQuick(t *testing.T) { conf, err := FromWgQuick(testInput, "test") if noError(t, err) { - lenTest(t, conf.Interface.Addresses, 2) - contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 10, 0, 1), uint8(16)}) - contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 192, 122, 1), uint8(24)}) + contains(t, conf.Interface.Addresses, netip.PrefixFrom(netip.AddrFrom4([4]byte{0, 10, 0, 1}), 16)) + contains(t, conf.Interface.Addresses, netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 192, 122, 1}), 24)) equal(t, "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", conf.Interface.PrivateKey.String()) equal(t, uint16(51820), conf.Interface.ListenPort) diff --git a/conf/writer.go b/conf/writer.go index 3e24559f2..5b7b4a759 100644 --- a/conf/writer.go +++ b/conf/writer.go @@ -7,10 +7,11 @@ package conf import ( "fmt" - "net" "strings" "unsafe" + "golang.zx2c4.com/go118/netip" + "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" @@ -111,8 +112,11 @@ func (config *Config) ToDriverConfiguration() (*driver.Interface, uint32) { } var endpoint winipcfg.RawSockaddrInet if !config.Peers[i].Endpoint.IsEmpty() { - flags |= driver.PeerHasEndpoint - endpoint.SetIP(net.ParseIP(config.Peers[i].Endpoint.Host), config.Peers[i].Endpoint.Port) + addr, err := netip.ParseAddr(config.Peers[i].Endpoint.Host) + if err == nil { + flags |= driver.PeerHasEndpoint + endpoint.SetAddrPort(netip.AddrPortFrom(addr, config.Peers[i].Endpoint.Port)) + } } c.AppendPeer(&driver.Peer{ Flags: flags, @@ -123,20 +127,19 @@ func (config *Config) ToDriverConfiguration() (*driver.Interface, uint32) { AllowedIPsCount: uint32(len(config.Peers[i].AllowedIPs)), }) for j := range config.Peers[i].AllowedIPs { - var family winipcfg.AddressFamily - var ip net.IP - if ip = config.Peers[i].AllowedIPs[j].IP.To4(); ip != nil { - family = windows.AF_INET - } else if ip = config.Peers[i].AllowedIPs[j].IP.To16(); ip != nil { - family = windows.AF_INET6 - } else { - ip = config.Peers[i].AllowedIPs[j].IP - } a := &driver.AllowedIP{ - AddressFamily: family, - Cidr: config.Peers[i].AllowedIPs[j].Cidr, + Address: config.Peers[i].AllowedIPs[j].Addr().As16(), + Cidr: uint8(config.Peers[i].AllowedIPs[j].Bits()), + } + if config.Peers[i].AllowedIPs[j].Addr().Is4() { + a.AddressFamily = windows.AF_INET + ip := config.Peers[i].AllowedIPs[j].Addr().As4() + copy(a.Address[:], ip[:]) + } else if config.Peers[i].AllowedIPs[j].Addr().Is6() { + a.AddressFamily = windows.AF_INET6 + ip := config.Peers[i].AllowedIPs[j].Addr().As16() + copy(a.Address[:], ip[:]) } - copy(a.Address[:], ip) c.AppendAllowedIP(a) } } diff --git a/go.mod b/go.mod index 55779a430..b681ae3ed 100644 --- a/go.mod +++ b/go.mod @@ -6,9 +6,10 @@ require ( github.com/lxn/walk v0.0.0-20210112085537-c389da54e794 github.com/lxn/win v0.0.0-20210218163916-a377121e959e golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 - golang.org/x/net v0.0.0-20211029160332-540bb53d3b2e - golang.org/x/sys v0.0.0-20211029165221-6e7872819dc8 - golang.org/x/text v0.3.8-0.20211029042148-bb1c79828956 + golang.org/x/net v0.0.0-20211101193420-4a448f8816b3 + golang.org/x/sys v0.0.0-20211102061401-a2f17f7b995c + golang.org/x/text v0.3.8-0.20211102165214-8da7c0fd2b03 + golang.zx2c4.com/go118/netip v0.0.0-20211102181655-912d2728c2f2 ) require ( diff --git a/go.mod.master b/go.mod.master index b106dfea4..897e97fa7 100644 --- a/go.mod.master +++ b/go.mod.master @@ -9,6 +9,7 @@ require ( golang.org/x/net latest golang.org/x/sys latest golang.org/x/text master + golang.zx2c4.com/go118/netip master ) replace ( diff --git a/go.sum b/go.sum index 58d5ab145..7791ac1fe 100644 --- a/go.sum +++ b/go.sum @@ -6,11 +6,11 @@ golang.org/x/mod v0.4.2 h1:Gz96sIWK3OalVv/I/qNygP42zyoKp3xptRVCWRFEBvo= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20211029160332-540bb53d3b2e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20211101193420-4a448f8816b3/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.8-0.20211029042148-bb1c79828956 h1:xw/3G76i8BwoCoEZ8RzhVpFrHEz4Qm9D7zPckwa7KVM= -golang.org/x/text v0.3.8-0.20211029042148-bb1c79828956/go.mod h1:EFNZuWvGYxIRUEX+K8UmCFwYmZjqcrnq15ZuVldZkZ0= +golang.org/x/text v0.3.8-0.20211102165214-8da7c0fd2b03 h1:UN1T9lGOePBbThtjpZ+qBS2MiQ/rkMJvVpUeaLDvwZY= +golang.org/x/text v0.3.8-0.20211102165214-8da7c0fd2b03/go.mod h1:EFNZuWvGYxIRUEX+K8UmCFwYmZjqcrnq15ZuVldZkZ0= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.7 h1:6j8CgantCy3yc8JGBqkDLMKWqZ0RDU2g1HVgacojGWQ= golang.org/x/tools v0.1.7/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo= @@ -18,6 +18,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.zx2c4.com/go118/netip v0.0.0-20211102181655-912d2728c2f2 h1:vYWh5bxKFFZ9HoHbDuIEGEYulJgFea41t11f4AU4zoc= +golang.zx2c4.com/go118/netip v0.0.0-20211102181655-912d2728c2f2/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg= golang.zx2c4.com/wireguard/windows v0.0.0-20210121140954-e7fc19d483bd h1:kAUzMAITME2MCtrXBaUa9P4tndiXGWO674k9gn6ZR28= golang.zx2c4.com/wireguard/windows v0.0.0-20210121140954-e7fc19d483bd/go.mod h1:Y+FYqVFaQO6a+1uigm0N0GiuaZrLEaBxEiJ8tfH9sMQ= golang.zx2c4.com/wireguard/windows v0.0.0-20210224134948-620c54ef6199 h1:ogXKLng/Myrt2odYTkleySGzQj/GWg9GV1AQ8P9NnU4= diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go index f315cd15f..97d186618 100644 --- a/tunnel/addressconfig.go +++ b/tunnel/addressconfig.go @@ -6,13 +6,12 @@ package tunnel import ( - "bytes" "fmt" "log" - "net" - "sort" "time" + "golang.zx2c4.com/go118/netip" + "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/conf" "golang.zx2c4.com/wireguard/windows/services" @@ -20,19 +19,13 @@ import ( "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) -func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []net.IPNet) { +func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []netip.Prefix) { if len(addresses) == 0 { return } - addrToStr := func(ip *net.IP) string { - if ip4 := ip.To4(); ip4 != nil { - return string(ip4) - } - return string(*ip) - } - addrHash := make(map[string]bool, len(addresses)) + addrHash := make(map[netip.Addr]bool, len(addresses)) for i := range addresses { - addrHash[addrToStr(&addresses[i].IP)] = true + addrHash[addresses[i].Addr()] = true } interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault) if err != nil { @@ -43,11 +36,11 @@ func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, add continue } for address := iface.FirstUnicastAddress; address != nil; address = address.Next { - ip := address.Address.IP() - if addrHash[addrToStr(&ip)] { - ipnet := net.IPNet{IP: ip, Mask: net.CIDRMask(int(address.OnLinkPrefixLength), 8*len(ip))} - log.Printf("Cleaning up stale address %s from interface ā€˜%sā€™", ipnet.String(), iface.FriendlyName()) - iface.LUID.DeleteIPAddress(ipnet) + ip, ok := netip.AddrFromSlice(address.Address.IP()) + if ok && addrHash[ip] { + prefix := netip.PrefixFrom(ip, int(address.OnLinkPrefixLength)) + log.Printf("Cleaning up stale address %s from interface ā€˜%sā€™", prefix.String(), iface.FriendlyName()) + iface.LUID.DeleteIPAddress(prefix) } } } @@ -69,14 +62,14 @@ startOver: for _, peer := range conf.Peers { estimatedRouteCount += len(peer.AllowedIPs) } - routes := make([]winipcfg.RouteData, 0, estimatedRouteCount) - addresses := make([]net.IPNet, len(conf.Interface.Addresses)) + routes := make(map[winipcfg.RouteData]bool, estimatedRouteCount) + addresses := make([]netip.Prefix, len(conf.Interface.Addresses)) var haveV4Address, haveV6Address bool for i, addr := range conf.Interface.Addresses { - addresses[i] = addr.IPNet() - if addr.Bits() == 32 { + addresses[i] = addr.Prefix + if addr.Addr().Is4() { haveV4Address = true - } else if addr.Bits() == 128 { + } else if addr.Addr().Is6() { haveV6Address = true } } @@ -85,53 +78,32 @@ startOver: foundDefault6 := false for _, peer := range conf.Peers { for _, allowedip := range peer.AllowedIPs { - allowedip.MaskSelf() - if (allowedip.Bits() == 32 && !haveV4Address) || (allowedip.Bits() == 128 && !haveV6Address) { + if (allowedip.Addr().Is4() && !haveV4Address) || (allowedip.Addr().Is6() && !haveV6Address) { continue } route := winipcfg.RouteData{ - Destination: allowedip.IPNet(), + Destination: allowedip.Masked(), Metric: 0, } - if allowedip.Bits() == 32 { - if allowedip.Cidr == 0 { + if allowedip.Addr().Is4() { + if allowedip.Bits() == 0 { foundDefault4 = true } - route.NextHop = net.IPv4zero - } else if allowedip.Bits() == 128 { - if allowedip.Cidr == 0 { + route.NextHop = netip.AddrFrom4([4]byte{}) + } else if allowedip.Addr().Is6() { + if allowedip.Bits() == 0 { foundDefault6 = true } - route.NextHop = net.IPv6zero + route.NextHop = netip.AddrFrom16([16]byte{}) } - routes = append(routes, route) + routes[route] = true } } deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes)) - sort.Slice(routes, func(i, j int) bool { - if routes[i].Metric != routes[j].Metric { - return routes[i].Metric < routes[j].Metric - } - if c := bytes.Compare(routes[i].NextHop, routes[j].NextHop); c != 0 { - return c < 0 - } - if c := bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP); c != 0 { - return c < 0 - } - if c := bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask); c != 0 { - return c < 0 - } - return false - }) - for i := 0; i < len(routes); i++ { - if i > 0 && routes[i].Metric == routes[i-1].Metric && - bytes.Equal(routes[i].NextHop, routes[i-1].NextHop) && - bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) && - bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) { - continue - } - deduplicatedRoutes = append(deduplicatedRoutes, &routes[i]) + for route := range routes { + r := route + deduplicatedRoutes = append(deduplicatedRoutes, &r) } if !conf.Interface.TableOff { @@ -189,14 +161,8 @@ startOver: func enableFirewall(conf *conf.Config, luid winipcfg.LUID) error { doNotRestrict := true if len(conf.Peers) == 1 && !conf.Interface.TableOff { - nextallowedip: for _, allowedip := range conf.Peers[0].AllowedIPs { - if allowedip.Cidr == 0 { - for _, b := range allowedip.IP { - if b != 0 { - continue nextallowedip - } - } + if allowedip.Bits() == 0 && allowedip.Prefix == allowedip.Masked() { doNotRestrict = false break } diff --git a/tunnel/deterministicguid.go b/tunnel/deterministicguid.go index 455deaeb0..afdab11ee 100644 --- a/tunnel/deterministicguid.go +++ b/tunnel/deterministicguid.go @@ -80,13 +80,13 @@ func deterministicGUID(c *conf.Config) *windows.GUID { b2Number(len(peer.AllowedIPs)) sortedAllowedIPs := peer.AllowedIPs sort.Slice(sortedAllowedIPs, func(i, j int) bool { - if bi, bj := sortedAllowedIPs[i].Bits(), sortedAllowedIPs[j].Bits(); bi != bj { + if bi, bj := sortedAllowedIPs[i].Addr().BitLen(), sortedAllowedIPs[j].Addr().BitLen(); bi != bj { return bi < bj } - if sortedAllowedIPs[i].Cidr != sortedAllowedIPs[j].Cidr { - return sortedAllowedIPs[i].Cidr < sortedAllowedIPs[j].Cidr + if sortedAllowedIPs[i].Bits() != sortedAllowedIPs[j].Bits() { + return sortedAllowedIPs[i].Bits() < sortedAllowedIPs[j].Bits() } - return bytes.Compare(sortedAllowedIPs[i].IP[:], sortedAllowedIPs[j].IP[:]) < 0 + return sortedAllowedIPs[i].Addr().Compare(sortedAllowedIPs[j].Addr()) < 0 }) for _, allowedip := range sortedAllowedIPs { b2String(allowedip.String()) diff --git a/tunnel/firewall/blocker.go b/tunnel/firewall/blocker.go index 2cb2c7f2b..4be62aa98 100644 --- a/tunnel/firewall/blocker.go +++ b/tunnel/firewall/blocker.go @@ -7,9 +7,10 @@ package firewall import ( "errors" - "net" "unsafe" + "golang.zx2c4.com/go118/netip" + "golang.org/x/sys/windows" ) @@ -101,7 +102,7 @@ func registerBaseObjects(session uintptr) (*baseObjects, error) { return bo, nil } -func EnableFirewall(luid uint64, doNotRestrict bool, restrictToDNSServers []net.IP) error { +func EnableFirewall(luid uint64, doNotRestrict bool, restrictToDNSServers []netip.Addr) error { if wfpSession != 0 { return errors.New("The firewall has already been enabled") } diff --git a/tunnel/firewall/rules.go b/tunnel/firewall/rules.go index c4488a317..7d0ed0e73 100644 --- a/tunnel/firewall/rules.go +++ b/tunnel/firewall/rules.go @@ -8,10 +8,11 @@ package firewall import ( "encoding/binary" "errors" - "net" "runtime" "unsafe" + "golang.zx2c4.com/go118/netip" + "golang.org/x/sys/windows" ) @@ -985,7 +986,7 @@ func blockAll(session uintptr, baseObjects *baseObjects, weight uint8) error { } // Block all DNS traffic except towards specified DNS servers. -func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weightAllow uint8, weightDeny uint8) error { +func blockDNS(except []netip.Addr, session uintptr, baseObjects *baseObjects, weightAllow uint8, weightDeny uint8) error { if weightDeny >= weightAllow { return errors.New("The allow weight must be greater than the deny weight") } @@ -1106,16 +1107,16 @@ func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weight allowConditionsV4 := make([]wtFwpmFilterCondition0, 0, len(denyConditions)+len(except)) allowConditionsV4 = append(allowConditionsV4, denyConditions...) for _, ip := range except { - ip4 := ip.To4() - if ip4 == nil { + if !ip.Is4() { continue } + ip4 := ip.As4() allowConditionsV4 = append(allowConditionsV4, wtFwpmFilterCondition0{ fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS, matchType: cFWP_MATCH_EQUAL, conditionValue: wtFwpConditionValue0{ _type: cFWP_UINT32, - value: uintptr(binary.BigEndian.Uint32(ip4)), + value: uintptr(binary.BigEndian.Uint32(ip4[:])), }, }) } @@ -1124,11 +1125,10 @@ func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weight allowConditionsV6 := make([]wtFwpmFilterCondition0, 0, len(denyConditions)+len(except)) allowConditionsV6 = append(allowConditionsV6, denyConditions...) for _, ip := range except { - if ip.To4() != nil { + if !ip.Is6() { continue } - var address wtFwpByteArray16 - copy(address.byteArray16[:], ip) + address := wtFwpByteArray16{byteArray16: ip.As16()} allowConditionsV6 = append(allowConditionsV6, wtFwpmFilterCondition0{ fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS, matchType: cFWP_MATCH_EQUAL, diff --git a/tunnel/winipcfg/luid.go b/tunnel/winipcfg/luid.go index ca388acce..744fee627 100644 --- a/tunnel/winipcfg/luid.go +++ b/tunnel/winipcfg/luid.go @@ -7,9 +7,10 @@ package winipcfg import ( "errors" - "net" "strings" + "golang.zx2c4.com/go118/netip" + "golang.org/x/sys/windows" ) @@ -76,10 +77,10 @@ func LUIDFromIndex(index uint32) (LUID, error) { // IPAddress method returns MibUnicastIPAddressRow struct that matches to provided 'ip' argument. Corresponds to GetUnicastIpAddressEntry // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getunicastipaddressentry) -func (luid LUID) IPAddress(ip net.IP) (*MibUnicastIPAddressRow, error) { +func (luid LUID) IPAddress(addr netip.Addr) (*MibUnicastIPAddressRow, error) { row := &MibUnicastIPAddressRow{InterfaceLUID: luid} - err := row.Address.SetIP(ip, 0) + err := row.Address.SetAddr(addr) if err != nil { return nil, err } @@ -94,25 +95,24 @@ func (luid LUID) IPAddress(ip net.IP) (*MibUnicastIPAddressRow, error) { // AddIPAddress method adds new unicast IP address to the interface. Corresponds to CreateUnicastIpAddressEntry function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry). -func (luid LUID) AddIPAddress(address net.IPNet) error { +func (luid LUID) AddIPAddress(address netip.Prefix) error { row := &MibUnicastIPAddressRow{} row.Init() row.InterfaceLUID = luid row.DadState = DadStatePreferred row.ValidLifetime = 0xffffffff row.PreferredLifetime = 0xffffffff - err := row.Address.SetIP(address.IP, 0) + err := row.Address.SetAddr(address.Addr()) if err != nil { return err } - ones, _ := address.Mask.Size() - row.OnLinkPrefixLength = uint8(ones) + row.OnLinkPrefixLength = uint8(address.Bits()) return row.Create() } // AddIPAddresses method adds multiple new unicast IP addresses to the interface. Corresponds to CreateUnicastIpAddressEntry function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry). -func (luid LUID) AddIPAddresses(addresses []net.IPNet) error { +func (luid LUID) AddIPAddresses(addresses []netip.Prefix) error { for i := range addresses { err := luid.AddIPAddress(addresses[i]) if err != nil { @@ -123,7 +123,7 @@ func (luid LUID) AddIPAddresses(addresses []net.IPNet) error { } // SetIPAddresses method sets new unicast IP addresses to the interface. -func (luid LUID) SetIPAddresses(addresses []net.IPNet) error { +func (luid LUID) SetIPAddresses(addresses []netip.Prefix) error { err := luid.FlushIPAddresses(windows.AF_UNSPEC) if err != nil { return err @@ -132,16 +132,15 @@ func (luid LUID) SetIPAddresses(addresses []net.IPNet) error { } // SetIPAddressesForFamily method sets new unicast IP addresses for a specific family to the interface. -func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []net.IPNet) error { +func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []netip.Prefix) error { err := luid.FlushIPAddresses(family) if err != nil { return err } for i := range addresses { - asV4 := addresses[i].IP.To4() - if asV4 == nil && family == windows.AF_INET { + if !addresses[i].Addr().Is4() && family == windows.AF_INET { continue - } else if asV4 != nil && family == windows.AF_INET6 { + } else if !addresses[i].Addr().Is6() && family == windows.AF_INET6 { continue } err := luid.AddIPAddress(addresses[i]) @@ -154,17 +153,16 @@ func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []net.I // DeleteIPAddress method deletes interface's unicast IP address. Corresponds to DeleteUnicastIpAddressEntry function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteunicastipaddressentry). -func (luid LUID) DeleteIPAddress(address net.IPNet) error { +func (luid LUID) DeleteIPAddress(address netip.Prefix) error { row := &MibUnicastIPAddressRow{} row.Init() row.InterfaceLUID = luid - err := row.Address.SetIP(address.IP, 0) + err := row.Address.SetAddr(address.Addr()) if err != nil { return err } // Note: OnLinkPrefixLength member is ignored by DeleteUnicastIpAddressEntry(). - ones, _ := address.Mask.Size() - row.OnLinkPrefixLength = uint8(ones) + row.OnLinkPrefixLength = uint8(address.Bits()) return row.Delete() } @@ -188,17 +186,17 @@ func (luid LUID) FlushIPAddresses(family AddressFamily) error { // Route method returns route determined with the input arguments. Corresponds to GetIpForwardEntry2 function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipforwardentry2). // NOTE: If the corresponding route isn't found, the method will return error. -func (luid LUID) Route(destination net.IPNet, nextHop net.IP) (*MibIPforwardRow2, error) { +func (luid LUID) Route(destination netip.Prefix, nextHop netip.Addr) (*MibIPforwardRow2, error) { row := &MibIPforwardRow2{} row.Init() row.InterfaceLUID = luid row.ValidLifetime = 0xffffffff row.PreferredLifetime = 0xffffffff - err := row.DestinationPrefix.SetIPNet(destination) + err := row.DestinationPrefix.SetPrefix(destination) if err != nil { return nil, err } - err = row.NextHop.SetIP(nextHop, 0) + err = row.NextHop.SetAddr(nextHop) if err != nil { return nil, err } @@ -212,15 +210,15 @@ func (luid LUID) Route(destination net.IPNet, nextHop net.IP) (*MibIPforwardRow2 // AddRoute method adds a route to the interface. Corresponds to CreateIpForwardEntry2 function, with added splitDefault feature. // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createipforwardentry2) -func (luid LUID) AddRoute(destination net.IPNet, nextHop net.IP, metric uint32) error { +func (luid LUID) AddRoute(destination netip.Prefix, nextHop netip.Addr, metric uint32) error { row := &MibIPforwardRow2{} row.Init() row.InterfaceLUID = luid - err := row.DestinationPrefix.SetIPNet(destination) + err := row.DestinationPrefix.SetPrefix(destination) if err != nil { return err } - err = row.NextHop.SetIP(nextHop, 0) + err = row.NextHop.SetAddr(nextHop) if err != nil { return err } @@ -255,10 +253,9 @@ func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteDat return err } for _, rd := range routesData { - asV4 := rd.Destination.IP.To4() - if asV4 == nil && family == windows.AF_INET { + if !rd.Destination.Addr().Is4() && family == windows.AF_INET { continue - } else if asV4 != nil && family == windows.AF_INET6 { + } else if !rd.Destination.Addr().Is6() && family == windows.AF_INET6 { continue } err := luid.AddRoute(rd.Destination, rd.NextHop, rd.Metric) @@ -271,15 +268,15 @@ func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteDat // DeleteRoute method deletes a route that matches the criteria. Corresponds to DeleteIpForwardEntry2 function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteipforwardentry2). -func (luid LUID) DeleteRoute(destination net.IPNet, nextHop net.IP) error { +func (luid LUID) DeleteRoute(destination netip.Prefix, nextHop netip.Addr) error { row := &MibIPforwardRow2{} row.Init() row.InterfaceLUID = luid - err := row.DestinationPrefix.SetIPNet(destination) + err := row.DestinationPrefix.SetPrefix(destination) if err != nil { return err } - err = row.NextHop.SetIP(nextHop, 0) + err = row.NextHop.SetAddr(nextHop) if err != nil { return err } @@ -312,17 +309,19 @@ func (luid LUID) FlushRoutes(family AddressFamily) error { } // DNS method returns all DNS server addresses associated with the adapter. -func (luid LUID) DNS() ([]net.IP, error) { +func (luid LUID) DNS() ([]netip.Addr, error) { addresses, err := GetAdaptersAddresses(windows.AF_UNSPEC, GAAFlagDefault) if err != nil { return nil, err } - r := make([]net.IP, 0, len(addresses)) + r := make([]netip.Addr, 0, len(addresses)) for _, addr := range addresses { if addr.LUID == luid { for dns := addr.FirstDNSServerAddress; dns != nil; dns = dns.Next { if ip := dns.Address.IP(); ip != nil { - r = append(r, ip) + if a, ok := netip.AddrFromSlice(ip); ok { + r = append(r, a) + } } else { return nil, windows.ERROR_INVALID_PARAMETER } @@ -333,17 +332,15 @@ func (luid LUID) DNS() ([]net.IP, error) { } // SetDNS method clears previous and associates new DNS servers and search domains with the adapter for a specific family. -func (luid LUID) SetDNS(family AddressFamily, servers []net.IP, domains []string) error { +func (luid LUID) SetDNS(family AddressFamily, servers []netip.Addr, domains []string) error { if family != windows.AF_INET && family != windows.AF_INET6 { return windows.ERROR_PROTOCOL_UNREACHABLE } var filteredServers []string for _, server := range servers { - if v4 := server.To4(); v4 != nil && family == windows.AF_INET { - filteredServers = append(filteredServers, v4.String()) - } else if v6 := server.To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 { - filteredServers = append(filteredServers, v6.String()) + if (server.Is4() && family == windows.AF_INET) || (server.Is6() && family == windows.AF_INET6) { + filteredServers = append(filteredServers, server.String()) } } servers16, err := windows.UTF16PtrFromString(strings.Join(filteredServers, ",")) diff --git a/tunnel/winipcfg/netsh.go b/tunnel/winipcfg/netsh.go index 1f3d12d02..17e0778c7 100644 --- a/tunnel/winipcfg/netsh.go +++ b/tunnel/winipcfg/netsh.go @@ -10,12 +10,13 @@ import ( "errors" "fmt" "io" - "net" "os/exec" "path/filepath" "strings" "syscall" + "golang.zx2c4.com/go118/netip" + "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" ) @@ -57,7 +58,7 @@ const ( netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no" ) -func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []net.IP) error { +func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []netip.Addr) error { var templateFlush string if family == windows.AF_INET { templateFlush = netshCmdTemplateFlush4 @@ -72,10 +73,10 @@ func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []net.IP) e } cmds = append(cmds, fmt.Sprintf(templateFlush, ipif.InterfaceIndex)) for i := 0; i < len(dnses); i++ { - if v4 := dnses[i].To4(); v4 != nil && family == windows.AF_INET { - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, v4.String())) - } else if v6 := dnses[i].To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 { - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, v6.String())) + if dnses[i].Is4() && family == windows.AF_INET { + cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, dnses[i].String())) + } else if dnses[i].Is6() && family == windows.AF_INET6 { + cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, dnses[i].String())) } } return runNetsh(cmds) diff --git a/tunnel/winipcfg/types.go b/tunnel/winipcfg/types.go index 789ee5017..599bf7898 100644 --- a/tunnel/winipcfg/types.go +++ b/tunnel/winipcfg/types.go @@ -8,9 +8,10 @@ package winipcfg import ( "encoding/binary" "fmt" - "net" "unsafe" + "golang.zx2c4.com/go118/netip" + "golang.org/x/sys/windows" ) @@ -584,8 +585,8 @@ const ( // RouteData structure describes a route to add type RouteData struct { - Destination net.IPNet - NextHop net.IP + Destination netip.Prefix + NextHop netip.Addr Metric uint32 } @@ -748,44 +749,50 @@ func htons(i uint16) uint16 { return *(*uint16)(unsafe.Pointer(&b[0])) } -// SetIP method sets family, address, and port to the given IPv4 or IPv6 address and port. +// SetAddrPort method sets family, address, and port to the given IPv4 or IPv6 address and port. // All other members of the structure are set to zero. -func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error { - if v4 := ip.To4(); v4 != nil { +func (addr *RawSockaddrInet) SetAddrPort(addrPort netip.AddrPort) error { + if addrPort.Addr().Is4() { addr4 := (*windows.RawSockaddrInet4)(unsafe.Pointer(addr)) addr4.Family = windows.AF_INET - copy(addr4.Addr[:], v4) - addr4.Port = htons(port) + addr4.Addr = addrPort.Addr().As4() + addr4.Port = htons(addrPort.Port()) for i := 0; i < 8; i++ { addr4.Zero[i] = 0 } return nil - } - - if v6 := ip.To16(); v6 != nil { + } else if addrPort.Addr().Is6() { addr6 := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr)) addr6.Family = windows.AF_INET6 - addr6.Port = htons(port) + addr6.Addr = addrPort.Addr().As16() + addr6.Port = htons(addrPort.Port()) addr6.Flowinfo = 0 - copy(addr6.Addr[:], v6) addr6.Scope_id = 0 return nil } - return windows.ERROR_INVALID_PARAMETER } -// IP returns IPv4 or IPv6 address, or nil if the address is neither. -func (addr *RawSockaddrInet) IP() net.IP { +// SetAddr method sets family and address to the given IPv4 or IPv6 address. +// All other members of the structure are set to zero. +func (addr *RawSockaddrInet) SetAddr(netAddr netip.Addr) error { + return addr.SetAddrPort(netip.AddrPortFrom(netAddr, 0)) +} + +// AddrPort returns the IP address and port. +func (addr *RawSockaddrInet) AddrPort() netip.AddrPort { + return netip.AddrPortFrom(addr.Addr(), addr.Port()) +} + +// Addr returns IPv4 or IPv6 address, or an invalid address if the address is neither. +func (addr *RawSockaddrInet) Addr() netip.Addr { switch addr.Family { case windows.AF_INET: - return (*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Addr[:] - + return netip.AddrFrom4((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Addr) case windows.AF_INET6: - return (*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Addr[:] + return netip.AddrFrom16((*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Addr) } - - return nil + return netip.Addr{} } // Port returns the port if the address if IPv4 or IPv6, or 0 if neither. @@ -793,11 +800,9 @@ func (addr *RawSockaddrInet) Port() uint16 { switch addr.Family { case windows.AF_INET: return ntohs((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Port) - case windows.AF_INET6: return ntohs((*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Port) } - return 0 } @@ -874,32 +879,30 @@ func (tab *mibAnycastIPAddressTable) free() { // IPAddressPrefix structure stores an IP address prefix. // https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_ip_address_prefix type IPAddressPrefix struct { - Prefix RawSockaddrInet + RawPrefix RawSockaddrInet PrefixLength uint8 _ [2]byte } -// SetIPNet method sets IP address prefix using net.IPNet. -func (prefix *IPAddressPrefix) SetIPNet(net net.IPNet) error { - err := prefix.Prefix.SetIP(net.IP, 0) +// SetPrefix method sets IP address prefix using netip.Prefix. +func (prefix *IPAddressPrefix) SetPrefix(netPrefix netip.Prefix) error { + err := prefix.RawPrefix.SetAddr(netPrefix.Addr()) if err != nil { return err } - ones, _ := net.Mask.Size() - prefix.PrefixLength = uint8(ones) + prefix.PrefixLength = uint8(netPrefix.Bits()) return nil } -// IPNet method returns IP address prefix as net.IPNet. -// If the address is neither IPv4 not IPv6 an empty net.IPNet is returned. The resulting net.IPNet should be checked appropriately. -func (prefix *IPAddressPrefix) IPNet() net.IPNet { - switch prefix.Prefix.Family { +// Prefix returns IP address prefix as netip.Prefix. +func (prefix *IPAddressPrefix) Prefix() netip.Prefix { + switch prefix.RawPrefix.Family { case windows.AF_INET: - return net.IPNet{IP: (*windows.RawSockaddrInet4)(unsafe.Pointer(&prefix.Prefix)).Addr[:], Mask: net.CIDRMask(int(prefix.PrefixLength), 8*net.IPv4len)} + return netip.PrefixFrom(netip.AddrFrom4((*windows.RawSockaddrInet4)(unsafe.Pointer(&prefix.RawPrefix)).Addr), int(prefix.PrefixLength)) case windows.AF_INET6: - return net.IPNet{IP: (*windows.RawSockaddrInet6)(unsafe.Pointer(&prefix.Prefix)).Addr[:], Mask: net.CIDRMask(int(prefix.PrefixLength), 8*net.IPv6len)} + return netip.PrefixFrom(netip.AddrFrom16((*windows.RawSockaddrInet6)(unsafe.Pointer(&prefix.RawPrefix)).Addr), int(prefix.PrefixLength)) } - return net.IPNet{} + return netip.Prefix{} } // MibIPforwardRow2 structure stores information about an IP route entry. diff --git a/tunnel/winipcfg/winipcfg_test.go b/tunnel/winipcfg/winipcfg_test.go index 5d3bc2767..d863b1a29 100644 --- a/tunnel/winipcfg/winipcfg_test.go +++ b/tunnel/winipcfg/winipcfg_test.go @@ -22,13 +22,13 @@ Some tests in this file require: package winipcfg import ( - "bytes" - "net" "strings" "syscall" "testing" "time" + "golang.zx2c4.com/go118/netip" + "golang.org/x/sys/windows" ) @@ -38,22 +38,13 @@ const ( // TODO: Add IPv6 tests. var ( - unexistentIPAddresToAdd = net.IPNet{ - IP: net.IP{172, 16, 1, 114}, - Mask: net.IPMask{255, 255, 255, 0}, - } - unexistentRouteIPv4ToAdd = RouteData{ - Destination: net.IPNet{ - IP: net.IP{172, 16, 200, 0}, - Mask: net.IPMask{255, 255, 255, 0}, - }, - NextHop: net.IP{172, 16, 1, 2}, - Metric: 0, - } - dnsesToSet = []net.IP{ - net.IPv4(8, 8, 8, 8), - net.IPv4(8, 8, 4, 4), + nonexistantIPv4ToAdd = netip.MustParsePrefix("172.16.1.114/24") + nonexistentRouteIPv4ToAdd = RouteData{ + Destination: netip.MustParsePrefix("172.16.200.0/24"), + NextHop: netip.MustParseAddr("172.16.1.2"), + Metric: 0, } + dnsesToSet = []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")} ) func runningElevated() bool { @@ -380,9 +371,9 @@ func TestAddDeleteIPAddress(t *testing.T) { return } - addr, err := ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP) + addr, err := ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr()) if err == nil { - t.Errorf("Unicast address %s already exists. Please set unexistentIPAddresToAdd appropriately.", unexistentIPAddresToAdd.IP.String()) + t.Errorf("Unicast address %s already exists. Please set nonexistantIPv4ToAdd appropriately.", nonexistantIPv4ToAdd.Addr().String()) return } else if err != windows.ERROR_NOT_FOUND { t.Errorf("LUID.IPAddress() returned an error: %w", err) @@ -410,7 +401,7 @@ func TestAddDeleteIPAddress(t *testing.T) { for addr := ifc.FirstUnicastAddress; addr != nil; addr = addr.Next { count-- } - err = ifc.LUID.AddIPAddresses([]net.IPNet{unexistentIPAddresToAdd}) + err = ifc.LUID.AddIPAddresses([]netip.Prefix{nonexistantIPv4ToAdd}) if err != nil { t.Errorf("LUID.AddIPAddresses() returned an error: %w", err) } @@ -424,26 +415,26 @@ func TestAddDeleteIPAddress(t *testing.T) { if count != 1 { t.Errorf("After adding there are %d new interface(s).", count) } - addr, err = ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP) + addr, err = ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr()) if err != nil { t.Errorf("LUID.IPAddress() returned an error: %w", err) } else if addr == nil { - t.Errorf("Unicast address %s still doesn't exist, although it's added successfully.", unexistentIPAddresToAdd.IP.String()) + t.Errorf("Unicast address %s still doesn't exist, although it's added successfully.", nonexistantIPv4ToAdd.Addr().String()) } if !created { t.Errorf("Notification handler has not been called on add.") } - err = ifc.LUID.DeleteIPAddress(unexistentIPAddresToAdd) + err = ifc.LUID.DeleteIPAddress(nonexistantIPv4ToAdd) if err != nil { t.Errorf("LUID.DeleteIPAddress() returned an error: %w", err) } time.Sleep(500 * time.Millisecond) - addr, err = ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP) + addr, err = ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr()) if err == nil { - t.Errorf("Unicast address %s still exists, although it's deleted successfully.", unexistentIPAddresToAdd.IP.String()) + t.Errorf("Unicast address %s still exists, although it's deleted successfully.", nonexistantIPv4ToAdd.Addr().String()) } else if err != windows.ERROR_NOT_FOUND { t.Errorf("LUID.IPAddress() returned an error: %w", err) } @@ -460,14 +451,13 @@ func TestGetRoutes(t *testing.T) { } func TestAddDeleteRoute(t *testing.T) { - findRoute := func(luid LUID, dest net.IPNet) ([]MibIPforwardRow2, error) { + findRoute := func(luid LUID, dest netip.Prefix) ([]MibIPforwardRow2, error) { var family AddressFamily - switch { - case dest.IP.To4() != nil: + if dest.Addr().Is4() { family = windows.AF_INET - case dest.IP.To16() != nil: + } else if dest.Addr().Is6() { family = windows.AF_INET6 - default: + } else { return nil, windows.ERROR_INVALID_PARAMETER } r, err := GetIPForwardTable2(family) @@ -475,9 +465,8 @@ func TestAddDeleteRoute(t *testing.T) { return nil, err } matches := make([]MibIPforwardRow2, 0, len(r)) - ones, _ := dest.Mask.Size() for _, route := range r { - if route.InterfaceLUID == luid && route.DestinationPrefix.PrefixLength == uint8(ones) && route.DestinationPrefix.Prefix.Family == family && route.DestinationPrefix.Prefix.IP().Equal(dest.IP) { + if route.InterfaceLUID == luid && route.DestinationPrefix.PrefixLength == uint8(dest.Bits()) && route.DestinationPrefix.RawPrefix.Family == family && route.DestinationPrefix.RawPrefix.Addr() == dest.Addr() { matches = append(matches, route) } } @@ -494,20 +483,20 @@ func TestAddDeleteRoute(t *testing.T) { return } - _, err = ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop) + _, err = ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop) if err == nil { - t.Error("LUID.Route() returned a route although it isn't added yet. Have you forgot to set unexistentRouteIPv4ToAdd appropriately?") + t.Error("LUID.Route() returned a route although it isn't added yet. Have you forgot to set nonexistentRouteIPv4ToAdd appropriately?") return } else if err != windows.ERROR_NOT_FOUND { t.Errorf("LUID.Route() returned an error: %w", err) return } - routes, err := findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination) + routes, err := findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination) if err != nil { t.Errorf("findRoute() returned an error: %w", err) } else if len(routes) != 0 { - t.Errorf("findRoute() returned %d items although the route isn't added yet. Have you forgot to set unexistentRouteIPv4ToAdd appropriately?", len(routes)) + t.Errorf("findRoute() returned %d items although the route isn't added yet. Have you forgot to set nonexistentRouteIPv4ToAdd appropriately?", len(routes)) } var created, deleted bool @@ -524,42 +513,42 @@ func TestAddDeleteRoute(t *testing.T) { } else { defer cb.Unregister() } - err = ifc.LUID.AddRoute(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop, unexistentRouteIPv4ToAdd.Metric) + err = ifc.LUID.AddRoute(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop, nonexistentRouteIPv4ToAdd.Metric) if err != nil { t.Errorf("LUID.AddRoute() returned an error: %w", err) } time.Sleep(500 * time.Millisecond) - route, err := ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop) + route, err := ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop) if err == windows.ERROR_NOT_FOUND { t.Error("LUID.Route() returned nil although the route is added successfully.") } else if err != nil { t.Errorf("LUID.Route() returned an error: %w", err) - } else if !route.DestinationPrefix.Prefix.IP().Equal(unexistentRouteIPv4ToAdd.Destination.IP) || !route.NextHop.IP().Equal(unexistentRouteIPv4ToAdd.NextHop) { + } else if route.DestinationPrefix.RawPrefix.Addr() != nonexistentRouteIPv4ToAdd.Destination.Addr() || route.NextHop.Addr() != nonexistentRouteIPv4ToAdd.NextHop { t.Error("LUID.Route() returned a wrong route!") } if !created { t.Errorf("Route handler has not been called on add.") } - routes, err = findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination) + routes, err = findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination) if err != nil { t.Errorf("findRoute() returned an error: %w", err) } else if len(routes) != 1 { t.Errorf("findRoute() returned %d items although %d is expected.", len(routes), 1) - } else if !routes[0].DestinationPrefix.Prefix.IP().Equal(unexistentRouteIPv4ToAdd.Destination.IP) { - t.Errorf("findRoute() returned a wrong route. Dest: %s; expected: %s.", routes[0].DestinationPrefix.Prefix.IP().String(), unexistentRouteIPv4ToAdd.Destination.IP.String()) + } else if routes[0].DestinationPrefix.RawPrefix.Addr() != nonexistentRouteIPv4ToAdd.Destination.Addr() { + t.Errorf("findRoute() returned a wrong route. Dest: %s; expected: %s.", routes[0].DestinationPrefix.RawPrefix.Addr().String(), nonexistentRouteIPv4ToAdd.Destination.Addr().String()) } - err = ifc.LUID.DeleteRoute(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop) + err = ifc.LUID.DeleteRoute(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop) if err != nil { t.Errorf("LUID.DeleteRoute() returned an error: %w", err) } time.Sleep(500 * time.Millisecond) - _, err = ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop) + _, err = ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop) if err == nil { t.Error("LUID.Route() returned a route although it is removed successfully.") } else if err != windows.ERROR_NOT_FOUND { @@ -569,7 +558,7 @@ func TestAddDeleteRoute(t *testing.T) { t.Errorf("Route handler has not been called on delete.") } - routes, err = findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination) + routes, err = findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination) if err != nil { t.Errorf("findRoute() returned an error: %w", err) } else if len(routes) != 0 { @@ -606,7 +595,7 @@ func TestFlushDNS(t *testing.T) { t.Errorf("LUID.DNS() returned an error: %w", err) } for _, a := range dns { - if len(a) != 16 || a.To4() != nil || !((a[15] == 1 || a[15] == 2 || a[15] == 3) && bytes.HasPrefix(a, []byte{0xfe, 0xc0, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})) { + if a.Is4() { n++ } } @@ -651,7 +640,7 @@ func TestSetDNS(t *testing.T) { t.Errorf("dnsesToSet contains %d items, while DNSServerAddresses contains %d.", len(dnsesToSet), len(newDNSes)) } else { for i := range dnsesToSet { - if !dnsesToSet[i].Equal(newDNSes[i]) { + if dnsesToSet[i] != newDNSes[i] { t.Errorf("dnsesToSet[%d] = %s while DNSServerAddresses[%d] = %s.", i, dnsesToSet[i].String(), i, newDNSes[i].String()) } } diff --git a/ui/editdialog.go b/ui/editdialog.go index 3b1521ff9..764116f11 100644 --- a/ui/editdialog.go +++ b/ui/editdialog.go @@ -8,6 +8,8 @@ package ui import ( "strings" + "golang.zx2c4.com/go118/netip" + "github.com/lxn/walk" "github.com/lxn/win" "golang.org/x/sys/windows" @@ -185,10 +187,12 @@ func (dlg *EditDialog) onBlockUntunneledTrafficCBCheckedChanged() { return } var ( - v40 = [4]byte{} - v60 = [16]byte{} - v48 = [4]byte{0x80} - v68 = [16]byte{0x80} + v400 = conf.IPCidr{netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 0)} + v600000 = conf.IPCidr{netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0)} + v401 = conf.IPCidr{netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 1)} + v600001 = conf.IPCidr{netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 1)} + v41281 = conf.IPCidr{netip.PrefixFrom(netip.AddrFrom4([4]byte{0x80}), 1)} + v680001 = conf.IPCidr{netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)} ) block := dlg.blockUntunneledTrafficCB.Checked() @@ -211,13 +215,13 @@ func (dlg *EditDialog) onBlockUntunneledTrafficCBCheckedChanged() { foundV680001 bool ) for _, allowedip := range cfg.Peers[0].AllowedIPs { - if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) { + if allowedip == v600001 { foundV600001 = true - } else if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v68[:]) { + } else if allowedip == v680001 { foundV680001 = true - } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) { + } else if allowedip == v401 { foundV401 = true - } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v48[:]) { + } else if allowedip == v41281 { foundV41281 = true } else { newAllowedIPs = append(newAllowedIPs, allowedip) @@ -227,44 +231,44 @@ func (dlg *EditDialog) onBlockUntunneledTrafficCBCheckedChanged() { goto err } if foundV401 && foundV41281 { - newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v40[:], 0}) + newAllowedIPs = append(newAllowedIPs, v400) } else if foundV401 { - newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v40[:], 1}) + newAllowedIPs = append(newAllowedIPs, v401) } else if foundV41281 { - newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v48[:], 1}) + newAllowedIPs = append(newAllowedIPs, v41281) } if foundV600001 && foundV680001 { - newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v60[:], 0}) + newAllowedIPs = append(newAllowedIPs, v600000) } else if foundV600001 { - newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v60[:], 1}) + newAllowedIPs = append(newAllowedIPs, v600001) } else if foundV680001 { - newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v68[:], 1}) + newAllowedIPs = append(newAllowedIPs, v680001) } cfg.Peers[0].AllowedIPs = newAllowedIPs } else { var ( - foundV400 bool - foundV600 bool + foundV400 bool + foundV600000 bool ) for _, allowedip := range cfg.Peers[0].AllowedIPs { - if allowedip.Cidr == 0 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) { - foundV600 = true - } else if allowedip.Cidr == 0 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) { + if allowedip == v600000 { + foundV600000 = true + } else if allowedip == v400 { foundV400 = true } else { newAllowedIPs = append(newAllowedIPs, allowedip) } } - if !(foundV400 || foundV600) { + if !(foundV400 || foundV600000) { goto err } if foundV400 { - newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v40[:], 1}) - newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v48[:], 1}) + newAllowedIPs = append(newAllowedIPs, v401) + newAllowedIPs = append(newAllowedIPs, v41281) } - if foundV600 { - newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v60[:], 1}) - newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v68[:], 1}) + if foundV600000 { + newAllowedIPs = append(newAllowedIPs, v600001) + newAllowedIPs = append(newAllowedIPs, v680001) } cfg.Peers[0].AllowedIPs = newAllowedIPs }