// smg_mesh/rate_limit_window.rs
use std::{sync::Arc, time::Duration};

use tokio::{sync::watch, time::interval};
use tracing::{debug, info};

use super::sync::MeshSyncManager;
/// Periodically resets the mesh-wide global rate limit counter on a fixed
/// time window, driven by a background task (see `start_reset_task`).
pub struct RateLimitWindow {
    // Shared sync manager used to perform the actual counter reset.
    sync_manager: Arc<MeshSyncManager>,
    // Length of the reset window, in seconds.
    window_seconds: u64,
}
18
19impl RateLimitWindow {
20 pub fn new(sync_manager: Arc<MeshSyncManager>, window_seconds: u64) -> Self {
21 Self {
22 sync_manager,
23 window_seconds,
24 }
25 }
26
27 pub async fn start_reset_task(self, mut shutdown_rx: watch::Receiver<bool>) {
33 let mut interval_timer = interval(Duration::from_secs(self.window_seconds));
34 info!(
35 "Starting rate limit window reset task with {}s interval",
36 self.window_seconds
37 );
38
39 loop {
40 tokio::select! {
41 _ = interval_timer.tick() => {
42 debug!("Resetting global rate limit counter");
43 self.sync_manager.reset_global_rate_limit_counter();
44 }
45 _ = shutdown_rx.changed() => {
46 info!("Rate limit window reset task received shutdown signal");
47 break;
48 }
49 }
50 }
51
52 info!("Rate limit window reset task stopped");
53 }
54}
55
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use tokio::time::sleep;

    use super::*;
    use crate::stores::{
        RateLimitConfig, StateStores, GLOBAL_RATE_LIMIT_COUNTER_KEY, GLOBAL_RATE_LIMIT_KEY,
    };

    #[test]
    fn test_rate_limit_window_new() {
        let stores = Arc::new(StateStores::with_self_name("node1".to_string()));
        let sync_manager = Arc::new(MeshSyncManager::new(stores, "node1".to_string()));

        let window = RateLimitWindow::new(sync_manager, 60);
        assert_eq!(window.window_seconds, 60);
    }

    #[test]
    fn test_rate_limit_window_different_intervals() {
        let stores = Arc::new(StateStores::with_self_name("node1".to_string()));
        let sync_manager = Arc::new(MeshSyncManager::new(stores, "node1".to_string()));

        let window1 = RateLimitWindow::new(sync_manager.clone(), 30);
        assert_eq!(window1.window_seconds, 30);

        let window2 = RateLimitWindow::new(sync_manager, 120);
        assert_eq!(window2.window_seconds, 120);
    }

    #[tokio::test]
    async fn test_rate_limit_window_reset_task_interval() {
        let stores = Arc::new(StateStores::with_self_name("node1".to_string()));
        let sync_manager = Arc::new(MeshSyncManager::new(stores, "node1".to_string()));

        let window = RateLimitWindow::new(sync_manager, 1);

        let (shutdown_tx, shutdown_rx) = watch::channel(false);

        #[expect(
            clippy::disallowed_methods,
            reason = "test: handle is awaited with timeout below"
        )]
        let task_handle = tokio::spawn(async move {
            window.start_reset_task(shutdown_rx).await;
        });

        // Let at least one full window elapse so the timer ticks.
        sleep(Duration::from_millis(1500)).await;

        shutdown_tx
            .send(true)
            .expect("failed to send shutdown signal");

        let res = tokio::time::timeout(Duration::from_secs(1), task_handle).await;
        assert!(res.is_ok(), "reset task did not shut down in time");
        let join_res = res.unwrap();
        assert!(join_res.is_ok(), "reset task panicked");
    }

    #[tokio::test]
    async fn test_rate_limit_window_reset_task() {
        let stores = Arc::new(StateStores::with_self_name("node1".to_string()));
        let sync_manager = Arc::new(MeshSyncManager::new(stores.clone(), "node1".to_string()));

        stores.rate_limit.update_membership(&["node1".to_string()]);

        let key = GLOBAL_RATE_LIMIT_KEY.to_string();
        let config = RateLimitConfig {
            limit_per_second: 100,
        };
        let serialized = serde_json::to_vec(&config).unwrap();
        let _ = stores.app.insert(
            key.clone(),
            crate::stores::AppState {
                key: GLOBAL_RATE_LIMIT_KEY.to_string(),
                value: serialized,
                version: 1,
            },
        );

        // With a single-member ring, node1 must own the counter key; assert
        // instead of silently skipping the rest of the test.
        assert!(
            stores.rate_limit.is_owner(GLOBAL_RATE_LIMIT_COUNTER_KEY),
            "node1 should own the counter after membership update"
        );

        sync_manager.sync_rate_limit_inc(GLOBAL_RATE_LIMIT_COUNTER_KEY.to_string(), 10);
        let value_before = sync_manager.get_rate_limit_value(GLOBAL_RATE_LIMIT_COUNTER_KEY);
        assert!(value_before.is_some() && value_before.unwrap() > 0);

        let window = RateLimitWindow::new(sync_manager.clone(), 1);
        let (shutdown_tx, shutdown_rx) = watch::channel(false);

        #[expect(
            clippy::disallowed_methods,
            reason = "test: handle is awaited with timeout below"
        )]
        let reset_handle = tokio::spawn(async move {
            window.start_reset_task(shutdown_rx).await;
        });

        sleep(Duration::from_millis(1500)).await;

        // Not asserted: the exact post-reset value is timing-dependent.
        let _value_after = sync_manager.get_rate_limit_value(GLOBAL_RATE_LIMIT_COUNTER_KEY);
        shutdown_tx
            .send(true)
            .expect("failed to send shutdown signal");

        let res = tokio::time::timeout(Duration::from_secs(1), reset_handle).await;
        assert!(res.is_ok(), "reset task did not shut down in time");
        let join_res = res.unwrap();
        assert!(join_res.is_ok(), "reset task panicked");
    }

    #[tokio::test]
    async fn test_rate_limit_window_reset_with_counter() {
        use crate::stores::MembershipState;

        let stores = Arc::new(StateStores::with_self_name("test_node".to_string()));
        let sync_manager = Arc::new(MeshSyncManager::new(
            stores.clone(),
            "test_node".to_string(),
        ));

        let membership_key = "test_node".to_string();
        let membership_state = MembershipState {
            name: "test_node".to_string(),
            address: "127.0.0.1:8080".to_string(),
            status: 1,
            version: 1,
            metadata: Default::default(),
        };
        let _ = stores.membership.insert(membership_key, membership_state);

        sync_manager.update_rate_limit_membership();

        let key = GLOBAL_RATE_LIMIT_COUNTER_KEY.to_string();
        let is_owner = stores.rate_limit.is_owner(&key);
        assert!(is_owner, "Node should be owner of the rate limit key");

        sync_manager.sync_rate_limit_inc(key.clone(), 10);

        let counter_opt = stores.rate_limit.get_counter(&key);
        assert!(counter_opt.is_some(), "Counter should exist after inc call");

        sync_manager.reset_global_rate_limit_counter();

        let reset_value = stores.rate_limit.value(&key).unwrap_or(0);
        assert!(
            reset_value <= 0,
            "Counter should be reset to 0 or less, got: {reset_value}"
        );
    }

    #[test]
    fn test_rate_limit_window_zero_seconds() {
        let stores = Arc::new(StateStores::with_self_name("node1".to_string()));
        let sync_manager = Arc::new(MeshSyncManager::new(stores, "node1".to_string()));

        // Construction with 0 is allowed; the reset task clamps to 1s.
        let window = RateLimitWindow::new(sync_manager, 0);
        assert_eq!(window.window_seconds, 0);
    }

    #[test]
    fn test_rate_limit_window_large_interval() {
        let stores = Arc::new(StateStores::with_self_name("node1".to_string()));
        let sync_manager = Arc::new(MeshSyncManager::new(stores, "node1".to_string()));

        let window = RateLimitWindow::new(sync_manager, 86400);
        assert_eq!(window.window_seconds, 86400);
    }

    #[tokio::test]
    async fn test_reset_global_rate_limit_counter_logic() {
        let stores = Arc::new(StateStores::with_self_name("node1".to_string()));
        let sync_manager = Arc::new(MeshSyncManager::new(stores.clone(), "node1".to_string()));

        stores.rate_limit.update_membership(&["node1".to_string()]);

        // Single-member ring: ownership is guaranteed, so fail loudly rather
        // than skipping the assertions below.
        assert!(
            stores.rate_limit.is_owner(GLOBAL_RATE_LIMIT_COUNTER_KEY),
            "node1 should own the counter after membership update"
        );

        sync_manager.sync_rate_limit_inc(GLOBAL_RATE_LIMIT_COUNTER_KEY.to_string(), 20);
        let value_before = sync_manager.get_rate_limit_value(GLOBAL_RATE_LIMIT_COUNTER_KEY);
        assert!(value_before.is_some() && value_before.unwrap() > 0);

        sync_manager.reset_global_rate_limit_counter();

        let value_after = sync_manager.get_rate_limit_value(GLOBAL_RATE_LIMIT_COUNTER_KEY);
        assert!(value_after.is_none() || value_after.unwrap() <= 0);
    }
}