// vcl_protocol/dns.rs

//! # VCL DNS Leak Protection
//!
//! Prevents DNS queries from leaking outside the VCL tunnel.
//!
//! ## The problem
//!
//! ```text
//! Without DNS protection:
//!   App → DNS query → OS resolver → ISP DNS → LEAK!
//!   App → data → VCL tunnel → OK
//!
//! With DNS protection:
//!   App → DNS query → DnsFilter → VCL tunnel → private DNS → OK
//!   App → data → VCL tunnel → OK
//! ```
//!
//! ## Example
//!
//! ```rust
//! use vcl_protocol::dns::{DnsConfig, DnsFilter, DnsPacket};
//!
//! let config = DnsConfig::default();
//! let mut filter = DnsFilter::new(config);
//!
//! // Check whether a UDP payload looks like a DNS packet that should be intercepted
//! let raw = vec![0u8; 12]; // an all-zero 12-byte DNS header
//! if DnsFilter::is_dns_packet(&raw) {
//!     // route through tunnel instead of OS resolver
//! }
//!
//! println!("Upstream DNS: {:?}", filter.config().upstream_servers);
//! ```
use std::collections::HashMap;
use std::net::IpAddr;
use std::time::{Duration, Instant};

use tracing::{debug, info, warn};

// Well-known public DNS resolvers, as `"ip:port"` strings suitable for
// `DnsConfig::upstream_servers`.

/// Cloudflare primary resolver (`1.1.1.1`, port 53).
pub const CLOUDFLARE_DNS: &str = "1.1.1.1:53";
/// Cloudflare secondary resolver (`1.0.0.1`, port 53).
pub const CLOUDFLARE_DNS2: &str = "1.0.0.1:53";
/// Google Public DNS primary resolver (`8.8.8.8`, port 53).
pub const GOOGLE_DNS: &str = "8.8.8.8:53";
/// Google Public DNS secondary resolver (`8.8.4.4`, port 53).
pub const GOOGLE_DNS2: &str = "8.8.4.4:53";
/// Quad9 resolver (`9.9.9.9`, port 53).
pub const QUAD9_DNS: &str = "9.9.9.9:53";

/// DNS query type (the QTYPE field of a question).
///
/// Well-known types get their own variant; any other code is carried verbatim
/// in [`DnsQueryType::Other`]. `from_u16(code).to_u16() == code` holds for
/// every `u16`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DnsQueryType {
    /// IPv4 address record (type 1).
    A,
    /// IPv6 address record (type 28).
    AAAA,
    /// Canonical name record (type 5).
    CNAME,
    /// Mail exchange record (type 15).
    MX,
    /// Text record (type 16).
    TXT,
    /// Domain name pointer, used for reverse lookups (type 12).
    PTR,
    /// Name server record (type 2).
    NS,
    /// Any other type code, kept as-is.
    ///
    /// `from_u16` never produces `Other` for a code that has a named variant.
    Other(u16),
}

impl DnsQueryType {
    /// Map a wire QTYPE code to a variant; unrecognised codes become `Other(code)`.
    pub fn from_u16(v: u16) -> Self {
        match v {
            1 => DnsQueryType::A,
            28 => DnsQueryType::AAAA,
            5 => DnsQueryType::CNAME,
            15 => DnsQueryType::MX,
            16 => DnsQueryType::TXT,
            12 => DnsQueryType::PTR,
            2 => DnsQueryType::NS,
            o => DnsQueryType::Other(o),
        }
    }

    /// Wire QTYPE code for this variant (inverse of [`DnsQueryType::from_u16`]).
    pub fn to_u16(&self) -> u16 {
        match self {
            DnsQueryType::A => 1,
            DnsQueryType::AAAA => 28,
            DnsQueryType::CNAME => 5,
            DnsQueryType::MX => 15,
            DnsQueryType::TXT => 16,
            DnsQueryType::PTR => 12,
            DnsQueryType::NS => 2,
            DnsQueryType::Other(o) => *o,
        }
    }
}

