sentinel_proxy/
geo_filter.rs

1//! GeoIP filtering for Sentinel proxy
2//!
3//! This module provides geolocation-based request filtering using MaxMind GeoLite2/GeoIP2
4//! and IP2Location databases. Filters can block, allow, or log requests based on country.
5//!
6//! # Features
7//! - Support for MaxMind (.mmdb) and IP2Location (.bin) databases
8//! - Block mode (blocklist) and Allow mode (allowlist)
9//! - Log-only mode for monitoring without blocking
10//! - Per-filter IP→Country caching with configurable TTL
11//! - Configurable fail-open/fail-closed on lookup errors
12//! - X-GeoIP-Country response header injection
13
14use std::collections::{HashMap, HashSet};
15use std::net::IpAddr;
16use std::path::{Path, PathBuf};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19
20use dashmap::DashMap;
21use notify::{Event, EventKind, RecursiveMode, Watcher};
22use parking_lot::RwLock;
23use tokio::sync::mpsc;
24use tracing::{debug, error, info, trace, warn};
25
26use sentinel_config::{GeoDatabaseType, GeoFailureMode, GeoFilter, GeoFilterAction};
27
28// =============================================================================
29// Error Types
30// =============================================================================
31
/// Failure modes of opening or querying a GeoIP database.
#[derive(Debug, Clone)]
pub enum GeoLookupError {
    /// An IP address string could not be parsed.
    InvalidIp(String),
    /// The database reported an error while resolving an address.
    DatabaseError(String),
    /// The database file could not be opened.
    LoadError(String),
}

impl std::fmt::Display for GeoLookupError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every variant renders as "<description>: <detail>".
        let (description, detail) = match self {
            Self::InvalidIp(ip) => ("invalid IP address", ip),
            Self::DatabaseError(msg) => ("database error", msg),
            Self::LoadError(msg) => ("failed to load database", msg),
        };
        write!(f, "{}: {}", description, detail)
    }
}

