From 1e2c348dbed37b110adfbd3e2346cf6d4247407b Mon Sep 17 00:00:00 2001 From: Sudhi Herle Date: Thu, 2 Jul 2020 17:02:58 -0700 Subject: [PATCH] Working quic server+client: - added mock quic server and client - fixed bug in quic server + streams (teach quicServer to start new go-routine to handle streams) --- go.mod | 4 +- go.sum | 10 +- gotun/conf.go | 38 ++++-- gotun/mocked_test.go | 285 +++++++++++++++++++++++++++++++++++++++- gotun/quic_test.go | 302 +++++++++++++++++++++++++++++++++++++++++++ gotun/quicdial.go | 17 ++- gotun/server.go | 89 ++++++++++++- gotun/tcp_test.go | 9 +- gotun/utils_test.go | 7 +- 9 files changed, 719 insertions(+), 42 deletions(-) create mode 100644 gotun/quic_test.go diff --git a/go.mod b/go.mod index 26155a5..c369882 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module github.com/opencoff/go-tunnel go 1.14 require ( - github.com/lucas-clemente/quic-go v0.17.1 - github.com/opencoff/go-logger v0.2.0 + github.com/lucas-clemente/quic-go v0.17.2 + github.com/opencoff/go-logger v0.2.2 github.com/opencoff/go-ratelimit v0.7.0 github.com/opencoff/pflag v0.5.0 gopkg.in/yaml.v2 v2.3.0 diff --git a/go.sum b/go.sum index ae3719b..b29e143 100644 --- a/go.sum +++ b/go.sum @@ -62,8 +62,8 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/lucas-clemente/quic-go v0.17.1 h1:ezsH76xpn6hKugfsXUy6voIJBFmAOwnM/Oy9F4b/n+M= -github.com/lucas-clemente/quic-go v0.17.1/go.mod h1:I0+fcNTdb9eS1ZcjQZbDVPGchJ86chcIxPALn9lEJqE= +github.com/lucas-clemente/quic-go v0.17.2 h1:4iQInIuNQkPNZmsy9rCnwuOzpH0qGnDo4jn0QfI/qE4= +github.com/lucas-clemente/quic-go v0.17.2/go.mod h1:I0+fcNTdb9eS1ZcjQZbDVPGchJ86chcIxPALn9lEJqE= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/marten-seemann/qpack v0.1.0/go.mod h1:LFt1NU/Ptjip0C2CPkhimBz5CGE3WGDAUWqna+CNTrI= @@ -80,8 +80,10 @@ github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W github.com/onsi/ginkgo v1.11.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.8.1/go.mod h1:Ho0h+IUsWyvy1OpqCwxlQ/21gkhVunqlU8fDGcoTdcA= -github.com/opencoff/go-logger v0.2.0 h1:8YvwfTljg0/kR87m8sEL0QFJczT694i30XiqGyaQjWk= -github.com/opencoff/go-logger v0.2.0/go.mod h1:0uZokzKt+uCJkbz12vSoChasSJoLc2aNuCS0A/U7Dqs= +github.com/opencoff/go-logger v0.2.1 h1:xQNQDLLSGh8gOTNVzc3hksRhQVyWCjSyOO27ruzlF9I= +github.com/opencoff/go-logger v0.2.1/go.mod h1:0uZokzKt+uCJkbz12vSoChasSJoLc2aNuCS0A/U7Dqs= +github.com/opencoff/go-logger v0.2.2 h1:xxQFaYbyXEYd4YL5QraKu2YUc1iEv5ACvx8L3uUxY4s= +github.com/opencoff/go-logger v0.2.2/go.mod h1:0uZokzKt+uCJkbz12vSoChasSJoLc2aNuCS0A/U7Dqs= github.com/opencoff/go-ratelimit v0.7.0 h1:hXadrYOPFlS6l+jmrok87BX0Oh+oeWuLelNYXAUgGqA= github.com/opencoff/go-ratelimit v0.7.0/go.mod h1:CZOjkRlhRo07XJt81kMF0NfOP7cYTfhZG1zPU5AAK78= github.com/opencoff/golang-lru v0.6.0 h1:e5jyAHA4AJbohh8mmPB6JpTvZMVrnh3z5GFAqTADVm8= diff --git a/gotun/conf.go b/gotun/conf.go index dbadf4b..3364501 100644 --- a/gotun/conf.go +++ b/gotun/conf.go @@ -9,6 +9,7 @@ package main import ( + "bytes" "crypto/tls" "crypto/x509" "fmt" @@ -432,38 +433,49 @@ func (c *Conf) Path(nm string) string { // Print config in human readable format func (c *Conf) Dump(w io.Writer) { - fmt.Fprintf(w, "config: %d listeners\n", len(c.Listen)) + + b := &bytes.Buffer{} + + fmt.Fprintf(b, "config: %d listeners\n", len(c.Listen)) for _, l := range c.Listen { - fmt.Fprintf(w, "listen on %s", l.Addr) + fmt.Fprintf(b, "listen on %s", l.Addr) + if l.Quic { + fmt.Fprintf(b, " quic") + } if t := l.Tls; t != nil { if len(t.Sni) > 0 { - fmt.Fprintf(w, " with tls sni using certstore %s", t.Sni) + fmt.Fprintf(b, " bith tls sni using certstore %s", t.Sni) } else { - fmt.Fprintf(w, " with tls using cert %s, key %s", + fmt.Fprintf(b, " bith tls using cert %s, key %s", t.Cert, t.Key) } if t.ClientCert == "required" { - fmt.Fprintf(w, " requiring client auth") + fmt.Fprintf(b, " requiring client auth") } else if t.ClientCert == "optional" { - fmt.Fprintf(w, " with optional client auth") + fmt.Fprintf(b, " bith optional client auth") } } c := &l.Connect - fmt.Fprintf(w, "\n\tconnect to %s", c.Addr) + fmt.Fprintf(b, "\n\tconnect to %s", c.Addr) if len(c.Bind) > 0 { - fmt.Fprintf(w, " from %s", c.Bind) + fmt.Fprintf(b, " from %s", c.Bind) } if len(c.ProxyProtocol) > 0 { - fmt.Fprintf(w, " using proxy-protocol %s", c.ProxyProtocol) + fmt.Fprintf(b, " using proxy-protocol %s", c.ProxyProtocol) + } + if c.Quic { + fmt.Fprintf(b, " with quic") } if t := c.Tls; t != nil { - fmt.Fprintf(w, " using tls") + fmt.Fprintf(b, " using tls") if len(t.Cert) > 0 { - fmt.Fprintf(w, " cert %s, key %s", t.Cert, t.Key) + fmt.Fprintf(b, " cert %s, key %s", t.Cert, t.Key) } - fmt.Fprintf(w, " and ca-bundle %s", t.Ca) + fmt.Fprintf(b, " and ca-bundle %s", t.Ca) } - fmt.Fprintf(w, "\n") + fmt.Fprintf(b, "\n") } + + w.Write(b.Bytes()) } diff --git a/gotun/mocked_test.go b/gotun/mocked_test.go index 3e09679..0a89f2e 100644 --- a/gotun/mocked_test.go +++ b/gotun/mocked_test.go @@ -21,6 +21,8 @@ import ( "sync" "testing" "time" + + "github.com/lucas-clemente/quic-go" ) const ( @@ -59,10 +61,10 @@ func newTcpServer(network, addr string, tcfg *tls.Config, t *testing.T) *tcpserv } func (s *tcpserver) stop() { - s.t.Logf("stopping mock server on %s", s.Addr()) s.cancel() s.Close() s.wg.Wait() + s.t.Logf("stopped mock tcp server on %s", s.Addr()) } func (s *tcpserver) accept() { @@ -71,7 +73,7 @@ func (s *tcpserver) accept() { assert := newAsserter(s.t) done := s.ctx.Done() addr := s.Addr().String() - s.t.Logf("%s: mock server waiting for new conn ..\n", addr) + s.t.Logf("%s: mock tcp server waiting for new conn ..\n", addr) for { conn, err := s.Accept() select { @@ -167,8 +169,6 @@ type tcpclient struct { t *testing.T ctx context.Context cancel context.CancelFunc - - wg sync.WaitGroup } func newTcpClient(network, addr string, tcfg *tls.Config, t *testing.T) *tcpclient { @@ -197,13 +197,14 @@ func (c *tcpclient) start(n int) error { return err } - c.t.Logf("mock tcp client: connected to %s\n", c.addr) + c.t.Logf("mock tcp client: %s connected to %s\n", c.LocalAddr().String(), c.addr) return c.loop(n) } func (c *tcpclient) stop() { c.cancel() c.Close() + c.t.Logf("mock tcp client %s-%s stopped\n", c.LocalAddr().String(), c.addr) } func (c *tcpclient) loop(n int) error { @@ -215,7 +216,7 @@ func (c *tcpclient) loop(n int) error { defer func() { c.Close() - c.t.Logf("mock tcp client: closing conn to %s\n", addr) + c.t.Logf("mock tcp client: closing conn %s\n", from) }() if c.tls != nil { @@ -308,6 +309,278 @@ func readfull(fd io.Reader, b []byte) (int, error) { } type quicserver struct { + quic.Listener + + t *testing.T + + nr int + nw int + + tls *tls.Config + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup +} + +func newQuicServer(network, addr string, tcfg *tls.Config, t *testing.T) *quicserver { + assert := newAsserter(t) + + la, err := net.ResolveUDPAddr("udp", addr) + assert(err == nil, "can't resolve addr %s: %s", addr, err) + + ln, err := net.ListenUDP("udp", la) + assert(err == nil, "can't listen UDP %s: %s", addr, err) + + q, err := quic.Listen(ln, tcfg, &quic.Config{}) + assert(err == nil, "can't listen quic %s: %s", addr, err) + + qs := &quicserver{ + Listener: q, + t: t, + } + + qs.ctx, qs.cancel = context.WithCancel(context.Background()) + qs.wg.Add(1) + go qs.accept() + return qs +} + +func (q *quicserver) stop() { + q.cancel() + q.Close() + q.wg.Wait() + q.t.Logf("stopped mock quic server on %s", q.Addr()) +} + +func (q *quicserver) accept() { + assert := newAsserter(q.t) + defer q.wg.Done() + + q.t.Logf("mock quic server listening on %s ..\n", q.Addr().String()) + done := q.ctx.Done() + for { + sess, err := q.Accept(q.ctx) + select { + case <-done: + return + default: + } + + assert(err == nil, "can't accept quic: %s", err) + + q.wg.Add(1) + go q.serveSession(sess) + } +} + +func (q *quicserver) serveSession(sess quic.Session) { + defer q.wg.Done() + + assert := newAsserter(q.t) + done := q.ctx.Done() + for { + conn, err := sess.AcceptStream(q.ctx) + select { + case <-done: + return + default: + } + assert(err == nil, "can't accept quic-stream: %s", err) + + qc := &qconn{ + Stream: conn, + s: sess, + } + + q.t.Logf("mock quic server accepted new stream %d from %s\n", conn.StreamID(), qc.RemoteAddr().String()) + q.wg.Add(1) + go q.relay(qc) + } +} + +func (q *quicserver) relay(fd *qconn) { + assert := newAsserter(q.t) + done := q.ctx.Done() + addr := fd.RemoteAddr().String() + from := fmt.Sprintf("%s-%s", fd.LocalAddr().String(), addr) + + defer func() { + q.wg.Done() + fd.Close() + q.t.Logf("mock quic server: closed conn from %s\n", from) + }() + + buf := make([]byte, IOSIZE) + var csum [sha256.Size]byte + + // All timeouts are v short + rto := 5 * time.Second + + h := sha256.New() + for i := 0; ; i++ { + fd.SetReadDeadline(time.Now().Add(rto)) + nr, err := readfull(fd, buf) + select { + case <-done: + return + default: + } + + if errors.Is(err, io.EOF) || nr == 0 { + q.t.Logf("%s: EOF? nr=%d, err %s\n", from, nr, err) + return + } + assert(err == nil, "%s: read err: %s", from, err) + + q.nr += nr + h.Reset() + h.Write(buf[:nr]) + sum := h.Sum(csum[:0]) + + //q.t.Logf("%s: %d: RX %d [%x]\n", from, i, nr, sum[:]) + fd.SetWriteDeadline(time.Now().Add(rto)) + nw, err := writefull(fd, sum[:]) + select { + case <-done: + return + default: + } + + assert(err == nil, "%s: write err: %s", from, err) + assert(nw == len(sum[:]), "%s: partial write; exp %d, saw %d", from, len(sum[:]), nw) + + //q.t.Logf("%s: RX %d bytes, TX %d\n", from, nr, len(sum[:])) + q.nw += len(sum[:]) + } +} + +type quicclient struct { + *qconn + + session quic.Session + + nr int + nw int + + t *testing.T + + ctx context.Context + cancel context.CancelFunc +} + +// abstraction to make this look like a net.Conn +type qconn struct { + quic.Stream + s quic.Session +} + +// qAddr is defined in quicdial.go +func (qc *qconn) LocalAddr() net.Addr { + return &qAddr{ + a: qc.s.LocalAddr(), + id: qc.StreamID(), + } +} + +func (qc *qconn) RemoteAddr() net.Addr { + return &qAddr{ + a: qc.s.RemoteAddr(), + id: qc.StreamID(), + } +} + +func newQuicClient(network, addr string, tcfg *tls.Config, t *testing.T) *quicclient { + assert := newAsserter(t) + ctx, cancel := context.WithCancel(context.Background()) + + d, err := quic.DialAddrContext(ctx, addr, tcfg, &quic.Config{}) + assert(err == nil, "can't dial quic %s: %s", addr, err) + + st := d.ConnectionState() + t.Logf("mock quic client: connected to %s [%s]\n", addr, st.ServerName) + + fd, err := d.OpenStream() + assert(err == nil, "can't open quic stream to %s: %s", addr, err) + + t.Logf("mock quic client: connected to %s\n", addr) + qc := &qconn{ + Stream: fd, + s: d, + } + + q := &quicclient{ + qconn: qc, + session: d, + t: t, + ctx: ctx, + cancel: cancel, + } + + return q +} + +func (q *quicclient) start(n int) { + assert := newAsserter(q.t) + done := q.ctx.Done() + addr := q.RemoteAddr().String() + from := fmt.Sprintf("%s-%s", q.LocalAddr().String(), addr) + + defer func() { + q.cancel() + q.Close() + q.t.Logf("mock quic client: closing conn %s\n", from) + }() + + buf := make([]byte, IOSIZE) + rand.Read(buf) + fd := q.qconn + + var sumr, csuma [sha256.Size]byte + + h := sha256.New() + for i := 0; i < n; i++ { + nw, err := writefull(fd, buf) + select { + case <-done: + return + default: + } + assert(err == nil, "%s: write err: %s", from, err) + assert(nw == len(buf), "%s: partial write, exp %d, saw %d", from, len(buf), nw) + + q.nw += nw + + h.Reset() + h.Write(buf) + suma := h.Sum(csuma[:0]) + + //c.t.Logf("%s: %d: TX %d [%x]\n", from, i, nw, suma[:]) + + nr, err := readfull(fd, sumr[:]) + select { + case <-done: + return + default: + } + + if errors.Is(err, io.EOF) || nr == 0 { + q.t.Logf("%s: EOF? nr %d\n", from, nr) + return + } + assert(err == nil, "%s: read err: %s", from, err) + assert(nr == len(sumr[:]), "%s: partial read, exp %d, saw %d", from, len(sumr[:]), nr) + + assert(byteEq(suma[:], sumr[:]), "%s: cksum mismatch;\nexp %x\nsaw %x", from, suma[:], sumr[:]) + inc(buf) + q.nr += len(sumr[:]) + //c.t.Logf("%s: TX %d, RX %d\n", addr, nw, len(sumr[:])) + } + return +} + +func (q *quicclient) stop() { + q.cancel() + q.Close() } type pki struct { diff --git a/gotun/quic_test.go b/gotun/quic_test.go new file mode 100644 index 0000000..7c74440 --- /dev/null +++ b/gotun/quic_test.go @@ -0,0 +1,302 @@ +// quic_test.go - test quic to {TCP, TLS} endpoints + +package main + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "sync" + "testing" +) + +// return a configured Conf +func quicSetup(lport, cport int) *Conf { + + // TCP connect + // We'll spin up a simple server on the connect endpoint + + laddr := fmt.Sprintf("127.0.0.1:%d", lport) + caddr := fmt.Sprintf("127.0.0.1:%d", cport) + + lc := &ListenConf{ + Addr: laddr, + Quic: true, + Connect: ConnectConf{ + Addr: caddr, + }, + } + + c := &Conf{ + Logging: "NONE", + Listen: []*ListenConf{lc}, + } + + return defaults(c) +} + +// Client -> gotun Quic +// gotun -> backend TCP +func TestQuicToTcp(t *testing.T) { + assert := newAsserter(t) + + pki, err := newPKI() + assert(err == nil, "can't create PKI: %s", err) + + cfg := quicSetup(8005, 8006) + lc := cfg.Listen[0] + + cert, err := pki.ServerCert("server.name", lc.Addr) + assert(err == nil, "can't create server cert: %s", err) + + pool := x509.NewCertPool() + pool.AddCert(pki.ca) + tlsCfg := &tls.Config{ + MinVersion: tls.VersionTLS13, + ServerName: "server.name", + SessionTicketsDisabled: true, + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, // Go 1.8 only + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, // Go 1.8 only + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + + // Best disabled, as they don't provide Forward Secrecy, + // but might be necessary for some clients + // tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + // tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + }, + NextProtos: []string{"relay"}, + RootCAs: pool, + Certificates: []tls.Certificate{cert}, + ClientAuth: tls.NoClientCert, + CurvePreferences: []tls.CurveID{ + tls.CurveP256, + tls.X25519, + }, + } + + lc.serverCfg = tlsCfg + + // client TLS config; we need the proper root. But no client Certs. + ctlsCfg := *tlsCfg + ctlsCfg.Certificates = []tls.Certificate{} + + // create a server on the other end of a connector + s := newTcpServer("tcp", lc.Connect.Addr, nil, t) + assert(s != nil, "server creation failed") + + log := newLogger(t) + gt := NewServer(lc, cfg, log) + gt.Start() + + // Now create a mock client to send data to mock server + c := newQuicClient("udp", lc.Addr, &ctlsCfg, t) + assert(c != nil, "client creation failed") + + c.start(10) + + assert(c.nw == s.nr, "i/o mismatch: client TX %d, server RX %d", c.nw, s.nr) + assert(c.nr == s.nw, "i/o mismatch: server TX %d, client RX %d", s.nw, c.nr) + + c.stop() + s.stop() + gt.Stop() + log.Close() +} + +// Client -> gotun Quic with client auth +// gotun -> backend TCP +func TestQuicAuthToTcp(t *testing.T) { + assert := newAsserter(t) + + pki, err := newPKI() + assert(err == nil, "can't create PKI: %s", err) + + pkic, err := newPKI() + assert(err == nil, "can't create client PKI: %s", err) + + clientCert, err := pkic.ClientCert("client.name") + assert(err == nil, "can't create client cert: %s", err) + + spool := x509.NewCertPool() + spool.AddCert(pki.ca) + + cpool := x509.NewCertPool() + cpool.AddCert(pkic.ca) + + cfg := quicSetup(8008, 8009) + lc := cfg.Listen[0] + + cert, err := pki.ServerCert("server.name", lc.Addr) + assert(err == nil, "can't create server cert: %s", err) + + tlsCfg := &tls.Config{ + MinVersion: tls.VersionTLS13, + ServerName: "server.name", + SessionTicketsDisabled: true, + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, // Go 1.8 only + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, // Go 1.8 only + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + + // Best disabled, as they don't provide Forward Secrecy, + // but might be necessary for some clients + // tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + // tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + }, + NextProtos: []string{"relay"}, + RootCAs: spool, + Certificates: []tls.Certificate{cert}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: cpool, + CurvePreferences: []tls.CurveID{ + tls.CurveP256, + tls.X25519, + }, + } + + lc.serverCfg = tlsCfg + + // client TLS config; we need the proper root + ctlsCfg := *tlsCfg + ctlsCfg.Certificates = []tls.Certificate{clientCert} + + // create a server on the other end of a connector + s := newTcpServer("tcp", lc.Connect.Addr, nil, t) + assert(s != nil, "server creation failed") + + log := newLogger(t) + gt := NewServer(lc, cfg, log) + gt.Start() + + // Now create a mock client to send data to mock server + c := newQuicClient("udp", lc.Addr, &ctlsCfg, t) + assert(c != nil, "client creation failed") + + c.start(10) + + assert(c.nw == s.nr, "i/o mismatch: client TX %d, server RX %d", c.nw, s.nr) + assert(c.nr == s.nw, "i/o mismatch: server TX %d, client RX %d", s.nw, c.nr) + + c.stop() + s.stop() + gt.Stop() + log.Close() +} + +// Client -> tcp +// gotun -> backend quic +func TestTcpToQuicAuth(t *testing.T) { + assert := newAsserter(t) + + log := newLogger(t) + + pki, err := newPKI() + assert(err == nil, "can't create PKI: %s", err) + + pkic, err := newPKI() + assert(err == nil, "can't create client PKI: %s", err) + + clientCert, err := pkic.ClientCert("client.name") + assert(err == nil, "can't create client cert: %s", err) + + spool := x509.NewCertPool() + spool.AddCert(pki.ca) + + cpool := x509.NewCertPool() + cpool.AddCert(pkic.ca) + + cfg := testSetup(8008, 8009) + lc := cfg.Listen[0] + + // we want outgoing connect to be quic + lc.Connect.Quic = true + + cfg.Dump(log) + + cert, err := pki.ServerCert("server.name", lc.Addr) + assert(err == nil, "can't create server cert: %s", err) + + tlsCfg := &tls.Config{ + MinVersion: tls.VersionTLS13, + ServerName: "server.name", + SessionTicketsDisabled: true, + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, // Go 1.8 only + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, // Go 1.8 only + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + + // Best disabled, as they don't provide Forward Secrecy, + // but might be necessary for some clients + // tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + // tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + }, + NextProtos: []string{"relay"}, + RootCAs: spool, + Certificates: []tls.Certificate{cert}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: cpool, + CurvePreferences: []tls.CurveID{ + tls.CurveP256, + tls.X25519, + }, + } + + // client TLS config; we need the proper root + ctlsCfg := *tlsCfg + ctlsCfg.Certificates = []tls.Certificate{clientCert} + + // outbound connection is a Quic client + lc.clientCfg = &ctlsCfg + + // create a server on the other end of a connector + s := newQuicServer("udp", lc.Connect.Addr, tlsCfg, t) + assert(s != nil, "server creation failed") + + gt := NewServer(lc, cfg, log) + gt.Start() + + // Now create a mock client to send data to mock server + c := newTcpClient("tcp", lc.Addr, nil, t) + assert(c != nil, "client creation failed") + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + err := c.start(10) + assert(err == nil, "tcp client can't connect: %s", err) + wg.Done() + }() + + // now we test muxing multiple inbound TCPs to a single + // quic session + multiple streams + + c2 := newTcpClient("tcp", lc.Addr, nil, t) + assert(c2 != nil, "second client creation failed") + + wg.Add(1) + go func() { + err := c2.start(10) + assert(err == nil, "tcp client can't connect: %s", err) + wg.Done() + }() + + wg.Wait() + assert(c.nw+c2.nw == s.nr, "i/o mismatch: client TX %d; %d, server RX %d", c.nw, c2.nw, s.nr) + assert(c.nr+c2.nr == s.nw, "i/o mismatch: server TX %d, client RX %d; %d", s.nw, c.nr, c2.nr) + + c.stop() + s.stop() + gt.Stop() + log.Close() +} diff --git a/gotun/quicdial.go b/gotun/quicdial.go index c91e38c..8ba6098 100644 --- a/gotun/quicdial.go +++ b/gotun/quicdial.go @@ -71,24 +71,27 @@ func (q *quicDialer) Dial(network, addr string, _ Conn, ctx context.Context) (Co if err != nil { q.Unlock() - q.log.Warn("quic: can't dial %s: %s", addr, err) + q.log.Warn("quic-client: can't dial %s: %s", addr, err) return nil, fmt.Errorf("quic: %s: %w", addr, err) } state := d.ConnectionState() - q.log.Debug("quic: Established new session with %s [%s]", addr, state.ServerName) + q.log.Debug("quic-client: established new session %s-%s [%s]", d.LocalAddr().String(), + d.RemoteAddr().String(), state.ServerName) q.dest[key] = d } q.Unlock() - t, err := d.OpenStream() + t, err := d.OpenStreamSync(ctx) if err != nil { - q.log.Warn("quic: %s: can't open new stream: %s", addr, err) + q.log.Warn("quic-client: %s: can't open new stream: %s", addr, err) return nil, fmt.Errorf("quic: %s: %w", addr, err) } - log := q.log.New(fmt.Sprintf("%s.%#x", addr, t.StreamID()), 0) - log.Debug("quic: %s: opened new stream %#x", addr, t.StreamID()) + connstr := fmt.Sprintf("%s-%s.%#x", d.LocalAddr().String(), d.RemoteAddr().String(), t.StreamID()) + log := q.log.New(connstr, 0) + log.Debug("quic-client: opened new stream %#x", t.StreamID()) + c := &qConn{ Stream: t, s: d, @@ -112,7 +115,7 @@ func (a *qAddr) String() string { return fmt.Sprintf("%s.%#x", a.a.String(), a.id) } -// implement new.Conn interfaces too +// implement net.Conn interfaces too func (c *qConn) LocalAddr() net.Addr { return &qAddr{ a: c.s.LocalAddr(), diff --git a/gotun/server.go b/gotun/server.go index 11ca58e..ee25911 100644 --- a/gotun/server.go +++ b/gotun/server.go @@ -244,10 +244,11 @@ func (p *QuicServer) Stop() { func (p *TCPServer) serveTCP() { n := 0 + done := p.ctx.Done() for { conn, err := p.Accept() select { - case <-p.ctx.Done(): + case <-done: return default: } @@ -275,9 +276,16 @@ func (p *TCPServer) serveTCP() { func (p *QuicServer) serveQuic() { n := 0 + done := p.ctx.Done() for { p.rl.Wait(p.ctx) sess, err := p.Accept(p.ctx) + select { + case <-done: + return + default: + } + if err != nil { n += 1 if n >= 10 { @@ -293,8 +301,26 @@ func (p *QuicServer) serveQuic() { // wait for per-host ratelimiter p.rl.WaitHost(p.ctx, sess.RemoteAddr()) + n = 0 + p.wg.Add(1) + go p.serviceSession(sess) + } +} + +func (p *QuicServer) serviceSession(sess quic.Session) { + defer p.wg.Done() + done := p.ctx.Done() + + n := 0 + for { // we also accept the corresponding stream conn, err := sess.AcceptStream(p.ctx) + select { + case <-done: + return + default: + } + if err != nil { n += 1 if n >= 10 { @@ -312,7 +338,7 @@ func (p *QuicServer) serveQuic() { Stream: conn, s: sess, } - peer := qc.RemoteAddr() + peer := qc.LocalAddr() ctx := context.WithValue(p.ctx, "client", peer.String()) qc.log = p.log.New(peer.String(), 0) @@ -362,8 +388,8 @@ func (p *Server) handleConn(conn Conn, ctx context.Context, log *L.Logger) { // we grab the printable info before the socket is closed lhs_theirs := conn.RemoteAddr().String() - inbound := fmt.Sprintf("%s-%s", lhs_theirs, conn.LocalAddr().String()) rhs_theirs := peer.RemoteAddr().String() + inbound := fmt.Sprintf("%s-%s", conn.LocalAddr().String(), lhs_theirs) outbound := fmt.Sprintf("%s-%s", peer.LocalAddr().String(), rhs_theirs) // we really need to log this in the parent logger @@ -417,8 +443,60 @@ func (p *Server) putBuf(b []byte) { p.pool.Put(b) } +func (p *Server) cancellableCopy(d, s Conn, buf []byte, ctx context.Context, log *L.Logger) (x, y int) { + rto := time.Duration(p.Timeout.Read) * time.Second + wto := time.Duration(p.Timeout.Write) * time.Second + done := ctx.Done() + for { + s.SetReadDeadline(time.Now().Add(rto)) + nr, err := s.Read(buf) + select { + case <-done: + return + default: + } + + if err != nil { + if err != io.EOF && err != context.Canceled && !isReset(err) { + log.Debug("%s: nr %d, read err %s", s.LocalAddr().String(), nr, err) + return + } + } + + switch { + case nr == 0: + log.Debug("EOF") + return + + case nr > 0: + d.SetWriteDeadline(time.Now().Add(wto)) + x += nr + nw, err := d.Write(buf[:nr]) + select { + case <-done: + return + default: + } + + if err != nil { + log.Debug("%s: Write Err %s", d.RemoteAddr().String(), err) + return + } + if nw != nr { + return + } + y += nw + } + if err != nil { + log.Debug("%s: read error: %s", s.RemoteAddr().String(), err) + return + } + } +} + +/* // interruptible copy -func (p *Server) cancellableCopy(d, s Conn, buf []byte, ctx context.Context, log *L.Logger) (r, w int) { +func (p *Server) xcancellableCopy(d, s Conn, buf []byte, ctx context.Context, log *L.Logger) (r, w int) { ch := make(chan bool) go func() { @@ -441,7 +519,7 @@ func (p *Server) cancellableCopy(d, s Conn, buf []byte, ctx context.Context, log } // copy from 's' to 'd' using 'buf' -func (p *Server) copyBuf(d, s Conn, buf []byte, log *L.Logger) (x, y int) { +func (p *Server) xcopyBuf(d, s Conn, buf []byte, log *L.Logger) (x, y int) { rto := time.Duration(p.Timeout.Read) * time.Second wto := time.Duration(p.Timeout.Write) * time.Second for { @@ -529,6 +607,7 @@ func (p *TCPServer) Accept() (net.Conn, error) { return nc, nil } } +*/ func (s *Server) getSNIHandler(dir string, log *L.Logger) func(h *tls.ClientHelloInfo) (*tls.Certificate, error) { conf := s.conf diff --git a/gotun/tcp_test.go b/gotun/tcp_test.go index 6070c14..a653c25 100644 --- a/gotun/tcp_test.go +++ b/gotun/tcp_test.go @@ -39,6 +39,9 @@ func testSetup(lport, cport int) *Conf { func TestTcpToTcp(t *testing.T) { assert := newAsserter(t) + // create a logger + log := newLogger(t) + cfg := testSetup(9000, 9001) lc := cfg.Listen[0] @@ -47,9 +50,6 @@ func TestTcpToTcp(t *testing.T) { s := newTcpServer("tcp", lc.Connect.Addr, nil, t) assert(s != nil, "server creation failed") - // create a logger - log := newLogger(t) - gt := NewServer(lc, cfg, log) gt.Start() @@ -57,7 +57,8 @@ func TestTcpToTcp(t *testing.T) { c := newTcpClient("tcp", lc.Addr, nil, t) assert(c != nil, "client creation failed") - c.start(10) + err := c.start(10) + assert(err == nil, "can't start tcp client: %s", err) assert(c.nw == s.nr, "i/o mismatch: client TX %d, server RX %d", c.nw, s.nr) assert(c.nr == s.nw, "i/o mismatch: server TX %d, client RX %d", s.nw, c.nr) diff --git a/gotun/utils_test.go b/gotun/utils_test.go index c6b659a..31c7ee6 100644 --- a/gotun/utils_test.go +++ b/gotun/utils_test.go @@ -44,7 +44,12 @@ type logWriter struct { } func (a *logWriter) Write(b []byte) (int, error) { - a.Logf("# %s\n", string(b)) + var nl string + + if b[len(b)-1] != '\n' { + nl = "\n" + } + a.Logf("# %s%s", string(b), nl) return len(b), nil }