From 037f10d3c0ace26a4f698058ea20f9803f7a92ac Mon Sep 17 00:00:00 2001 From: Nick Peng Date: Mon, 25 Dec 2023 23:23:19 +0800 Subject: [PATCH] mdns: add test for mdns-lookup --- src/dns_client.c | 19 +++++- src/dns_client.h | 7 ++ src/dns_server.c | 96 ++++++++++++++++++++++---- src/util.c | 8 ++- test/cases/test-mdns.cc | 148 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 259 insertions(+), 19 deletions(-) create mode 100644 test/cases/test-mdns.cc diff --git a/src/dns_client.c b/src/dns_client.c index 25fe39de69..a2eebce54a 100644 --- a/src/dns_client.c +++ b/src/dns_client.c @@ -70,9 +70,6 @@ #define SOCKET_PRIORITY (6) #define SOCKET_IP_TOS (IPTOS_LOWDELAY | IPTOS_RELIABILITY) -#define DNS_MDNS_IP "224.0.0.251" -#define DNS_MDNS_PORT 5353 - /* ECS info */ struct dns_client_ecs { int enable; @@ -4613,6 +4610,22 @@ static int _dns_client_add_mdns_server(void) goto errout; } +#ifdef TEST + ret = _dns_client_server_add(DNS_MDNS_IP, "lo", DNS_MDNS_PORT, DNS_SERVER_MDNS, &server_flags); + if (ret != 0) { + tlog(TLOG_ERROR, "add mdns server failed."); + goto errout; + } + + if (dns_client_add_to_group(DNS_SERVER_GROUP_MDNS, DNS_MDNS_IP, DNS_MDNS_PORT, DNS_SERVER_MDNS, &server_flags) != + 0) { + tlog(TLOG_ERROR, "add mdns server to group failed."); + goto errout; + } + + return 0; +#endif + if (getifaddrs(&ifaddr) == -1) { goto errout; } diff --git a/src/dns_client.h b/src/dns_client.h index 38ee9af225..dba191df2f 100644 --- a/src/dns_client.h +++ b/src/dns_client.h @@ -29,6 +29,13 @@ extern "C" { #define DNS_SERVER_GROUP_DEFAULT "default" #define DNS_SERVER_GROUP_MDNS "mdns" #define DNS_SERVER_GROUP_LOCAL "local" +#ifdef TEST +#define DNS_MDNS_IP "127.0.0.1" +#define DNS_MDNS_PORT 55353 +#else +#define DNS_MDNS_IP "224.0.0.251" +#define DNS_MDNS_PORT 5353 +#endif typedef enum { DNS_SERVER_UDP, diff --git a/src/dns_server.c b/src/dns_server.c index f7e436d297..b1217bfe9d 100644 --- a/src/dns_server.c +++ b/src/dns_server.c @@ -433,6 +433,10 @@ static int _dns_server_get_conf_ttl(struct dns_request *request, int ttl) int rr_ttl_min = dns_conf_rr_ttl_min; int rr_ttl_max = dns_conf_rr_ttl_max; + if (request->is_mdns_lookup) { + rr_ttl_min = DNS_SERVER_ADDR_TTL; + } + struct dns_ttl_rule *ttl_rule = _dns_server_get_dns_rule(request, DOMAIN_RULE_TTL); if (ttl_rule != NULL) { if (ttl_rule->ttl > 0) { @@ -1407,11 +1411,16 @@ static int _dns_server_get_cache_timeout(struct dns_request *request, struct dns { int timeout = 0; int prefetch_time = 0; + int is_serve_expired = dns_conf_serve_expired; if (request->rcode != DNS_RC_NOERROR) { return ttl + 1; } + if (request->is_mdns_lookup == 1) { + return ttl + 1; + } + if (dns_conf_prefetch && _dns_cache_is_specify_packet(request->qtype) != 0) { prefetch_time = 1; } @@ -1424,8 +1433,12 @@ static int _dns_server_get_cache_timeout(struct dns_request *request, struct dns prefetch_time = 0; } + if (request->no_serve_expired) { + is_serve_expired = 0; + } + if (prefetch_time == 1) { - if (dns_conf_serve_expired) { + if (is_serve_expired) { timeout = dns_conf_serve_expired_prefetch_time; if (timeout == 0) { timeout = dns_conf_serve_expired_ttl / 2; @@ -1451,7 +1464,7 @@ static int _dns_server_get_cache_timeout(struct dns_request *request, struct dns } } else { timeout = ttl; - if (dns_conf_serve_expired) { + if (is_serve_expired) { timeout += dns_conf_serve_expired_ttl; } @@ -1491,6 +1504,12 @@ static int _dns_server_request_update_cache(struct dns_request *request, int spe cache_key.query_flag = request->server_flags; if (request->prefetch) { + /* no prefetch for mdns request */ + if (request->is_mdns_lookup) { + ret = 0; + goto errout; + } + if (dns_cache_replace(&cache_key, request->rcode, ttl, speed, _dns_server_get_cache_timeout(request, &cache_key, ttl), !(request->prefetch_flags & PREFETCH_FLAGS_EXPIRED), cache_data) != 0) { @@ -1681,6 +1700,12 @@ static int _dns_cache_packet(struct dns_server_post_context *context) cache_key.query_flag = request->server_flags; if (request->prefetch) { + /* no prefetch for mdns request */ + if (request->is_mdns_lookup) { + ret = 0; + goto errout; + } + if (dns_cache_replace(&cache_key, request->rcode, request->ip_ttl, -1, _dns_server_get_cache_timeout(request, &cache_key, request->ip_ttl), !(request->prefetch_flags & PREFETCH_FLAGS_EXPIRED), cache_packet) != 0) { @@ -2219,6 +2244,31 @@ static int _dns_server_reply_all_pending_list(struct dns_request *request, struc return ret; } +static void _dns_server_need_append_mdns_local_cname(struct dns_request *request) +{ + if (request->is_mdns_lookup == 0) { + return; + } + + if (request->has_cname != 0) { + return; + } + + if (request->domain[0] == '\0') { + return; + } + + if (strstr(request->domain, ".") != NULL) { + return; + } + + request->has_cname = 1; + snprintf(request->cname, sizeof(request->cname), "%.*s.%s", + (int)(sizeof(request->cname) - sizeof(DNS_SERVER_GROUP_LOCAL) - 1), request->domain, + DNS_SERVER_GROUP_LOCAL); + return; +} + static void _dns_server_check_complete_dualstack(struct dns_request *request, struct dns_request *dualstack_request) { if (dualstack_request == NULL || request == NULL) { @@ -2300,6 +2350,10 @@ static int _dns_server_request_complete_with_all_IPs(struct dns_request *request ttl = DNS_SERVER_FAIL_TTL; } + if (request->ip_ttl == 0) { + request->ip_ttl = ttl; + } + if (request->prefetch == 1) { return 0; } @@ -2320,6 +2374,8 @@ static int _dns_server_request_complete_with_all_IPs(struct dns_request *request goto out; } + _dns_server_need_append_mdns_local_cname(request); + if (request->has_soa) { tlog(TLOG_INFO, "result: %s, qtype: %d, SOA", request->domain, request->qtype); } else { @@ -2559,6 +2615,8 @@ static void _dns_server_complete_with_multi_ipaddress(struct dns_request *reques return; } + _dns_server_need_append_mdns_local_cname(request); + _dns_server_post_context_init(&context, request); context.do_cache = 1; context.do_ipset = 1; @@ -3367,6 +3425,7 @@ static int _dns_server_process_answer(struct dns_request *request, const char *d domain, request->qtype, request->soa.mname, request->soa.rname, request->soa.serial, request->soa.refresh, request->soa.retry, request->soa.expire, request->soa.minimum); + request->ip_ttl = _dns_server_get_conf_ttl(request, ttl); int soa_num = atomic_inc_return(&request->soa_num); if ((soa_num >= ((int)ceil((float)dns_server_alive_num() / 3) + 1) || soa_num > 4) && atomic_read(&request->ip_map_num) <= 0) { @@ -3461,7 +3520,7 @@ static int _dns_server_passthrough_rule_check(struct dns_request *request, const } } - ttl = ttl_tmp; + ttl = _dns_server_get_conf_ttl(request, ttl_tmp); _dns_server_request_release(request); } break; case DNS_T_AAAA: { @@ -3502,7 +3561,7 @@ static int _dns_server_passthrough_rule_check(struct dns_request *request, const } } - ttl = ttl_tmp; + ttl = _dns_server_get_conf_ttl(request, ttl_tmp); _dns_server_request_release(request); } break; case DNS_T_CNAME: { @@ -3515,7 +3574,7 @@ static int _dns_server_passthrough_rule_check(struct dns_request *request, const char tmpbuf[DNS_MAX_CNAME_LEN]; dns_get_CNAME(rrs, tmpname, DNS_MAX_CNAME_LEN, &ttl, tmpbuf, DNS_MAX_CNAME_LEN); if (request->ip_ttl == 0) { - request->ip_ttl = ttl; + request->ip_ttl = _dns_server_get_conf_ttl(request, ttl); } } break; @@ -3646,6 +3705,7 @@ static int _dns_server_get_answer(struct dns_server_post_context *context) "%d, minimum: %d", request->domain, request->qtype, request->soa.mname, request->soa.rname, request->soa.serial, request->soa.refresh, request->soa.retry, request->soa.expire, request->soa.minimum); + request->ip_ttl = _dns_server_get_conf_ttl(request, ttl); } break; default: break; @@ -3708,7 +3768,7 @@ static void _dns_server_query_end(struct dns_request *request) if (request->is_mdns_lookup == 1 && request->rcode == DNS_RC_SERVFAIL) { request->rcode = DNS_RC_NOERROR; request->force_soa = 1; - request->ip_ttl = _dns_server_get_local_ttl(request); + request->ip_ttl = _dns_server_get_conf_ttl(request, DNS_SERVER_ADDR_TTL); } pthread_mutex_lock(&request->ip_map_lock); @@ -5560,6 +5620,16 @@ static int _dns_server_setup_query_option(struct dns_request *request, struct dn return 0; } +static void _dns_server_mdns_query_setup_server_group(struct dns_request *request, const char **group_name) +{ + if (request->is_mdns_lookup == 0 || group_name == NULL) { + return; + } + + *group_name = DNS_SERVER_GROUP_MDNS; + return; +} + static int _dns_server_mdns_query_setup(struct dns_request *request, const char *group_name, char **request_domain, char *domain_buffer, int domain_buffer_len) { @@ -5672,6 +5742,10 @@ static int _dns_server_do_query(struct dns_request *request, int skip_notify_eve safe_strncpy(request->dns_group_name, group_name, DNS_GROUP_NAME_LEN); } + if (_dns_server_mdns_query_setup(request, group_name, &request_domain, domain_buffer, sizeof(domain_buffer)) != 0) { + goto errout; + } + if (_dns_server_process_cname_pre(request) != 0) { goto errout; } @@ -5738,6 +5812,7 @@ static int _dns_server_do_query(struct dns_request *request, int skip_notify_eve // setup options _dns_server_setup_query_option(request, &options); + _dns_server_mdns_query_setup_server_group(request, &group_name); pthread_mutex_lock(&server.request_list_lock); if (list_empty(&server.request_list) && skip_notify_event == 1) { @@ -5746,15 +5821,6 @@ static int _dns_server_do_query(struct dns_request *request, int skip_notify_eve list_add_tail(&request->list, &server.request_list); pthread_mutex_unlock(&server.request_list_lock); - if (_dns_server_mdns_query_setup(request, group_name, &request_domain, domain_buffer, sizeof(domain_buffer)) != 0) { - goto errout; - } - - /* if request MDNS */ - if (request->is_mdns_lookup) { - group_name = DNS_SERVER_GROUP_MDNS; - } - // Get reference for DNS query request->request_wait++; _dns_server_request_get(request); diff --git a/src/util.c b/src/util.c index a20e0ed814..be6263d0a2 100644 --- a/src/util.c +++ b/src/util.c @@ -449,6 +449,11 @@ int check_is_ipv6(const char *ip) continue; } + /* scope id, end of ipv6 address*/ + if (c == '%') { + break; + } + if (c == ':') { colon_num++; dig_num = 0; @@ -1949,7 +1954,8 @@ static int _dns_debug_display(struct dns_packet *packet) char name[DNS_MAX_CNAME_LEN] = {0}; char target[DNS_MAX_CNAME_LEN]; - ret = dns_get_SRV(rrs, name, DNS_MAX_CNAME_LEN, &ttl, &priority, &weight, &port, target, DNS_MAX_CNAME_LEN); + ret = dns_get_SRV(rrs, name, DNS_MAX_CNAME_LEN, &ttl, &priority, &weight, &port, target, + DNS_MAX_CNAME_LEN); if (ret < 0) { tlog(TLOG_DEBUG, "decode SRV failed, %s", name); return -1; diff --git a/test/cases/test-mdns.cc b/test/cases/test-mdns.cc new file mode 100644 index 0000000000..d50acaeabc --- /dev/null +++ b/test/cases/test-mdns.cc @@ -0,0 +1,148 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "client.h" +#include "dns.h" +#include "dns_client.h" +#include "include/utils.h" +#include "server.h" +#include "gtest/gtest.h" +#include + +class mDNS : public ::testing::Test +{ + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST(mDNS, query) +{ + smartdns::MockServer server_upstream1; + smartdns::MockServer server_upstream2; + smartdns::Server server; + + std::string listen_url = "udp://"; + listen_url += DNS_MDNS_IP; + listen_url += ":" + std::to_string(DNS_MDNS_PORT); + + server_upstream1.Start(listen_url.c_str(), [](struct smartdns::ServerRequestContext *request) { + std::string domain = request->domain; + if (request->domain.length() == 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (request->qtype == DNS_T_A) { + unsigned char addr[][4] = {{1, 2, 3, 4}}; + dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[0]); + } else if (request->qtype == DNS_T_AAAA) { + unsigned char addr[][16] = {{1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}}; + dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[0]); + } else { + return smartdns::SERVER_REQUEST_ERROR; + } + + return smartdns::SERVER_REQUEST_OK; + }); + + server_upstream2.Start("udp://0.0.0.0:61053", + [](struct smartdns::ServerRequestContext *request) { return smartdns::SERVER_REQUEST_SOA; }); + + server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 100); + server.MockPing(PING_TYPE_ICMP, "102:304:500::1", 60, 100); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +dualstack-ip-selection no +mdns-lookup yes +)"""); + smartdns::Client client; + + ASSERT_TRUE(client.Query("b.com A", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 0); + EXPECT_EQ(client.GetStatus(), "NXDOMAIN"); + + ASSERT_TRUE(client.Query("host A", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 2); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "host"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "host.local."); + EXPECT_EQ(client.GetAnswer()[1].GetName(), "host.local"); + EXPECT_EQ(client.GetAnswer()[1].GetData(), "1.2.3.4"); + + ASSERT_TRUE(client.Query("host AAAA", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 2); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "host"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "host.local."); + EXPECT_EQ(client.GetAnswer()[1].GetName(), "host.local"); + EXPECT_EQ(client.GetAnswer()[1].GetData(), "102:304:500::1"); + + ASSERT_TRUE(client.Query("host.local A", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "host.local"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); +} + +TEST(mDNS, ptr) +{ + smartdns::MockServer server_upstream1; + smartdns::MockServer server_upstream2; + smartdns::Server server; + + std::string listen_url = "udp://"; + listen_url += DNS_MDNS_IP; + listen_url += ":" + std::to_string(DNS_MDNS_PORT); + + server_upstream1.Start(listen_url.c_str(), [](struct smartdns::ServerRequestContext *request) { + std::string domain = request->domain; + if (request->domain.length() == 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (request->qtype != DNS_T_PTR) { + return smartdns::SERVER_REQUEST_SOA; + } + + dns_add_PTR(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, "host.local"); + + return smartdns::SERVER_REQUEST_OK; + }); + + server_upstream2.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) { + return smartdns::SERVER_REQUEST_ERROR; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +dualstack-ip-selection no +mdns-lookup yes +)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("-x 192.168.1.1", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "1.1.168.192.in-addr.arpa"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "host.local."); +}