impl std::error::Error for GeoLookupError {}
54
55// =============================================================================
56// GeoDatabase Trait
57// =============================================================================
58
59/// Trait for GeoIP database backends
60pub trait GeoDatabase: Send + Sync {
61    /// Look up the country code for an IP address
62    fn lookup(&self, ip: IpAddr) -> Result<Option<String>, GeoLookupError>;
63
64    /// Get the database type
65    fn database_type(&self) -> GeoDatabaseType;
66}
67
68// =============================================================================
69// MaxMind Database Backend
70// =============================================================================
71
72/// MaxMind GeoLite2/GeoIP2 database backend
73pub struct MaxMindDatabase {
74    reader: maxminddb::Reader<Vec<u8>>,
75}
76
77impl MaxMindDatabase {
78    /// Open a MaxMind database file
79    pub fn open(path: impl AsRef<Path>) -> Result<Self, GeoLookupError> {
80        let path = path.as_ref();
81        let reader = maxminddb::Reader::open_readfile(path).map_err(|e| {
82            GeoLookupError::LoadError(format!("failed to open MaxMind database {:?}: {}", path, e))
83        })?;
84
85        debug!(path = ?path, "Opened MaxMind GeoIP database");
86        Ok(Self { reader })
87    }
88}
89
90impl GeoDatabase for MaxMindDatabase {
91    fn lookup(&self, ip: IpAddr) -> Result<Option<String>, GeoLookupError> {
92        match self.reader.lookup::<maxminddb::geoip2::Country>(ip) {
93            Ok(record) => {
94                let country_code = record
95                    .country
96                    .and_then(|c| c.iso_code)
97                    .map(|s| s.to_string());
98                trace!(ip = %ip, country = ?country_code, "MaxMind lookup");
99                Ok(country_code)
100            }
101            Err(maxminddb::MaxMindDBError::AddressNotFoundError(_)) => {
102                trace!(ip = %ip, "IP not found in MaxMind database");
103                Ok(None)
104            }
105            Err(e) => {
106                warn!(ip = %ip, error = %e, "MaxMind lookup error");
107                Err(GeoLookupError::DatabaseError(e.to_string()))
108            }
109        }
110    }
111
112    fn database_type(&self) -> GeoDatabaseType {
113        GeoDatabaseType::MaxMind
114    }
115}
116
117// =============================================================================
118// IP2Location Database Backend
119// =============================================================================
120
121/// IP2Location database backend
122pub struct Ip2LocationDatabase {
123    db: ip2location::DB,
124}
125
126impl Ip2LocationDatabase {
127    /// Open an IP2Location database file
128    pub fn open(path: impl AsRef<Path>) -> Result<Self, GeoLookupError> {
129        let path = path.as_ref();
130        let db = ip2location::DB::from_file(path).map_err(|e| {
131            GeoLookupError::LoadError(format!(
132                "failed to open IP2Location database {:?}: {}",
133                path, e
134            ))
135        })?;
136
137        debug!(path = ?path, "Opened IP2Location GeoIP database");
138        Ok(Self { db })
139    }
140}
141
142impl GeoDatabase for Ip2LocationDatabase {
143    fn lookup(&self, ip: IpAddr) -> Result<Option<String>, GeoLookupError> {
144        match self.db.ip_lookup(ip) {
145            Ok(record) => {
146                // Record is an enum - extract country from the LocationDb variant
147                let country_code = match record {
148                    ip2location::Record::LocationDb(loc) => {
149                        loc.country.map(|c| c.short_name.to_string())
150                    }
151                    ip2location::Record::ProxyDb(proxy) => {
152                        proxy.country.map(|c| c.short_name.to_string())
153                    }
154                };
155                trace!(ip = %ip, country = ?country_code, "IP2Location lookup");
156                Ok(country_code)
157            }
158            Err(ip2location::error::Error::RecordNotFound) => {
159                trace!(ip = %ip, "IP not found in IP2Location database");
160                Ok(None)
161            }
162            Err(e) => {
163                warn!(ip = %ip, error = %e, "IP2Location lookup error");
164                Err(GeoLookupError::DatabaseError(e.to_string()))
165            }
166        }
167    }
168
169    fn database_type(&self) -> GeoDatabaseType {
170        GeoDatabaseType::Ip2Location
171    }
172}
173
174// =============================================================================
175// Cached Country Entry
176// =============================================================================
177
/// One memoised IP → country resolution, stamped with the time it was stored.
struct CachedCountry {
    /// Moment the entry was inserted; compared against the pool's TTL.
    cached_at: Instant,
    /// Resolved country code; `None` records that the database had no country.
    country_code: Option<String>,
}
185
186// =============================================================================
187// GeoFilterResult
188// =============================================================================
189
/// Result of a geo filter check.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly (for example
/// in tests) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeoFilterResult {
    /// Whether the request is allowed
    pub allowed: bool,
    /// The country code (if found)
    pub country_code: Option<String>,
    /// Whether the country came from the IP → country cache
    pub cache_hit: bool,
    /// Whether to add the country header
    pub add_header: bool,
    /// HTTP status code to return if blocked
    pub status_code: u16,
    /// Block message to return if blocked
    pub block_message: Option<String>,
}
206
207// =============================================================================
208// GeoFilterPool
209// =============================================================================
210
211/// A single geo filter instance with its database and cache
212pub struct GeoFilterPool {
213    /// The underlying GeoIP database (wrapped in RwLock for hot reload)
214    database: RwLock<Arc<dyn GeoDatabase>>,
215    /// IP → Country cache
216    cache: DashMap<IpAddr, CachedCountry>,
217    /// Filter configuration
218    config: GeoFilter,
219    /// Pre-computed set of countries for fast lookup
220    countries_set: HashSet<String>,
221    /// Cache TTL duration
222    cache_ttl: Duration,
223    /// Database file path for reload
224    database_path: PathBuf,
225    /// Database type
226    database_type: GeoDatabaseType,
227}
228
229impl GeoFilterPool {
230    /// Create a new geo filter pool from configuration
231    pub fn new(config: GeoFilter) -> Result<Self, GeoLookupError> {
232        // Determine database type (auto-detect from extension if not specified)
233        let db_type = config.database_type.clone().unwrap_or_else(|| {
234            if config.database_path.ends_with(".mmdb") {
235                GeoDatabaseType::MaxMind
236            } else {
237                GeoDatabaseType::Ip2Location
238            }
239        });
240
241        let database_path = PathBuf::from(&config.database_path);
242
243        // Open the database
244        let database: Arc<dyn GeoDatabase> = match db_type {
245            GeoDatabaseType::MaxMind => Arc::new(MaxMindDatabase::open(&config.database_path)?),
246            GeoDatabaseType::Ip2Location => {
247                Arc::new(Ip2LocationDatabase::open(&config.database_path)?)
248            }
249        };
250
251        // Build countries set for fast lookup
252        let countries_set: HashSet<String> = config.countries.iter().cloned().collect();
253
254        let cache_ttl = Duration::from_secs(config.cache_ttl_secs);
255
256        debug!(
257            database_path = %config.database_path,
258            database_type = ?db_type,
259            action = ?config.action,
260            countries_count = countries_set.len(),
261            cache_ttl_secs = config.cache_ttl_secs,
262            "Created GeoFilterPool"
263        );
264
265        Ok(Self {
266            database: RwLock::new(database),
267            cache: DashMap::new(),
268            config,
269            countries_set,
270            cache_ttl,
271            database_path,
272            database_type: db_type,
273        })
274    }
275
276    /// Reload the database from disk
277    ///
278    /// This atomically swaps the database and clears the cache.
279    pub fn reload_database(&self) -> Result<(), GeoLookupError> {
280        info!(
281            database_path = %self.database_path.display(),
282            database_type = ?self.database_type,
283            "Reloading geo database"
284        );
285
286        // Open the new database
287        let new_database: Arc<dyn GeoDatabase> = match self.database_type {
288            GeoDatabaseType::MaxMind => Arc::new(MaxMindDatabase::open(&self.database_path)?),
289            GeoDatabaseType::Ip2Location => {
290                Arc::new(Ip2LocationDatabase::open(&self.database_path)?)
291            }
292        };
293
294        // Atomically swap the database
295        {
296            let mut db = self.database.write();
297            *db = new_database;
298        }
299
300        // Clear the cache since country mappings may have changed
301        self.cache.clear();
302
303        info!(
304            database_path = %self.database_path.display(),
305            "Geo database reloaded successfully"
306        );
307
308        Ok(())
309    }
310
311    /// Get the database file path
312    pub fn database_path(&self) -> &Path {
313        &self.database_path
314    }
315
316    /// Check if a client IP should be allowed or blocked
317    pub fn check(&self, client_ip: &str) -> GeoFilterResult {
318        // Parse the IP address
319        let ip: IpAddr = match client_ip.parse() {
320            Ok(ip) => ip,
321            Err(_) => {
322                warn!(client_ip = %client_ip, "Failed to parse client IP for geo filter");
323                return self.handle_failure();
324            }
325        };
326
327        // Check cache first
328        let now = Instant::now();
329        if let Some(entry) = self.cache.get(&ip) {
330            if now.duration_since(entry.cached_at) < self.cache_ttl {
331                trace!(ip = %ip, country = ?entry.country_code, "Geo cache hit");
332                return self.evaluate(entry.country_code.clone(), true);
333            }
334            // Entry expired, will be replaced
335        }
336
337        // Lookup in database
338        let database = self.database.read();
339        match database.lookup(ip) {
340            Ok(country_code) => {
341                // Cache the result
342                self.cache.insert(
343                    ip,
344                    CachedCountry {
345                        country_code: country_code.clone(),
346                        cached_at: now,
347                    },
348                );
349                self.evaluate(country_code, false)
350            }
351            Err(e) => {
352                warn!(ip = %ip, error = %e, "Geo lookup failed");
353                self.handle_failure()
354            }
355        }
356    }
357
358    /// Evaluate the filter action based on country code
359    fn evaluate(&self, country_code: Option<String>, cache_hit: bool) -> GeoFilterResult {
360        let in_list = country_code
361            .as_ref()
362            .map(|c| self.countries_set.contains(c))
363            .unwrap_or(false);
364
365        let allowed = match self.config.action {
366            GeoFilterAction::Block => {
367                // Block mode: block if country is in the list
368                !in_list
369            }
370            GeoFilterAction::Allow => {
371                // Allow mode: allow only if country is in the list
372                // If no country found and list is not empty, block
373                if self.countries_set.is_empty() {
374                    true
375                } else {
376                    in_list
377                }
378            }
379            GeoFilterAction::LogOnly => {
380                // Log-only mode: always allow
381                true
382            }
383        };
384
385        trace!(
386            country = ?country_code,
387            in_list = in_list,
388            action = ?self.config.action,
389            allowed = allowed,
390            "Geo filter evaluation"
391        );
392
393        GeoFilterResult {
394            allowed,
395            country_code,
396            cache_hit,
397            add_header: self.config.add_country_header,
398            status_code: self.config.status_code,
399            block_message: self.config.block_message.clone(),
400        }
401    }
402
403    /// Handle lookup failure based on failure mode
404    fn handle_failure(&self) -> GeoFilterResult {
405        let allowed = match self.config.on_failure {
406            GeoFailureMode::Open => true,
407            GeoFailureMode::Closed => false,
408        };
409
410        GeoFilterResult {
411            allowed,
412            country_code: None,
413            cache_hit: false,
414            add_header: false,
415            status_code: self.config.status_code,
416            block_message: self.config.block_message.clone(),
417        }
418    }
419
420    /// Get cache statistics
421    pub fn cache_stats(&self) -> (usize, usize) {
422        let now = Instant::now();
423        let total = self.cache.len();
424        let valid = self
425            .cache
426            .iter()
427            .filter(|e| now.duration_since(e.cached_at) < self.cache_ttl)
428            .count();
429        (total, valid)
430    }
431
432    /// Clear expired cache entries
433    pub fn clear_expired(&self) {
434        let now = Instant::now();
435        self.cache
436            .retain(|_, v| now.duration_since(v.cached_at) < self.cache_ttl);
437    }
438}
439
440// =============================================================================
441// GeoFilterManager
442// =============================================================================
443
444/// Manages all geo filter instances
445pub struct GeoFilterManager {
446    /// Filter ID → GeoFilterPool mapping
447    filter_pools: DashMap<String, Arc<GeoFilterPool>>,
448}
449
450impl GeoFilterManager {
451    /// Create a new empty geo filter manager
452    pub fn new() -> Self {
453        Self {
454            filter_pools: DashMap::new(),
455        }
456    }
457
458    /// Register a geo filter from configuration
459    pub fn register_filter(
460        &self,
461        filter_id: &str,
462        config: GeoFilter,
463    ) -> Result<(), GeoLookupError> {
464        let pool = GeoFilterPool::new(config)?;
465        self.filter_pools
466            .insert(filter_id.to_string(), Arc::new(pool));
467        debug!(filter_id = %filter_id, "Registered geo filter");
468        Ok(())
469    }
470
471    /// Check a client IP against a specific filter
472    pub fn check(&self, filter_id: &str, client_ip: &str) -> Option<GeoFilterResult> {
473        self.filter_pools
474            .get(filter_id)
475            .map(|pool| pool.check(client_ip))
476    }
477
478    /// Get a reference to a filter pool
479    pub fn get_pool(&self, filter_id: &str) -> Option<Arc<GeoFilterPool>> {
480        self.filter_pools.get(filter_id).map(|r| r.clone())
481    }
482
483    /// Check if a filter exists
484    pub fn has_filter(&self, filter_id: &str) -> bool {
485        self.filter_pools.contains_key(filter_id)
486    }
487
488    /// Get all filter IDs
489    pub fn filter_ids(&self) -> Vec<String> {
490        self.filter_pools.iter().map(|r| r.key().clone()).collect()
491    }
492
493    /// Clear expired cache entries in all pools
494    pub fn clear_expired_caches(&self) {
495        for pool in self.filter_pools.iter() {
496            pool.clear_expired();
497        }
498    }
499
500    /// Reload a filter's database from disk
501    pub fn reload_filter(&self, filter_id: &str) -> Result<(), GeoLookupError> {
502        if let Some(pool) = self.filter_pools.get(filter_id) {
503            pool.reload_database()
504        } else {
505            Err(GeoLookupError::LoadError(format!(
506                "Filter '{}' not found",
507                filter_id
508            )))
509        }
510    }
511
512    /// Reload database for all filters using the given path
513    pub fn reload_by_path(&self, path: &Path) -> Vec<(String, Result<(), GeoLookupError>)> {
514        let mut results = Vec::new();
515        for entry in self.filter_pools.iter() {
516            if entry.value().database_path() == path {
517                let filter_id = entry.key().clone();
518                let result = entry.value().reload_database();
519                results.push((filter_id, result));
520            }
521        }
522        results
523    }
524
525    /// Get all unique database paths being used
526    pub fn database_paths(&self) -> Vec<(String, PathBuf)> {
527        self.filter_pools
528            .iter()
529            .map(|e| (e.key().clone(), e.value().database_path().to_path_buf()))
530            .collect()
531    }
532}
533
534impl Default for GeoFilterManager {
535    fn default() -> Self {
536        Self::new()
537    }
538}
539
540// =============================================================================
541// GeoDatabaseWatcher
542// =============================================================================
543
544/// Watches geo database files for changes and triggers reloads
545pub struct GeoDatabaseWatcher {
546    /// The watcher instance
547    watcher: RwLock<Option<notify::RecommendedWatcher>>,
548    /// Mapping from database path to filter IDs using it
549    path_to_filters: RwLock<HashMap<PathBuf, Vec<String>>>,
550    /// Reference to the geo filter manager
551    manager: Arc<GeoFilterManager>,
552}
553
554impl GeoDatabaseWatcher {
555    /// Create a new database watcher
556    pub fn new(manager: Arc<GeoFilterManager>) -> Self {
557        Self {
558            watcher: RwLock::new(None),
559            path_to_filters: RwLock::new(HashMap::new()),
560            manager,
561        }
562    }
563
564    /// Start watching all registered database files
565    pub fn start_watching(&self) -> Result<mpsc::Receiver<PathBuf>, GeoLookupError> {
566        // Build path → filter ID mapping
567        let db_paths = self.manager.database_paths();
568        let mut path_map: HashMap<PathBuf, Vec<String>> = HashMap::new();
569        for (filter_id, path) in db_paths {
570            path_map
571                .entry(path)
572                .or_insert_with(Vec::new)
573                .push(filter_id);
574        }
575
576        if path_map.is_empty() {
577            debug!("No geo databases to watch");
578            let (_tx, rx) = mpsc::channel(1);
579            return Ok(rx);
580        }
581
582        // Store the mapping
583        *self.path_to_filters.write() = path_map.clone();
584
585        // Create channel for events
586        let (tx, rx) = mpsc::channel::<PathBuf>(10);
587
588        // Create file watcher
589        let paths: Vec<PathBuf> = path_map.keys().cloned().collect();
590        let watcher = notify::recommended_watcher(move |event: Result<Event, notify::Error>| {
591            if let Ok(event) = event {
592                if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_)) {
593                    for path in &event.paths {
594                        let _ = tx.blocking_send(path.clone());
595                    }
596                }
597            }
598        })
599        .map_err(|e| {
600            GeoLookupError::LoadError(format!("Failed to create file watcher: {}", e))
601        })?;
602
603        // Store watcher
604        *self.watcher.write() = Some(watcher);
605
606        // Add watches for each database path
607        if let Some(ref mut watcher) = *self.watcher.write() {
608            for path in &paths {
609                if let Err(e) = watcher.watch(path, RecursiveMode::NonRecursive) {
610                    warn!(
611                        path = %path.display(),
612                        error = %e,
613                        "Failed to watch geo database file"
614                    );
615                } else {
616                    info!(
617                        path = %path.display(),
618                        "Watching geo database for changes"
619                    );
620                }
621            }
622        }
623
624        Ok(rx)
625    }
626
627    /// Handle a file change event
628    pub fn handle_change(&self, path: &Path) {
629        let path_map = self.path_to_filters.read();
630        if let Some(filter_ids) = path_map.get(path) {
631            info!(
632                path = %path.display(),
633                filters = ?filter_ids,
634                "Geo database file changed, reloading"
635            );
636
637            for filter_id in filter_ids {
638                match self.manager.reload_filter(filter_id) {
639                    Ok(()) => {
640                        info!(
641                            filter_id = %filter_id,
642                            "Geo filter database reloaded successfully"
643                        );
644                    }
645                    Err(e) => {
646                        error!(
647                            filter_id = %filter_id,
648                            error = %e,
649                            "Failed to reload geo filter database"
650                        );
651                    }
652                }
653            }
654        }
655    }
656
657    /// Stop watching
658    pub fn stop(&self) {
659        *self.watcher.write() = None;
660        info!("Stopped watching geo database files");
661    }
662}
663
664// =============================================================================
665// Tests
666// =============================================================================
667
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_geo_lookup_error_display() {
        // Each variant's rendering must contain its descriptive prefix.
        let cases = [
            (GeoLookupError::InvalidIp("not-an-ip".to_string()), "invalid IP"),
            (GeoLookupError::DatabaseError("db error".to_string()), "database error"),
            (GeoLookupError::LoadError("load error".to_string()), "failed to load"),
        ];
        for (err, needle) in &cases {
            assert!(err.to_string().contains(*needle));
        }
    }

    #[test]
    fn test_geo_filter_result_default() {
        let GeoFilterResult {
            allowed,
            country_code,
            cache_hit,
            add_header,
            ..
        } = GeoFilterResult {
            allowed: true,
            country_code: Some("US".to_string()),
            cache_hit: false,
            add_header: true,
            status_code: 403,
            block_message: None,
        };

        assert!(allowed);
        assert_eq!(country_code.as_deref(), Some("US"));
        assert!(!cache_hit);
        assert!(add_header);
    }

    #[test]
    fn test_geo_filter_manager_new() {
        let empty = GeoFilterManager::new();
        assert!(!empty.has_filter("test"));
        assert!(empty.filter_ids().is_empty());
    }

    // Tests that need real GeoIP database files live in the integration test suite.
}