// smg_mesh/rate_limit_window.rs

1//! Rate limit time window management
2//!
3//! Manages time windows for global rate limiting with periodic counter resets
4
5use std::{sync::Arc, time::Duration};
6
7use tokio::{sync::watch, time::interval};
8use tracing::{debug, info};
9
10use super::sync::MeshSyncManager;
11
/// Rate limit window manager
/// Handles periodic reset of rate limit counters for time window management
pub struct RateLimitWindow {
    /// Shared mesh sync manager whose global rate limit counter is reset
    /// once per window by the background task.
    sync_manager: Arc<MeshSyncManager>,
    /// Length of one rate limit window, in seconds, between counter resets.
    window_seconds: u64,
}
18
19impl RateLimitWindow {
20    pub fn new(sync_manager: Arc<MeshSyncManager>, window_seconds: u64) -> Self {
21        Self {
22            sync_manager,
23            window_seconds,
24        }
25    }
26
27    /// Start the window reset task
28    /// This task periodically resets the global rate limit counter
29    ///
30    /// # Arguments
31    /// * `shutdown_rx` - A watch receiver that signals when to stop the task
32    pub async fn start_reset_task(self, mut shutdown_rx: watch::Receiver<bool>) {
33        let mut interval_timer = interval(Duration::from_secs(self.window_seconds));
34        info!(
35            "Starting rate limit window reset task with {}s interval",
36            self.window_seconds
37        );
38
39        loop {
40            tokio::select! {
41                _ = interval_timer.tick() => {
42                    debug!("Resetting global rate limit counter");
43                    self.sync_manager.reset_global_rate_limit_counter();
44                }
45                _ = shutdown_rx.changed() => {
46                    info!("Rate limit window reset task received shutdown signal");
47                    break;
48                }
49            }
50        }
51
52        info!("Rate limit window reset task stopped");
53    }
54}
55
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use tokio::time::sleep;

    use super::*;
    use crate::stores::{
        RateLimitConfig, StateStores, GLOBAL_RATE_LIMIT_COUNTER_KEY, GLOBAL_RATE_LIMIT_KEY,
    };

    /// Construction stores the configured window length unchanged.
    #[test]
    fn test_rate_limit_window_new() {
        let stores = Arc::new(StateStores::with_self_name("node1".to_string()));
        let sync_manager = Arc::new(MeshSyncManager::new(stores, "node1".to_string()));

        let window = RateLimitWindow::new(sync_manager, 60);
        // Should create without panicking
        assert_eq!(window.window_seconds, 60);
    }

    /// Multiple windows may share one sync manager with different intervals.
    #[test]
    fn test_rate_limit_window_different_intervals() {
        let stores = Arc::new(StateStores::with_self_name("node1".to_string()));
        let sync_manager = Arc::new(MeshSyncManager::new(stores, "node1".to_string()));

        let window1 = RateLimitWindow::new(sync_manager.clone(), 30);
        assert_eq!(window1.window_seconds, 30);

        let window2 = RateLimitWindow::new(sync_manager, 120);
        assert_eq!(window2.window_seconds, 120);
    }

    /// The reset task starts, runs at least one interval, and shuts down
    /// promptly when signalled.
    ///
    /// NOTE(review): uses a real-clock sleep (~1.5s); consider
    /// tokio::time::pause() to make this test instant and deterministic.
    #[tokio::test]
    async fn test_rate_limit_window_reset_task_interval() {
        let stores = Arc::new(StateStores::with_self_name("node1".to_string()));
        let sync_manager = Arc::new(MeshSyncManager::new(stores, "node1".to_string()));

        // Set a very short window for testing (1 second)
        let window = RateLimitWindow::new(sync_manager, 1);

        // Create shutdown channel
        let (shutdown_tx, shutdown_rx) = watch::channel(false);

        // Spawn the reset task
        #[expect(
            clippy::disallowed_methods,
            reason = "test: handle is awaited with timeout below"
        )]
        let task_handle = tokio::spawn(async move {
            window.start_reset_task(shutdown_rx).await;
        });

        // Wait a bit to allow the task to run
        sleep(Duration::from_millis(1500)).await;

        // Send shutdown signal
        shutdown_tx
            .send(true)
            .expect("failed to send shutdown signal");

        // Wait for task to complete gracefully
        let res = tokio::time::timeout(Duration::from_secs(1), task_handle).await;
        assert!(res.is_ok(), "reset task did not shut down in time");
        let join_res = res.unwrap();
        assert!(join_res.is_ok(), "reset task panicked");

        // The task should have started and stopped gracefully
    }

    /// End-to-end: increment the counter, run the reset task for one window,
    /// then shut it down cleanly.
    ///
    /// NOTE(review): all assertions are gated behind the `is_owner` check —
    /// if this node is not the owner the test silently passes without
    /// asserting anything. Consider asserting ownership up front (as
    /// `test_rate_limit_window_reset_with_counter` does).
    #[tokio::test]
    async fn test_rate_limit_window_reset_task() {
        let stores = Arc::new(StateStores::with_self_name("node1".to_string()));
        let sync_manager = Arc::new(MeshSyncManager::new(stores.clone(), "node1".to_string()));

        // Setup membership
        stores.rate_limit.update_membership(&["node1".to_string()]);

        // Setup config
        let key = GLOBAL_RATE_LIMIT_KEY.to_string();
        let config = RateLimitConfig {
            limit_per_second: 100,
        };
        let serialized = serde_json::to_vec(&config).unwrap();
        let _ = stores.app.insert(
            key.clone(),
            crate::stores::AppState {
                key: GLOBAL_RATE_LIMIT_KEY.to_string(),
                value: serialized,
                version: 1,
            },
        );

        // Increment counter
        if stores.rate_limit.is_owner(GLOBAL_RATE_LIMIT_COUNTER_KEY) {
            sync_manager.sync_rate_limit_inc(GLOBAL_RATE_LIMIT_COUNTER_KEY.to_string(), 10);
            let value_before = sync_manager.get_rate_limit_value(GLOBAL_RATE_LIMIT_COUNTER_KEY);
            assert!(value_before.is_some() && value_before.unwrap() > 0);

            // Create window manager with short interval for testing
            let window = RateLimitWindow::new(sync_manager.clone(), 1); // 1 second

            // Create shutdown channel
            let (shutdown_tx, shutdown_rx) = watch::channel(false);

            // Start reset task in background
            #[expect(
                clippy::disallowed_methods,
                reason = "test: handle is awaited with timeout below"
            )]
            let reset_handle = tokio::spawn(async move {
                window.start_reset_task(shutdown_rx).await;
            });

            // Wait a bit for reset to happen
            sleep(Duration::from_millis(1500)).await;

            // Check that counter was reset (or at least decremented)
            let _value_after = sync_manager.get_rate_limit_value(GLOBAL_RATE_LIMIT_COUNTER_KEY);
            // Counter should be reset or significantly reduced
            // Note: The exact value depends on timing, but it should be less than initial

            // Send shutdown signal
            shutdown_tx
                .send(true)
                .expect("failed to send shutdown signal");

            // Wait for task to complete gracefully
            let res = tokio::time::timeout(Duration::from_secs(1), reset_handle).await;
            assert!(res.is_ok(), "reset task did not shut down in time");
            let join_res = res.unwrap();
            assert!(join_res.is_ok(), "reset task panicked");
        }
    }

    /// Direct reset-path test: asserts ownership, increments the counter,
    /// then calls `reset_global_rate_limit_counter` and checks the result.
    #[tokio::test]
    async fn test_rate_limit_window_reset_with_counter() {
        use crate::stores::MembershipState;

        // Use with_self_name to ensure RateLimitStore uses the same self_name
        let stores = Arc::new(StateStores::with_self_name("test_node".to_string()));
        let sync_manager = Arc::new(MeshSyncManager::new(
            stores.clone(),
            "test_node".to_string(),
        ));

        // First, add this node to membership so it can be an owner
        let membership_key = "test_node".to_string();
        let membership_state = MembershipState {
            name: "test_node".to_string(),
            address: "127.0.0.1:8080".to_string(),
            status: 1, // NodeStatus::Alive
            version: 1,
            metadata: Default::default(),
        };
        let _ = stores.membership.insert(membership_key, membership_state);

        // Update rate limit membership so this node becomes an owner
        sync_manager.update_rate_limit_membership();

        // Check if node is owner before incrementing
        let key = GLOBAL_RATE_LIMIT_COUNTER_KEY.to_string();
        let is_owner = stores.rate_limit.is_owner(&key);
        assert!(is_owner, "Node should be owner of the rate limit key");

        // Set up a rate limit counter via sync_manager
        // This should increment the counter if the node is an owner
        sync_manager.sync_rate_limit_inc(key.clone(), 10);

        // Verify counter exists (was created)
        // Note: The actual value might be 0 due to PNCounter implementation details,
        // but the counter should exist after inc is called
        let counter_opt = stores.rate_limit.get_counter(&key);
        assert!(counter_opt.is_some(), "Counter should exist after inc call");

        // Verify counter was created after inc call
        // Note: The actual value depends on PNCounter implementation,
        // but the counter should exist after inc is called

        // Reset the counter
        sync_manager.reset_global_rate_limit_counter();

        // Verify reset was called (counter should still exist)
        // The reset implementation decrements by current count,
        // so the value should be 0 or negative after reset
        let reset_value = stores.rate_limit.value(&key).unwrap_or(0);
        // After reset, value should be <= 0 (since we decrement by current count)
        assert!(
            reset_value <= 0,
            "Counter should be reset to 0 or less, got: {reset_value}"
        );
    }

    /// A zero-second window is constructible; only the background task
    /// (not tested here) has to cope with it.
    #[test]
    fn test_rate_limit_window_zero_seconds() {
        let stores = Arc::new(StateStores::with_self_name("node1".to_string()));
        let sync_manager = Arc::new(MeshSyncManager::new(stores, "node1".to_string()));

        // Should handle zero seconds (though not recommended in practice)
        let window = RateLimitWindow::new(sync_manager, 0);
        assert_eq!(window.window_seconds, 0);
    }

    /// Large intervals are stored verbatim.
    #[test]
    fn test_rate_limit_window_large_interval() {
        let stores = Arc::new(StateStores::with_self_name("node1".to_string()));
        let sync_manager = Arc::new(MeshSyncManager::new(stores, "node1".to_string()));

        // Test with a large interval
        let window = RateLimitWindow::new(sync_manager, 86400); // 24 hours
        assert_eq!(window.window_seconds, 86400);
    }

    /// Exercises the reset logic the window task invokes, without the timer.
    ///
    /// NOTE(review): assertions are gated behind `is_owner` and may be
    /// silently skipped — see note on `test_rate_limit_window_reset_task`.
    #[tokio::test]
    async fn test_reset_global_rate_limit_counter_logic() {
        let stores = Arc::new(StateStores::with_self_name("node1".to_string()));
        let sync_manager = Arc::new(MeshSyncManager::new(stores.clone(), "node1".to_string()));

        // Setup membership
        stores.rate_limit.update_membership(&["node1".to_string()]);

        if stores.rate_limit.is_owner(GLOBAL_RATE_LIMIT_COUNTER_KEY) {
            // Increment counter
            sync_manager.sync_rate_limit_inc(GLOBAL_RATE_LIMIT_COUNTER_KEY.to_string(), 20);
            let value_before = sync_manager.get_rate_limit_value(GLOBAL_RATE_LIMIT_COUNTER_KEY);
            assert!(value_before.is_some() && value_before.unwrap() > 0);

            // Reset
            sync_manager.reset_global_rate_limit_counter();

            // Check that counter was reset
            let value_after = sync_manager.get_rate_limit_value(GLOBAL_RATE_LIMIT_COUNTER_KEY);
            // Should be 0 or negative after reset
            assert!(value_after.is_none() || value_after.unwrap() <= 0);
        }
    }
}