tower_ipfilter/lib.rs

1pub mod types;
2mod compress;
3mod extract;
4mod body;
5pub mod geo_filter;
6pub mod ip_filter;
7pub mod network_filter_service;
8pub mod connection_info_service;
9
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;
    use std::str::FromStr;

    use dashmap::DashMap;
    use ipnetwork::Ipv4Network;

    use super::geo_filter::GeoIpv4Filter;
    use super::types::CountryLocation;

    /// Builds a `GeoIpv4Filter` whose `networks` table maps three private
    /// IPv4 ranges to fixed countries:
    ///
    /// * `192.168.0.0/16` -> United Kingdom
    /// * `10.0.0.0/8`     -> United States
    /// * `172.16.0.0/12`  -> France (covers 172.16.0.0 through 172.31.255.255)
    ///
    /// The `addresses` and `countries` maps start empty and `mode` is
    /// `Default::default()`.
    fn create_test_geo_ip_service() -> GeoIpv4Filter {
        let ip_networks = DashMap::new();

        ip_networks.insert(
            Ipv4Network::from_str("192.168.0.0/16").unwrap(),
            CountryLocation {
                geoname_id: 1,
                locale_code: "EN".to_string(),
                continent_code: "EU".to_string(),
                continent_name: "Europe".to_string(),
                country_iso_code: Some("GB".to_string()),
                country_name: Some("United Kingdom".to_string()),
                is_in_european_union: false,
            },
        );
        ip_networks.insert(
            Ipv4Network::from_str("10.0.0.0/8").unwrap(),
            CountryLocation {
                geoname_id: 2,
                locale_code: "EN".to_string(),
                continent_code: "NA".to_string(),
                continent_name: "North America".to_string(),
                country_iso_code: Some("US".to_string()),
                country_name: Some("United States".to_string()),
                is_in_european_union: false,
            },
        );
        ip_networks.insert(
            Ipv4Network::from_str("172.16.0.0/12").unwrap(),
            CountryLocation {
                geoname_id: 3,
                locale_code: "FR".to_string(),
                continent_code: "EU".to_string(),
                continent_name: "Europe".to_string(),
                country_iso_code: Some("FR".to_string()),
                country_name: Some("France".to_string()),
                is_in_european_union: true,
            },
        );

        GeoIpv4Filter {
            networks: ip_networks,
            addresses: DashMap::new(),
            countries: DashMap::new(),
            mode: Default::default(),
        }
    }

    /// Parses an IPv4 literal used by these tests.
    fn addr(ip: &str) -> Ipv4Addr {
        Ipv4Addr::from_str(ip).expect("test IP literal must be valid IPv4")
    }

    /// Returns the `country_name` of the network that contains `ip`.
    ///
    /// Panics if `ip` is not inside any network of the test table.
    async fn country_name_for(service: &GeoIpv4Filter, ip: &str) -> Option<String> {
        service
            .get_country_for_ip(&addr(ip))
            .await
            .expect("address should fall inside a test network")
            .country_name
            .clone()
    }

    #[tokio::test]
    async fn test_get_country_for_ip() {
        let service = create_test_geo_ip_service();

        assert_eq!(
            country_name_for(&service, "192.168.1.1").await,
            Some("United Kingdom".to_string())
        );
        assert_eq!(
            country_name_for(&service, "10.0.0.1").await,
            Some("United States".to_string())
        );
        assert_eq!(
            country_name_for(&service, "172.16.0.1").await,
            Some("France".to_string())
        );

        // An address outside every configured network has no country.
        assert!(service.get_country_for_ip(&addr("8.8.8.8")).await.is_none());
    }

    #[tokio::test]
    async fn test_get_country_for_ip_edge_cases() {
        let service = create_test_geo_ip_service();

        // Last address of 192.168.0.0/16.
        assert_eq!(
            country_name_for(&service, "192.168.255.255").await,
            Some("United Kingdom".to_string())
        );

        // First address of 10.0.0.0/8.
        assert_eq!(
            country_name_for(&service, "10.0.0.0").await,
            Some("United States".to_string())
        );

        // Last address of 10.0.0.0/8.
        assert_eq!(
            country_name_for(&service, "10.255.255.255").await,
            Some("United States".to_string())
        );

        // Last address of 172.16.0.0/12, whose prefix is not octet-aligned.
        assert_eq!(
            country_name_for(&service, "172.31.255.255").await,
            Some("France".to_string())
        );

        // First address after 172.16.0.0/12 belongs to no configured network.
        assert!(service.get_country_for_ip(&addr("172.32.0.0")).await.is_none());
    }

    #[tokio::test]
    async fn test_blocklist() {
        let service = create_test_geo_ip_service();

        service.set_countries(vec!["United States".to_string(), "France".to_string()]);

        // Country-name checks against the configured list.
        assert!(service.is_country_blocked("United States").await);
        assert!(service.is_country_blocked("France").await);
        assert!(!service.is_country_blocked("United Kingdom").await);
        assert!(!service.is_country_blocked("Japan").await);

        // Per-IP checks.
        assert!(service.is_ip_blocked(&addr("10.0.0.1")).await); // US
        assert!(service.is_ip_blocked(&addr("172.16.0.1")).await); // France
        assert!(!service.is_ip_blocked(&addr("192.168.1.1")).await); // UK

        // An IP outside every configured network is not blocked.
        assert!(!service.is_ip_blocked(&addr("8.8.8.8")).await);

        // `set_countries` replaces the list, so previously listed countries
        // are no longer blocked.
        service.set_countries(vec!["Japan".to_string()]);

        assert!(service.is_country_blocked("Japan").await);
        assert!(!service.is_country_blocked("United States").await);
        assert!(!service.is_ip_blocked(&addr("10.0.0.1")).await); // US
        assert!(!service.is_ip_blocked(&addr("172.16.0.1")).await); // France
    }
}