diff --git a/Makefile b/Makefile index 47eef93..05d4d18 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ run: build docker: docker build -t glutton . - docker run --rm --cap-add=NET_ADMIN -it glutton + docker run --rm --cap-add=NET_ADMIN -it --name glutton glutton test: go test -v ./... diff --git a/config/rules.yaml b/config/rules.yaml index 2e73028..35d39f6 100644 --- a/config/rules.yaml +++ b/config/rules.yaml @@ -29,6 +29,9 @@ rules: - match: tcp dst port 11211 type: conn_handler target: memcache + - match: udp dst port 53 + type: conn_handler + target: dns - match: tcp type: conn_handler target: tcp diff --git a/glutton.go b/glutton.go index 7a54247..56e9907 100644 --- a/glutton.go +++ b/glutton.go @@ -171,8 +171,22 @@ func (g *Glutton) udpListen(wg *sync.WaitGroup) { if hfunc, ok := g.udpProtocolHandlers[rule.Target]; ok { data := buffer[:n] go func() { - if err := hfunc(g.ctx, srcAddr, dstAddr, data, md); err != nil { + response, err := hfunc(g.ctx, srcAddr, dstAddr, data, md) + if err != nil { g.Logger.Error("failed to handle UDP payload", producer.ErrAttr(err)) + return + } + if response != nil { + con, err := net.DialUDP("udp", dstAddr, srcAddr) + if err != nil { + g.Logger.Error("failed to dial UDP connection", producer.ErrAttr(err)) + return + } + defer con.Close() + _, err = con.Write(response) + if err != nil { + g.Logger.Error("failed to send UDP response", producer.ErrAttr(err)) + } } }() } diff --git a/go.mod b/go.mod index de00cbf..abf2207 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.19.0 github.com/stretchr/testify v1.10.0 + golang.org/x/net v0.33.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v2 v2.4.0 ) @@ -44,7 +45,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.31.0 // indirect golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 // indirect - golang.org/x/net v0.33.0 // indirect golang.org/x/sys v0.28.0 // indirect golang.org/x/term v0.27.0 // indirect golang.org/x/text v0.21.0 // indirect diff --git a/protocols/protocols.go b/protocols/protocols.go index d5aba62..87a5d42 100644 --- a/protocols/protocols.go +++ b/protocols/protocols.go @@ -15,13 +15,16 @@ import ( type TCPHandlerFunc func(ctx context.Context, conn net.Conn, md connection.Metadata) error -type UDPHandlerFunc func(ctx context.Context, srcAddr, dstAddr *net.UDPAddr, data []byte, md connection.Metadata) error +type UDPHandlerFunc func(ctx context.Context, srcAddr, dstAddr *net.UDPAddr, data []byte, md connection.Metadata) ([]byte, error) // MapUDPProtocolHandlers map protocol handlers to corresponding protocol func MapUDPProtocolHandlers(log interfaces.Logger, h interfaces.Honeypot) map[string]UDPHandlerFunc { protocolHandlers := map[string]UDPHandlerFunc{} - protocolHandlers["udp"] = func(ctx context.Context, srcAddr, dstAddr *net.UDPAddr, data []byte, md connection.Metadata) error { - return udp.HandleUDP(ctx, srcAddr, dstAddr, data, md, log, h) + protocolHandlers["udp"] = func(ctx context.Context, srcAddr, dstAddr *net.UDPAddr, data []byte, md connection.Metadata) ([]byte, error) { + return nil, udp.HandleUDP(ctx, srcAddr, dstAddr, data, md, log, h) + } + protocolHandlers["dns"] = func(ctx context.Context, srcAddr, dstAddr *net.UDPAddr, data []byte, md connection.Metadata) ([]byte, error) { + return udp.HandleDNS(ctx, srcAddr, dstAddr, data, md, log, h) } return protocolHandlers } diff --git a/protocols/udp/dns.go b/protocols/udp/dns.go new file mode 100644 index 0000000..c670f52 --- /dev/null +++ b/protocols/udp/dns.go @@ -0,0 +1,115 @@ +package udp + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/mushorg/glutton/connection" + "github.com/mushorg/glutton/producer" + "github.com/mushorg/glutton/protocols/interfaces" + + "golang.org/x/net/dns/dnsmessage" +) + +var ( + maxRequestCount = 3 + throttleSeconds = int64(60) +) + +type throttleState struct { + count int + last int64 +} + +var throttle = map[string]throttleState{} + +func cleanupThrottle() { + for ip, state := range throttle { + if state.last+throttleSeconds < time.Now().Unix() { + delete(throttle, ip) + } + } +} + +func shouldThrottle(ip string) bool { + defer func() { go cleanupThrottle() }() + if _, ok := throttle[ip]; ok { + if throttle[ip].count > maxRequestCount { + if throttle[ip].last+throttleSeconds > time.Now().Unix() { + return true + } + throttle[ip] = throttleState{count: 1, last: time.Now().Unix()} + return false + } + throttle[ip] = throttleState{count: throttle[ip].count + 1, last: time.Now().Unix()} + return false + } + throttle[ip] = throttleState{count: 1, last: time.Now().Unix()} + return false +} + +// HandleDNS handles DNS packets +func HandleDNS(ctx context.Context, srcAddr, dstAddr *net.UDPAddr, data []byte, md connection.Metadata, log interfaces.Logger, h interfaces.Honeypot) ([]byte, error) { + if shouldThrottle(srcAddr.IP.String()) { + return nil, fmt.Errorf("throttling DNS requests") + } + var p dnsmessage.Parser + if _, err := p.Start(data); err != nil { + return nil, fmt.Errorf("failed to parse DNS query: %w", err) + } + + questions, err := p.AllQuestions() + if err != nil { + return nil, fmt.Errorf("failed to parse DNS questions: %w", err) + } + + msg := dnsmessage.Message{ + Header: dnsmessage.Header{Response: true, Authoritative: true}, + } + + for _, q := range questions { + msg.Questions = append(msg.Questions, q) + name, err := dnsmessage.NewName(q.Name.String()) + if err != nil { + return nil, err + } + + answer := dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: name, + Type: q.Type, + Class: q.Class, + TTL: 453, + }, + } + + switch q.Type { + case dnsmessage.TypeA: + answer.Body = &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}} + case dnsmessage.TypeAAAA: + answer.Body = &dnsmessage.AAAAResource{AAAA: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 127, 0, 0, 1}} + case dnsmessage.TypeCNAME: + answer.Body = &dnsmessage.CNAMEResource{CNAME: dnsmessage.MustNewName("localhost")} + case dnsmessage.TypeNS: + answer.Body = &dnsmessage.NSResource{NS: dnsmessage.MustNewName("localhost")} + case dnsmessage.TypePTR: + answer.Body = &dnsmessage.PTRResource{PTR: dnsmessage.MustNewName("localhost")} + case dnsmessage.TypeTXT: + answer.Body = &dnsmessage.TXTResource{TXT: []string{"localhost"}} + } + msg.Answers = append(msg.Answers, answer) + } + + buf, err := msg.Pack() + if err != nil { + return nil, fmt.Errorf("failed to pack DNS response: %w", err) + } + + if err := h.ProduceUDP("dns", srcAddr, dstAddr, md, data[:len(data)%1024], nil); err != nil { + log.Error("failed to produce DNS payload", producer.ErrAttr(err)) + return nil, err + } + return buf, nil +} diff --git a/protocols/udp/dns_test.go b/protocols/udp/dns_test.go new file mode 100644 index 0000000..011fe66 --- /dev/null +++ b/protocols/udp/dns_test.go @@ -0,0 +1,37 @@ +package udp + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestThrottling(t *testing.T) { + testIP := net.IPv4(192, 168, 1, 1).String() + throttle[testIP] = throttleState{count: maxRequestCount + 1, last: time.Now().Unix()} + tests := []struct { + name string + ip string + expected bool + }{ + { + name: "throttle", + ip: testIP, + expected: true, + }, + { + name: "no throttle", + ip: net.IPv4(192, 168, 1, 2).String(), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ok := shouldThrottle(tt.ip) + require.Equal(t, tt.expected, ok) + }) + } +}