sentinel_proxy/
geo_filter.rs

1//! GeoIP filtering for Sentinel proxy
2//!
3//! This module provides geolocation-based request filtering using MaxMind GeoLite2/GeoIP2
4//! and IP2Location databases. Filters can block, allow, or log requests based on country.
5//!
6//! # Features
7//! - Support for MaxMind (.mmdb) and IP2Location (.bin) databases
8//! - Block mode (blocklist) and Allow mode (allowlist)
9//! - Log-only mode for monitoring without blocking
10//! - Per-filter IP→Country caching with configurable TTL
11//! - Configurable fail-open/fail-closed on lookup errors
12//! - X-GeoIP-Country response header injection
13
14use std::collections::{HashMap, HashSet};
15use std::net::IpAddr;
16use std::path::{Path, PathBuf};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19
20use dashmap::DashMap;
21use notify::{Event, EventKind, RecursiveMode, Watcher};
22use parking_lot::RwLock;
23use tokio::sync::mpsc;
24use tracing::{debug, error, info, trace, warn};
25
26use sentinel_config::{GeoDatabaseType, GeoFailureMode, GeoFilter, GeoFilterAction};
27
28// =============================================================================
29// Error Types
30// =============================================================================
31
/// Failure modes of a geo lookup or of loading a geo database.
///
/// Every variant carries a free-form detail string; the `Display`
/// implementation below prefixes it with a fixed, variant-specific label.
#[derive(Debug, Clone)]
pub enum GeoLookupError {
    /// The IP address text could not be parsed.
    InvalidIp(String),
    /// The database backend reported an error during a lookup.
    DatabaseError(String),
    /// The database file could not be loaded.
    LoadError(String),
}

impl std::fmt::Display for GeoLookupError {
    /// Renders the error as `<label>: <detail>`.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidIp(raw) => out.write_fmt(format_args!("invalid IP address: {}", raw)),
            Self::DatabaseError(detail) => {
                out.write_fmt(format_args!("database error: {}", detail))
            }
            Self::LoadError(detail) => {
                out.write_fmt(format_args!("failed to load database: {}", detail))
            }
        }
    }
}

// Marker impl: lets the error be used as `dyn std::error::Error`.
impl std::error::Error for GeoLookupError {}
54
55// =============================================================================
56// GeoDatabase Trait
57// =============================================================================
58
/// Trait for GeoIP database backends
///
/// The `Send + Sync` supertraits are what allow a backend to be stored behind
/// `Arc<dyn GeoDatabase>` and shared between threads.
pub trait GeoDatabase: Send + Sync {
    /// Look up the country code for an IP address
    ///
    /// The two backends in this file return `Ok(None)` when the address has no
    /// country entry, and `Err(GeoLookupError::DatabaseError)` when the
    /// underlying library reports a failure.
    fn lookup(&self, ip: IpAddr) -> Result<Option<String>, GeoLookupError>;

