From 61bc092458084939d549b726ca87701c2c56a870 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 19 Jun 2024 12:12:11 +0200 Subject: [PATCH] Process routes before peers (#2105) --- client/internal/dns/server_test.go | 6 +- client/internal/engine.go | 45 ++++++++---- client/internal/engine_test.go | 6 +- client/internal/peer/conn.go | 28 -------- client/internal/routemanager/manager_test.go | 2 +- .../systemops/systemops_generic.go | 3 +- .../systemops/systemops_generic_test.go | 6 +- iface/bind/bind.go | 10 ++- iface/bind/udp_mux_universal.go | 72 ++++++++++++++++++- iface/iface_android.go | 6 +- iface/iface_darwin.go | 7 +- iface/iface_ios.go | 6 +- iface/iface_test.go | 16 ++--- iface/iface_unix.go | 7 +- iface/iface_windows.go | 7 +- iface/tun_android.go | 4 +- iface/tun_darwin.go | 4 +- iface/tun_ios.go | 4 +- iface/tun_kernel_unix.go | 7 +- iface/tun_netstack.go | 4 +- iface/tun_usp_unix.go | 4 +- iface/tun_windows.go | 4 +- 22 files changed, 165 insertions(+), 93 deletions(-) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 3709c32ce48..6cbd9ea1527 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -265,7 +265,7 @@ func TestUpdateDNSServer(t *testing.T) { if err != nil { t.Fatal(err) } - wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil) + wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } @@ -343,7 +343,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { } privKey, _ := wgtypes.GeneratePrivateKey() - wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil) + wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) if err != nil { t.Errorf("build interface wireguard: %v", err) return @@ -801,7 +801,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { } privKey, _ := wgtypes.GeneratePrivateKey() - wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil) + wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) if err != nil { t.Fatalf("build interface wireguard: %v", err) return nil, err diff --git a/client/internal/engine.go b/client/internal/engine.go index 68c287046b5..5e1e469163f 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -29,6 +29,7 @@ import ( "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/wgproxy" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" @@ -735,6 +736,20 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { return nil } + protoRoutes := networkMap.GetRoutes() + if protoRoutes == nil { + protoRoutes = []*mgmProto.Route{} + } + + _, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)) + if err != nil { + log.Errorf("failed to update clientRoutes, err: %v", err) + } + + e.clientRoutesMu.Lock() + e.clientRoutes = clientRoutes + e.clientRoutesMu.Unlock() + log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) e.updateOfflinePeers(networkMap.GetOfflinePeers()) @@ -776,19 +791,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } } } - protoRoutes := networkMap.GetRoutes() - if protoRoutes == nil { - protoRoutes = []*mgmProto.Route{} - } - - _, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)) - if err != nil { - log.Errorf("failed to update clientRoutes, err: %v", err) - } - - e.clientRoutesMu.Lock() - e.clientRoutes = clientRoutes - e.clientRoutesMu.Unlock() protoDNSConfig := networkMap.GetDNSConfig() if protoDNSConfig == nil { @@ -1287,7 +1289,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) { default: } - return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs) + return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes) } func (e *Engine) wgInterfaceCreate() (err error) { @@ -1485,6 +1487,21 @@ func (e *Engine) startNetworkMonitor() { }() } +func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) { + var vpnRoutes []netip.Prefix + for _, routes := range e.GetClientRoutes() { + if len(routes) > 0 && routes[0] != nil { + vpnRoutes = append(vpnRoutes, routes[0].Network) + } + } + + if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn { + return true, prefix, nil + } + + return false, netip.Prefix{}, nil +} + // isChecksEqual checks if two slices of checks are equal. func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 9f95fbc27cc..0db0ab74c56 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -217,7 +217,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { if err != nil { t.Fatal(err) } - engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil) + engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } @@ -574,7 +574,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { if err != nil { t.Fatal(err) } - engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil) + engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil) assert.NoError(t, err, "shouldn't return error") input := struct { inputSerial uint64 @@ -745,7 +745,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { if err != nil { t.Fatal(err) } - engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil) + engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil) assert.NoError(t, err, "shouldn't return error") mockRouteManager := &routemanager.MockManager{ diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 3a38d14c1fd..0d8fd932c0a 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net" - "net/netip" "runtime" "strings" "sync" @@ -15,7 +14,6 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/wgproxy" "github.com/netbirdio/netbird/iface" @@ -763,10 +761,6 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa return } - if candidateViaRoutes(candidate, haRoutes) { - return - } - err := conn.agent.AddRemoteCandidate(candidate) if err != nil { log.Errorf("error while handling remote candidate from peer %s", conn.config.Key) @@ -797,25 +791,3 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive RelPort: relatedAdd.Port, }) } - -func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool { - var vpnRoutes []netip.Prefix - for _, routes := range clientRoutes { - if len(routes) > 0 && routes[0] != nil { - vpnRoutes = append(vpnRoutes, routes[0].Network) - } - } - - addr, err := netip.ParseAddr(candidate.Address()) - if err != nil { - log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err) - return false - } - - if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn { - log.Debugf("Ignoring candidate [%s], its address is routed to network %s", candidate.String(), prefix) - return true - } - - return false -} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 1b226da29bb..455c7ac0b9a 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -407,7 +407,7 @@ func TestManagerUpdateRoutes(t *testing.T) { if err != nil { t.Fatal(err) } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 53bab6edf40..c2db523697c 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -154,7 +154,8 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIfac // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() { - log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) + log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop) + exitNextHop = initialNextHop } diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 594aaee4ace..292166582be 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -61,7 +61,7 @@ func TestAddRemoveRoutes(t *testing.T) { if err != nil { t.Fatal(err) } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() @@ -213,7 +213,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) { if err != nil { t.Fatal(err) } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() @@ -345,7 +345,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen newNet, err := stdnet.NewNet() require.NoError(t, err) - wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) require.NoError(t, err, "should create testing WireGuard interface") err = wgInterface.Create() diff --git a/iface/bind/bind.go b/iface/bind/bind.go index 00af25f67fd..ba6153cb738 100644 --- a/iface/bind/bind.go +++ b/iface/bind/bind.go @@ -28,11 +28,14 @@ type ICEBind struct { transportNet transport.Net udpMux *UniversalUDPMuxDefault + + filterFn FilterFn } -func NewICEBind(transportNet transport.Net) *ICEBind { +func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind { ib := &ICEBind{ transportNet: transportNet, + filterFn: filterFn, } rc := receiverCreator{ @@ -59,8 +62,9 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC s.udpMux = NewUniversalUDPMuxDefault( UniversalUDPMuxParams{ - UDPConn: conn, - Net: s.transportNet, + UDPConn: conn, + Net: s.transportNet, + FilterFn: s.filterFn, }, ) return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { diff --git a/iface/bind/udp_mux_universal.go b/iface/bind/udp_mux_universal.go index 7121f1ff456..ebbefe03566 100644 --- a/iface/bind/udp_mux_universal.go +++ b/iface/bind/udp_mux_universal.go @@ -8,6 +8,8 @@ import ( "context" "fmt" "net" + "net/netip" + "sync" "time" log "github.com/sirupsen/logrus" @@ -17,6 +19,10 @@ import ( "github.com/pion/transport/v3" ) +// FilterFn is a function that filters out candidates based on the address. +// If it returns true, the address is to be filtered. It also returns the prefix of matching route. +type FilterFn func(address netip.Addr) (bool, netip.Prefix, error) + // UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn // It then passes packets to the UDPMux that does the actual connection muxing. type UniversalUDPMuxDefault struct { @@ -34,6 +40,7 @@ type UniversalUDPMuxParams struct { UDPConn net.PacketConn XORMappedAddrCacheTTL time.Duration Net transport.Net + FilterFn FilterFn } // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux @@ -56,6 +63,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef PacketConn: params.UDPConn, mux: m, logger: params.Logger, + filterFn: params.FilterFn, } // embed UDPMux @@ -105,8 +113,68 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) { // udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets type udpConn struct { net.PacketConn - mux *UniversalUDPMuxDefault - logger logging.LeveledLogger + mux *UniversalUDPMuxDefault + logger logging.LeveledLogger + filterFn FilterFn + // TODO: reset cache on route changes + addrCache sync.Map +} + +func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { + if u.filterFn == nil { + return u.PacketConn.WriteTo(b, addr) + } + + if isRouted, found := u.addrCache.Load(addr.String()); found { + return u.handleCachedAddress(isRouted.(bool), b, addr) + } + + return u.handleUncachedAddress(b, addr) +} + +func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) { + if isRouted { + return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr) + } + return u.PacketConn.WriteTo(b, addr) +} + +func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) { + if err := u.performFilterCheck(addr); err != nil { + return 0, err + } + return u.PacketConn.WriteTo(b, addr) +} + +func (u *udpConn) performFilterCheck(addr net.Addr) error { + host, err := getHostFromAddr(addr) + if err != nil { + log.Errorf("Failed to get host from address %s: %v", addr, err) + return nil + } + + a, err := netip.ParseAddr(host) + if err != nil { + log.Errorf("Failed to parse address %s: %v", addr, err) + return nil + } + + if isRouted, prefix, err := u.filterFn(a); err != nil { + log.Errorf("Failed to check if address %s is routed: %v", addr, err) + } else { + u.addrCache.Store(addr.String(), isRouted) + if isRouted { + // Extra log, as the error only shows up with ICE logging enabled + log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix) + return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix) + } + } + return nil +} + +func getHostFromAddr(addr net.Addr) (string, error) { + host, _, err := net.SplitHostPort(addr.String()) + return host, err } // GetSharedConn returns the shared udp conn diff --git a/iface/iface_android.go b/iface/iface_android.go index d1876e4955d..99f6885a5e9 100644 --- a/iface/iface_android.go +++ b/iface/iface_android.go @@ -4,17 +4,19 @@ import ( "fmt" "github.com/pion/transport/v3" + + "github.com/netbirdio/netbird/iface/bind" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { wgAddress, err := parseWGAddress(address) if err != nil { return nil, err } wgIFace := &WGIface{ - tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter), + tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn), userspaceBind: true, } return wgIFace, nil diff --git a/iface/iface_darwin.go b/iface/iface_darwin.go index d68f562cd2c..15e4a781735 100644 --- a/iface/iface_darwin.go +++ b/iface/iface_darwin.go @@ -7,11 +7,12 @@ import ( "github.com/pion/transport/v3" + "github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/netstack" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { wgAddress, err := parseWGAddress(address) if err != nil { return nil, err @@ -22,11 +23,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, } if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr()) + wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) return wgIFace, nil } - wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) + wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) return wgIFace, nil } diff --git a/iface/iface_ios.go b/iface/iface_ios.go index 39032e6bdb5..6babe596419 100644 --- a/iface/iface_ios.go +++ b/iface/iface_ios.go @@ -6,16 +6,18 @@ import ( "fmt" "github.com/pion/transport/v3" + + "github.com/netbirdio/netbird/iface/bind" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { wgAddress, err := parseWGAddress(address) if err != nil { return nil, err } wgIFace := &WGIface{ - tun: newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd), + tun: newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn), userspaceBind: true, } return wgIFace, nil diff --git a/iface/iface_test.go b/iface/iface_test.go index f227eaf8351..43c44b770fb 100644 --- a/iface/iface_test.go +++ b/iface/iface_test.go @@ -41,7 +41,7 @@ func TestWGIface_UpdateAddr(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, addr, wgPort, key, DefaultMTU, newNet, nil) + iface, err := NewWGIFace(ifaceName, addr, wgPort, key, DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } @@ -114,7 +114,7 @@ func Test_CreateInterface(t *testing.T) { if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil) + iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } @@ -149,7 +149,7 @@ func Test_Close(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil) + iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } @@ -182,7 +182,7 @@ func Test_ConfigureInterface(t *testing.T) { if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil) + iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } @@ -230,7 +230,7 @@ func Test_UpdatePeer(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil) + iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } @@ -291,7 +291,7 @@ func Test_RemovePeer(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil) + iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } @@ -345,7 +345,7 @@ func Test_ConnectPeers(t *testing.T) { t.Fatal(err) } - iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil) + iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } @@ -368,7 +368,7 @@ func Test_ConnectPeers(t *testing.T) { if err != nil { t.Fatal(err) } - iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, peer2wgPort, peer2Key.String(), DefaultMTU, newNet, nil) + iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, peer2wgPort, peer2Key.String(), DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } diff --git a/iface/iface_unix.go b/iface/iface_unix.go index b378abef3c9..9608df1ad9a 100644 --- a/iface/iface_unix.go +++ b/iface/iface_unix.go @@ -8,11 +8,12 @@ import ( "github.com/pion/transport/v3" + "github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/netstack" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { wgAddress, err := parseWGAddress(address) if err != nil { return nil, err @@ -22,7 +23,7 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, // move the kernel/usp/netstack preference evaluation to upper layer if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr()) + wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) wgIFace.userspaceBind = true return wgIFace, nil } @@ -36,7 +37,7 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, if !tunModuleIsLoaded() { return nil, fmt.Errorf("couldn't check or load tun module") } - wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) + wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil) wgIFace.userspaceBind = true return wgIFace, nil } diff --git a/iface/iface_windows.go b/iface/iface_windows.go index d3a16a52fe4..c5edd27a9ce 100644 --- a/iface/iface_windows.go +++ b/iface/iface_windows.go @@ -5,11 +5,12 @@ import ( "github.com/pion/transport/v3" + "github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/netstack" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { wgAddress, err := parseWGAddress(address) if err != nil { return nil, err @@ -20,11 +21,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, } if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr()) + wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) return wgIFace, nil } - wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) + wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) return wgIFace, nil } diff --git a/iface/tun_android.go b/iface/tun_android.go index 834b2cb42d9..dc6abea36b7 100644 --- a/iface/tun_android.go +++ b/iface/tun_android.go @@ -31,13 +31,13 @@ type wgTunDevice struct { configurer wgConfigurer } -func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter) wgTunDevice { +func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) wgTunDevice { return wgTunDevice{ address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet), + iceBind: bind.NewICEBind(transportNet, filterFn), tunAdapter: tunAdapter, } } diff --git a/iface/tun_darwin.go b/iface/tun_darwin.go index 8dc10bd0e73..7d684f52e96 100644 --- a/iface/tun_darwin.go +++ b/iface/tun_darwin.go @@ -27,14 +27,14 @@ type tunDevice struct { configurer wgConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice { +func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { return &tunDevice{ name: name, address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet), + iceBind: bind.NewICEBind(transportNet, filterFn), } } diff --git a/iface/tun_ios.go b/iface/tun_ios.go index ea980818d78..83e26e08d6c 100644 --- a/iface/tun_ios.go +++ b/iface/tun_ios.go @@ -29,13 +29,13 @@ type tunDevice struct { configurer wgConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int) *tunDevice { +func newTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *tunDevice { return &tunDevice{ name: name, address: address, port: port, key: key, - iceBind: bind.NewICEBind(transportNet), + iceBind: bind.NewICEBind(transportNet, filterFn), tunFd: tunFd, } } diff --git a/iface/tun_kernel_unix.go b/iface/tun_kernel_unix.go index db47b68cf30..019dd786bbd 100644 --- a/iface/tun_kernel_unix.go +++ b/iface/tun_kernel_unix.go @@ -27,6 +27,8 @@ type tunKernelDevice struct { link *wgLink udpMuxConn net.PacketConn udpMux *bind.UniversalUDPMuxDefault + + filterFn bind.FilterFn } func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice { @@ -96,8 +98,9 @@ func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return nil, err } bindParams := bind.UniversalUDPMuxParams{ - UDPConn: rawSock, - Net: t.transportNet, + UDPConn: rawSock, + Net: t.transportNet, + FilterFn: t.filterFn, } mux := bind.NewUniversalUDPMuxDefault(bindParams) go mux.ReadFromConn(t.ctx) diff --git a/iface/tun_netstack.go b/iface/tun_netstack.go index e1d01ecc90b..beb3acc3fef 100644 --- a/iface/tun_netstack.go +++ b/iface/tun_netstack.go @@ -30,7 +30,7 @@ type tunNetstackDevice struct { configurer wgConfigurer } -func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string) wgTunDevice { +func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) wgTunDevice { return &tunNetstackDevice{ name: name, address: address, @@ -38,7 +38,7 @@ func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string key: key, mtu: mtu, listenAddress: listenAddress, - iceBind: bind.NewICEBind(transportNet), + iceBind: bind.NewICEBind(transportNet, filterFn), } } diff --git a/iface/tun_usp_unix.go b/iface/tun_usp_unix.go index 2e4be5280d1..b18794b2579 100644 --- a/iface/tun_usp_unix.go +++ b/iface/tun_usp_unix.go @@ -29,7 +29,7 @@ type tunUSPDevice struct { configurer wgConfigurer } -func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice { +func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { log.Infof("using userspace bind mode") checkUser() @@ -40,7 +40,7 @@ func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu i port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet), + iceBind: bind.NewICEBind(transportNet, filterFn), } } diff --git a/iface/tun_windows.go b/iface/tun_windows.go index 900e62fc3e8..5c77f1d166b 100644 --- a/iface/tun_windows.go +++ b/iface/tun_windows.go @@ -29,14 +29,14 @@ type tunDevice struct { configurer wgConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice { +func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { return &tunDevice{ name: name, address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet), + iceBind: bind.NewICEBind(transportNet, filterFn), } }