// synapse_pingora/horizon/blocklist.rs

//! Blocklist cache for fast IP and fingerprint lookups.
2
use std::collections::HashSet;
use std::sync::atomic::{AtomicU64, Ordering};

use dashmap::DashMap;
use serde::{Deserialize, Serialize};
6
7/// Type of block entry.
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
9#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
10pub enum BlockType {
11    Ip,
12    Fingerprint,
13}
14
15/// A blocklist entry.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17#[serde(rename_all = "camelCase")]
18pub struct BlocklistEntry {
19    pub block_type: BlockType,
20    pub indicator: String,
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub expires_at: Option<String>,
23    pub source: String,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub reason: Option<String>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub created_at: Option<String>,
28}
29
30/// An incremental blocklist update.
31#[derive(Debug, Clone, Serialize, Deserialize)]
32#[serde(rename_all = "camelCase")]
33pub struct BlocklistUpdate {
34    pub action: BlocklistAction,
35    pub block_type: BlockType,
36    pub indicator: String,
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub source: Option<String>,
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub reason: Option<String>,
41}
42
43/// Blocklist update action.
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
45#[serde(rename_all = "lowercase")]
46pub enum BlocklistAction {
47    Add,
48    Remove,
49}
50
51/// High-performance blocklist cache with O(1) lookups.
52///
53/// Uses DashMap for lock-free concurrent access.
54pub struct BlocklistCache {
55    /// IP blocklist
56    ips: DashMap<String, BlocklistEntry>,
57    /// Fingerprint blocklist
58    fingerprints: DashMap<String, BlocklistEntry>,
59    /// Current sequence ID from hub
60    sequence_id: AtomicU64,
61}
62
63impl Default for BlocklistCache {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69impl BlocklistCache {
70    /// Create a new empty blocklist cache.
71    pub fn new() -> Self {
72        Self {
73            ips: DashMap::new(),
74            fingerprints: DashMap::new(),
75            sequence_id: AtomicU64::new(0),
76        }
77    }
78
79    /// Check if an IP is blocked.
80    ///
81    /// This is an O(1) lookup.
82    #[inline]
83    pub fn is_ip_blocked(&self, ip: &str) -> bool {
84        self.ips.contains_key(ip)
85    }
86
87    /// Check if a fingerprint is blocked.
88    ///
89    /// This is an O(1) lookup.
90    #[inline]
91    pub fn is_fingerprint_blocked(&self, fingerprint: &str) -> bool {
92        self.fingerprints.contains_key(fingerprint)
93    }
94
95    /// Check if either IP or fingerprint is blocked.
96    #[inline]
97    pub fn is_blocked(&self, ip: Option<&str>, fingerprint: Option<&str>) -> bool {
98        if let Some(ip) = ip {
99            if self.is_ip_blocked(ip) {
100                return true;
101            }
102        }
103        if let Some(fp) = fingerprint {
104            if self.is_fingerprint_blocked(fp) {
105                return true;
106            }
107        }
108        false
109    }
110
111    /// Get an IP block entry.
112    pub fn get_ip(&self, ip: &str) -> Option<BlocklistEntry> {
113        self.ips.get(ip).map(|r| r.value().clone())
114    }
115
116    /// Get a fingerprint block entry.
117    pub fn get_fingerprint(&self, fingerprint: &str) -> Option<BlocklistEntry> {
118        self.fingerprints
119            .get(fingerprint)
120            .map(|r| r.value().clone())
121    }
122
123    /// Add a blocklist entry.
124    pub fn add(&self, entry: BlocklistEntry) {
125        match entry.block_type {
126            BlockType::Ip => {
127                self.ips.insert(entry.indicator.clone(), entry);
128            }
129            BlockType::Fingerprint => {
130                self.fingerprints.insert(entry.indicator.clone(), entry);
131            }
132        }
133    }
134
135    /// Remove a blocklist entry.
136    pub fn remove(&self, block_type: BlockType, indicator: &str) {
137        match block_type {
138            BlockType::Ip => {
139                self.ips.remove(indicator);
140            }
141            BlockType::Fingerprint => {
142                self.fingerprints.remove(indicator);
143            }
144        }
145    }
146
147    /// Load a full blocklist snapshot from the hub.
148    pub fn load_snapshot(&self, entries: Vec<BlocklistEntry>, sequence_id: u64) {
149        // Clear existing entries
150        self.ips.clear();
151        self.fingerprints.clear();
152
153        // Load new entries
154        for entry in entries {
155            self.add(entry);
156        }
157
158        self.sequence_id.store(sequence_id, Ordering::SeqCst);
159    }
160
161    /// Apply incremental updates from the hub.
162    pub fn apply_updates(&self, updates: Vec<BlocklistUpdate>, sequence_id: u64) {
163        for update in updates {
164            match update.action {
165                BlocklistAction::Add => {
166                    self.add(BlocklistEntry {
167                        block_type: update.block_type,
168                        indicator: update.indicator,
169                        expires_at: None,
170                        source: update.source.unwrap_or_else(|| "hub".to_string()),
171                        reason: update.reason,
172                        created_at: None,
173                    });
174                }
175                BlocklistAction::Remove => {
176                    self.remove(update.block_type, &update.indicator);
177                }
178            }
179        }
180
181        self.sequence_id.store(sequence_id, Ordering::SeqCst);
182    }
183
184    /// Get the total blocklist size.
185    pub fn size(&self) -> usize {
186        self.ips.len() + self.fingerprints.len()
187    }
188
189    /// Get the IP blocklist size.
190    pub fn ip_count(&self) -> usize {
191        self.ips.len()
192    }
193
194    /// Get the fingerprint blocklist size.
195    pub fn fingerprint_count(&self) -> usize {
196        self.fingerprints.len()
197    }
198
199    /// Get the current sequence ID.
200    pub fn sequence_id(&self) -> u64 {
201        self.sequence_id.load(Ordering::SeqCst)
202    }
203
204    /// Clear all entries.
205    pub fn clear(&self) {
206        self.ips.clear();
207        self.fingerprints.clear();
208        self.sequence_id.store(0, Ordering::SeqCst);
209    }
210
211    /// Get all IP entries.
212    pub fn all_ips(&self) -> Vec<BlocklistEntry> {
213        self.ips.iter().map(|r| r.value().clone()).collect()
214    }
215
216    /// Get all fingerprint entries.
217    pub fn all_fingerprints(&self) -> Vec<BlocklistEntry> {
218        self.fingerprints
219            .iter()
220            .map(|r| r.value().clone())
221            .collect()
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an entry with only the mandatory fields populated.
    fn entry(block_type: BlockType, indicator: &str, source: &str) -> BlocklistEntry {
        BlocklistEntry {
            block_type,
            indicator: indicator.to_string(),
            expires_at: None,
            source: source.to_string(),
            reason: None,
            created_at: None,
        }
    }

    #[test]
    fn test_ip_blocking() {
        let cache = BlocklistCache::new();
        cache.add(entry(BlockType::Ip, "192.168.1.100", "test"));

        assert!(cache.is_ip_blocked("192.168.1.100"));
        assert!(!cache.is_ip_blocked("192.168.1.101"));
        // IP entries must not leak into the fingerprint map.
        assert!(!cache.is_fingerprint_blocked("192.168.1.100"));
    }

    #[test]
    fn test_fingerprint_blocking() {
        let cache = BlocklistCache::new();
        cache.add(entry(BlockType::Fingerprint, "t13d1516h2_abc123", "test"));

        assert!(cache.is_fingerprint_blocked("t13d1516h2_abc123"));
        assert!(!cache.is_fingerprint_blocked("t13d1516h2_def456"));
        assert!(!cache.is_ip_blocked("t13d1516h2_abc123"));
    }

    #[test]
    fn test_is_blocked_combined() {
        let cache = BlocklistCache::new();
        cache.add(entry(BlockType::Ip, "10.0.0.1", "test"));
        cache.add(entry(BlockType::Fingerprint, "fp-bad", "test"));

        assert!(cache.is_blocked(Some("10.0.0.1"), None));
        assert!(cache.is_blocked(Some("10.0.0.1"), Some("fp123")));
        assert!(!cache.is_blocked(Some("10.0.0.2"), Some("fp123")));
        assert!(cache.is_blocked(Some("10.0.0.2"), Some("fp-bad")));
        assert!(cache.is_blocked(None, Some("fp-bad")));
        assert!(!cache.is_blocked(None, None));
    }

    #[test]
    fn test_get_entries() {
        let cache = BlocklistCache::new();
        cache.add(entry(BlockType::Ip, "10.0.0.9", "manual"));

        let found = cache.get_ip("10.0.0.9").expect("entry was just added");
        assert_eq!(found.source, "manual");
        assert_eq!(found.block_type, BlockType::Ip);
        assert!(cache.get_ip("10.0.0.10").is_none());
        assert!(cache.get_fingerprint("10.0.0.9").is_none());
        assert_eq!(cache.all_ips().len(), 1);
        assert!(cache.all_fingerprints().is_empty());
    }

    #[test]
    fn test_load_snapshot() {
        let cache = BlocklistCache::new();

        // Add some initial entries
        cache.add(entry(BlockType::Ip, "old-ip", "old"));

        // Load snapshot (should replace)
        cache.load_snapshot(vec![entry(BlockType::Ip, "new-ip", "snapshot")], 42);

        assert!(!cache.is_ip_blocked("old-ip"));
        assert!(cache.is_ip_blocked("new-ip"));
        assert_eq!(cache.sequence_id(), 42);
    }

    #[test]
    fn test_load_snapshot_replaces_overlapping_entries() {
        let cache = BlocklistCache::new();
        cache.add(entry(BlockType::Ip, "keep", "old"));
        cache.add(entry(BlockType::Ip, "drop", "old"));
        cache.add(entry(BlockType::Fingerprint, "fp-old", "old"));

        cache.load_snapshot(
            vec![
                entry(BlockType::Ip, "keep", "snapshot"),
                entry(BlockType::Fingerprint, "fp-new", "snapshot"),
            ],
            7,
        );

        assert!(cache.is_ip_blocked("keep"));
        assert_eq!(cache.get_ip("keep").unwrap().source, "snapshot");
        assert!(!cache.is_ip_blocked("drop"));
        assert!(!cache.is_fingerprint_blocked("fp-old"));
        assert!(cache.is_fingerprint_blocked("fp-new"));
        assert_eq!(cache.ip_count(), 1);
        assert_eq!(cache.fingerprint_count(), 1);
        assert_eq!(cache.size(), 2);
        assert_eq!(cache.sequence_id(), 7);
    }

    #[test]
    fn test_apply_updates() {
        let cache = BlocklistCache::new();

        cache.apply_updates(
            vec![
                BlocklistUpdate {
                    action: BlocklistAction::Add,
                    block_type: BlockType::Ip,
                    indicator: "10.0.0.1".to_string(),
                    source: Some("hub".to_string()),
                    reason: None,
                },
                BlocklistUpdate {
                    action: BlocklistAction::Add,
                    block_type: BlockType::Fingerprint,
                    indicator: "fp1".to_string(),
                    source: None,
                    reason: Some("malicious".to_string()),
                },
            ],
            100,
        );

        assert!(cache.is_ip_blocked("10.0.0.1"));
        assert!(cache.is_fingerprint_blocked("fp1"));
        assert_eq!(cache.size(), 2);
        assert_eq!(cache.sequence_id(), 100);

        // A missing source defaults to "hub"; the reason is carried over.
        let fp = cache.get_fingerprint("fp1").expect("fp1 was added");
        assert_eq!(fp.source, "hub");
        assert_eq!(fp.reason.as_deref(), Some("malicious"));

        // Remove update
        cache.apply_updates(
            vec![BlocklistUpdate {
                action: BlocklistAction::Remove,
                block_type: BlockType::Ip,
                indicator: "10.0.0.1".to_string(),
                source: None,
                reason: None,
            }],
            101,
        );

        assert!(!cache.is_ip_blocked("10.0.0.1"));
        assert_eq!(cache.size(), 1);
        assert_eq!(cache.sequence_id(), 101);
    }

    #[test]
    fn test_remove_missing_is_noop() {
        let cache = BlocklistCache::new();
        cache.add(entry(BlockType::Ip, "10.0.0.1", "test"));

        cache.remove(BlockType::Ip, "10.0.0.2");
        // Same string, wrong map: must not touch the IP entry.
        cache.remove(BlockType::Fingerprint, "10.0.0.1");

        assert!(cache.is_ip_blocked("10.0.0.1"));
        assert_eq!(cache.size(), 1);
    }

    #[test]
    fn test_clear_resets_state() {
        let cache = BlocklistCache::default();
        cache.load_snapshot(
            vec![
                entry(BlockType::Ip, "1.2.3.4", "snapshot"),
                entry(BlockType::Fingerprint, "fp", "snapshot"),
            ],
            9,
        );
        assert_eq!(cache.size(), 2);

        cache.clear();

        assert_eq!(cache.size(), 0);
        assert_eq!(cache.sequence_id(), 0);
        assert!(!cache.is_blocked(Some("1.2.3.4"), Some("fp")));
    }
}