    /// Get the database type
    fn database_type(&self) -> GeoDatabaseType;
}
67
68// =============================================================================
69// MaxMind Database Backend
70// =============================================================================
71
/// MaxMind GeoLite2/GeoIP2 database backend
pub struct MaxMindDatabase {
    /// Reader over the database contents, backed by an owned byte buffer.
    reader: maxminddb::Reader<Vec<u8>>,
}
76
77impl MaxMindDatabase {
78    /// Open a MaxMind database file
79    pub fn open(path: impl AsRef<Path>) -> Result<Self, GeoLookupError> {
80        let path = path.as_ref();
81        let reader = maxminddb::Reader::open_readfile(path).map_err(|e| {
82            GeoLookupError::LoadError(format!("failed to open MaxMind database {:?}: {}", path, e))
83        })?;
84
85        debug!(path = ?path, "Opened MaxMind GeoIP database");
86        Ok(Self { reader })
87    }
88}
89
90impl GeoDatabase for MaxMindDatabase {
91    fn lookup(&self, ip: IpAddr) -> Result<Option<String>, GeoLookupError> {
92        match self.reader.lookup(ip) {
93            Ok(result) => {
94                if !result.has_data() {
95                    trace!(ip = %ip, "IP not found in MaxMind database");
96                    return Ok(None);
97                }
98                match result.decode::<maxminddb::geoip2::Country>() {
99                    Ok(Some(record)) => {
100                        let country_code = record.country.iso_code.map(|s| s.to_string());
101                        trace!(ip = %ip, country = ?country_code, "MaxMind lookup");
102                        Ok(country_code)
103                    }
104                    Ok(None) => {
105                        trace!(ip = %ip, "No country data for IP in MaxMind database");
106                        Ok(None)
107                    }
108                    Err(e) => {
109                        warn!(ip = %ip, error = %e, "MaxMind decode error");
110                        Err(GeoLookupError::DatabaseError(e.to_string()))
111                    }
112                }
113            }
114            Err(e) => {
115                warn!(ip = %ip, error = %e, "MaxMind lookup error");
116                Err(GeoLookupError::DatabaseError(e.to_string()))
117            }
118        }
119    }
120
121    fn database_type(&self) -> GeoDatabaseType {
122        GeoDatabaseType::MaxMind
123    }
124}
125
126// =============================================================================
127// IP2Location Database Backend
128// =============================================================================
129
/// IP2Location database backend
pub struct Ip2LocationDatabase {
    /// Handle to the opened IP2Location database.
    db: ip2location::DB,
}
134
135impl Ip2LocationDatabase {
136    /// Open an IP2Location database file
137    pub fn open(path: impl AsRef<Path>) -> Result<Self, GeoLookupError> {
138        let path = path.as_ref();
139        let db = ip2location::DB::from_file(path).map_err(|e| {
140            GeoLookupError::LoadError(format!(
141                "failed to open IP2Location database {:?}: {}",
142                path, e
143            ))
144        })?;
145
146        debug!(path = ?path, "Opened IP2Location GeoIP database");
147        Ok(Self { db })
148    }
149}
150
151impl GeoDatabase for Ip2LocationDatabase {
152    fn lookup(&self, ip: IpAddr) -> Result<Option<String>, GeoLookupError> {
153        match self.db.ip_lookup(ip) {
154            Ok(record) => {
155                // Record is an enum - extract country from the LocationDb variant
156                let country_code = match record {
157                    ip2location::Record::LocationDb(loc) => {
158                        loc.country.map(|c| c.short_name.to_string())
159                    }
160                    ip2location::Record::ProxyDb(proxy) => {
161                        proxy.country.map(|c| c.short_name.to_string())
162                    }
163                };
164                trace!(ip = %ip, country = ?country_code, "IP2Location lookup");
165                Ok(country_code)
166            }
167            Err(ip2location::error::Error::RecordNotFound) => {
168                trace!(ip = %ip, "IP not found in IP2Location database");
169                Ok(None)
170            }
171            Err(e) => {
172                warn!(ip = %ip, error = %e, "IP2Location lookup error");
173                Err(GeoLookupError::DatabaseError(e.to_string()))
174            }
175        }
176    }
177
178    fn database_type(&self) -> GeoDatabaseType {
179        GeoDatabaseType::Ip2Location
180    }
181}
182
183// =============================================================================
184// Cached Country Entry
185// =============================================================================
186
/// Cached country lookup result
///
/// "Not found" results (`None`) are cached as well, so repeated lookups of an
/// unknown address do not hit the database again until the entry expires.
struct CachedCountry {
    /// The country code (or None if not found)
    country_code: Option<String>,
    /// When this entry was cached
    cached_at: Instant,
}
194
195// =============================================================================
196// GeoFilterResult
197// =============================================================================
198
/// Result of a geo filter check
///
/// Produced by `GeoFilterPool::check`, both for normal evaluations and for the
/// failure path (unparseable IP or database error).
#[derive(Debug, Clone)]
pub struct GeoFilterResult {
    /// Whether the request is allowed
    pub allowed: bool,
    /// The country code (if found); always `None` on the failure path
    pub country_code: Option<String>,
    /// Whether this was a cache hit
    pub cache_hit: bool,
    /// Whether to add the country header (copied from the filter
    /// configuration; always `false` on the failure path)
    pub add_header: bool,
    /// HTTP status code to return if blocked (copied from the filter configuration)
    pub status_code: u16,
    /// Block message to return if blocked (copied from the filter configuration)
    pub block_message: Option<String>,
}
215
216// =============================================================================
217// GeoFilterPool
218// =============================================================================
219
/// A single geo filter instance with its database and cache
pub struct GeoFilterPool {
    /// The underlying GeoIP database (wrapped in RwLock for hot reload)
    database: RwLock<Arc<dyn GeoDatabase>>,
    /// IP → Country cache; entries older than `cache_ttl` are treated as expired
    cache: DashMap<IpAddr, CachedCountry>,
    /// Filter configuration
    config: GeoFilter,
    /// Pre-computed set of countries for fast lookup (built from `config.countries`)
    countries_set: HashSet<String>,
    /// Cache TTL duration (derived from `config.cache_ttl_secs`)
    cache_ttl: Duration,
    /// Database file path for reload
    database_path: PathBuf,
    /// Database type (explicit from config, or auto-detected at construction)
    database_type: GeoDatabaseType,
}
237
238impl GeoFilterPool {
239    /// Create a new geo filter pool from configuration
240    pub fn new(config: GeoFilter) -> Result<Self, GeoLookupError> {
241        // Determine database type (auto-detect from extension if not specified)
242        let db_type = config.database_type.clone().unwrap_or_else(|| {
243            if config.database_path.ends_with(".mmdb") {
244                GeoDatabaseType::MaxMind
245            } else {
246                GeoDatabaseType::Ip2Location
247            }
248        });
249
250        let database_path = PathBuf::from(&config.database_path);
251
252        // Open the database
253        let database: Arc<dyn GeoDatabase> = match db_type {
254            GeoDatabaseType::MaxMind => Arc::new(MaxMindDatabase::open(&config.database_path)?),
255            GeoDatabaseType::Ip2Location => {
256                Arc::new(Ip2LocationDatabase::open(&config.database_path)?)
257            }
258        };
259
260        // Build countries set for fast lookup
261        let countries_set: HashSet<String> = config.countries.iter().cloned().collect();
262
263        let cache_ttl = Duration::from_secs(config.cache_ttl_secs);
264
265        debug!(
266            database_path = %config.database_path,
267            database_type = ?db_type,
268            action = ?config.action,
269            countries_count = countries_set.len(),
270            cache_ttl_secs = config.cache_ttl_secs,
271            "Created GeoFilterPool"
272        );
273
274        Ok(Self {
275            database: RwLock::new(database),
276            cache: DashMap::new(),
277            config,
278            countries_set,
279            cache_ttl,
280            database_path,
281            database_type: db_type,
282        })
283    }
284
285    /// Reload the database from disk
286    ///
287    /// This atomically swaps the database and clears the cache.
288    pub fn reload_database(&self) -> Result<(), GeoLookupError> {
289        info!(
290            database_path = %self.database_path.display(),
291            database_type = ?self.database_type,
292            "Reloading geo database"
293        );
294
295        // Open the new database
296        let new_database: Arc<dyn GeoDatabase> = match self.database_type {
297            GeoDatabaseType::MaxMind => Arc::new(MaxMindDatabase::open(&self.database_path)?),
298            GeoDatabaseType::Ip2Location => {
299                Arc::new(Ip2LocationDatabase::open(&self.database_path)?)
300            }
301        };
302
303        // Atomically swap the database
304        {
305            let mut db = self.database.write();
306            *db = new_database;
307        }
308
309        // Clear the cache since country mappings may have changed
310        self.cache.clear();
311
312        info!(
313            database_path = %self.database_path.display(),
314            "Geo database reloaded successfully"
315        );
316
317        Ok(())
318    }
319
320    /// Get the database file path
321    pub fn database_path(&self) -> &Path {
322        &self.database_path
323    }
324
325    /// Check if a client IP should be allowed or blocked
326    pub fn check(&self, client_ip: &str) -> GeoFilterResult {
327        // Parse the IP address
328        let ip: IpAddr = match client_ip.parse() {
329            Ok(ip) => ip,
330            Err(_) => {
331                warn!(client_ip = %client_ip, "Failed to parse client IP for geo filter");
332                return self.handle_failure();
333            }
334        };
335
336        // Check cache first
337        let now = Instant::now();
338        if let Some(entry) = self.cache.get(&ip) {
339            if now.duration_since(entry.cached_at) < self.cache_ttl {
340                trace!(ip = %ip, country = ?entry.country_code, "Geo cache hit");
341                return self.evaluate(entry.country_code.clone(), true);
342            }
343            // Entry expired, will be replaced
344        }
345
346        // Lookup in database
347        let database = self.database.read();
348        match database.lookup(ip) {
349            Ok(country_code) => {
350                // Cache the result
351                self.cache.insert(
352                    ip,
353                    CachedCountry {
354                        country_code: country_code.clone(),
355                        cached_at: now,
356                    },
357                );
358                self.evaluate(country_code, false)
359            }
360            Err(e) => {
361                warn!(ip = %ip, error = %e, "Geo lookup failed");
362                self.handle_failure()
363            }
364        }
365    }
366
367    /// Evaluate the filter action based on country code
368    fn evaluate(&self, country_code: Option<String>, cache_hit: bool) -> GeoFilterResult {
369        let in_list = country_code
370            .as_ref()
371            .map(|c| self.countries_set.contains(c))
372            .unwrap_or(false);
373
374        let allowed = match self.config.action {
375            GeoFilterAction::Block => {
376                // Block mode: block if country is in the list
377                !in_list
378            }
379            GeoFilterAction::Allow => {
380                // Allow mode: allow only if country is in the list
381                // If no country found and list is not empty, block
382                if self.countries_set.is_empty() {
383                    true
384                } else {
385                    in_list
386                }
387            }
388            GeoFilterAction::LogOnly => {
389                // Log-only mode: always allow
390                true
391            }
392        };
393
394        trace!(
395            country = ?country_code,
396            in_list = in_list,
397            action = ?self.config.action,
398            allowed = allowed,
399            "Geo filter evaluation"
400        );
401
402        GeoFilterResult {
403            allowed,
404            country_code,
405            cache_hit,
406            add_header: self.config.add_country_header,
407            status_code: self.config.status_code,
408            block_message: self.config.block_message.clone(),
409        }
410    }
411
412    /// Handle lookup failure based on failure mode
413    fn handle_failure(&self) -> GeoFilterResult {
414        let allowed = match self.config.on_failure {
415            GeoFailureMode::Open => true,
416            GeoFailureMode::Closed => false,
417        };
418
419        GeoFilterResult {
420            allowed,
421            country_code: None,
422            cache_hit: false,
423            add_header: false,
424            status_code: self.config.status_code,
425            block_message: self.config.block_message.clone(),
426        }
427    }
428
429    /// Get cache statistics
430    pub fn cache_stats(&self) -> (usize, usize) {
431        let now = Instant::now();
432        let total = self.cache.len();
433        let valid = self
434            .cache
435            .iter()
436            .filter(|e| now.duration_since(e.cached_at) < self.cache_ttl)
437            .count();
438        (total, valid)
439    }
440
441    /// Clear expired cache entries
442    pub fn clear_expired(&self) {
443        let now = Instant::now();
444        self.cache
445            .retain(|_, v| now.duration_since(v.cached_at) < self.cache_ttl);
446    }
447}
448
449// =============================================================================
450// GeoFilterManager
451// =============================================================================
452
/// Manages all geo filter instances
pub struct GeoFilterManager {
    /// Filter ID → GeoFilterPool mapping; pools are `Arc`-wrapped so callers
    /// can hold a handle independently of the map
    filter_pools: DashMap<String, Arc<GeoFilterPool>>,
}
458
459impl GeoFilterManager {
460    /// Create a new empty geo filter manager
461    pub fn new() -> Self {
462        Self {
463            filter_pools: DashMap::new(),
464        }
465    }
466
467    /// Register a geo filter from configuration
468    pub fn register_filter(
469        &self,
470        filter_id: &str,
471        config: GeoFilter,
472    ) -> Result<(), GeoLookupError> {
473        let pool = GeoFilterPool::new(config)?;
474        self.filter_pools
475            .insert(filter_id.to_string(), Arc::new(pool));
476        debug!(filter_id = %filter_id, "Registered geo filter");
477        Ok(())
478    }
479
480    /// Check a client IP against a specific filter
481    pub fn check(&self, filter_id: &str, client_ip: &str) -> Option<GeoFilterResult> {
482        self.filter_pools
483            .get(filter_id)
484            .map(|pool| pool.check(client_ip))
485    }
486
487    /// Get a reference to a filter pool
488    pub fn get_pool(&self, filter_id: &str) -> Option<Arc<GeoFilterPool>> {
489        self.filter_pools.get(filter_id).map(|r| r.clone())
490    }
491
492    /// Check if a filter exists
493    pub fn has_filter(&self, filter_id: &str) -> bool {
494        self.filter_pools.contains_key(filter_id)
495    }
496
497    /// Get all filter IDs
498    pub fn filter_ids(&self) -> Vec<String> {
499        self.filter_pools.iter().map(|r| r.key().clone()).collect()
500    }
501
502    /// Clear expired cache entries in all pools
503    pub fn clear_expired_caches(&self) {
504        for pool in self.filter_pools.iter() {
505            pool.clear_expired();
506        }
507    }
508
509    /// Reload a filter's database from disk
510    pub fn reload_filter(&self, filter_id: &str) -> Result<(), GeoLookupError> {
511        if let Some(pool) = self.filter_pools.get(filter_id) {
512            pool.reload_database()
513        } else {
514            Err(GeoLookupError::LoadError(format!(
515                "Filter '{}' not found",
516                filter_id
517            )))
518        }
519    }
520
521    /// Reload database for all filters using the given path
522    pub fn reload_by_path(&self, path: &Path) -> Vec<(String, Result<(), GeoLookupError>)> {
523        let mut results = Vec::new();
524        for entry in self.filter_pools.iter() {
525            if entry.value().database_path() == path {
526                let filter_id = entry.key().clone();
527                let result = entry.value().reload_database();
528                results.push((filter_id, result));
529            }
530        }
531        results
532    }
533
534    /// Get all unique database paths being used
535    pub fn database_paths(&self) -> Vec<(String, PathBuf)> {
536        self.filter_pools
537            .iter()
538            .map(|e| (e.key().clone(), e.value().database_path().to_path_buf()))
539            .collect()
540    }
541}
542
impl Default for GeoFilterManager {
    /// Equivalent to `GeoFilterManager::new()`: a manager with no registered filters.
    fn default() -> Self {
        Self::new()
    }
}
548
549// =============================================================================
550// GeoDatabaseWatcher
551// =============================================================================
552
/// Watches geo database files for changes and triggers reloads
pub struct GeoDatabaseWatcher {
    /// The watcher instance (`None` until watching starts and again after `stop`)
    watcher: RwLock<Option<notify::RecommendedWatcher>>,
    /// Mapping from database path to filter IDs using it
    path_to_filters: RwLock<HashMap<PathBuf, Vec<String>>>,
    /// Reference to the geo filter manager
    manager: Arc<GeoFilterManager>,
}
562
563impl GeoDatabaseWatcher {
564    /// Create a new database watcher
565    pub fn new(manager: Arc<GeoFilterManager>) -> Self {
566        Self {
567            watcher: RwLock::new(None),
568            path_to_filters: RwLock::new(HashMap::new()),
569            manager,
570        }
571    }
572
573    /// Start watching all registered database files
574    pub fn start_watching(&self) -> Result<mpsc::Receiver<PathBuf>, GeoLookupError> {
575        // Build path → filter ID mapping
576        let db_paths = self.manager.database_paths();
577        let mut path_map: HashMap<PathBuf, Vec<String>> = HashMap::new();
578        for (filter_id, path) in db_paths {
579            path_map
580                .entry(path)
581                .or_default()
582                .push(filter_id);
583        }
584
585        if path_map.is_empty() {
586            debug!("No geo databases to watch");
587            let (_tx, rx) = mpsc::channel(1);
588            return Ok(rx);
589        }
590
591        // Store the mapping
592        *self.path_to_filters.write() = path_map.clone();
593
594        // Create channel for events
595        let (tx, rx) = mpsc::channel::<PathBuf>(10);
596
597        // Create file watcher
598        let paths: Vec<PathBuf> = path_map.keys().cloned().collect();
599        let watcher = notify::recommended_watcher(move |event: Result<Event, notify::Error>| {
600            if let Ok(event) = event {
601                if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_)) {
602                    for path in &event.paths {
603                        let _ = tx.blocking_send(path.clone());
604                    }
605                }
606            }
607        })
608        .map_err(|e| {
609            GeoLookupError::LoadError(format!("Failed to create file watcher: {}", e))
610        })?;
611
612        // Store watcher
613        *self.watcher.write() = Some(watcher);
614
615        // Add watches for each database path
616        if let Some(ref mut watcher) = *self.watcher.write() {
617            for path in &paths {
618                if let Err(e) = watcher.watch(path, RecursiveMode::NonRecursive) {
619                    warn!(
620                        path = %path.display(),
621                        error = %e,
622                        "Failed to watch geo database file"
623                    );
624                } else {
625                    info!(
626                        path = %path.display(),
627                        "Watching geo database for changes"
628                    );
629                }
630            }
631        }
632
633        Ok(rx)
634    }
635
636    /// Handle a file change event
637    pub fn handle_change(&self, path: &Path) {
638        let path_map = self.path_to_filters.read();
639        if let Some(filter_ids) = path_map.get(path) {
640            info!(
641                path = %path.display(),
642                filters = ?filter_ids,
643                "Geo database file changed, reloading"
644            );
645
646            for filter_id in filter_ids {
647                match self.manager.reload_filter(filter_id) {
648                    Ok(()) => {
649                        info!(
650                            filter_id = %filter_id,
651                            "Geo filter database reloaded successfully"
652                        );
653                    }
654                    Err(e) => {
655                        error!(
656                            filter_id = %filter_id,
657                            error = %e,
658                            "Failed to reload geo filter database"
659                        );
660                    }
661                }
662            }
663        }
664    }
665
666    /// Stop watching
667    pub fn stop(&self) {
668        *self.watcher.write() = None;
669        info!("Stopped watching geo database files");
670    }
671}
672
673// =============================================================================
674// Tests
675// =============================================================================
676
#[cfg(test)]
mod tests {
    use super::*;

