Skip to content

Commit

Permalink
fix: Unix socket issues
Browse files Browse the repository at this point in the history
  • Loading branch information
dogukanoksuz committed May 9, 2023
1 parent 0ca04c8 commit 81b167e
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 42 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ require (
github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d
github.com/alessio/shellescape v1.4.1
github.com/andybalholm/brotli v1.0.5 // indirect
github.com/avast/retry-go v3.0.0+incompatible
github.com/go-sql-driver/mysql v1.7.1 // indirect
github.com/hirochachacha/go-smb2 v1.1.0
github.com/jackc/pgpassfile v1.0.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs=
github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/avast/retry-go v3.0.0+incompatible h1:4SOWQ7Qs+oroOTQOYnAHqelpCO0biHSxpiH9JdtuBj0=
github.com/avast/retry-go v3.0.0+incompatible/go.mod h1:XtSnn+n/sHqQIpZ10K1qAevBhOOCWBLXXy3hyiqqBrY=
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
Expand Down
2 changes: 2 additions & 0 deletions internal/bridge/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ func (t *TunnelPool) Get(remoteHost, remotePort, username string) (*Tunnel, erro

// Set Tunnel connection to pool
func (t *TunnelPool) Set(remoteHost, remotePort, username string, tunnel *Tunnel) {
mutex.Lock()
defer mutex.Unlock()
Tunnels[remoteHost+":"+remotePort+":"+username] = tunnel
}

Expand Down
60 changes: 50 additions & 10 deletions internal/bridge/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"net"
"time"

"github.com/avast/retry-go"
"github.com/limanmys/render-engine/pkg/helpers"
"golang.org/x/crypto/ssh"
)
Expand All @@ -24,7 +25,19 @@ func InitShellWithPassword(username, password, host, port string) (*ssh.Client,
return nil, err
}

conn, err := ssh.Dial("tcp", net.JoinHostPort(ipAddress, port), config)
var conn *ssh.Client
err = retry.Do(
func() error {
conn, err = ssh.Dial("tcp", net.JoinHostPort(ipAddress, port), config)
if err != nil {
return err
}
return nil
},
retry.Attempts(5),
retry.Delay(1*time.Second),
)

if err != nil {
return nil, err
}
Expand Down Expand Up @@ -53,10 +66,18 @@ func InitShellWithCert(username, certificate, host, port string) (*ssh.Client, e
return nil, err
}

conn, err := ssh.Dial("tcp", net.JoinHostPort(ipAddress, port), config)
if err != nil {
return nil, err
}
var conn *ssh.Client
err = retry.Do(
func() error {
conn, err = ssh.Dial("tcp", net.JoinHostPort(ipAddress, port), config)
if err != nil {
return err
}
return nil
},
retry.Attempts(5),
retry.Delay(1*time.Second),
)

return conn, nil
}
Expand All @@ -77,10 +98,18 @@ func VerifySSH(username, password, host, port string) bool {
return false
}

conn, err := ssh.Dial("tcp", net.JoinHostPort(ipAddress, port), config)
if err != nil {
return false
}
var conn *ssh.Client
err = retry.Do(
func() error {
conn, err = ssh.Dial("tcp", net.JoinHostPort(ipAddress, port), config)
if err != nil {
return err
}
return nil
},
retry.Attempts(5),
retry.Delay(1*time.Second),
)

defer conn.Close()
return true
Expand All @@ -107,7 +136,18 @@ func VerifySSHCertificate(username, certificate, host, port string) bool {
return false
}

conn, err := ssh.Dial("tcp", net.JoinHostPort(ipAddress, port), config)
var conn *ssh.Client
err = retry.Do(
func() error {
conn, err = ssh.Dial("tcp", net.JoinHostPort(ipAddress, port), config)
if err != nil {
return err
}
return nil
},
retry.Attempts(5),
retry.Delay(1*time.Second),
)
if err != nil {
return false
}
Expand Down
90 changes: 58 additions & 32 deletions internal/bridge/ssh_tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"sync"
"time"

"github.com/avast/retry-go"
"github.com/limanmys/render-engine/pkg/logger"
"github.com/phayes/freeport"
"golang.org/x/crypto/ssh"
Expand Down Expand Up @@ -44,26 +45,28 @@ var mut sync.Mutex = sync.Mutex{}