86/// A parsed DNS packet (header + first question only).
87#[derive(Debug, Clone)]
88pub struct DnsPacket {
89    /// Transaction ID.
90    pub id: u16,
91    /// True if this is a query (QR bit = 0), false if response.
92    pub is_query: bool,
93    /// Query domain name (e.g. "example.com").
94    pub domain: String,
95    /// Query type.
96    pub query_type: DnsQueryType,
97    /// Raw packet bytes.
98    pub raw: Vec<u8>,
99}
100
101impl DnsPacket {
102    /// Parse a raw DNS packet.
103    ///
104    /// Returns `None` if the packet is too short or malformed.
105    pub fn parse(raw: Vec<u8>) -> Option<Self> {
106        if raw.len() < 12 {
107            return None;
108        }
109
110        let id = u16::from_be_bytes([raw[0], raw[1]]);
111        let flags = u16::from_be_bytes([raw[2], raw[3]]);
112        let is_query = (flags >> 15) == 0;
113        let qdcount = u16::from_be_bytes([raw[4], raw[5]]);
114
115        if qdcount == 0 {
116            return Some(DnsPacket {
117                id,
118                is_query,
119                domain: String::new(),
120                query_type: DnsQueryType::A,
121                raw,
122            });
123        }
124
125        // Parse first question
126        let (domain, offset) = parse_dns_name(&raw, 12)?;
127        if offset + 4 > raw.len() {
128            return None;
129        }
130        let qtype = u16::from_be_bytes([raw[offset], raw[offset + 1]]);
131
132        debug!(id, domain = %domain, is_query, "DNS packet parsed");
133
134        Some(DnsPacket {
135            id,
136            is_query,
137            domain,
138            query_type: DnsQueryType::from_u16(qtype),
139            raw,
140        })
141    }
142
143    /// Returns `true` if this is a query (not a response).
144    pub fn is_query(&self) -> bool {
145        self.is_query
146    }
147}
148
/// Parse a DNS name, following message-compression pointers.
///
/// Reads from `offset` in `data` and returns `(name, offset_after_name)`:
/// - `name` is the labels joined with `.`; the root name yields `""`.
/// - `offset_after_name` is the position just past the name *as laid out at
///   `offset`*: after the terminating zero byte, or after the first two-byte
///   compression pointer if one was followed.
///
/// Returns `None` if:
/// - the name runs past the end of `data`;
/// - it uses a reserved/obsolete label type (top bits `01` or `10`);
/// - a compression pointer does not point strictly before the segment being
///   read (this rules out loops and forward references);
/// - a label is not valid UTF-8;
/// - decoding takes more than 128 steps.
fn parse_dns_name(data: &[u8], mut offset: usize) -> Option<(String, usize)> {
    let mut labels = Vec::new();
    // Where the caller should resume reading; fixed by the first pointer jump.
    let mut resume_at: Option<usize> = None;
    // Start of the contiguous run of labels currently being read. Every pointer
    // must jump strictly before it, so segment starts strictly decrease and the
    // walk cannot cycle.
    let mut segment_start = offset;
    let mut iterations = 0;

    loop {
        if offset >= data.len() || iterations > 128 {
            return None;
        }
        iterations += 1;

        let len_byte = data[offset];
        match len_byte & 0xC0 {
            // Zero-length root label terminates the name.
            0x00 if len_byte == 0 => {
                offset += 1;
                break;
            }
            // Ordinary label of 1..=63 bytes.
            0x00 => {
                let len = usize::from(len_byte);
                let start = offset + 1;
                let label = data.get(start..start + len)?;
                labels.push(std::str::from_utf8(label).ok()?.to_string());
                offset = start + len;
            }
            // Compression pointer: 14-bit target offset across two bytes.
            0xC0 => {
                let low = *data.get(offset + 1)?;
                let target = (usize::from(len_byte & 0x3F) << 8) | usize::from(low);
                if target >= segment_start {
                    return None;
                }
                if resume_at.is_none() {
                    resume_at = Some(offset + 2);
                }
                offset = target;
                segment_start = target;
            }
            // 0x40 / 0x80: reserved or obsolete label types we cannot decode.
            _ => return None,
        }
    }

    Some((labels.join("."), resume_at.unwrap_or(offset)))
}

