// tower_ipfilter/geo_filter.rs
use std::{
    error::Error,
    net::{IpAddr, Ipv4Addr, Ipv6Addr},
    path::{Path, PathBuf},
};

use dashmap::DashMap;
use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network};
use tracing::info;

use crate::{
    body::{create_geo_access_denied_response, IpResponseBody},
    compress::{load_compressed_data, save_compressed_data},
    extract::extract_and_parse_csv,
    network_filter_service::NetworkFilter,
    types::{CountryLocation, Mode},
};

14pub trait IpAddrExt: Sized + Send {
15 fn to_network(self) -> IpNetwork;
16 fn to_ip_addr(self) -> IpAddr;
17 fn is_ipv4(&self) -> bool;
18}
19
20impl IpAddrExt for Ipv4Addr {
21 fn to_ip_addr(self) -> IpAddr {
22 IpAddr::V4(self)
23 }
24 fn to_network(self) -> IpNetwork {
25 IpNetwork::V4(Ipv4Network::from(self))
26 }
27 fn is_ipv4(&self) -> bool {
28 true
29 }
30}
31
32impl IpAddrExt for Ipv6Addr {
33 fn to_ip_addr(self) -> IpAddr {
34 IpAddr::V6(self)
35 }
36 fn to_network(self) -> IpNetwork {
37 IpNetwork::V6(Ipv6Network::from(self))
38 }
39 fn is_ipv4(&self) -> bool {
40 false
41 }
42}
43
44impl IpAddrExt for IpAddr {
45 fn to_ip_addr(self) -> IpAddr {
46 self
47 }
48 fn to_network(self) -> IpNetwork {
49 match self {
50 IpAddr::V4(ip) => IpNetwork::V4(Ipv4Network::from(ip)),
51 IpAddr::V6(ip) => IpNetwork::V6(Ipv6Network::from(ip)),
52 }
53 }
54 fn is_ipv4(&self) -> bool {
55 match self {
56 IpAddr::V4(_) => true,
57 IpAddr::V6(_) => false,
58 }
59 }
60}
61
62impl IpAddrExt for Ipv4Network {
63 fn to_ip_addr(self) -> IpAddr {
64 IpAddr::V4(self.network())
65 }
66 fn to_network(self) -> IpNetwork {
67 IpNetwork::V4(self)
68 }
69 fn is_ipv4(&self) -> bool {
70 true
71 }
72}
73
74#[derive(Debug, Clone)]
75pub struct GeoIpv4Filter {
76 pub networks: DashMap<Ipv4Network, CountryLocation>,
77 pub addresses: DashMap<Ipv4Addr, CountryLocation>,
78 pub countries: DashMap<String, bool>,
79 pub mode: Mode,
80}
81
82impl GeoIpv4Filter {
83 pub fn new(mode: Mode, path_to_data: impl Into<PathBuf>) -> Result<Self, Box<dyn Error>> {
84 let data_path = Path::new("geo_ip_data.bin.gz");
85
86 let geo_data = if !data_path.exists() {
87 let data = extract_and_parse_csv(&path_to_data.into())?;
88 save_compressed_data(&data, data_path)?;
89 data
90 } else {
91 load_compressed_data(data_path)?
92 };
93
94 info!(
95 "Loaded {} ip blocks and {} country locations",
96 geo_data.ip_blocks.len(),
97 geo_data.country_locations.len()
98 );
99
100 let ip_country_map = DashMap::<Ipv4Network, CountryLocation>::new();
101
102 ip_country_map.insert(
104 Ipv4Network::from(Ipv4Addr::new(127, 0, 0, 1)),
105 CountryLocation {
106 geoname_id: 0,
107 locale_code: "NB".to_string(),
108 continent_code: "NA".to_string(),
109 continent_name: "Europe".to_string(),
110 country_iso_code: Some("NO".to_string()),
111 country_name: Some("Norway".to_string()),
112 is_in_european_union: true,
113 },
114 );
115
116 for block in geo_data.ip_blocks {
117 if let Some(geoname_id) = block.geoname_id {
118 if let Ok(network) = block.network.parse() {
119 if let Some(country) = geo_data.country_locations.get(&geoname_id) {
120 ip_country_map.insert(network, country.clone());
121 } else {
122 println!("No country found for geoname_id: {}", geoname_id);
123 }
124 }
125 }
126 }
127
128 Ok(Self {
129 networks: ip_country_map,
130 addresses: DashMap::new(),
131 countries: DashMap::new(),
132 mode,
133 })
134 }
135
136 pub async fn get_country_for_ip(&self, ip: &Ipv4Addr) -> Option<CountryLocation> {
137 let mut country = None;
138
139 if let Some(location) = self.addresses.get(ip) {
140 return Some(location.clone());
141 }
142
143 for kv in self.networks.iter() {
144 let (network, location) = kv.pair();
145 if network.contains(*ip) {
146 country = Some(location.clone());
147 break;
148 }
149 }
150 country
151 }
152
153 pub async fn add_ip(&self, ip: Ipv4Addr) {
154 if let Some(country) = self.get_country_for_ip(&ip).await {
155 self.addresses.insert(ip, country.clone());
156 }
157 }
158
159 pub fn remove_ip(&self, ip: Ipv4Addr) {
160 self.addresses.remove(&ip);
161 }
162
163 pub async fn add_network(&self, network: Ipv4Network) {
164 if let Some(country) = self.get_country_for_ip(&network.network()).await {
165 self.networks.insert(network, country.clone());
166 tracing::info!("Added network: {} from country: {}", network, country.country_name.unwrap());
167 }
168 }
169
170 pub fn remove_network(&self, network: Ipv4Network) {
171 self.networks.remove(&network);
172 }
173
174 pub fn set_countries(&self, countries: Vec<String>) {
175 self.countries.clear();
176 tracing::info!("Setting countries: {:?}, mode: {}", countries, self.mode);
177 for country in countries {
178 self.countries.insert(country, true);
179
180 }
181 }
182
183 pub async fn is_country_blocked(&self, country: &str) -> bool {
184 match self.mode {
185 Mode::BlackList => self.countries.contains_key(country),
186 Mode::WhiteList => !self.countries.contains_key(country),
187 }
188 }
189
190 pub async fn is_ip_blocked(&self, ip: &Ipv4Addr) -> bool {
191 if let Some(country) = self.get_country_for_ip(ip).await {
192 let name = country.country_name.unwrap();
193 let is_blocked = self.is_country_blocked(&name).await;
194 if is_blocked {
195 tracing::warn!("Blocked ip: {} from country: {}", ip, name);
196 return true;
197 } else {
198 tracing::debug!("Allowed ip: {} from country: {}", ip, name);
199 return false;
200 }
201
202 } else {
203 false
204 }
205 }
206}
207
208impl NetworkFilter for GeoIpv4Filter {
209 fn block(
210 &self,
211 ip: impl IpAddrExt,
212 network: bool,
213 ) -> impl std::future::Future<Output = ()> + Send {
214 async move {
215 if network {
216 match ip.to_network() {
217 IpNetwork::V4(ip) => {
218 self.add_network(ip)
219 .await;
220 }
221 _ => {}
222 }
223 } else {
224 match ip.to_ip_addr() {
225 IpAddr::V4(ip) => {
226 self.add_ip(ip)
227 .await;
228 }
229 _ => {}
230 }
231 }
232 }
233 }
234
235 fn unblock(
236 &self,
237 ip: impl IpAddrExt,
238 network: bool,
239 ) -> impl std::future::Future<Output = ()> + Send {
240 async move {
241 if network {
242 match ip.to_network() {
243 IpNetwork::V4(ip) => {
244 self.remove_network(ip);
245 }
246 _ => {}
247 }
248 } else {
249 match ip.to_ip_addr() {
250 IpAddr::V4(ip) => {
251 self.remove_ip(ip);
252 }
253 _ => {}
254 }
255 }
256 }
257 }
258
259 fn is_blocked(&self, ip: impl IpAddrExt) -> impl std::future::Future<Output = bool> + Send {
260 async move {
261 match ip.to_ip_addr() {
262 IpAddr::V4(ip) => self.is_ip_blocked(&ip).await,
263 _ => false,
264 }
265 }
266 }
267
268 fn to_denied_response<T: http_body::Body>(&self) -> http::Response<IpResponseBody<T>>{
269 create_geo_access_denied_response()
270 }
271}