Skip to content

Commit

Permalink
Fix bug for parsing IPv6 addresses (#17)
Browse files Browse the repository at this point in the history
IPv6 addresses contain colons ':' and since sshtunnel uses strings.Split(addr, ":"), to spit host and port, it ends up resulting in faulty splits. Using net.SplitHostPort fixes this since it handles IPv6 addresses correctly. If the user did not supply a port, then endpoint.Host is left as is.

This includes a partially breaking change where NewEdnpoint and NewSSHTunnel will now return an error. If you are not using IPv6 it is safe to ignore this error.
  • Loading branch information
bnmoch3 authored Jun 20, 2023
1 parent 27700fc commit 3d56ada
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 10 deletions.
24 changes: 19 additions & 5 deletions endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sshtunnel

import (
"fmt"
"net"
"strconv"
"strings"
)
Expand All @@ -12,7 +13,11 @@ type Endpoint struct {
User string
}

func NewEndpoint(s string) *Endpoint {
// NewEndpoint creates an Endpoint from a string that contains a user, host and
// port. Both User and Port are optional (depending on context). The host can
// be a domain name, IPv4 address or IPv6 address. If it's an IPv6, it must be
// enclosed in square brackets
func NewEndpoint(s string) (*Endpoint, error) {
endpoint := &Endpoint{
Host: s,
}
Expand All @@ -22,12 +27,21 @@ func NewEndpoint(s string) *Endpoint {
endpoint.Host = parts[1]
}

if parts := strings.Split(endpoint.Host, ":"); len(parts) > 1 {
endpoint.Host = parts[0]
endpoint.Port, _ = strconv.Atoi(parts[1])
host, port, err := net.SplitHostPort(endpoint.Host)
if err != nil {
// if error results from missing port in address, we ignore the error
// since either we'll use a random port assigned by the OS or set a
// suitable default directly, e.g. port 22 for SSH. Also worth noting,
// the host is set to the rest of the string since no port is provided
if !strings.Contains(err.Error(), "missing port in address") {
return nil, err
}
} else {
endpoint.Host = host
endpoint.Port, _ = strconv.Atoi(port)
}

return endpoint
return endpoint, nil
}

func (endpoint *Endpoint) String() string {
Expand Down
69 changes: 69 additions & 0 deletions endpoint_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package sshtunnel_test

import (
"reflect"
"testing"

"github.com/elliotchance/sshtunnel"
)

func TestCreateEndpoint(t *testing.T) {
// these are test cases for which we expect no error to occur when
// constructing endpoints i.e. they should be correct
testCases := []struct {
input string
expectedEndpoint *sshtunnel.Endpoint
}{
{
"localhost:9000",
&sshtunnel.Endpoint{
Host: "localhost",
Port: 9000,
User: "",
},
},
{
"[email protected]",
&sshtunnel.Endpoint{
Host: "jumpbox.us-east-1.mydomain.com",
Port: 0,
User: "ec2-user",
},
},
{
"dqrsdfdssdfx.us-east-1.redshift.amazonaws.com:5439",
&sshtunnel.Endpoint{
Host: "dqrsdfdssdfx.us-east-1.redshift.amazonaws.com",
Port: 5439,
User: "",
},
},
{
"[email protected]:22", // IPv4 address
&sshtunnel.Endpoint{
Host: "1.2.3.4",
Port: 22,
User: "admin",
},
},
{
"admin@[2001:db8:1::ab9:C0A8:102]:22", // IPv6 address
&sshtunnel.Endpoint{
Host: "2001:db8:1::ab9:C0A8:102",
Port: 22,
User: "admin",
},
},
}
for i, tc := range testCases {
got, err := sshtunnel.NewEndpoint(tc.input)
if err != nil {
t.Errorf("unexpected error for correct input '%s': %v",
tc.input, err)
}
if !reflect.DeepEqual(got, tc.expectedEndpoint) {
t.Errorf("For test case %d, expected: %+v, got: %+v",
i, *tc.expectedEndpoint, *got)
}
}
}
20 changes: 15 additions & 5 deletions ssh_tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,25 @@ func (tunnel *SSHTunnel) Close() {
}

// NewSSHTunnel creates a new single-use tunnel. Supplying "0" for localport will use a random port.
func NewSSHTunnel(tunnel string, auth ssh.AuthMethod, destination string, localport string) *SSHTunnel {
func NewSSHTunnel(tunnel string, auth ssh.AuthMethod, destination string, localport string) (*SSHTunnel, error) {

localEndpoint := NewEndpoint("localhost:" + localport)
localEndpoint, err := NewEndpoint("localhost:" + localport)
if err != nil {
return nil, err
}

server := NewEndpoint(tunnel)
server, err := NewEndpoint(tunnel)
if err != nil {
return nil, err
}
if server.Port == 0 {
server.Port = 22
}

remoteEndpoint, err := NewEndpoint(destination)
if err != nil {
return nil, err
}
sshTunnel := &SSHTunnel{
Config: &ssh.ClientConfig{
User: server.User,
Expand All @@ -167,9 +177,9 @@ func NewSSHTunnel(tunnel string, auth ssh.AuthMethod, destination string, localp
},
Local: localEndpoint,
Server: server,
Remote: NewEndpoint(destination),
Remote: remoteEndpoint,
close: make(chan interface{}),
}

return sshTunnel
return sshTunnel, nil
}

0 comments on commit 3d56ada

Please sign in to comment.