/// Verdict for a single intercepted DNS query.
#[derive(Clone, Debug, PartialEq)]
pub enum DnsAction {
    /// Send the query to an upstream resolver through the VCL tunnel.
    ForwardThroughTunnel,
    /// Refuse to resolve the query; the intended reply is NXDOMAIN.
    Block,
    /// Answer with this address taken from the local cache.
    ReturnCached(IpAddr),
    /// Let the query bypass the tunnel and resolve directly (split DNS for local domains).
    AllowDirect,
}

/// One cached answer: the resolved address and the instant after which it is stale.
#[derive(Clone, Debug)]
struct CacheEntry {
    addr: IpAddr,
    expires_at: Instant,
}

impl CacheEntry {
    /// `true` once the current time has moved past `expires_at`.
    fn is_expired(&self) -> bool {
        self.expires_at < Instant::now()
    }
}

/// Settings that control how intercepted DNS queries are handled.
#[derive(Clone, Debug)]
pub struct DnsConfig {
    /// Upstream resolvers used inside the tunnel, as `"ip:port"` strings.
    pub upstream_servers: Vec<String>,
    /// Local/split-DNS domains that bypass the tunnel (e.g. "corp.internal").
    pub split_dns_domains: Vec<String>,
    /// Domains that are refused outright (e.g. an ad/tracking blocklist).
    pub blocked_domains: Vec<String>,
    /// Whether resolved answers are cached.
    pub enable_cache: bool,
    /// How long a cached answer stays valid.
    pub cache_ttl: Duration,
    /// Maximum number of entries kept in the cache.
    pub max_cache_size: usize,
}

