Skip to main content

synapse_pingora/waf/
state.rs

//! Stateful per-IP tracking for rate limiting (event counts) and unique-value counting.
2
3use std::collections::HashMap;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6/// State store for per-IP tracking.
7#[derive(Default)]
8pub struct StateStore {
9    /// Unique values per (IP, key) pair.
10    unique_values: HashMap<(String, String), HashMap<String, u64>>,
11    /// Event counts per (IP, key) pair.
12    event_counts: HashMap<(String, String), Vec<u64>>,
13}
14
15impl StateStore {
16    /// Record unique values for an IP and return the current count.
17    pub fn record_unique_values(
18        &mut self,
19        ip: &str,
20        key: &str,
21        values: &[String],
22        timeframe_sec: u64,
23    ) -> usize {
24        let now = now_ms();
25        let window_ms = timeframe_sec.saturating_mul(1000).max(1);
26        let map_key = (ip.to_string(), key.to_string());
27        let entry = self.unique_values.entry(map_key).or_default();
28
29        for value in values {
30            let normalized = if value.len() > 256 {
31                value[..256].to_string()
32            } else {
33                value.clone()
34            };
35            entry.insert(normalized, now);
36        }
37
38        // Cleanup expired entries
39        entry.retain(|_, ts| now.saturating_sub(*ts) <= window_ms);
40        entry.len()
41    }
42
43    /// Get the current unique count for an IP.
44    pub fn get_unique_count(&mut self, ip: &str, key: &str, timeframe_sec: u64) -> usize {
45        let now = now_ms();
46        let window_ms = timeframe_sec.saturating_mul(1000).max(1);
47        let map_key = (ip.to_string(), key.to_string());
48        let Some(entry) = self.unique_values.get_mut(&map_key) else {
49            return 0;
50        };
51        entry.retain(|_, ts| now.saturating_sub(*ts) <= window_ms);
52        entry.len()
53    }
54
55    /// Record an event and return the current count.
56    pub fn record_event(&mut self, ip: &str, key: &str, timeframe_sec: u64) -> usize {
57        let now = now_ms();
58        let window_ms = timeframe_sec.saturating_mul(1000).max(1);
59        let map_key = (ip.to_string(), key.to_string());
60        let list = self.event_counts.entry(map_key).or_default();
61        list.push(now);
62
63        // Remove expired events
64        while let Some(first) = list.first().copied() {
65            if now.saturating_sub(first) > window_ms {
66                list.remove(0);
67            } else {
68                break;
69            }
70        }
71        list.len()
72    }
73
74    /// Clear all state (for testing).
75    #[cfg(test)]
76    #[allow(dead_code)]
77    pub fn clear(&mut self) {
78        self.unique_values.clear();
79        self.event_counts.clear();
80    }
81}
82
/// Get current time in milliseconds since the Unix epoch.
///
/// This is wall-clock (`SystemTime`) time, so successive calls are not guaranteed
/// to be monotonic. Returns 0 if the clock reports a time before the epoch. The
/// `u128` millisecond count is clamped to `u64::MAX` rather than silently truncated.
pub fn now_ms() -> u64 {
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();
    // Clamped first, so the cast below is lossless.
    millis.min(u128::from(u64::MAX)) as u64
}
90
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unique_count() {
        let mut store = StateStore::default();

        let count = store.record_unique_values(
            "192.168.1.1",
            "test",
            &["value1".to_string(), "value2".to_string()],
            60,
        );
        assert_eq!(count, 2);

        // Recording same values shouldn't increase count
        let count = store.record_unique_values("192.168.1.1", "test", &["value1".to_string()], 60);
        assert_eq!(count, 2);

        // New value should increase count
        let count = store.record_unique_values("192.168.1.1", "test", &["value3".to_string()], 60);
        assert_eq!(count, 3);
    }

    #[test]
    fn test_get_unique_count() {
        let mut store = StateStore::default();

        // Nothing recorded yet
        assert_eq!(store.get_unique_count("192.168.1.1", "test", 60), 0);

        store.record_unique_values(
            "192.168.1.1",
            "test",
            &["a".to_string(), "b".to_string()],
            60,
        );
        assert_eq!(store.get_unique_count("192.168.1.1", "test", 60), 2);

        // Other keys and other IPs are tracked separately
        assert_eq!(store.get_unique_count("192.168.1.1", "other", 60), 0);
        assert_eq!(store.get_unique_count("192.168.1.2", "test", 60), 0);
    }

    #[test]
    fn test_long_values_are_truncated() {
        let mut store = StateStore::default();

        // Both values share the same first 256 bytes, so they count as one.
        let values = ["x".repeat(300), "x".repeat(256)];
        let count = store.record_unique_values("192.168.1.1", "test", &values, 60);
        assert_eq!(count, 1);
    }

    #[test]
    fn test_event_count() {
        let mut store = StateStore::default();

        let count = store.record_event("192.168.1.1", "test", 60);
        assert_eq!(count, 1);

        let count = store.record_event("192.168.1.1", "test", 60);
        assert_eq!(count, 2);

        // Different IP should have separate count
        let count = store.record_event("192.168.1.2", "test", 60);
        assert_eq!(count, 1);

        // Different key for the same IP should have separate count
        let count = store.record_event("192.168.1.1", "other", 60);
        assert_eq!(count, 1);
    }

    #[test]
    fn test_clear() {
        let mut store = StateStore::default();

        store.record_event("192.168.1.1", "test", 60);
        store.record_unique_values("192.168.1.1", "test", &["v".to_string()], 60);
        store.clear();

        assert_eq!(store.get_unique_count("192.168.1.1", "test", 60), 0);
        assert_eq!(store.record_event("192.168.1.1", "test", 60), 1);
    }
}