Skip to content

Commit

Permalink
Process routes before peers (#2105)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal authored Jun 19, 2024
1 parent b679404 commit 61bc092
Show file tree
Hide file tree
Showing 22 changed files with 165 additions and 93 deletions.
6 changes: 3 additions & 3 deletions client/internal/dns/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ func TestUpdateDNSServer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -343,7 +343,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
}

privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
Expand Down Expand Up @@ -801,7 +801,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
}

privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatalf("build interface wireguard: %v", err)
return nil, err
Expand Down
45 changes: 31 additions & 14 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/wgproxy"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
Expand Down Expand Up @@ -735,6 +736,20 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil
}

protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}

_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}

e.clientRoutesMu.Lock()
e.clientRoutes = clientRoutes
e.clientRoutesMu.Unlock()

log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))

e.updateOfflinePeers(networkMap.GetOfflinePeers())
Expand Down Expand Up @@ -776,19 +791,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
}
}
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}

_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}

e.clientRoutesMu.Lock()
e.clientRoutes = clientRoutes
e.clientRoutesMu.Unlock()

protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
Expand Down Expand Up @@ -1287,7 +1289,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
default:
}

return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs)
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes)
}

func (e *Engine) wgInterfaceCreate() (err error) {
Expand Down Expand Up @@ -1485,6 +1487,21 @@ func (e *Engine) startNetworkMonitor() {
}()
}

func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
var vpnRoutes []netip.Prefix
for _, routes := range e.GetClientRoutes() {
if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network)
}
}

if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
return true, prefix, nil
}

return false, netip.Prefix{}, nil
}

// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
Expand Down
6 changes: 3 additions & 3 deletions client/internal/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -574,7 +574,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
assert.NoError(t, err, "shouldn't return error")
input := struct {
inputSerial uint64
Expand Down Expand Up @@ -745,7 +745,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil)
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil)
assert.NoError(t, err, "shouldn't return error")

mockRouteManager := &routemanager.MockManager{
Expand Down
28 changes: 0 additions & 28 deletions client/internal/peer/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"sync"
Expand All @@ -15,7 +14,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface"
Expand Down Expand Up @@ -763,10 +761,6 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
return
}

if candidateViaRoutes(candidate, haRoutes) {
return
}

err := conn.agent.AddRemoteCandidate(candidate)
if err != nil {
log.Errorf("error while handling remote candidate from peer %s", conn.config.Key)
Expand Down Expand Up @@ -797,25 +791,3 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
RelPort: relatedAdd.Port,
})
}

func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
var vpnRoutes []netip.Prefix
for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network)
}
}

addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}

if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
log.Debugf("Ignoring candidate [%s], its address is routed to network %s", candidate.String(), prefix)
return true
}

return false
}
2 changes: 1 addition & 1 deletion client/internal/routemanager/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()

Expand Down
3 changes: 2 additions & 1 deletion client/internal/routemanager/systemops/systemops_generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIfac

// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)

exitNextHop = initialNextHop
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func TestAddRemoveRoutes(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()

Expand Down Expand Up @@ -213,7 +213,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()

Expand Down Expand Up @@ -345,7 +345,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
newNet, err := stdnet.NewNet()
require.NoError(t, err)

wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
require.NoError(t, err, "should create testing WireGuard interface")

err = wgInterface.Create()
Expand Down
10 changes: 7 additions & 3 deletions iface/bind/bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@ type ICEBind struct {

transportNet transport.Net
udpMux *UniversalUDPMuxDefault

filterFn FilterFn
}

func NewICEBind(transportNet transport.Net) *ICEBind {
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
ib := &ICEBind{
transportNet: transportNet,
filterFn: filterFn,
}

rc := receiverCreator{
Expand All @@ -59,8 +62,9 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC

s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
Expand Down
72 changes: 70 additions & 2 deletions iface/bind/udp_mux_universal.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"

log "github.com/sirupsen/logrus"
Expand All @@ -17,6 +19,10 @@ import (
"github.com/pion/transport/v3"
)

// FilterFn is a function that filters out candidates based on the address.
// If it returns true, the address is to be filtered. It also returns the prefix of matching route.
type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)

// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
Expand All @@ -34,6 +40,7 @@ type UniversalUDPMuxParams struct {
UDPConn net.PacketConn
XORMappedAddrCacheTTL time.Duration
Net transport.Net
FilterFn FilterFn
}

// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
Expand All @@ -56,6 +63,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
PacketConn: params.UDPConn,
mux: m,
logger: params.Logger,
filterFn: params.FilterFn,
}

// embed UDPMux
Expand Down Expand Up @@ -105,8 +113,68 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct {
net.PacketConn
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
filterFn FilterFn
// TODO: reset cache on route changes
addrCache sync.Map
}

func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
if u.filterFn == nil {
return u.PacketConn.WriteTo(b, addr)
}

if isRouted, found := u.addrCache.Load(addr.String()); found {
return u.handleCachedAddress(isRouted.(bool), b, addr)
}

return u.handleUncachedAddress(b, addr)
}

func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
}
return u.PacketConn.WriteTo(b, addr)
}

func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil {
return 0, err
}
return u.PacketConn.WriteTo(b, addr)
}

func (u *udpConn) performFilterCheck(addr net.Addr) error {
host, err := getHostFromAddr(addr)
if err != nil {
log.Errorf("Failed to get host from address %s: %v", addr, err)
return nil
}

a, err := netip.ParseAddr(host)
if err != nil {
log.Errorf("Failed to parse address %s: %v", addr, err)
return nil
}

if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else {
u.addrCache.Store(addr.String(), isRouted)
if isRouted {
// Extra log, as the error only shows up with ICE logging enabled
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
}
}
return nil
}

func getHostFromAddr(addr net.Addr) (string, error) {
host, _, err := net.SplitHostPort(addr.String())
return host, err
}

// GetSharedConn returns the shared udp conn
Expand Down
6 changes: 4 additions & 2 deletions iface/iface_android.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@ import (
"fmt"

"github.com/pion/transport/v3"

"github.com/netbirdio/netbird/iface/bind"
)

// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) {
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := parseWGAddress(address)
if err != nil {
return nil, err
}

wgIFace := &WGIface{
tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter),
tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn),
userspaceBind: true,
}
return wgIFace, nil
Expand Down
Loading

0 comments on commit 61bc092

Please sign in to comment.