tower_ipfilter/
geo_filter.rs

1use dashmap::DashMap;
2use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network};
3use tracing::info;
4
5use crate::{
6    body::{create_geo_access_denied_response, IpResponseBody}, compress::{load_compressed_data, save_compressed_data}, extract::extract_and_parse_csv, network_filter_service::NetworkFilter, types::{CountryLocation, Mode}
7};
8use std::{
9    error::Error,
10    net::{IpAddr, Ipv4Addr, Ipv6Addr},
11    path::{Path, PathBuf},
12};
13
14pub trait IpAddrExt: Sized + Send {
15    fn to_network(self) -> IpNetwork;
16    fn to_ip_addr(self) -> IpAddr;
17    fn is_ipv4(&self) -> bool;
18}
19
20impl IpAddrExt for Ipv4Addr {
21    fn to_ip_addr(self) -> IpAddr {
22        IpAddr::V4(self)
23    }
24    fn to_network(self) -> IpNetwork {
25        IpNetwork::V4(Ipv4Network::from(self))
26    }
27    fn is_ipv4(&self) -> bool {
28        true
29    }
30}
31
32impl IpAddrExt for Ipv6Addr {
33    fn to_ip_addr(self) -> IpAddr {
34        IpAddr::V6(self)
35    }
36    fn to_network(self) -> IpNetwork {
37        IpNetwork::V6(Ipv6Network::from(self))
38    }
39    fn is_ipv4(&self) -> bool {
40        false
41    }
42}
43
44impl IpAddrExt for IpAddr {
45    fn to_ip_addr(self) -> IpAddr {
46        self
47    }
48    fn to_network(self) -> IpNetwork {
49        match self {
50            IpAddr::V4(ip) => IpNetwork::V4(Ipv4Network::from(ip)),
51            IpAddr::V6(ip) => IpNetwork::V6(Ipv6Network::from(ip)),
52        }
53    }
54    fn is_ipv4(&self) -> bool {
55        match self {
56            IpAddr::V4(_) => true,
57            IpAddr::V6(_) => false,
58        }
59    }
60}
61
62impl IpAddrExt for Ipv4Network {
63    fn to_ip_addr(self) -> IpAddr {
64        IpAddr::V4(self.network())
65    }
66    fn to_network(self) -> IpNetwork {
67        IpNetwork::V4(self)
68    }
69    fn is_ipv4(&self) -> bool {
70        true
71    }
72}
73
74#[derive(Debug, Clone)]
75pub struct GeoIpv4Filter {
76    pub networks: DashMap<Ipv4Network, CountryLocation>,
77    pub addresses: DashMap<Ipv4Addr, CountryLocation>,
78    pub countries: DashMap<String, bool>,
79    pub mode: Mode,
80}
81
82impl GeoIpv4Filter {
83    pub fn new(mode: Mode, path_to_data: impl Into<PathBuf>) -> Result<Self, Box<dyn Error>> {
84        let data_path = Path::new("geo_ip_data.bin.gz");
85
86        let geo_data = if !data_path.exists() {
87            let data = extract_and_parse_csv(&path_to_data.into())?;
88            save_compressed_data(&data, data_path)?;
89            data
90        } else {
91            load_compressed_data(data_path)?
92        };
93
94        info!(
95            "Loaded {} ip blocks and {} country locations",
96            geo_data.ip_blocks.len(),
97            geo_data.country_locations.len()
98        );
99
100        let ip_country_map = DashMap::<Ipv4Network, CountryLocation>::new();
101
102        // add localhost
103        ip_country_map.insert(
104            Ipv4Network::from(Ipv4Addr::new(127, 0, 0, 1)),
105            CountryLocation {
106                geoname_id: 0,
107                locale_code: "NB".to_string(),
108                continent_code: "NA".to_string(),
109                continent_name: "Europe".to_string(),
110                country_iso_code: Some("NO".to_string()),
111                country_name: Some("Norway".to_string()),
112                is_in_european_union: true,
113            },
114        );
115
116        for block in geo_data.ip_blocks {
117            if let Some(geoname_id) = block.geoname_id {
118                if let Ok(network) = block.network.parse() {
119                    if let Some(country) = geo_data.country_locations.get(&geoname_id) {
120                        ip_country_map.insert(network, country.clone());
121                    } else {
122                        println!("No country found for geoname_id: {}", geoname_id);
123                    }
124                }
125            }
126        }
127
128        Ok(Self {
129            networks: ip_country_map,
130            addresses: DashMap::new(),
131            countries: DashMap::new(),
132            mode,
133        })
134    }
135
136    pub async fn get_country_for_ip(&self, ip: &Ipv4Addr) -> Option<CountryLocation> {
137        let mut country = None;
138
139        if let Some(location) = self.addresses.get(ip) {
140            return Some(location.clone());
141        }
142
143        for kv in self.networks.iter() {
144            let (network, location) = kv.pair();
145            if network.contains(*ip) {
146                country = Some(location.clone());
147                break;
148            }
149        }
150        country
151    }
152
153    pub async fn add_ip(&self, ip: Ipv4Addr) {
154        if let Some(country) = self.get_country_for_ip(&ip).await {
155            self.addresses.insert(ip, country.clone());
156        }
157    }
158
159    pub fn remove_ip(&self, ip: Ipv4Addr) {
160        self.addresses.remove(&ip);
161    }
162
163    pub async fn add_network(&self, network: Ipv4Network) {
164        if let Some(country) = self.get_country_for_ip(&network.network()).await {
165            self.networks.insert(network, country.clone());
166            tracing::info!("Added network: {} from country: {}", network, country.country_name.unwrap());
167        }
168    }
169
170    pub fn remove_network(&self, network: Ipv4Network) {
171        self.networks.remove(&network);
172    }
173
174    pub fn set_countries(&self, countries: Vec<String>) {
175        self.countries.clear();
176        tracing::info!("Setting countries: {:?}, mode: {}", countries, self.mode);
177        for country in countries {
178            self.countries.insert(country, true);
179            
180        }
181    }
182
183    pub async fn is_country_blocked(&self, country: &str) -> bool {
184        match self.mode {
185            Mode::BlackList => self.countries.contains_key(country),
186            Mode::WhiteList => !self.countries.contains_key(country),
187        }
188    }
189
190    pub async fn is_ip_blocked(&self, ip: &Ipv4Addr) -> bool {
191        if let Some(country) = self.get_country_for_ip(ip).await {
192            let name = country.country_name.unwrap();
193            let is_blocked = self.is_country_blocked(&name).await;
194            if is_blocked {
195                tracing::warn!("Blocked ip: {} from country: {}", ip, name);
196                return true;
197            } else {
198                tracing::debug!("Allowed ip: {} from country: {}", ip, name);
199                return false;
200            }
201            
202        } else {
203            false
204        }
205    }
206}
207
208impl NetworkFilter for GeoIpv4Filter {
209    fn block(
210        &self,
211        ip: impl IpAddrExt,
212        network: bool,
213    ) -> impl std::future::Future<Output = ()> + Send {
214        async move {
215            if network {
216                match ip.to_network() {
217                    IpNetwork::V4(ip) => {
218                        self.add_network(ip)
219                            .await;
220                    }
221                    _ => {}
222                }
223            } else {
224                match ip.to_ip_addr() {
225                    IpAddr::V4(ip) => {
226                        self.add_ip(ip)
227                            .await;
228                    }
229                    _ => {}
230                }
231            }
232        }
233    }
234
235    fn unblock(
236        &self,
237        ip: impl IpAddrExt,
238        network: bool,
239    ) -> impl std::future::Future<Output = ()> + Send {
240        async move {
241            if network {
242                match ip.to_network() {
243                    IpNetwork::V4(ip) => {
244                        self.remove_network(ip);
245                    }
246                    _ => {}
247                }
248            } else {
249                match ip.to_ip_addr() {
250                    IpAddr::V4(ip) => {
251                        self.remove_ip(ip);
252                    }
253                    _ => {}
254                }
255            }
256        }
257    }
258
259    fn is_blocked(&self, ip: impl IpAddrExt) -> impl std::future::Future<Output = bool> + Send {
260        async move {
261            match ip.to_ip_addr() {
262                IpAddr::V4(ip) => self.is_ip_blocked(&ip).await,
263                _ => false,
264            }
265        }
266    }
267
268    fn to_denied_response<T: http_body::Body>(&self) -> http::Response<IpResponseBody<T>>{
269        create_geo_access_denied_response()
270    }
271}