1use dashmap::DashMap;
4use serde::{Deserialize, Serialize};
5use std::sync::atomic::{AtomicU64, Ordering};
6
/// Kind of indicator a blocklist entry matches on.
///
/// Serialized with `rename_all = "SCREAMING_SNAKE_CASE"`, so the variants
/// appear on the wire as `"IP"` and `"FINGERPRINT"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum BlockType {
    /// An IP address indicator.
    Ip,
    /// A fingerprint indicator (the tests use values like `t13d1516h2_abc123`).
    Fingerprint,
}
14
/// A single blocked indicator together with its metadata.
///
/// Serialized in camelCase (`blockType`, `indicator`, `expiresAt`, `source`,
/// `reason`, `createdAt`); `None` optional fields are omitted on serialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BlocklistEntry {
    /// Which map (IP or fingerprint) this entry belongs to.
    pub block_type: BlockType,
    /// The blocked value itself; used as the lookup key.
    pub indicator: String,
    // NOTE(review): nothing in this file reads `expires_at` — entries are never
    // expired by `BlocklistCache`. Format presumably a timestamp string; confirm
    // expiry is enforced upstream (e.g. by the hub not re-sending expired entries).
    /// Optional expiry time, carried as an opaque string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<String>,
    /// Origin of the entry; updates without a source are recorded as `"hub"`.
    pub source: String,
    /// Optional human-readable reason for the block.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
    /// Optional creation time, carried as an opaque string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<String>,
}
29
/// One incremental change to the blocklist (add or remove an indicator).
///
/// Serialized in camelCase (`action`, `blockType`, `indicator`, `source`,
/// `reason`); `None` optional fields are omitted on serialization. Unlike
/// [`BlocklistEntry`], an update carries no expiry or creation time.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BlocklistUpdate {
    /// Whether to add or remove the indicator.
    pub action: BlocklistAction,
    /// Which map (IP or fingerprint) the change targets.
    pub block_type: BlockType,
    /// The indicator value to add or remove.
    pub indicator: String,
    /// Origin of the change; only used for `Add`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub source: Option<String>,
    /// Optional reason; only used for `Add`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
}
42
/// Operation carried by a [`BlocklistUpdate`].
///
/// Serialized in lowercase: `"add"` / `"remove"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum BlocklistAction {
    /// Insert (or overwrite) the indicator.
    Add,
    /// Delete the indicator if present.
    Remove,
}
50
51pub struct BlocklistCache {
55 ips: DashMap<String, BlocklistEntry>,
57 fingerprints: DashMap<String, BlocklistEntry>,
59 sequence_id: AtomicU64,
61}
62
63impl Default for BlocklistCache {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69impl BlocklistCache {
70 pub fn new() -> Self {
72 Self {
73 ips: DashMap::new(),
74 fingerprints: DashMap::new(),
75 sequence_id: AtomicU64::new(0),
76 }
77 }
78
79 #[inline]
83 pub fn is_ip_blocked(&self, ip: &str) -> bool {
84 self.ips.contains_key(ip)
85 }
86
87 #[inline]
91 pub fn is_fingerprint_blocked(&self, fingerprint: &str) -> bool {
92 self.fingerprints.contains_key(fingerprint)
93 }
94
95 #[inline]
97 pub fn is_blocked(&self, ip: Option<&str>, fingerprint: Option<&str>) -> bool {
98 if let Some(ip) = ip {
99 if self.is_ip_blocked(ip) {
100 return true;
101 }
102 }
103 if let Some(fp) = fingerprint {
104 if self.is_fingerprint_blocked(fp) {
105 return true;
106 }
107 }
108 false
109 }
110
111 pub fn get_ip(&self, ip: &str) -> Option<BlocklistEntry> {
113 self.ips.get(ip).map(|r| r.value().clone())
114 }
115
116 pub fn get_fingerprint(&self, fingerprint: &str) -> Option<BlocklistEntry> {
118 self.fingerprints
119 .get(fingerprint)
120 .map(|r| r.value().clone())
121 }
122
123 pub fn add(&self, entry: BlocklistEntry) {
125 match entry.block_type {
126 BlockType::Ip => {
127 self.ips.insert(entry.indicator.clone(), entry);
128 }
129 BlockType::Fingerprint => {
130 self.fingerprints.insert(entry.indicator.clone(), entry);
131 }
132 }
133 }
134
135 pub fn remove(&self, block_type: BlockType, indicator: &str) {
137 match block_type {
138 BlockType::Ip => {
139 self.ips.remove(indicator);
140 }
141 BlockType::Fingerprint => {
142 self.fingerprints.remove(indicator);
143 }
144 }
145 }
146
147 pub fn load_snapshot(&self, entries: Vec<BlocklistEntry>, sequence_id: u64) {
149 self.ips.clear();
151 self.fingerprints.clear();
152
153 for entry in entries {
155 self.add(entry);
156 }
157
158 self.sequence_id.store(sequence_id, Ordering::SeqCst);
159 }
160
161 pub fn apply_updates(&self, updates: Vec<BlocklistUpdate>, sequence_id: u64) {
163 for update in updates {
164 match update.action {
165 BlocklistAction::Add => {
166 self.add(BlocklistEntry {
167 block_type: update.block_type,
168 indicator: update.indicator,
169 expires_at: None,
170 source: update.source.unwrap_or_else(|| "hub".to_string()),
171 reason: update.reason,
172 created_at: None,
173 });
174 }
175 BlocklistAction::Remove => {
176 self.remove(update.block_type, &update.indicator);
177 }
178 }
179 }
180
181 self.sequence_id.store(sequence_id, Ordering::SeqCst);
182 }
183
184 pub fn size(&self) -> usize {
186 self.ips.len() + self.fingerprints.len()
187 }
188
189 pub fn ip_count(&self) -> usize {
191 self.ips.len()
192 }
193
194 pub fn fingerprint_count(&self) -> usize {
196 self.fingerprints.len()
197 }
198
199 pub fn sequence_id(&self) -> u64 {
201 self.sequence_id.load(Ordering::SeqCst)
202 }
203
204 pub fn clear(&self) {
206 self.ips.clear();
207 self.fingerprints.clear();
208 self.sequence_id.store(0, Ordering::SeqCst);
209 }
210
211 pub fn all_ips(&self) -> Vec<BlocklistEntry> {
213 self.ips.iter().map(|r| r.value().clone()).collect()
214 }
215
216 pub fn all_fingerprints(&self) -> Vec<BlocklistEntry> {
218 self.fingerprints
219 .iter()
220 .map(|r| r.value().clone())
221 .collect()
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an entry with no expiry, reason, or creation time.
    fn entry(block_type: BlockType, indicator: &str, source: &str) -> BlocklistEntry {
        BlocklistEntry {
            block_type,
            indicator: indicator.to_string(),
            expires_at: None,
            source: source.to_string(),
            reason: None,
            created_at: None,
        }
    }

    #[test]
    fn test_ip_blocking() {
        let cache = BlocklistCache::new();
        cache.add(entry(BlockType::Ip, "192.168.1.100", "test"));

        assert!(cache.is_ip_blocked("192.168.1.100"));
        assert!(!cache.is_ip_blocked("192.168.1.101"));
    }

    #[test]
    fn test_fingerprint_blocking() {
        let cache = BlocklistCache::new();
        cache.add(entry(BlockType::Fingerprint, "t13d1516h2_abc123", "test"));

        assert!(cache.is_fingerprint_blocked("t13d1516h2_abc123"));
        assert!(!cache.is_fingerprint_blocked("t13d1516h2_def456"));
    }

    #[test]
    fn test_is_blocked_combined() {
        let cache = BlocklistCache::new();
        cache.add(entry(BlockType::Ip, "10.0.0.1", "test"));

        assert!(cache.is_blocked(Some("10.0.0.1"), None));
        assert!(cache.is_blocked(Some("10.0.0.1"), Some("fp123")));
        assert!(!cache.is_blocked(Some("10.0.0.2"), Some("fp123")));
        assert!(!cache.is_blocked(None, None));
    }

    #[test]
    fn test_load_snapshot() {
        let cache = BlocklistCache::new();
        // Pre-existing entry that the snapshot should replace.
        cache.add(entry(BlockType::Ip, "old-ip", "old"));

        cache.load_snapshot(vec![entry(BlockType::Ip, "new-ip", "snapshot")], 42);

        assert!(!cache.is_ip_blocked("old-ip"));
        assert!(cache.is_ip_blocked("new-ip"));
        assert_eq!(cache.sequence_id(), 42);
    }

    #[test]
    fn test_load_snapshot_keeps_overlap_and_replaces_fingerprints() {
        let cache = BlocklistCache::new();
        cache.add(entry(BlockType::Ip, "keep", "old"));
        cache.add(entry(BlockType::Ip, "drop", "old"));
        cache.add(entry(BlockType::Fingerprint, "old-fp", "old"));

        cache.load_snapshot(
            vec![
                entry(BlockType::Ip, "keep", "snapshot"),
                entry(BlockType::Fingerprint, "new-fp", "snapshot"),
            ],
            7,
        );

        assert!(cache.is_ip_blocked("keep"));
        // The snapshot's copy of an overlapping indicator wins.
        assert_eq!(cache.get_ip("keep").map(|e| e.source).as_deref(), Some("snapshot"));
        assert!(!cache.is_ip_blocked("drop"));
        assert!(!cache.is_fingerprint_blocked("old-fp"));
        assert!(cache.is_fingerprint_blocked("new-fp"));
        assert_eq!(cache.ip_count(), 1);
        assert_eq!(cache.fingerprint_count(), 1);
        assert_eq!(cache.sequence_id(), 7);
    }

    #[test]
    fn test_load_empty_snapshot_clears_everything() {
        let cache = BlocklistCache::new();
        cache.add(entry(BlockType::Ip, "10.0.0.1", "test"));
        cache.add(entry(BlockType::Fingerprint, "fp1", "test"));

        cache.load_snapshot(Vec::new(), 3);

        assert_eq!(cache.size(), 0);
        assert_eq!(cache.sequence_id(), 3);
    }

    #[test]
    fn test_apply_updates() {
        let cache = BlocklistCache::new();

        cache.apply_updates(
            vec![
                BlocklistUpdate {
                    action: BlocklistAction::Add,
                    block_type: BlockType::Ip,
                    indicator: "10.0.0.1".to_string(),
                    source: Some("hub".to_string()),
                    reason: None,
                },
                BlocklistUpdate {
                    action: BlocklistAction::Add,
                    block_type: BlockType::Fingerprint,
                    indicator: "fp1".to_string(),
                    source: None,
                    reason: Some("malicious".to_string()),
                },
            ],
            100,
        );

        assert!(cache.is_ip_blocked("10.0.0.1"));
        assert!(cache.is_fingerprint_blocked("fp1"));
        assert_eq!(cache.size(), 2);
        assert_eq!(cache.sequence_id(), 100);

        // An update without a source is attributed to the hub.
        let fp = cache.get_fingerprint("fp1").expect("fp1 was just added");
        assert_eq!(fp.source, "hub");
        assert_eq!(fp.reason.as_deref(), Some("malicious"));

        cache.apply_updates(
            vec![BlocklistUpdate {
                action: BlocklistAction::Remove,
                block_type: BlockType::Ip,
                indicator: "10.0.0.1".to_string(),
                source: None,
                reason: None,
            }],
            101,
        );

        assert!(!cache.is_ip_blocked("10.0.0.1"));
        assert_eq!(cache.size(), 1);
        assert_eq!(cache.sequence_id(), 101);
    }

    #[test]
    fn test_remove_missing_is_noop() {
        let cache = BlocklistCache::new();
        cache.add(entry(BlockType::Ip, "10.0.0.1", "test"));

        cache.remove(BlockType::Ip, "10.0.0.2");
        // Same indicator but wrong type must not touch the IP map.
        cache.remove(BlockType::Fingerprint, "10.0.0.1");

        assert!(cache.is_ip_blocked("10.0.0.1"));
        assert_eq!(cache.size(), 1);
    }

    #[test]
    fn test_clear_resets_entries_and_sequence() {
        let cache = BlocklistCache::new();
        cache.load_snapshot(
            vec![
                entry(BlockType::Ip, "10.0.0.1", "test"),
                entry(BlockType::Fingerprint, "fp1", "test"),
            ],
            9,
        );

        cache.clear();

        assert_eq!(cache.size(), 0);
        assert!(cache.all_ips().is_empty());
        assert!(cache.all_fingerprints().is_empty());
        assert_eq!(cache.sequence_id(), 0);
    }
}