diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 8b8b3c5c077..f4a701f7fca 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -484,11 +484,11 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { // switch back to relay connection if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay { conn.log.Debugf("ICE disconnected, set Relay to active connection") + conn.workerRelay.EnableWgWatcher(conn.ctx) err := conn.configureWGEndpoint(conn.endpointRelay) if err != nil { conn.log.Errorf("failed to switch to relay conn: %v", err) } - conn.workerRelay.EnableWgWatcher(conn.ctx) conn.currentConnPriority = connPriorityRelay } @@ -551,6 +551,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { } } + conn.workerRelay.EnableWgWatcher(conn.ctx) err = conn.configureWGEndpoint(endpointUdpAddr) if err != nil { if err := wgProxy.CloseConn(); err != nil { @@ -560,7 +561,6 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { return } wgConfigWorkaround() - conn.workerRelay.EnableWgWatcher(conn.ctx) if conn.wgProxyRelay != nil { if err := conn.wgProxyRelay.CloseConn(); err != nil { diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index 930a8f5b6d6..3457faa465d 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -14,7 +14,7 @@ import ( ) var ( - wgHandshakePeriod = 2 * time.Minute + wgHandshakePeriod = 3 * time.Minute wgHandshakeOvertime = 30 * time.Second ) @@ -109,7 +109,7 @@ func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) { } ctx, ctxCancel := context.WithCancel(ctx) - go w.wgStateCheck(ctx) + w.wgStateCheck(ctx) w.ctxWgWatch = ctx w.ctxCancelWgWatch = ctxCancel @@ -157,37 +157,50 @@ func (w *WorkerRelay) CloseConn() { } } -// wgStateCheck help to check the state of the wireguard handshake and relay connection +// wgStateCheck help to check the state of the WireGuard handshake and relay connection func (w *WorkerRelay) wgStateCheck(ctx context.Context) { - timer := time.NewTimer(wgHandshakeOvertime) - defer timer.Stop() - expected := wgHandshakeOvertime - for { - select { - case <-timer.C: - lastHandshake, err := w.wgState() - if err != nil { - w.log.Errorf("failed to read wg stats: %v", err) - continue - } - w.log.Tracef("last handshake: %v", lastHandshake) - - if time.Since(lastHandshake) > expected { - w.log.Infof("Wireguard handshake timed out, closing relay connection") - w.relayLock.Lock() - _ = w.relayedConn.Close() - w.relayLock.Unlock() - w.callBacks.OnDisconnected() + lastHandshake, err := w.wgState() + if err != nil { + w.log.Errorf("failed to read wg stats: %v", err) + lastHandshake = time.Time{} + } + + go func(lastHandshake time.Time) { + timer := time.NewTimer(wgHandshakeOvertime) + defer timer.Stop() + + for { + select { + case <-timer.C: + + handshake, err := w.wgState() + if err != nil { + w.log.Errorf("failed to read wg stats: %v", err) + timer.Reset(wgHandshakeOvertime) + continue + } + + w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake) + + if handshake.Equal(lastHandshake) { + w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake) + w.relayLock.Lock() + _ = w.relayedConn.Close() + w.relayLock.Unlock() + w.callBacks.OnDisconnected() + return + } + + resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime)) + lastHandshake = handshake + timer.Reset(resetTime) + case <-ctx.Done(): + w.log.Debugf("WireGuard watcher stopped") return } - resetTime := time.Until(lastHandshake.Add(wgHandshakePeriod + wgHandshakeOvertime)) - timer.Reset(resetTime) - expected = wgHandshakePeriod - case <-ctx.Done(): - w.log.Debugf("WireGuard watcher stopped") - return } - } + }(lastHandshake) + } func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {