1use std::collections::{HashMap, HashSet};
15use std::net::IpAddr;
16use std::path::{Path, PathBuf};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19
20use dashmap::DashMap;
21use notify::{Event, EventKind, RecursiveMode, Watcher};
22use parking_lot::RwLock;
23use tokio::sync::mpsc;
24use tracing::{debug, error, info, trace, warn};
25
26use sentinel_config::{GeoDatabaseType, GeoFailureMode, GeoFilter, GeoFilterAction};
27
/// Errors reported by the geo filter components in this module.
#[derive(Debug, Clone)]
pub enum GeoLookupError {
    /// An address string was not a valid IP; the payload is shown after
    /// `invalid IP address: ` when the error is displayed.
    InvalidIp(String),
    /// A database backend failed while looking up or decoding a record.
    DatabaseError(String),
    /// A database (or a component that depends on one) could not be loaded;
    /// the payload describes what went wrong.
    LoadError(String),
}

impl std::fmt::Display for GeoLookupError {
    /// Renders the error as `<fixed prefix>: <payload>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the fixed prefix and the variant payload, then format once.
        let (prefix, detail) = match self {
            Self::InvalidIp(ip) => ("invalid IP address", ip),
            Self::DatabaseError(msg) => ("database error", msg),
            Self::LoadError(msg) => ("failed to load database", msg),
        };
        write!(f, "{}: {}", prefix, detail)
    }
}

impl std::error::Error for GeoLookupError {}
54
55pub trait GeoDatabase: Send + Sync {
61 fn lookup(&self, ip: IpAddr) -> Result<Option<String>, GeoLookupError>;
63
64 fn database_type(&self) -> GeoDatabaseType;
66}
67
68pub struct MaxMindDatabase {
74 reader: maxminddb::Reader<Vec<u8>>,
75}
76
77impl MaxMindDatabase {
78 pub fn open(path: impl AsRef<Path>) -> Result<Self, GeoLookupError> {
80 let path = path.as_ref();
81 let reader = maxminddb::Reader::open_readfile(path).map_err(|e| {
82 GeoLookupError::LoadError(format!("failed to open MaxMind database {:?}: {}", path, e))
83 })?;
84
85 debug!(path = ?path, "Opened MaxMind GeoIP database");
86 Ok(Self { reader })
87 }
88}
89
90impl GeoDatabase for MaxMindDatabase {
91 fn lookup(&self, ip: IpAddr) -> Result<Option<String>, GeoLookupError> {
92 match self.reader.lookup(ip) {
93 Ok(result) => {
94 if !result.has_data() {
95 trace!(ip = %ip, "IP not found in MaxMind database");
96 return Ok(None);
97 }
98 match result.decode::<maxminddb::geoip2::Country>() {
99 Ok(Some(record)) => {
100 let country_code = record.country.iso_code.map(|s| s.to_string());
101 trace!(ip = %ip, country = ?country_code, "MaxMind lookup");
102 Ok(country_code)
103 }
104 Ok(None) => {
105 trace!(ip = %ip, "No country data for IP in MaxMind database");
106 Ok(None)
107 }
108 Err(e) => {
109 warn!(ip = %ip, error = %e, "MaxMind decode error");
110 Err(GeoLookupError::DatabaseError(e.to_string()))
111 }
112 }
113 }
114 Err(e) => {
115 warn!(ip = %ip, error = %e, "MaxMind lookup error");
116 Err(GeoLookupError::DatabaseError(e.to_string()))
117 }
118 }
119 }
120
121 fn database_type(&self) -> GeoDatabaseType {
122 GeoDatabaseType::MaxMind
123 }
124}
125
126pub struct Ip2LocationDatabase {
132 db: ip2location::DB,
133}
134
135impl Ip2LocationDatabase {
136 pub fn open(path: impl AsRef<Path>) -> Result<Self, GeoLookupError> {
138 let path = path.as_ref();
139 let db = ip2location::DB::from_file(path).map_err(|e| {
140 GeoLookupError::LoadError(format!(
141 "failed to open IP2Location database {:?}: {}",
142 path, e
143 ))
144 })?;
145
146 debug!(path = ?path, "Opened IP2Location GeoIP database");
147 Ok(Self { db })
148 }
149}
150
151impl GeoDatabase for Ip2LocationDatabase {
152 fn lookup(&self, ip: IpAddr) -> Result<Option<String>, GeoLookupError> {
153 match self.db.ip_lookup(ip) {
154 Ok(record) => {
155 let country_code = match record {
157 ip2location::Record::LocationDb(loc) => {
158 loc.country.map(|c| c.short_name.to_string())
159 }
160 ip2location::Record::ProxyDb(proxy) => {
161 proxy.country.map(|c| c.short_name.to_string())
162 }
163 };
164 trace!(ip = %ip, country = ?country_code, "IP2Location lookup");
165 Ok(country_code)
166 }
167 Err(ip2location::error::Error::RecordNotFound) => {
168 trace!(ip = %ip, "IP not found in IP2Location database");
169 Ok(None)
170 }
171 Err(e) => {
172 warn!(ip = %ip, error = %e, "IP2Location lookup error");
173 Err(GeoLookupError::DatabaseError(e.to_string()))
174 }
175 }
176 }
177
178 fn database_type(&self) -> GeoDatabaseType {
179 GeoDatabaseType::Ip2Location
180 }
181}
182
/// One cached lookup outcome together with the instant it was stored.
struct CachedCountry {
    /// Country code returned by the database, or `None` when the lookup
    /// produced no country.
    country_code: Option<String>,
    /// When the entry was recorded; compared against the cache TTL to decide
    /// whether the entry is still usable.
    cached_at: Instant,
}
194
/// Outcome of evaluating one client address against a geo filter.
#[derive(Debug, Clone)]
pub struct GeoFilterResult {
    /// `true` when the request may proceed.
    pub allowed: bool,
    /// Country resolved for the client, if any; `None` on failure results.
    pub country_code: Option<String>,
    /// `true` when the country came from the lookup cache rather than the
    /// database.
    pub cache_hit: bool,
    /// Mirrors the filter's `add_country_header` setting; always `false` on
    /// failure results.
    pub add_header: bool,
    /// Status code copied from the filter configuration (presumably the HTTP
    /// status used when the request is blocked; confirm with the caller).
    pub status_code: u16,
    /// Optional block message copied from the filter configuration.
    pub block_message: Option<String>,
}
215
216pub struct GeoFilterPool {
222 database: RwLock<Arc<dyn GeoDatabase>>,
224 cache: DashMap<IpAddr, CachedCountry>,
226 config: GeoFilter,
228 countries_set: HashSet<String>,
230 cache_ttl: Duration,
232 database_path: PathBuf,
234 database_type: GeoDatabaseType,
236}
237
238impl GeoFilterPool {
239 pub fn new(config: GeoFilter) -> Result<Self, GeoLookupError> {
241 let db_type = config.database_type.clone().unwrap_or_else(|| {
243 if config.database_path.ends_with(".mmdb") {
244 GeoDatabaseType::MaxMind
245 } else {
246 GeoDatabaseType::Ip2Location
247 }
248 });
249
250 let database_path = PathBuf::from(&config.database_path);
251
252 let database: Arc<dyn GeoDatabase> = match db_type {
254 GeoDatabaseType::MaxMind => Arc::new(MaxMindDatabase::open(&config.database_path)?),
255 GeoDatabaseType::Ip2Location => {
256 Arc::new(Ip2LocationDatabase::open(&config.database_path)?)
257 }
258 };
259
260 let countries_set: HashSet<String> = config.countries.iter().cloned().collect();
262
263 let cache_ttl = Duration::from_secs(config.cache_ttl_secs);
264
265 debug!(
266 database_path = %config.database_path,
267 database_type = ?db_type,
268 action = ?config.action,
269 countries_count = countries_set.len(),
270 cache_ttl_secs = config.cache_ttl_secs,
271 "Created GeoFilterPool"
272 );
273
274 Ok(Self {
275 database: RwLock::new(database),
276 cache: DashMap::new(),
277 config,
278 countries_set,
279 cache_ttl,
280 database_path,
281 database_type: db_type,
282 })
283 }
284
285 pub fn reload_database(&self) -> Result<(), GeoLookupError> {
289 info!(
290 database_path = %self.database_path.display(),
291 database_type = ?self.database_type,
292 "Reloading geo database"
293 );
294
295 let new_database: Arc<dyn GeoDatabase> = match self.database_type {
297 GeoDatabaseType::MaxMind => Arc::new(MaxMindDatabase::open(&self.database_path)?),
298 GeoDatabaseType::Ip2Location => {
299 Arc::new(Ip2LocationDatabase::open(&self.database_path)?)
300 }
301 };
302
303 {
305 let mut db = self.database.write();
306 *db = new_database;
307 }
308
309 self.cache.clear();
311
312 info!(
313 database_path = %self.database_path.display(),
314 "Geo database reloaded successfully"
315 );
316
317 Ok(())
318 }
319
320 pub fn database_path(&self) -> &Path {
322 &self.database_path
323 }
324
325 pub fn check(&self, client_ip: &str) -> GeoFilterResult {
327 let ip: IpAddr = match client_ip.parse() {
329 Ok(ip) => ip,
330 Err(_) => {
331 warn!(client_ip = %client_ip, "Failed to parse client IP for geo filter");
332 return self.handle_failure();
333 }
334 };
335
336 let now = Instant::now();
338 if let Some(entry) = self.cache.get(&ip) {
339 if now.duration_since(entry.cached_at) < self.cache_ttl {
340 trace!(ip = %ip, country = ?entry.country_code, "Geo cache hit");
341 return self.evaluate(entry.country_code.clone(), true);
342 }
343 }
345
346 let database = self.database.read();
348 match database.lookup(ip) {
349 Ok(country_code) => {
350 self.cache.insert(
352 ip,
353 CachedCountry {
354 country_code: country_code.clone(),
355 cached_at: now,
356 },
357 );
358 self.evaluate(country_code, false)
359 }
360 Err(e) => {
361 warn!(ip = %ip, error = %e, "Geo lookup failed");
362 self.handle_failure()
363 }
364 }
365 }
366
367 fn evaluate(&self, country_code: Option<String>, cache_hit: bool) -> GeoFilterResult {
369 let in_list = country_code
370 .as_ref()
371 .map(|c| self.countries_set.contains(c))
372 .unwrap_or(false);
373
374 let allowed = match self.config.action {
375 GeoFilterAction::Block => {
376 !in_list
378 }
379 GeoFilterAction::Allow => {
380 if self.countries_set.is_empty() {
383 true
384 } else {
385 in_list
386 }
387 }
388 GeoFilterAction::LogOnly => {
389 true
391 }
392 };
393
394 trace!(
395 country = ?country_code,
396 in_list = in_list,
397 action = ?self.config.action,
398 allowed = allowed,
399 "Geo filter evaluation"
400 );
401
402 GeoFilterResult {
403 allowed,
404 country_code,
405 cache_hit,
406 add_header: self.config.add_country_header,
407 status_code: self.config.status_code,
408 block_message: self.config.block_message.clone(),
409 }
410 }
411
412 fn handle_failure(&self) -> GeoFilterResult {
414 let allowed = match self.config.on_failure {
415 GeoFailureMode::Open => true,
416 GeoFailureMode::Closed => false,
417 };
418
419 GeoFilterResult {
420 allowed,
421 country_code: None,
422 cache_hit: false,
423 add_header: false,
424 status_code: self.config.status_code,
425 block_message: self.config.block_message.clone(),
426 }
427 }
428
429 pub fn cache_stats(&self) -> (usize, usize) {
431 let now = Instant::now();
432 let total = self.cache.len();
433 let valid = self
434 .cache
435 .iter()
436 .filter(|e| now.duration_since(e.cached_at) < self.cache_ttl)
437 .count();
438 (total, valid)
439 }
440
441 pub fn clear_expired(&self) {
443 let now = Instant::now();
444 self.cache
445 .retain(|_, v| now.duration_since(v.cached_at) < self.cache_ttl);
446 }
447}
448
449pub struct GeoFilterManager {
455 filter_pools: DashMap<String, Arc<GeoFilterPool>>,
457}
458
459impl GeoFilterManager {
460 pub fn new() -> Self {
462 Self {
463 filter_pools: DashMap::new(),
464 }
465 }
466
467 pub fn register_filter(
469 &self,
470 filter_id: &str,
471 config: GeoFilter,
472 ) -> Result<(), GeoLookupError> {
473 let pool = GeoFilterPool::new(config)?;
474 self.filter_pools
475 .insert(filter_id.to_string(), Arc::new(pool));
476 debug!(filter_id = %filter_id, "Registered geo filter");
477 Ok(())
478 }
479
480 pub fn check(&self, filter_id: &str, client_ip: &str) -> Option<GeoFilterResult> {
482 self.filter_pools
483 .get(filter_id)
484 .map(|pool| pool.check(client_ip))
485 }
486
487 pub fn get_pool(&self, filter_id: &str) -> Option<Arc<GeoFilterPool>> {
489 self.filter_pools.get(filter_id).map(|r| r.clone())
490 }
491
492 pub fn has_filter(&self, filter_id: &str) -> bool {
494 self.filter_pools.contains_key(filter_id)
495 }
496
497 pub fn filter_ids(&self) -> Vec<String> {
499 self.filter_pools.iter().map(|r| r.key().clone()).collect()
500 }
501
502 pub fn clear_expired_caches(&self) {
504 for pool in self.filter_pools.iter() {
505 pool.clear_expired();
506 }
507 }
508
509 pub fn reload_filter(&self, filter_id: &str) -> Result<(), GeoLookupError> {
511 if let Some(pool) = self.filter_pools.get(filter_id) {
512 pool.reload_database()
513 } else {
514 Err(GeoLookupError::LoadError(format!(
515 "Filter '{}' not found",
516 filter_id
517 )))
518 }
519 }
520
521 pub fn reload_by_path(&self, path: &Path) -> Vec<(String, Result<(), GeoLookupError>)> {
523 let mut results = Vec::new();
524 for entry in self.filter_pools.iter() {
525 if entry.value().database_path() == path {
526 let filter_id = entry.key().clone();
527 let result = entry.value().reload_database();
528 results.push((filter_id, result));
529 }
530 }
531 results
532 }
533
534 pub fn database_paths(&self) -> Vec<(String, PathBuf)> {
536 self.filter_pools
537 .iter()
538 .map(|e| (e.key().clone(), e.value().database_path().to_path_buf()))
539 .collect()
540 }
541}
542
543impl Default for GeoFilterManager {
544 fn default() -> Self {
545 Self::new()
546 }
547}
548
549pub struct GeoDatabaseWatcher {
555 watcher: RwLock<Option<notify::RecommendedWatcher>>,
557 path_to_filters: RwLock<HashMap<PathBuf, Vec<String>>>,
559 manager: Arc<GeoFilterManager>,
561}
562
563impl GeoDatabaseWatcher {
564 pub fn new(manager: Arc<GeoFilterManager>) -> Self {
566 Self {
567 watcher: RwLock::new(None),
568 path_to_filters: RwLock::new(HashMap::new()),
569 manager,
570 }
571 }
572
573 pub fn start_watching(&self) -> Result<mpsc::Receiver<PathBuf>, GeoLookupError> {
575 let db_paths = self.manager.database_paths();
577 let mut path_map: HashMap<PathBuf, Vec<String>> = HashMap::new();
578 for (filter_id, path) in db_paths {
579 path_map
580 .entry(path)
581 .or_default()
582 .push(filter_id);
583 }
584
585 if path_map.is_empty() {
586 debug!("No geo databases to watch");
587 let (_tx, rx) = mpsc::channel(1);
588 return Ok(rx);
589 }
590
591 *self.path_to_filters.write() = path_map.clone();
593
594 let (tx, rx) = mpsc::channel::<PathBuf>(10);
596
597 let paths: Vec<PathBuf> = path_map.keys().cloned().collect();
599 let watcher = notify::recommended_watcher(move |event: Result<Event, notify::Error>| {
600 if let Ok(event) = event {
601 if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_)) {
602 for path in &event.paths {
603 let _ = tx.blocking_send(path.clone());
604 }
605 }
606 }
607 })
608 .map_err(|e| {
609 GeoLookupError::LoadError(format!("Failed to create file watcher: {}", e))
610 })?;
611
612 *self.watcher.write() = Some(watcher);
614
615 if let Some(ref mut watcher) = *self.watcher.write() {
617 for path in &paths {
618 if let Err(e) = watcher.watch(path, RecursiveMode::NonRecursive) {
619 warn!(
620 path = %path.display(),
621 error = %e,
622 "Failed to watch geo database file"
623 );
624 } else {
625 info!(
626 path = %path.display(),
627 "Watching geo database for changes"
628 );
629 }
630 }
631 }
632
633 Ok(rx)
634 }
635
636 pub fn handle_change(&self, path: &Path) {
638 let path_map = self.path_to_filters.read();
639 if let Some(filter_ids) = path_map.get(path) {
640 info!(
641 path = %path.display(),
642 filters = ?filter_ids,
643 "Geo database file changed, reloading"
644 );
645
646 for filter_id in filter_ids {
647 match self.manager.reload_filter(filter_id) {
648 Ok(()) => {
649 info!(
650 filter_id = %filter_id,
651 "Geo filter database reloaded successfully"
652 );
653 }
654 Err(e) => {
655 error!(
656 filter_id = %filter_id,
657 error = %e,
658 "Failed to reload geo filter database"
659 );
660 }
661 }
662 }
663 }
664 }
665
666 pub fn stop(&self) {
668 *self.watcher.write() = None;
669 info!("Stopped watching geo database files");
670 }
671}
672
#[cfg(test)]
mod tests {
    use super::*;

    /// Every error variant renders with its fixed message prefix.
    #[test]
    fn test_geo_lookup_error_display() {
        let cases = [
            (
                GeoLookupError::InvalidIp("not-an-ip".to_string()),
                "invalid IP",
            ),
            (
                GeoLookupError::DatabaseError("db error".to_string()),
                "database error",
            ),
            (
                GeoLookupError::LoadError("load error".to_string()),
                "failed to load",
            ),
        ];

        for (err, expected) in cases.iter() {
            assert!(err.to_string().contains(*expected));
        }
    }

    /// A hand-built result exposes the values it was constructed with.
    #[test]
    fn test_geo_filter_result_default() {
        let result = GeoFilterResult {
            allowed: true,
            country_code: Some("US".to_string()),
            cache_hit: false,
            add_header: true,
            status_code: 403,
            block_message: None,
        };

        assert!(result.allowed);
        assert_eq!(result.country_code, Some("US".to_string()));
        assert!(!result.cache_hit);
        assert!(result.add_header);
    }

    /// A fresh manager has no filters registered.
    #[test]
    fn test_geo_filter_manager_new() {
        let manager = GeoFilterManager::new();
        assert!(manager.filter_ids().is_empty());
        assert!(!manager.has_filter("test"));
    }
}