Skip to content

Commit

Permalink
Merge pull request #167 from moov-io/trap-errors-and-reconnect
Browse files Browse the repository at this point in the history
feat: trap error messages to re-connect underneath of callers
  • Loading branch information
adamdecaf authored Sep 13, 2023
2 parents 9c407ee + 48f310e commit cadc3d3
Showing 1 changed file with 69 additions and 10 deletions.
79 changes: 69 additions & 10 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,12 @@ type Client interface {
}

type client struct {
conn *ssh.Client
client *sftp.Client
cfg ClientConfig
logger log.Logger
cfg ClientConfig

mu sync.Mutex // protects all read/write methods
conn *ssh.Client
client *sftp.Client
}

func NewClient(logger log.Logger, cfg *ClientConfig) (Client, error) {
Expand All @@ -89,11 +90,17 @@ func NewClient(logger log.Logger, cfg *ClientConfig) (Client, error) {

conn, err := cc.connection()
cc.record(err) // track up metric for remote server
err = cc.clearConnectionOnError(err)

// Print an initial startup message
if conn != nil && logger != nil {
wd, _ := conn.Getwd()
logger.Logf("starting SFTP client in %s", wd)
wd, wdErr := conn.Getwd()
if wdErr != nil {
err = cc.clearConnectionOnError(wdErr)
}
if wd != "" {
logger.Logf("starting SFTP client in %s", wd)
}
}

return cc, err
Expand Down Expand Up @@ -143,6 +150,39 @@ func (c *client) connection() (*sftp.Client, error) {
return c.client, nil
}

// clearConnectionOnError accepts any error from a call involving the SSH/SFTP connection.
// If an error is encountered that causes either connection (SSH or SFTP) to be
// lost it will tear down the connections. The next invocation of c.connection()
// will re-establish new connections.
//
// When the error is captured by clearConnectionOnError the client will attempt to reconnect
// and that new connection error will be returned.
func (c *client) clearConnectionOnError(err error) error {
if err == nil {
return nil
}
// Possible errors from github.com/pkg/sftp/request-errors.go
switch {
case errors.Is(err, sftp.ErrSSHFxEOF),
errors.Is(err, sftp.ErrSSHFxFailure),
errors.Is(err, sftp.ErrSSHFxBadMessage),
errors.Is(err, sftp.ErrSSHFxNoConnection),
errors.Is(err, sftp.ErrSSHFxConnectionLost):
// Teardown the existing connections
if c.conn != nil {
c.conn.Close()
c.conn = nil
}
if c.client != nil {
c.client.Close()
c.client = nil
}
// Reconnect if needed and replace the initial error
_, err = c.connection()
}
return err
}

var (
hostKeyCallbackOnce sync.Once
hostKeyCallback = func(logger log.Logger) {
Expand Down Expand Up @@ -249,12 +289,14 @@ func (c *client) Ping() error {

conn, err := c.connection()
c.record(err)
err = c.clearConnectionOnError(err)
if err != nil {
return err
}

_, err = conn.ReadDir(".")
c.record(err)
err = c.clearConnectionOnError(err)
if err != nil {
return fmt.Errorf("sftp: ping %w", err)
}
Expand Down Expand Up @@ -290,16 +332,20 @@ func (c *client) Delete(path string) error {
defer c.mu.Unlock()

conn, err := c.connection()
err = c.clearConnectionOnError(err)
if err != nil {
return err
}

info, err := conn.Stat(path)
err = c.clearConnectionOnError(err)
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("sftp: delete stat: %w", err)
}
if info != nil {
if err := conn.Remove(path); err != nil {
err := conn.Remove(path)
err = c.clearConnectionOnError(err)
if err != nil {
return fmt.Errorf("sftp: delete: %w", err)
}
}
Expand All @@ -316,6 +362,7 @@ func (c *client) UploadFile(path string, contents io.ReadCloser) error {
defer c.mu.Unlock()

conn, err := c.connection()
err = c.clearConnectionOnError(err)
if err != nil {
return err
}
Expand All @@ -325,9 +372,12 @@ func (c *client) UploadFile(path string, contents io.ReadCloser) error {
dir, _ := filepath.Split(path)

info, err := conn.Stat(dir)
err = c.clearConnectionOnError(err)
if info == nil || err != nil {
if os.IsNotExist(err) || strings.Contains(err.Error(), "file does not exist") {
if err := conn.MkdirAll(dir); err != nil {
err := conn.MkdirAll(dir)
err = c.clearConnectionOnError(err)
if err != nil {
return fmt.Errorf("sftp: problem creating %s as parent dir: %w", dir, err)
}
} else {
Expand All @@ -339,6 +389,7 @@ func (c *client) UploadFile(path string, contents io.ReadCloser) error {
// Some servers don't allow you to open a file for reading and writing at the same time.
// For these we follow the pkg/sftp docs to open files for writing (not reading).
fd, err := conn.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC)
err = c.clearConnectionOnError(err)
if err != nil {
return fmt.Errorf("sftp: problem creating remote file %s: %w", path, err)
}
Expand All @@ -348,20 +399,26 @@ func (c *client) UploadFile(path string, contents io.ReadCloser) error {
return fmt.Errorf("sftp: problem copying (n=%d) %s: %w", n, path, err)
}

if err := fd.Sync(); err != nil {
err = fd.Sync()
err = c.clearConnectionOnError(err)
if err != nil {
// Skip sync if the remote server doesn't support it
if !strings.Contains(err.Error(), "SSH_FX_OP_UNSUPPORTED") {
return fmt.Errorf("sftp: problem with sync on %s: %v", path, err)
}
}

if !c.cfg.SkipChmodAfterUpload {
if err = fd.Chmod(0600); err != nil {
err := fd.Chmod(0600)
err = c.clearConnectionOnError(err)
if err != nil {
return fmt.Errorf("sftp: problem chmod %s: %w", path, err)
}
}

if err := fd.Close(); err != nil {
err = fd.Close()
err = c.clearConnectionOnError(err)
if err != nil {
return fmt.Errorf("sftp: closing %s after writing failed: %w", path, err)
}

Expand Down Expand Up @@ -443,11 +500,13 @@ func (c *client) Reader(path string) (*File, error) {
defer c.mu.Unlock()

conn, err := c.connection()
err = c.clearConnectionOnError(err)
if err != nil {
return nil, err
}

fd, err := conn.Open(path)
err = c.clearConnectionOnError(err)
if err != nil {
return nil, fmt.Errorf("sftp: open %s: %w", path, err)
}
Expand Down

0 comments on commit cadc3d3

Please sign in to comment.