From ecaa10727788f1a6392efd4c951d083afc5d8acf Mon Sep 17 00:00:00 2001 From: Nick Peng Date: Wed, 27 Dec 2023 23:23:18 +0800 Subject: [PATCH] test: add test case for edns. --- src/dns_server.c | 9 +- src/util.c | 49 +++- test/cases/test-bind.cc | 2 +- test/cases/test-cache.cc | 2 +- test/cases/test-edns.cc | 341 ++++++++++++++++++++++++++ test/cases/test-same-pending-query.cc | 4 +- test/client.cc | 21 ++ test/client.h | 3 + 8 files changed, 421 insertions(+), 10 deletions(-) create mode 100644 test/cases/test-edns.cc diff --git a/src/dns_server.c b/src/dns_server.c index 9f6a2a0f8f4..6d5be71faf7 100644 --- a/src/dns_server.c +++ b/src/dns_server.c @@ -1421,7 +1421,7 @@ static int _dns_server_get_cache_timeout(struct dns_request *request, struct dns return ttl + 1; } - if (dns_conf_prefetch && _dns_cache_is_specify_packet(request->qtype) != 0) { + if (dns_conf_prefetch) { prefetch_time = 1; } @@ -2403,6 +2403,9 @@ static int _dns_server_request_complete_with_all_IPs(struct dns_request *request context.do_audit = 1; context.do_reply = 1; context.reply_ttl = _dns_server_get_reply_ttl(request, ttl); + if (with_all_ips == 0) { + context.cache_ttl = _dns_server_get_reply_ttl(request, ttl); + } context.skip_notify_count = 1; context.select_all_best_ip = with_all_ips; context.no_release_parent = 1; @@ -5824,7 +5827,7 @@ static int _dns_server_do_query(struct dns_request *request, int skip_notify_eve if (_dns_server_process_dns64(request) != 0) { goto errout; } - + // Get reference for DNS query request->request_wait++; _dns_server_request_get(request); @@ -7052,7 +7055,7 @@ static dns_cache_tmout_action_t _dns_server_cache_expired(struct dns_cache *dns_ return DNS_CACHE_TMOUT_ACTION_DEL; } - if (dns_conf_prefetch == 1 && _dns_cache_is_specify_packet(dns_cache->info.qtype) != 0) { + if (dns_conf_prefetch == 1) { if (dns_conf_serve_expired == 1) { return _dns_server_prefetch_expired_domain(dns_cache); } else { diff --git a/src/util.c b/src/util.c index be6263d0a29..f2c62b1d7b8 100644 --- a/src/util.c +++ b/src/util.c @@ -1828,7 +1828,7 @@ daemon_ret daemon_run(int *wstatus) return DAEMON_RET_ERR; } -#ifdef DEBUG +#if defined(DEBUG) || defined(TEST) struct _dns_read_packet_info { int data_len; int message_len; @@ -1922,8 +1922,9 @@ static int _dns_debug_display(struct dns_packet *packet) struct dns_rrs *rrs = NULL; int rr_count = 0; char req_host[MAX_IP_LEN]; + int ret; - for (j = 1; j < DNS_RRS_END; j++) { + for (j = 1; j < DNS_RRS_OPT; j++) { rrs = dns_get_rrs_start(packet, j, &rr_count); printf("section: %d\n", j); for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) { @@ -1949,7 +1950,6 @@ static int _dns_debug_display(struct dns_packet *packet) unsigned short priority = 0; unsigned short weight = 0; unsigned short port = 0; - int ret = 0; char name[DNS_MAX_CNAME_LEN] = {0}; char target[DNS_MAX_CNAME_LEN]; @@ -1969,7 +1969,6 @@ static int _dns_debug_display(struct dns_packet *packet) char target[DNS_MAX_CNAME_LEN] = {0}; struct dns_https_param *p = NULL; int priority = 0; - int ret = 0; ret = dns_get_HTTPS_svcparm_start(rrs, &p, name, DNS_MAX_CNAME_LEN, &ttl, &priority, target, DNS_MAX_CNAME_LEN); @@ -2068,6 +2067,48 @@ static int _dns_debug_display(struct dns_packet *packet) printf("\n"); } + rr_count = 0; + rrs = dns_get_rrs_start(packet, DNS_RRS_OPT, &rr_count); + if (rr_count <= 0) { + return 0; + } + + printf("section opt:\n"); + for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) { + switch (rrs->type) { + case DNS_OPT_T_TCP_KEEPALIVE: { + unsigned short idle_timeout = 0; + ret = dns_get_OPT_TCP_KEEPALIVE(rrs, &idle_timeout); + if (idle_timeout == 0) { + continue; + } + + printf("tcp keepalive: %d\n", idle_timeout); + } break; + case DNS_OPT_T_ECS: { + struct dns_opt_ecs ecs; + memset(&ecs, 0, sizeof(ecs)); + ret = dns_get_OPT_ECS(rrs, &ecs); + if (ret != 0) { + continue; + } + printf("ecs family: %d, src_prefix: %d, scope_prefix: %d, ", ecs.family, ecs.source_prefix, + ecs.scope_prefix); + if (ecs.family == 1) { + char ip[16] = {0}; + inet_ntop(AF_INET, ecs.addr, ip, sizeof(ip)); + printf("ecs address: %s\n", ip); + } else if (ecs.family == 2) { + char ip[64] = {0}; + inet_ntop(AF_INET6, ecs.addr, ip, sizeof(ip)); + printf("ecs address: %s\n", ip); + } + } break; + default: + break; + } + } + return 0; } diff --git a/test/cases/test-bind.cc b/test/cases/test-bind.cc index 0f327250570..bbf0baa3149 100644 --- a/test/cases/test-bind.cc +++ b/test/cases/test-bind.cc @@ -104,7 +104,7 @@ server 127.0.0.1:61053 std::cout << client.GetResult() << std::endl; ASSERT_EQ(client.GetAnswerNum(), 1); EXPECT_EQ(client.GetStatus(), "NOERROR"); - EXPECT_GE(client.GetAnswer()[0].GetTTL(), 609); + EXPECT_GE(client.GetAnswer()[0].GetTTL(), 3); EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); } diff --git a/test/cases/test-cache.cc b/test/cases/test-cache.cc index 3a2e35688e2..824a59c5f96 100644 --- a/test/cases/test-cache.cc +++ b/test/cases/test-cache.cc @@ -178,7 +178,7 @@ rr-ttl-reply-max 6 ASSERT_EQ(client.GetAnswerNum(), 1); EXPECT_EQ(client.GetStatus(), "NOERROR"); EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); - EXPECT_GE(client.GetAnswer()[0].GetTTL(), 5); + EXPECT_GE(client.GetAnswer()[0].GetTTL(), 3); EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); } diff --git a/test/cases/test-edns.cc b/test/cases/test-edns.cc new file mode 100644 index 00000000000..3ba3996e5d6 --- /dev/null +++ b/test/cases/test-edns.cc @@ -0,0 +1,341 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "client.h" +#include "dns.h" +#include "include/utils.h" +#include "server.h" +#include "util.h" +#include "gtest/gtest.h" +#include + +class EDNS : public ::testing::Test +{ + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(EDNS, client) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + struct dns_opt_ecs ecs; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + int rr_count = 0; + int i = 0; + int ret = 0; + struct dns_rrs *rrs = NULL; + rrs = dns_get_rrs_start(request->packet, DNS_RRS_OPT, &rr_count); + if (rr_count > 0) { + for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(request->packet, rrs)) { + switch (rrs->type) { + case DNS_OPT_T_ECS: { + memset(&ecs, 0, sizeof(ecs)); + ret = dns_get_OPT_ECS(rrs, &ecs); + if (ret != 0) { + continue; + } + + dns_add_OPT_ECS(request->response_packet, &ecs); + + } break; + default: + break; + } + } + } + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700); + return smartdns::SERVER_REQUEST_OK; + } else if (request->qtype == DNS_T_AAAA) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700); + return smartdns::SERVER_REQUEST_OK; + } + + return smartdns::SERVER_REQUEST_SOA; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +speed-check-mode none +)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com A +subnet=2.2.2.2/24", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + ASSERT_EQ(client.GetOpt().size(), 2); + EXPECT_EQ(client.GetOpt()[1], "CLIENT-SUBNET: 2.2.2.0/24/0"); + EXPECT_EQ(ecs.family, 1); + EXPECT_EQ(ecs.source_prefix, 24); + EXPECT_EQ(ecs.scope_prefix, 0); + unsigned char edns_addr[4] = {2, 2, 2, 0}; + EXPECT_EQ(memcmp(ecs.addr, &edns_addr, 4), 0); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "A"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); +} + +TEST_F(EDNS, server) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + struct dns_opt_ecs ecs; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + int rr_count = 0; + int i = 0; + int ret = 0; + struct dns_rrs *rrs = NULL; + rrs = dns_get_rrs_start(request->packet, DNS_RRS_OPT, &rr_count); + if (rr_count > 0) { + for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(request->packet, rrs)) { + switch (rrs->type) { + case DNS_OPT_T_ECS: { + memset(&ecs, 0, sizeof(ecs)); + ret = dns_get_OPT_ECS(rrs, &ecs); + if (ret != 0) { + continue; + } + + dns_add_OPT_ECS(request->response_packet, &ecs); + + } break; + default: + break; + } + } + } + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700); + return smartdns::SERVER_REQUEST_OK; + } else if (request->qtype == DNS_T_AAAA) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700); + return smartdns::SERVER_REQUEST_OK; + } + + return smartdns::SERVER_REQUEST_SOA; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 -subnet=2.2.2.0/24 +speed-check-mode none +)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com A", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(ecs.family, 1); + EXPECT_EQ(ecs.source_prefix, 24); + EXPECT_EQ(ecs.scope_prefix, 0); + unsigned char edns_addr[4] = {2, 2, 2, 0}; + EXPECT_EQ(memcmp(ecs.addr, &edns_addr, 4), 0); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "A"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); +} + +TEST_F(EDNS, server_v6) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + struct dns_opt_ecs ecs; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + int rr_count = 0; + int i = 0; + int ret = 0; + struct dns_rrs *rrs = NULL; + rrs = dns_get_rrs_start(request->packet, DNS_RRS_OPT, &rr_count); + if (rr_count > 0) { + for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(request->packet, rrs)) { + switch (rrs->type) { + case DNS_OPT_T_ECS: { + memset(&ecs, 0, sizeof(ecs)); + ret = dns_get_OPT_ECS(rrs, &ecs); + if (ret != 0) { + continue; + } + + dns_add_OPT_ECS(request->response_packet, &ecs); + + } break; + default: + break; + } + } + } + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700); + return smartdns::SERVER_REQUEST_OK; + } else if (request->qtype == DNS_T_AAAA) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700); + return smartdns::SERVER_REQUEST_OK; + } + + return smartdns::SERVER_REQUEST_SOA; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 -subnet=64:ff9b::/96 +speed-check-mode none +)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com A", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(ecs.family, 2); + EXPECT_EQ(ecs.source_prefix, 96); + EXPECT_EQ(ecs.scope_prefix, 0); + unsigned char edns_addr[16] = {00, 0x64, 0xff, 0x9b, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + EXPECT_EQ(memcmp(ecs.addr, &edns_addr, 16), 0); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "A"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); +} + +TEST_F(EDNS, edns_client_subnet) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + struct dns_opt_ecs ecs; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + int rr_count = 0; + int i = 0; + int ret = 0; + struct dns_rrs *rrs = NULL; + rrs = dns_get_rrs_start(request->packet, DNS_RRS_OPT, &rr_count); + if (rr_count > 0) { + for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(request->packet, rrs)) { + switch (rrs->type) { + case DNS_OPT_T_ECS: { + memset(&ecs, 0, sizeof(ecs)); + ret = dns_get_OPT_ECS(rrs, &ecs); + if (ret != 0) { + continue; + } + + dns_add_OPT_ECS(request->response_packet, &ecs); + + } break; + default: + break; + } + } + } + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700); + return smartdns::SERVER_REQUEST_OK; + } else if (request->qtype == DNS_T_AAAA) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700); + return smartdns::SERVER_REQUEST_OK; + } + + return smartdns::SERVER_REQUEST_SOA; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +speed-check-mode none +edns-client-subnet 2.2.2.2/24 +)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com A", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(ecs.family, 1); + EXPECT_EQ(ecs.source_prefix, 24); + EXPECT_EQ(ecs.scope_prefix, 0); + unsigned char edns_addr[4] = {2, 2, 2, 0}; + EXPECT_EQ(memcmp(ecs.addr, &edns_addr, 4), 0); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "A"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); +} + +TEST_F(EDNS, edns_client_subnet_v6) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + struct dns_opt_ecs ecs; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + int rr_count = 0; + int i = 0; + int ret = 0; + struct dns_rrs *rrs = NULL; + rrs = dns_get_rrs_start(request->packet, DNS_RRS_OPT, &rr_count); + if (rr_count > 0) { + for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(request->packet, rrs)) { + switch (rrs->type) { + case DNS_OPT_T_ECS: { + memset(&ecs, 0, sizeof(ecs)); + ret = dns_get_OPT_ECS(rrs, &ecs); + if (ret != 0) { + continue; + } + + dns_add_OPT_ECS(request->response_packet, &ecs); + + } break; + default: + break; + } + } + } + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700); + return smartdns::SERVER_REQUEST_OK; + } else if (request->qtype == DNS_T_AAAA) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700); + return smartdns::SERVER_REQUEST_OK; + } + + return smartdns::SERVER_REQUEST_SOA; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +speed-check-mode none +edns-client-subnet 64:ff9b::/96 +)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com A", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(ecs.family, 2); + EXPECT_EQ(ecs.source_prefix, 96); + EXPECT_EQ(ecs.scope_prefix, 0); + unsigned char edns_addr[16] = {00, 0x64, 0xff, 0x9b, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + EXPECT_EQ(memcmp(ecs.addr, &edns_addr, 16), 0); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "A"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); +} diff --git a/test/cases/test-same-pending-query.cc b/test/cases/test-same-pending-query.cc index 57840d84926..a8d17c9777f 100644 --- a/test/cases/test-same-pending-query.cc +++ b/test/cases/test-same-pending-query.cc @@ -74,7 +74,7 @@ log-level error std::vector threads; uint64_t tick = get_tick_count(); - for (int i = 0; i < 10; i++) { + for (int i = 0; i < 5; i++) { auto t = std::thread([&]() { for (int j = 0; j < 10; j++) { smartdns::Client client; @@ -91,4 +91,6 @@ log-level error for (auto &t : threads) { t.join(); } + + EXPECT_LT(qid_map.size(), 80); } diff --git a/test/client.cc b/test/client.cc index 38c4e40fede..d4837fe82e6 100644 --- a/test/client.cc +++ b/test/client.cc @@ -179,6 +179,11 @@ std::vector Client::GetAdditional() return records_additional_; } +std::vector Client::GetOpt() +{ + return records_opt_; +} + int Client::GetAnswerNum() { return answer_num_; @@ -266,6 +271,22 @@ bool Client::ParserResult() return false; } + std::regex reg_opt(";; OPT PSEUDOSECTION:\\n((?:.|\\n|\\r\\n)+?)\\n;;", + std::regex::ECMAScript | std::regex::optimize); + if (std::regex_search(result_, match, reg_opt)) { + std::string opt_str = match[1]; + + std::vector lines = StringSplit(opt_str, '\n'); + for (auto &line : lines) { + if (line.length() <= 0) { + continue; + } + + line = line.substr(2); + records_opt_.push_back(line); + } + } + std::regex reg_answer_num(", ANSWER: ([0-9]+),"); if (std::regex_search(result_, match, reg_answer_num)) { answer_num_ = std::stoi(match[1]); diff --git a/test/client.h b/test/client.h index 2c59de9a226..41fd8a19ad4 100644 --- a/test/client.h +++ b/test/client.h @@ -69,6 +69,8 @@ class Client std::vector GetAdditional(); + std::vector GetOpt(); + int GetAnswerNum(); int GetAuthorityNum(); @@ -103,6 +105,7 @@ class Client std::vector records_answer_; std::vector records_authority_; std::vector records_additional_; + std::vector records_opt_; }; } // namespace smartdns