    /// Each error variant renders with its fixed, variant-specific prefix.
    #[test]
    fn test_geo_lookup_error_display() {
        let cases = vec![
            (
                GeoLookupError::InvalidIp("not-an-ip".to_string()),
                "invalid IP",
            ),
            (
                GeoLookupError::DatabaseError("db error".to_string()),
                "database error",
            ),
            (
                GeoLookupError::LoadError("load error".to_string()),
                "failed to load",
            ),
        ];

        for (err, expected_fragment) in cases {
            let rendered = err.to_string();
            assert!(rendered.contains(expected_fragment));
        }
    }

    /// A hand-built result exposes exactly the values it was constructed with.
    #[test]
    fn test_geo_filter_result_default() {
        let result = GeoFilterResult {
            allowed: true,
            country_code: Some("US".to_string()),
            cache_hit: false,
            add_header: true,
            status_code: 403,
            block_message: None,
        };

        assert!(result.allowed);
        assert_eq!(result.country_code.as_deref(), Some("US"));
        assert!(!result.cache_hit);
        assert!(result.add_header);
    }

    /// A freshly created manager has no filters registered.
    #[test]
    fn test_geo_filter_manager_new() {
        let manager = GeoFilterManager::new();
        assert!(manager.filter_ids().is_empty());
        assert!(!manager.has_filter("test"));
    }

    // Integration tests would require actual database files
    // These are covered in the integration test suite
}