226impl Default for DnsConfig {
227    fn default() -> Self {
228        DnsConfig {
229            upstream_servers: vec![
230                CLOUDFLARE_DNS.to_string(),
231                CLOUDFLARE_DNS2.to_string(),
232            ],
233            split_dns_domains: Vec::new(),
234            blocked_domains: Vec::new(),
235            enable_cache: true,
236            cache_ttl: Duration::from_secs(300),
237            max_cache_size: 1024,
238        }
239    }
240}
241
242impl DnsConfig {
243    /// Config using Cloudflare DNS (1.1.1.1).
244    pub fn cloudflare() -> Self {
245        DnsConfig::default()
246    }
247
248    /// Config using Google DNS (8.8.8.8).
249    pub fn google() -> Self {
250        DnsConfig {
251            upstream_servers: vec![
252                GOOGLE_DNS.to_string(),
253                GOOGLE_DNS2.to_string(),
254            ],
255            ..Default::default()
256        }
257    }
258
259    /// Config using Quad9 DNS (9.9.9.9) — blocks malware domains.
260    pub fn quad9() -> Self {
261        DnsConfig {
262            upstream_servers: vec![QUAD9_DNS.to_string()],
263            ..Default::default()
264        }
265    }
266
267    /// Add a split DNS domain (goes directly, not through tunnel).
268    pub fn with_split_domain(mut self, domain: &str) -> Self {
269        self.split_dns_domains.push(domain.to_string());
270        self
271    }
272
273    /// Add a blocked domain.
274    pub fn with_blocked_domain(mut self, domain: &str) -> Self {
275        self.blocked_domains.push(domain.to_string());
276        self
277    }
278}
279
280/// DNS leak protection filter.
281///
282/// Intercepts DNS queries, checks the blocklist and cache,
283/// and decides whether to forward through the tunnel or block.
284pub struct DnsFilter {
285    config: DnsConfig,
286    cache: HashMap<String, CacheEntry>,
287    /// Total queries intercepted.
288    total_intercepted: u64,
289    /// Total queries blocked.
290    total_blocked: u64,
291    /// Total cache hits.
292    total_cache_hits: u64,
293    /// Total queries forwarded through tunnel.
294    total_forwarded: u64,
295}
296
297impl DnsFilter {
298    /// Create a new DNS filter with the given config.
299    pub fn new(config: DnsConfig) -> Self {
300        info!(
301            upstream = ?config.upstream_servers,
302            blocked_count = config.blocked_domains.len(),
303            "DnsFilter created"
304        );
305        DnsFilter {
306            config,
307            cache: HashMap::new(),
308            total_intercepted: 0,
309            total_blocked: 0,
310            total_cache_hits: 0,
311            total_forwarded: 0,
312        }
313    }
314
315    /// Returns `true` if a raw UDP payload looks like a DNS packet.
316    ///
317    /// Checks minimum length and QR/opcode field sanity.
318    pub fn is_dns_packet(data: &[u8]) -> bool {
319        if data.len() < 12 {
320            return false;
321        }
322        // Opcode should be 0 (standard query) or 1 (inverse)
323        let opcode = (data[2] >> 3) & 0x0F;
324        opcode <= 2
325    }
326
327    /// Decide what to do with a DNS query for the given domain.
328    ///
329    /// Checks in order: cache → blocklist → split DNS → forward.
330    pub fn decide(&mut self, domain: &str, query_type: &DnsQueryType) -> DnsAction {
331        self.total_intercepted += 1;
332
333        // Clean expired cache entries periodically
334        if self.total_intercepted % 100 == 0 {
335            self.evict_expired();
336        }
337
338        // Check cache first
339        if self.config.enable_cache {
340            if let Some(entry) = self.cache.get(domain) {
341                if !entry.is_expired() {
342                    self.total_cache_hits += 1;
343                    debug!(domain, "DNS cache hit");
344                    return DnsAction::ReturnCached(entry.addr);
345                }
346            }
347        }
348
349        // Check blocklist
350        if self.is_blocked(domain) {
351            self.total_blocked += 1;
352            warn!(domain, "DNS query blocked");
353            return DnsAction::Block;
354        }
355
356        // Check split DNS
357        if self.is_split_dns(domain) {
358            debug!(domain, "DNS split — allowing direct");
359            return DnsAction::AllowDirect;
360        }
361
362        // Forward through tunnel
363        self.total_forwarded += 1;
364        debug!(domain, query_type = ?query_type, "DNS forwarding through tunnel");
365        DnsAction::ForwardThroughTunnel
366    }
367
368    /// Cache a DNS response for a domain.
369    pub fn cache_response(&mut self, domain: &str, addr: IpAddr) {
370        if !self.config.enable_cache {
371            return;
372        }
373        if self.cache.len() >= self.config.max_cache_size {
374            self.evict_expired();
375            // If still full, remove oldest entry
376            if self.cache.len() >= self.config.max_cache_size {
377                if let Some(key) = self.cache.keys().next().cloned() {
378                    self.cache.remove(&key);
379                }
380            }
381        }
382        self.cache.insert(domain.to_string(), CacheEntry {
383            addr,
384            expires_at: Instant::now() + self.config.cache_ttl,
385        });
386        debug!(domain, addr = %addr, "DNS response cached");
387    }
388
389    /// Returns `true` if the domain is in the blocklist.
390    ///
391    /// Supports wildcard suffix matching: blocking "ads.com" also blocks "sub.ads.com".
392    pub fn is_blocked(&self, domain: &str) -> bool {
393        let domain_lower = domain.to_lowercase();
394        self.config.blocked_domains.iter().any(|blocked| {
395            let b = blocked.to_lowercase();
396            domain_lower == b || domain_lower.ends_with(&format!(".{}", b))
397        })
398    }
399
400    /// Returns `true` if the domain should use split DNS (bypass tunnel).
401    pub fn is_split_dns(&self, domain: &str) -> bool {
402        let domain_lower = domain.to_lowercase();
403        self.config.split_dns_domains.iter().any(|split| {
404            let s = split.to_lowercase();
405            domain_lower == s || domain_lower.ends_with(&format!(".{}", s))
406        })
407    }
408
409    /// Add a domain to the blocklist at runtime.
410    pub fn block_domain(&mut self, domain: &str) {
411        info!(domain, "DNS domain blocked");
412        self.config.blocked_domains.push(domain.to_string());
413    }
414
415    /// Add a split DNS domain at runtime.
416    pub fn add_split_domain(&mut self, domain: &str) {
417        info!(domain, "DNS split domain added");
418        self.config.split_dns_domains.push(domain.to_string());
419    }
420
421    /// Get the first upstream DNS server address.
422    pub fn primary_upstream(&self) -> Option<&str> {
423        self.config.upstream_servers.first().map(|s| s.as_str())
424    }
425
426    /// Remove expired entries from the cache.
427    pub fn evict_expired(&mut self) {
428        let before = self.cache.len();
429        self.cache.retain(|_, v| !v.is_expired());
430        let removed = before - self.cache.len();
431        if removed > 0 {
432            debug!(removed, "DNS cache eviction");
433        }
434    }
435
436    /// Clear the entire DNS cache.
437    pub fn clear_cache(&mut self) {
438        self.cache.clear();
439        debug!("DNS cache cleared");
440    }
441
442    /// Returns the number of entries currently in the cache.
443    pub fn cache_size(&self) -> usize {
444        self.cache.len()
445    }
446
447    /// Returns total queries intercepted.
448    pub fn total_intercepted(&self) -> u64 {
449        self.total_intercepted
450    }
451
452    /// Returns total queries blocked.
453    pub fn total_blocked(&self) -> u64 {
454        self.total_blocked
455    }
456
457    /// Returns total cache hits.
458    pub fn total_cache_hits(&self) -> u64 {
459        self.total_cache_hits
460    }
461
462    /// Returns total queries forwarded through tunnel.
463    pub fn total_forwarded(&self) -> u64 {
464        self.total_forwarded
465    }
466
467    /// Returns a reference to the config.
468    pub fn config(&self) -> &DnsConfig {
469        &self.config
470    }
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal single-question DNS query (QTYPE A, QCLASS IN) for `domain`.
    fn minimal_dns_query(domain: &str) -> Vec<u8> {
        let mut pkt = vec![
            0x00, 0x01, // ID = 1
            0x01, 0x00, // flags: standard query (opcode 0), RD (recursion desired) set
            0x00, 0x01, // QDCOUNT = 1
            0x00, 0x00, // ANCOUNT = 0
            0x00, 0x00, // NSCOUNT = 0
            0x00, 0x00, // ARCOUNT = 0
        ];
        // Encode domain name as length-prefixed labels
        for label in domain.split('.') {
            pkt.push(label.len() as u8);
            pkt.extend_from_slice(label.as_bytes());
        }
        pkt.push(0x00); // end of name
        pkt.extend_from_slice(&[0x00, 0x01]); // QTYPE = A
        pkt.extend_from_slice(&[0x00, 0x01]); // QCLASS = IN
        pkt
    }

    #[test]
    fn test_is_dns_packet_valid() {
        let pkt = minimal_dns_query("example.com");
        assert!(DnsFilter::is_dns_packet(&pkt));
    }

    #[test]
    fn test_is_dns_packet_too_short() {
        assert!(!DnsFilter::is_dns_packet(&[0u8; 5]));
        assert!(!DnsFilter::is_dns_packet(&[]));
    }

    #[test]
    fn test_is_dns_packet_rejects_unknown_opcode() {
        let mut pkt = minimal_dns_query("example.com");
        pkt[2] = 5 << 3; // opcode 5 (UPDATE)
        assert!(!DnsFilter::is_dns_packet(&pkt));
        pkt[2] = 2 << 3; // opcode 2 (STATUS)
        assert!(DnsFilter::is_dns_packet(&pkt));
    }

    #[test]
    fn test_dns_packet_parse() {
        let raw = minimal_dns_query("example.com");
        let pkt = DnsPacket::parse(raw).unwrap();
        assert_eq!(pkt.id, 1);
        assert!(pkt.is_query());
        assert_eq!(pkt.domain, "example.com");
        assert_eq!(pkt.query_type, DnsQueryType::A);
    }

    #[test]
    fn test_dns_packet_parse_response() {
        let mut raw = minimal_dns_query("example.com");
        raw[2] = 0x81; // QR = 1 (response), RD = 1
        raw[3] = 0x80; // RA = 1
        let pkt = DnsPacket::parse(raw).unwrap();
        assert!(!pkt.is_query());
        assert_eq!(pkt.domain, "example.com");
    }

    #[test]
    fn test_dns_packet_parse_no_questions() {
        let mut raw = minimal_dns_query("example.com");
        raw[5] = 0x00; // QDCOUNT = 0
        let pkt = DnsPacket::parse(raw).unwrap();
        assert_eq!(pkt.domain, "");
        assert_eq!(pkt.query_type, DnsQueryType::A);
    }

    #[test]
    fn test_dns_packet_parse_truncated_question() {
        let mut raw = minimal_dns_query("example.com");
        raw.truncate(raw.len() - 2); // drop QCLASS
        assert!(DnsPacket::parse(raw).is_none());
    }

    #[test]
    fn test_dns_packet_parse_too_short() {
        assert!(DnsPacket::parse(vec![0u8; 5]).is_none());
    }

    #[test]
    fn test_dns_config_default() {
        let c = DnsConfig::default();
        assert!(c.upstream_servers.contains(&CLOUDFLARE_DNS.to_string()));
        assert!(c.enable_cache);
    }

    #[test]
    fn test_dns_config_google() {
        let c = DnsConfig::google();
        assert!(c.upstream_servers.contains(&GOOGLE_DNS.to_string()));
    }

    #[test]
    fn test_dns_config_quad9() {
        let c = DnsConfig::quad9();
        assert!(c.upstream_servers.contains(&QUAD9_DNS.to_string()));
    }

    #[test]
    fn test_dns_config_with_blocked() {
        let c = DnsConfig::default().with_blocked_domain("ads.com");
        assert!(c.blocked_domains.contains(&"ads.com".to_string()));
    }

    #[test]
    fn test_dns_config_with_split() {
        let c = DnsConfig::default().with_split_domain("corp.internal");
        assert!(c.split_dns_domains.contains(&"corp.internal".to_string()));
    }

    #[test]
    fn test_filter_forward() {
        let mut f = DnsFilter::new(DnsConfig::default());
        let action = f.decide("example.com", &DnsQueryType::A);
        assert_eq!(action, DnsAction::ForwardThroughTunnel);
        assert_eq!(f.total_forwarded(), 1);
    }

    #[test]
    fn test_filter_block() {
        let config = DnsConfig::default().with_blocked_domain("ads.com");
        let mut f = DnsFilter::new(config);
        let action = f.decide("ads.com", &DnsQueryType::A);
        assert_eq!(action, DnsAction::Block);
        assert_eq!(f.total_blocked(), 1);
    }

    #[test]
    fn test_filter_block_subdomain() {
        let config = DnsConfig::default().with_blocked_domain("ads.com");
        let mut f = DnsFilter::new(config);
        let action = f.decide("tracker.ads.com", &DnsQueryType::A);
        assert_eq!(action, DnsAction::Block);
    }

    #[test]
    fn test_filter_split_dns() {
        let config = DnsConfig::default().with_split_domain("corp.internal");
        let mut f = DnsFilter::new(config);
        let action = f.decide("server.corp.internal", &DnsQueryType::A);
        assert_eq!(action, DnsAction::AllowDirect);
    }

    #[test]
    fn test_filter_cache_hit() {
        let mut f = DnsFilter::new(DnsConfig::default());
        let addr: IpAddr = "1.2.3.4".parse().unwrap();
        f.cache_response("example.com", addr);
        let action = f.decide("example.com", &DnsQueryType::A);
        assert_eq!(action, DnsAction::ReturnCached(addr));
        assert_eq!(f.total_cache_hits(), 1);
    }

    #[test]
    fn test_filter_cache_disabled() {
        let config = DnsConfig {
            enable_cache: false,
            ..DnsConfig::default()
        };
        let mut f = DnsFilter::new(config);
        f.cache_response("example.com", "1.2.3.4".parse().unwrap());
        assert_eq!(f.cache_size(), 0);
        assert_eq!(
            f.decide("example.com", &DnsQueryType::A),
            DnsAction::ForwardThroughTunnel
        );
    }

    #[test]
    fn test_filter_cache_size() {
        let mut f = DnsFilter::new(DnsConfig::default());
        f.cache_response("a.com", "1.1.1.1".parse().unwrap());
        f.cache_response("b.com", "2.2.2.2".parse().unwrap());
        assert_eq!(f.cache_size(), 2);
    }

    #[test]
    fn test_filter_cache_respects_max_size() {
        let config = DnsConfig {
            max_cache_size: 2,
            ..DnsConfig::default()
        };
        let mut f = DnsFilter::new(config);
        f.cache_response("a.com", "1.1.1.1".parse().unwrap());
        f.cache_response("b.com", "2.2.2.2".parse().unwrap());
        f.cache_response("c.com", "3.3.3.3".parse().unwrap());
        assert_eq!(f.cache_size(), 2);
    }

    #[test]
    fn test_filter_clear_cache() {
        let mut f = DnsFilter::new(DnsConfig::default());
        f.cache_response("a.com", "1.1.1.1".parse().unwrap());
        f.clear_cache();
        assert_eq!(f.cache_size(), 0);
    }

    #[test]
    fn test_filter_block_runtime() {
        let mut f = DnsFilter::new(DnsConfig::default());
        f.block_domain("evil.com");
        assert_eq!(f.decide("evil.com", &DnsQueryType::A), DnsAction::Block);
    }

    #[test]
    fn test_filter_split_runtime() {
        let mut f = DnsFilter::new(DnsConfig::default());
        f.add_split_domain("local.net");
        assert_eq!(f.decide("host.local.net", &DnsQueryType::A), DnsAction::AllowDirect);
    }

    #[test]
    fn test_is_blocked_exact() {
        let config = DnsConfig::default().with_blocked_domain("bad.com");
        let f = DnsFilter::new(config);
        assert!(f.is_blocked("bad.com"));
        assert!(!f.is_blocked("good.com"));
    }

    #[test]
    fn test_is_blocked_subdomain() {
        let config = DnsConfig::default().with_blocked_domain("bad.com");
        let f = DnsFilter::new(config);
        assert!(f.is_blocked("sub.bad.com"));
        assert!(f.is_blocked("deep.sub.bad.com"));
    }

    #[test]
    fn test_is_blocked_case_insensitive_and_label_aligned() {
        let f = DnsFilter::new(DnsConfig::default().with_blocked_domain("Ads.COM"));
        assert!(f.is_blocked("TRACKER.ads.com"));
        assert!(!f.is_blocked("badads.com"));
    }

    #[test]
    fn test_is_split_dns() {
        let config = DnsConfig::default().with_split_domain("internal");
        let f = DnsFilter::new(config);
        assert!(f.is_split_dns("host.internal"));
        assert!(!f.is_split_dns("external.com"));
    }

    #[test]
    fn test_primary_upstream() {
        let f = DnsFilter::new(DnsConfig::cloudflare());
        assert_eq!(f.primary_upstream(), Some(CLOUDFLARE_DNS));
    }

    #[test]
    fn test_stats() {
        let mut f = DnsFilter::new(DnsConfig::default().with_blocked_domain("bad.com"));
        f.decide("example.com", &DnsQueryType::A);
        f.decide("bad.com", &DnsQueryType::A);
        assert_eq!(f.total_intercepted(), 2);
        assert_eq!(f.total_blocked(), 1);
        assert_eq!(f.total_forwarded(), 1);
    }

    #[test]
    fn test_query_type_from_u16() {
        assert_eq!(DnsQueryType::from_u16(1), DnsQueryType::A);
        assert_eq!(DnsQueryType::from_u16(28), DnsQueryType::AAAA);
        assert_eq!(DnsQueryType::from_u16(99), DnsQueryType::Other(99));
    }

    #[test]
    fn test_query_type_to_u16() {
        assert_eq!(DnsQueryType::A.to_u16(), 1);
        assert_eq!(DnsQueryType::AAAA.to_u16(), 28);
        assert_eq!(DnsQueryType::Other(99).to_u16(), 99);
    }

    #[test]
    fn test_evict_expired() {
        let config = DnsConfig {
            cache_ttl: Duration::from_millis(1),
            ..DnsConfig::default()
        };
        let mut f = DnsFilter::new(config);
        f.cache_response("a.com", "1.1.1.1".parse().unwrap());
        std::thread::sleep(Duration::from_millis(5));
        f.evict_expired();
        assert_eq!(f.cache_size(), 0);
    }
}