1use std::collections::{HashMap, HashSet};
15use std::net::IpAddr;
16use std::path::{Path, PathBuf};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19
20use dashmap::DashMap;
21use notify::{Event, EventKind, RecursiveMode, Watcher};
22use parking_lot::RwLock;
23use tokio::sync::mpsc;
24use tracing::{debug, error, info, trace, warn};
25
26use sentinel_config::{GeoDatabaseType, GeoFailureMode, GeoFilter, GeoFilterAction};
27
/// Errors produced while loading or querying a geo database.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors
/// directly instead of matching on their rendered text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GeoLookupError {
    /// A client address could not be parsed; carries the offending text.
    InvalidIp(String),
    /// The database backend reported an error during a lookup.
    DatabaseError(String),
    /// A database (or the file watcher guarding it) could not be set up.
    LoadError(String),
}

impl std::fmt::Display for GeoLookupError {
    /// Renders the error as a short lowercase prefix followed by its detail text.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidIp(ip) => write!(f, "invalid IP address: {}", ip),
            Self::DatabaseError(msg) => write!(f, "database error: {}", msg),
            Self::LoadError(msg) => write!(f, "failed to load database: {}", msg),
        }
    }
}

impl std::error::Error for GeoLookupError {}
54
/// Backend-agnostic interface to a GeoIP country database.
///
/// Implementations must be `Send + Sync` because they are shared behind an
/// `Arc` by `GeoFilterPool`.
pub trait GeoDatabase: Send + Sync {
    /// Resolves `ip` to a country code.
    ///
    /// The implementations in this file return `Ok(None)` when the address is
    /// not present in the database and `Err(GeoLookupError::DatabaseError)`
    /// for any other backend failure.
    fn lookup(&self, ip: IpAddr) -> Result<Option<String>, GeoLookupError>;

    /// Reports which database format backs this implementation.
    fn database_type(&self) -> GeoDatabaseType;
}
67
/// `GeoDatabase` backend for MaxMind databases.
pub struct MaxMindDatabase {
    /// MaxMind reader over an owned byte buffer (created via `open_readfile`).
    reader: maxminddb::Reader<Vec<u8>>,
}
76
77impl MaxMindDatabase {
78 pub fn open(path: impl AsRef<Path>) -> Result<Self, GeoLookupError> {
80 let path = path.as_ref();
81 let reader = maxminddb::Reader::open_readfile(path).map_err(|e| {
82 GeoLookupError::LoadError(format!("failed to open MaxMind database {:?}: {}", path, e))
83 })?;
84
85 debug!(path = ?path, "Opened MaxMind GeoIP database");
86 Ok(Self { reader })
87 }
88}
89
90impl GeoDatabase for MaxMindDatabase {
91 fn lookup(&self, ip: IpAddr) -> Result<Option<String>, GeoLookupError> {
92 match self.reader.lookup::<maxminddb::geoip2::Country>(ip) {
93 Ok(record) => {
94 let country_code = record
95 .country
96 .and_then(|c| c.iso_code)
97 .map(|s| s.to_string());
98 trace!(ip = %ip, country = ?country_code, "MaxMind lookup");
99 Ok(country_code)
100 }
101 Err(maxminddb::MaxMindDBError::AddressNotFoundError(_)) => {
102 trace!(ip = %ip, "IP not found in MaxMind database");
103 Ok(None)
104 }
105 Err(e) => {
106 warn!(ip = %ip, error = %e, "MaxMind lookup error");
107 Err(GeoLookupError::DatabaseError(e.to_string()))
108 }
109 }
110 }
111
112 fn database_type(&self) -> GeoDatabaseType {
113 GeoDatabaseType::MaxMind
114 }
115}
116
/// `GeoDatabase` backend for IP2Location databases.
pub struct Ip2LocationDatabase {
    /// Handle to the opened IP2Location database file.
    db: ip2location::DB,
}
125
126impl Ip2LocationDatabase {
127 pub fn open(path: impl AsRef<Path>) -> Result<Self, GeoLookupError> {
129 let path = path.as_ref();
130 let db = ip2location::DB::from_file(path).map_err(|e| {
131 GeoLookupError::LoadError(format!(
132 "failed to open IP2Location database {:?}: {}",
133 path, e
134 ))
135 })?;
136
137 debug!(path = ?path, "Opened IP2Location GeoIP database");
138 Ok(Self { db })
139 }
140}
141
142impl GeoDatabase for Ip2LocationDatabase {
143 fn lookup(&self, ip: IpAddr) -> Result<Option<String>, GeoLookupError> {
144 match self.db.ip_lookup(ip) {
145 Ok(record) => {
146 let country_code = match record {
148 ip2location::Record::LocationDb(loc) => {
149 loc.country.map(|c| c.short_name.to_string())
150 }
151 ip2location::Record::ProxyDb(proxy) => {
152 proxy.country.map(|c| c.short_name.to_string())
153 }
154 };
155 trace!(ip = %ip, country = ?country_code, "IP2Location lookup");
156 Ok(country_code)
157 }
158 Err(ip2location::error::Error::RecordNotFound) => {
159 trace!(ip = %ip, "IP not found in IP2Location database");
160 Ok(None)
161 }
162 Err(e) => {
163 warn!(ip = %ip, error = %e, "IP2Location lookup error");
164 Err(GeoLookupError::DatabaseError(e.to_string()))
165 }
166 }
167 }
168
169 fn database_type(&self) -> GeoDatabaseType {
170 GeoDatabaseType::Ip2Location
171 }
172}
173
/// One cached lookup result, stored per client IP in `GeoFilterPool::cache`.
struct CachedCountry {
    /// Country code returned by the database; `None` when the lookup found no country.
    country_code: Option<String>,
    /// When the entry was stored; compared against the pool's TTL to decide freshness.
    cached_at: Instant,
}
185
/// Outcome of evaluating a client IP against a geo filter.
///
/// `PartialEq`/`Eq` are derived (every field is a plain std type) so results
/// can be compared directly by callers and in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeoFilterResult {
    /// Whether the request passes the filter.
    pub allowed: bool,
    /// Country code resolved for the client; `None` when unknown or on failure.
    pub country_code: Option<String>,
    /// `true` when the country came from the pool's cache rather than a database lookup.
    pub cache_hit: bool,
    /// Copied from the filter's `add_country_header` setting; `false` on failure results.
    pub add_header: bool,
    /// Copied from the filter's configured `status_code`
    /// (presumably the HTTP status to send when blocking — confirm with the caller).
    pub status_code: u16,
    /// Copied from the filter's configured `block_message`, if any.
    pub block_message: Option<String>,
}
206
/// A single geo filter: its database handle, per-IP result cache and policy.
pub struct GeoFilterPool {
    /// Active database; swapped wholesale under the write lock by `reload_database`.
    database: RwLock<Arc<dyn GeoDatabase>>,
    /// Per-IP cache of lookup results with their insertion time.
    cache: DashMap<IpAddr, CachedCountry>,
    /// Filter configuration this pool was built from.
    config: GeoFilter,
    /// Country codes from `config.countries`, collected into a set for O(1) membership tests.
    countries_set: HashSet<String>,
    /// Lifetime of a cache entry, derived from `config.cache_ttl_secs`.
    cache_ttl: Duration,
    /// Location of the database file, reused when reloading.
    database_path: PathBuf,
    /// Database format chosen at construction, reused when reloading.
    database_type: GeoDatabaseType,
}
228
229impl GeoFilterPool {
230 pub fn new(config: GeoFilter) -> Result<Self, GeoLookupError> {
232 let db_type = config.database_type.clone().unwrap_or_else(|| {
234 if config.database_path.ends_with(".mmdb") {
235 GeoDatabaseType::MaxMind
236 } else {
237 GeoDatabaseType::Ip2Location
238 }
239 });
240
241 let database_path = PathBuf::from(&config.database_path);
242
243 let database: Arc<dyn GeoDatabase> = match db_type {
245 GeoDatabaseType::MaxMind => Arc::new(MaxMindDatabase::open(&config.database_path)?),
246 GeoDatabaseType::Ip2Location => {
247 Arc::new(Ip2LocationDatabase::open(&config.database_path)?)
248 }
249 };
250
251 let countries_set: HashSet<String> = config.countries.iter().cloned().collect();
253
254 let cache_ttl = Duration::from_secs(config.cache_ttl_secs);
255
256 debug!(
257 database_path = %config.database_path,
258 database_type = ?db_type,
259 action = ?config.action,
260 countries_count = countries_set.len(),
261 cache_ttl_secs = config.cache_ttl_secs,
262 "Created GeoFilterPool"
263 );
264
265 Ok(Self {
266 database: RwLock::new(database),
267 cache: DashMap::new(),
268 config,
269 countries_set,
270 cache_ttl,
271 database_path,
272 database_type: db_type,
273 })
274 }
275
276 pub fn reload_database(&self) -> Result<(), GeoLookupError> {
280 info!(
281 database_path = %self.database_path.display(),
282 database_type = ?self.database_type,
283 "Reloading geo database"
284 );
285
286 let new_database: Arc<dyn GeoDatabase> = match self.database_type {
288 GeoDatabaseType::MaxMind => Arc::new(MaxMindDatabase::open(&self.database_path)?),
289 GeoDatabaseType::Ip2Location => {
290 Arc::new(Ip2LocationDatabase::open(&self.database_path)?)
291 }
292 };
293
294 {
296 let mut db = self.database.write();
297 *db = new_database;
298 }
299
300 self.cache.clear();
302
303 info!(
304 database_path = %self.database_path.display(),
305 "Geo database reloaded successfully"
306 );
307
308 Ok(())
309 }
310
311 pub fn database_path(&self) -> &Path {
313 &self.database_path
314 }
315
316 pub fn check(&self, client_ip: &str) -> GeoFilterResult {
318 let ip: IpAddr = match client_ip.parse() {
320 Ok(ip) => ip,
321 Err(_) => {
322 warn!(client_ip = %client_ip, "Failed to parse client IP for geo filter");
323 return self.handle_failure();
324 }
325 };
326
327 let now = Instant::now();
329 if let Some(entry) = self.cache.get(&ip) {
330 if now.duration_since(entry.cached_at) < self.cache_ttl {
331 trace!(ip = %ip, country = ?entry.country_code, "Geo cache hit");
332 return self.evaluate(entry.country_code.clone(), true);
333 }
334 }
336
337 let database = self.database.read();
339 match database.lookup(ip) {
340 Ok(country_code) => {
341 self.cache.insert(
343 ip,
344 CachedCountry {
345 country_code: country_code.clone(),
346 cached_at: now,
347 },
348 );
349 self.evaluate(country_code, false)
350 }
351 Err(e) => {
352 warn!(ip = %ip, error = %e, "Geo lookup failed");
353 self.handle_failure()
354 }
355 }
356 }
357
358 fn evaluate(&self, country_code: Option<String>, cache_hit: bool) -> GeoFilterResult {
360 let in_list = country_code
361 .as_ref()
362 .map(|c| self.countries_set.contains(c))
363 .unwrap_or(false);
364
365 let allowed = match self.config.action {
366 GeoFilterAction::Block => {
367 !in_list
369 }
370 GeoFilterAction::Allow => {
371 if self.countries_set.is_empty() {
374 true
375 } else {
376 in_list
377 }
378 }
379 GeoFilterAction::LogOnly => {
380 true
382 }
383 };
384
385 trace!(
386 country = ?country_code,
387 in_list = in_list,
388 action = ?self.config.action,
389 allowed = allowed,
390 "Geo filter evaluation"
391 );
392
393 GeoFilterResult {
394 allowed,
395 country_code,
396 cache_hit,
397 add_header: self.config.add_country_header,
398 status_code: self.config.status_code,
399 block_message: self.config.block_message.clone(),
400 }
401 }
402
403 fn handle_failure(&self) -> GeoFilterResult {
405 let allowed = match self.config.on_failure {
406 GeoFailureMode::Open => true,
407 GeoFailureMode::Closed => false,
408 };
409
410 GeoFilterResult {
411 allowed,
412 country_code: None,
413 cache_hit: false,
414 add_header: false,
415 status_code: self.config.status_code,
416 block_message: self.config.block_message.clone(),
417 }
418 }
419
420 pub fn cache_stats(&self) -> (usize, usize) {
422 let now = Instant::now();
423 let total = self.cache.len();
424 let valid = self
425 .cache
426 .iter()
427 .filter(|e| now.duration_since(e.cached_at) < self.cache_ttl)
428 .count();
429 (total, valid)
430 }
431
432 pub fn clear_expired(&self) {
434 let now = Instant::now();
435 self.cache
436 .retain(|_, v| now.duration_since(v.cached_at) < self.cache_ttl);
437 }
438}
439
/// Registry of geo filters keyed by filter id.
pub struct GeoFilterManager {
    /// Filter id → shared pool; pools are `Arc`ed so they can be handed out via `get_pool`.
    filter_pools: DashMap<String, Arc<GeoFilterPool>>,
}
449
450impl GeoFilterManager {
451 pub fn new() -> Self {
453 Self {
454 filter_pools: DashMap::new(),
455 }
456 }
457
458 pub fn register_filter(
460 &self,
461 filter_id: &str,
462 config: GeoFilter,
463 ) -> Result<(), GeoLookupError> {
464 let pool = GeoFilterPool::new(config)?;
465 self.filter_pools
466 .insert(filter_id.to_string(), Arc::new(pool));
467 debug!(filter_id = %filter_id, "Registered geo filter");
468 Ok(())
469 }
470
471 pub fn check(&self, filter_id: &str, client_ip: &str) -> Option<GeoFilterResult> {
473 self.filter_pools
474 .get(filter_id)
475 .map(|pool| pool.check(client_ip))
476 }
477
478 pub fn get_pool(&self, filter_id: &str) -> Option<Arc<GeoFilterPool>> {
480 self.filter_pools.get(filter_id).map(|r| r.clone())
481 }
482
483 pub fn has_filter(&self, filter_id: &str) -> bool {
485 self.filter_pools.contains_key(filter_id)
486 }
487
488 pub fn filter_ids(&self) -> Vec<String> {
490 self.filter_pools.iter().map(|r| r.key().clone()).collect()
491 }
492
493 pub fn clear_expired_caches(&self) {
495 for pool in self.filter_pools.iter() {
496 pool.clear_expired();
497 }
498 }
499
500 pub fn reload_filter(&self, filter_id: &str) -> Result<(), GeoLookupError> {
502 if let Some(pool) = self.filter_pools.get(filter_id) {
503 pool.reload_database()
504 } else {
505 Err(GeoLookupError::LoadError(format!(
506 "Filter '{}' not found",
507 filter_id
508 )))
509 }
510 }
511
512 pub fn reload_by_path(&self, path: &Path) -> Vec<(String, Result<(), GeoLookupError>)> {
514 let mut results = Vec::new();
515 for entry in self.filter_pools.iter() {
516 if entry.value().database_path() == path {
517 let filter_id = entry.key().clone();
518 let result = entry.value().reload_database();
519 results.push((filter_id, result));
520 }
521 }
522 results
523 }
524
525 pub fn database_paths(&self) -> Vec<(String, PathBuf)> {
527 self.filter_pools
528 .iter()
529 .map(|e| (e.key().clone(), e.value().database_path().to_path_buf()))
530 .collect()
531 }
532}
533
534impl Default for GeoFilterManager {
535 fn default() -> Self {
536 Self::new()
537 }
538}
539
/// Watches geo database files and reloads the filters that use them.
pub struct GeoDatabaseWatcher {
    /// Active file watcher; `None` until `start_watching` succeeds and again after `stop`.
    watcher: RwLock<Option<notify::RecommendedWatcher>>,
    /// Database path → ids of the filters backed by that file.
    path_to_filters: RwLock<HashMap<PathBuf, Vec<String>>>,
    /// Manager whose filters are reloaded when a watched file changes.
    manager: Arc<GeoFilterManager>,
}
553
554impl GeoDatabaseWatcher {
555 pub fn new(manager: Arc<GeoFilterManager>) -> Self {
557 Self {
558 watcher: RwLock::new(None),
559 path_to_filters: RwLock::new(HashMap::new()),
560 manager,
561 }
562 }
563
564 pub fn start_watching(&self) -> Result<mpsc::Receiver<PathBuf>, GeoLookupError> {
566 let db_paths = self.manager.database_paths();
568 let mut path_map: HashMap<PathBuf, Vec<String>> = HashMap::new();
569 for (filter_id, path) in db_paths {
570 path_map
571 .entry(path)
572 .or_insert_with(Vec::new)
573 .push(filter_id);
574 }
575
576 if path_map.is_empty() {
577 debug!("No geo databases to watch");
578 let (_tx, rx) = mpsc::channel(1);
579 return Ok(rx);
580 }
581
582 *self.path_to_filters.write() = path_map.clone();
584
585 let (tx, rx) = mpsc::channel::<PathBuf>(10);
587
588 let paths: Vec<PathBuf> = path_map.keys().cloned().collect();
590 let watcher = notify::recommended_watcher(move |event: Result<Event, notify::Error>| {
591 if let Ok(event) = event {
592 if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_)) {
593 for path in &event.paths {
594 let _ = tx.blocking_send(path.clone());
595 }
596 }
597 }
598 })
599 .map_err(|e| {
600 GeoLookupError::LoadError(format!("Failed to create file watcher: {}", e))
601 })?;
602
603 *self.watcher.write() = Some(watcher);
605
606 if let Some(ref mut watcher) = *self.watcher.write() {
608 for path in &paths {
609 if let Err(e) = watcher.watch(path, RecursiveMode::NonRecursive) {
610 warn!(
611 path = %path.display(),
612 error = %e,
613 "Failed to watch geo database file"
614 );
615 } else {
616 info!(
617 path = %path.display(),
618 "Watching geo database for changes"
619 );
620 }
621 }
622 }
623
624 Ok(rx)
625 }
626
627 pub fn handle_change(&self, path: &Path) {
629 let path_map = self.path_to_filters.read();
630 if let Some(filter_ids) = path_map.get(path) {
631 info!(
632 path = %path.display(),
633 filters = ?filter_ids,
634 "Geo database file changed, reloading"
635 );
636
637 for filter_id in filter_ids {
638 match self.manager.reload_filter(filter_id) {
639 Ok(()) => {
640 info!(
641 filter_id = %filter_id,
642 "Geo filter database reloaded successfully"
643 );
644 }
645 Err(e) => {
646 error!(
647 filter_id = %filter_id,
648 error = %e,
649 "Failed to reload geo filter database"
650 );
651 }
652 }
653 }
654 }
655 }
656
657 pub fn stop(&self) {
659 *self.watcher.write() = None;
660 info!("Stopped watching geo database files");
661 }
662}
663
#[cfg(test)]
mod tests {
    use super::*;

    /// Every error variant renders with its own recognisable prefix.
    #[test]
    fn test_geo_lookup_error_display() {
        let cases = [
            (
                GeoLookupError::InvalidIp("not-an-ip".to_string()),
                "invalid IP",
            ),
            (
                GeoLookupError::DatabaseError("db error".to_string()),
                "database error",
            ),
            (
                GeoLookupError::LoadError("load error".to_string()),
                "failed to load",
            ),
        ];

        for (err, expected) in &cases {
            assert!(err.to_string().contains(*expected));
        }
    }

    /// A hand-built result exposes exactly the values it was given.
    #[test]
    fn test_geo_filter_result_default() {
        let result = GeoFilterResult {
            allowed: true,
            country_code: Some("US".to_string()),
            cache_hit: false,
            add_header: true,
            status_code: 403,
            block_message: None,
        };

        assert!(result.allowed);
        assert_eq!(result.country_code.as_deref(), Some("US"));
        assert!(!result.cache_hit);
        assert!(result.add_header);
    }

    /// A freshly created manager has no filters registered.
    #[test]
    fn test_geo_filter_manager_new() {
        let manager = GeoFilterManager::new();
        assert!(manager.filter_ids().is_empty());
        assert!(!manager.has_filter("test"));
    }
}