// CreateTunnel starts a new tunnel instance and sets it into TunnelPool
func CreateTunnel(remoteHost, remotePort, username, password, sshPort string) int {
mut.Lock()
defer mut.Unlock()

ch := make(chan int)
time.AfterFunc(30*time.Second, func() {
ch <- 1
})

t, err := Tunnels.Get(remoteHost, remotePort, username)
if err == nil {
if t.password != password {
return 0
}

OL:
startedLoop:
for {
if t.Started {
break
}

select {
case <-ch:
break OL
break startedLoop
default:
time.Sleep(5 * time.Millisecond)
continue
Expand All @@ -74,9 +77,6 @@ func CreateTunnel(remoteHost, remotePort, username, password, sshPort string) in
return t.Port
}

mut.Lock()
defer mut.Unlock()

port, err := freeport.GetFreePort()
if err != nil {
logger.Sugar().Errorw(err.Error())
Expand All @@ -92,7 +92,7 @@ func CreateTunnel(remoteHost, remotePort, username, password, sshPort string) in
}

sshTunnel := &Tunnel{
auth: []ssh.AuthMethod{ssh.Password(password)},
auth: []ssh.AuthMethod{ssh.RetryableAuthMethod(ssh.Password(password), 3)},
hostKeys: ssh.InsecureIgnoreHostKey(),
user: username,
mode: '>',
Expand All @@ -114,15 +114,15 @@ func CreateTunnel(remoteHost, remotePort, username, password, sshPort string) in

hasError := sshTunnel.Start()
if !hasError {
L:
loop:
for {
if sshTunnel.Started {
break
}

select {
case <-ch:
break L
break loop
default:
time.Sleep(5 * time.Millisecond)
continue
Expand Down Expand Up @@ -177,20 +177,33 @@ func (t *Tunnel) bindTunnel(ctx context.Context, wg *sync.WaitGroup, hasError *b
for {
var once sync.Once // Only print errors once per session
func() {
// Connect to the server host via SSH.
cl, err := ssh.Dial("tcp", t.hostAddr, &ssh.ClientConfig{
User: t.user,
Auth: t.auth,
HostKeyCallback: t.hostKeys,
Timeout: 5 * time.Second,
})
var cl *ssh.Client
var err error

err = retry.Do(
func() error {
cl, err = ssh.Dial("tcp", t.hostAddr, &ssh.ClientConfig{
User: t.user,
Auth: t.auth,
HostKeyCallback: t.hostKeys,
Timeout: 5 * time.Second,
})
if err != nil {
return err
}
return nil
},
retry.Attempts(5),
retry.Delay(1*time.Second),
)

if err != nil {
once.Do(func() {
t.log.Errorw("ssh dial error", "details", fmt.Sprintf("%v, %v", t, err))
t.errHandler()
t.Stop()
*hasError = true
wg.Done()
t.errHandler()
})
return
}
Expand All @@ -210,10 +223,10 @@ func (t *Tunnel) bindTunnel(ctx context.Context, wg *sync.WaitGroup, hasError *b
if err != nil {
once.Do(func() {
t.log.Errorw("bind error", "details", fmt.Sprintf("%v, %v", t, err))
t.errHandler()
t.Stop()
*hasError = true
wg.Done()
t.errHandler()
})
return
}
Expand Down Expand Up @@ -244,8 +257,8 @@ func (t *Tunnel) bindTunnel(ctx context.Context, wg *sync.WaitGroup, hasError *b
t.log.Errorw("accept error", "details", fmt.Sprintf("%v, %v", t, err))
t.Stop()
*hasError = true
t.errHandler()
wg.Done()
t.errHandler()
})
return
}
Expand Down Expand Up @@ -273,21 +286,35 @@ func (t *Tunnel) dialTunnel(ctx context.Context, wg *sync.WaitGroup, client *ssh
}()

// Establish the outbound connection.
var once sync.Once
var cn2 net.Conn
var err error
switch t.mode {
case '>':
cn2, err = client.Dial(t.dialType, t.dialAddr)
case '<':
cn2, err = net.Dial(t.dialType, t.dialAddr)
}
if err != nil {
t.Stop()
t.log.Errorw("ssh dial error", "details", fmt.Sprintf("%v, %v", t, err))
t.errHandler()
*hasError = true

wg.Done()
err = retry.Do(
func() error {
switch t.mode {
case '>':
cn2, err = client.Dial(t.dialType, t.dialAddr)
case '<':
cn2, err = net.Dial(t.dialType, t.dialAddr)
}

if err != nil {
return err
}
return nil
},
retry.Attempts(5),
retry.Delay(1*time.Second),
)
if err != nil {
once.Do(func() {
t.Stop()
t.log.Errorw("ssh dial error", "details", fmt.Sprintf("%v, %v", t, err))
*hasError = true
wg.Done()
t.errHandler()
})
return
}

Expand All @@ -300,7 +327,6 @@ func (t *Tunnel) dialTunnel(ctx context.Context, wg *sync.WaitGroup, client *ssh
//defer t.log.Infow("connection closed", "details", t)

// Copy bytes from one connection to the other until one side closes.
var once sync.Once
var wg2 sync.WaitGroup
wg2.Add(2)
go func() {
Expand Down

0 comments on commit 81b167e

Please sign in to comment.