// tower_ipfilter/ip_filter.rs

1use std::{marker::PhantomData, net::IpAddr};
2
3use dashmap::DashMap;
4use ipnetwork::IpNetwork;
5
6use crate::{
7    body::{create_ip_address_denied_response, IpResponseBody},
8    geo_filter::IpAddrExt,
9    network_filter_service::NetworkFilter,
10    types::Mode,
11};
12
/// Metadata stored alongside every address or network entry of an `IpFilter`.
#[derive(Clone, Debug)]
pub struct IpMetaData {
    /// Free-form explanation of why the entry exists (this module's
    /// `block_ip` writes `"Blocked"`).
    pub reason: String,
    /// Date the entry was recorded, kept as a plain string (this module's
    /// `block_ip` writes it in `YYYY-MM-DD` form).
    pub date: String,
}
18
/// Marker trait for the IP-family type parameter of `IpFilter`.
pub trait IpType {}

/// IPv4 family marker. It has no variants, so it can only appear as a type
/// parameter, never as a value.
#[derive(Clone, Debug)]
pub enum V4 {}

impl IpType for V4 {}

/// IPv6 family marker. It has no variants, so it can only appear as a type
/// parameter, never as a value.
#[derive(Clone, Debug)]
pub enum V6 {}

impl IpType for V6 {}
29
30#[derive(Debug, Clone)]
31pub struct IpFilter<S: IpType> {
32    pub addresses: DashMap<IpAddr, IpMetaData>,
33    pub networks: DashMap<IpNetwork, IpMetaData>,
34    pub mode: Mode,
35    marker: PhantomData<S>,
36}
37
38impl<S: IpType> IpFilter<S> {
39    pub fn new(mode: Mode) -> Self {
40        Self {
41            networks: DashMap::new(),
42            addresses: DashMap::new(),
43            mode,
44            marker: PhantomData,
45        }
46    }
47    pub async fn add_ip(&self, ip: IpAddr, reason: String, date: String) {
48        self.addresses.insert(ip, IpMetaData { reason, date });
49    }
50    pub async fn add_network(&self, network: IpNetwork, reason: String, date: String) {
51        self.networks.insert(network, IpMetaData { reason, date });
52    }
53
54    async fn is_ip_blocked(&self, ip: &IpAddr) -> bool {
55        if self.addresses.contains_key(ip) {
56            match self.mode {
57                Mode::BlackList => return true,
58                Mode::WhiteList => return false,
59            }
60        } else {
61            for kv in self.networks.iter() {
62                let (network, _) = kv.pair();
63                if network.contains(*ip) {
64                    match self.mode {
65                        Mode::BlackList => return true,
66                        Mode::WhiteList => return false,
67                    }
68                }
69            }
70
71            match self.mode {
72                Mode::BlackList => return false,
73                Mode::WhiteList => return true,
74            }
75        }
76    }
77
78    async fn block_ip(&self, ip: impl IpAddrExt, network: bool) {
79        if network {
80            match ip.to_network() {
81                IpNetwork::V4(ip) => {
82                    self.add_network(
83                        IpNetwork::V4(ip),
84                        "Blocked".to_string(),
85                        "2021-09-01".to_string(),
86                    )
87                    .await;
88                }
89                IpNetwork::V6(ip) => {
90                    self.add_network(
91                        IpNetwork::V6(ip),
92                        "Blocked".to_string(),
93                        "2021-09-01".to_string(),
94                    )
95                    .await;
96                }
97            }
98        } else {
99            match ip.to_ip_addr() {
100                IpAddr::V4(ip) => {
101                    self.add_ip(
102                        IpAddr::V4(ip),
103                        "Blocked".to_string(),
104                        "2021-09-01".to_string(),
105                    )
106                    .await;
107                }
108                IpAddr::V6(ip) => {
109                    self.add_ip(
110                        IpAddr::V6(ip),
111                        "Blocked".to_string(),
112                        "2021-09-01".to_string(),
113                    )
114                    .await;
115                }
116            }
117        }
118    }
119
120    async fn unblock_ip(&self, ip: impl IpAddrExt, network: bool) {
121        if network {
122            match ip.to_network() {
123                IpNetwork::V4(ip) => {
124                    self.networks.remove(&IpNetwork::V4(ip));
125                }
126                IpNetwork::V6(ip) => {
127                    self.networks.remove(&IpNetwork::V6(ip));
128                }
129            }
130        } else {
131            match ip.to_ip_addr() {
132                IpAddr::V4(ip) => {
133                    self.addresses.remove(&IpAddr::V4(ip));
134                }
135                IpAddr::V6(ip) => {
136                    self.addresses.remove(&IpAddr::V6(ip));
137                }
138            }
139        }
140    }
141}
142
143impl NetworkFilter for IpFilter<V4> {
144    fn block(
145        &self,
146        ip: impl IpAddrExt,
147        network: bool,
148    ) -> impl std::future::Future<Output = ()> + Send {
149        async move {
150            if ip.is_ipv4() {
151                self.block_ip(ip, network).await;
152            } else {
153                panic!("Invalid IP address");
154            }
155        }
156    }
157
158    fn unblock(
159        &self,
160        ip: impl IpAddrExt,
161        network: bool,
162    ) -> impl std::future::Future<Output = ()> + Send {
163        async move {
164            if ip.is_ipv4() {
165                self.unblock_ip(ip, network).await;
166            } else {
167                panic!("Invalid IP address");
168            }
169        }
170    }
171
172    fn is_blocked(&self, ip: impl IpAddrExt) -> impl std::future::Future<Output = bool> + Send {
173        async move {
174            if ip.is_ipv4() {
175                self.is_ip_blocked(&ip.to_ip_addr()).await
176            } else {
177                panic!("Invalid IP address");
178            }
179        }
180    }
181
182    fn to_denied_response<T: http_body::Body>(&self) -> http::Response<IpResponseBody<T>> {
183        create_ip_address_denied_response()
184    }
185}
186
187impl NetworkFilter for IpFilter<V6> {
188    fn block(
189        &self,
190        ip: impl IpAddrExt,
191        network: bool,
192    ) -> impl std::future::Future<Output = ()> + Send {
193        async move {
194            if !ip.is_ipv4() {
195                self.block_ip(ip, network).await;
196            } else {
197                panic!("Invalid IP address");
198            }
199        }
200    }
201
202    fn unblock(
203        &self,
204        ip: impl IpAddrExt,
205        network: bool,
206    ) -> impl std::future::Future<Output = ()> + Send {
207        async move {
208            if !ip.is_ipv4() {
209                self.unblock_ip(ip, network).await;
210            } else {
211                panic!("Invalid IP address");
212            }
213        }
214    }
215
216    fn is_blocked(&self, ip: impl IpAddrExt) -> impl std::future::Future<Output = bool> + Send {
217        async move {
218            if !ip.is_ipv4() {
219                self.is_ip_blocked(&ip.to_ip_addr()).await
220            } else {
221                panic!("Invalid IP address");
222            }
223        }
224    }
225
226    fn to_denied_response<T: http_body::Body>(&self) -> http::Response<IpResponseBody<T>> {
227        create_ip_address_denied_response()
228    }
229}