1use std::collections::BTreeMap;
17use std::sync::atomic::{AtomicBool, Ordering};
18use std::sync::{Arc, Mutex, OnceLock};
19use std::time::{Duration, Instant};
20
21use tokio::sync::oneshot;
22
/// Process-wide intercept-mode flag (`false` = off at startup).
///
/// In this file it is written only by `toggle_intercept_mode` and `set_intercept_mode`, both of
/// which hold `MODE_TRANSITION` while doing so; reads go through `intercept_mode_enabled`.
static INTERCEPT_MODE: AtomicBool = AtomicBool::new(false);
27
/// Lazily-created process-wide store; initialized on first call to `global_store()`.
static INTERCEPT_STORE: OnceLock<InterceptStore> = OnceLock::new();
31
/// Reports whether intercept mode is currently on.
///
/// The `Acquire` load pairs with the `Release` writes performed by `toggle_intercept_mode` and
/// `set_intercept_mode`.
///
/// NOTE(review): the answer can be stale by the time the caller acts on it. A caller that checks
/// this and then calls `InterceptStore::register` may register just after the mode was switched
/// off (and the store drained); confirm the caller bounds that wait.
#[must_use]
pub fn intercept_mode_enabled() -> bool {
    INTERCEPT_MODE.load(Ordering::Acquire)
}
47
/// Serializes intercept-mode transitions so that flipping the flag and draining the global store
/// on an on→off transition happen as one step relative to other transitions.
static MODE_TRANSITION: std::sync::Mutex<()> = std::sync::Mutex::new(());
54
55pub fn toggle_intercept_mode() -> bool {
59 let _guard = MODE_TRANSITION
64 .lock()
65 .unwrap_or_else(std::sync::PoisonError::into_inner);
66 let prev = INTERCEPT_MODE.fetch_xor(true, Ordering::Release);
70 let now_on = !prev;
71 if !now_on {
72 let _ = global_store().drain_release();
73 }
74 now_on
75}
76
77pub fn set_intercept_mode(on: bool) {
80 let _guard = MODE_TRANSITION
81 .lock()
82 .unwrap_or_else(std::sync::PoisonError::into_inner);
83 let prev = INTERCEPT_MODE.swap(on, Ordering::Release);
86 if prev && !on {
87 let _ = global_store().drain_release();
88 }
89}
90
/// Returns the process-wide `InterceptStore`, creating an empty one on first use.
///
/// This is the store that `toggle_intercept_mode` / `set_intercept_mode` drain when intercept
/// mode is switched off.
pub fn global_store() -> &'static InterceptStore {
    INTERCEPT_STORE.get_or_init(InterceptStore::new)
}
96
/// Verdict delivered to a parked request through its oneshot channel.
///
/// What each verdict does to the request is decided by whoever awaits the receiver returned from
/// `InterceptStore::register`, not by this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InterceptDecision {
    /// Sent by `resolve` on request, and by `drain_release` to everything still parked when
    /// intercept mode is switched off; presumably lets the request continue.
    Release,
    /// Presumably aborts the request — confirm against the receiving side.
    Kill,
}
106
/// Descriptive metadata for one parked request, as returned by `InterceptStore::snapshot`.
#[derive(Debug, Clone)]
pub struct PendingIntercept {
    /// Store-assigned id (never 0); pass it to `resolve` / `cancel`.
    pub id: u64,
    /// Host string exactly as given to `register`.
    pub host: String,
    /// Method string exactly as given to `register`.
    pub method: String,
    /// Path string exactly as given to `register`.
    pub path: String,
    /// When the request was registered.
    pub since: Instant,
}
117
/// Thread-safe registry of parked (intercepted) requests awaiting a decision.
///
/// `Clone` copies the inner `Arc`, so every clone is a handle to the same shared state.
#[derive(Debug, Default, Clone)]
pub struct InterceptStore {
    inner: Arc<Mutex<InterceptInner>>,
}
123
/// Lock-protected state behind an `InterceptStore`.
#[derive(Debug, Default)]
struct InterceptInner {
    // One-shot channel back to each parked request, keyed by id. Every method of
    // `InterceptStore` adds/removes an id here and in `pending` together.
    senders: BTreeMap<u64, oneshot::Sender<InterceptDecision>>,
    // Display metadata for the same ids as `senders`; backs `snapshot` and `pending_count`.
    pending: BTreeMap<u64, PendingIntercept>,
    // Last id handed out. Starts at 0 (Default), and 0 is never issued, so the first id is 1.
    next_id: u64,
}
138
/// 30-second intercept timeout.
///
/// NOTE(review): nothing in this file reads this constant. By its name it presumably bounds how
/// long a caller awaits the receiver from `InterceptStore::register` before falling back to a
/// default decision — confirm at the call site.
pub const INTERCEPT_TIMEOUT: Duration = Duration::from_secs(30);
143
144impl InterceptStore {
145 pub fn new() -> Self {
146 Self::default()
147 }
148
149 pub fn register(
158 &self,
159 host: impl Into<String>,
160 method: impl Into<String>,
161 path: impl Into<String>,
162 ) -> (u64, oneshot::Receiver<InterceptDecision>) {
163 let (tx, rx) = oneshot::channel();
164 let mut inner = self
165 .inner
166 .lock()
167 .unwrap_or_else(std::sync::PoisonError::into_inner);
168 let dead: Vec<u64> = inner
170 .senders
171 .iter()
172 .filter(|(_, tx)| tx.is_closed())
173 .map(|(id, _)| *id)
174 .collect();
175 for id in dead {
176 inner.senders.remove(&id);
177 inner.pending.remove(&id);
178 }
179 inner.next_id = inner.next_id.wrapping_add(1);
186 if inner.next_id == 0 {
187 inner.next_id = 1;
188 }
189 let id = inner.next_id;
190 inner.senders.insert(id, tx);
191 inner.pending.insert(
192 id,
193 PendingIntercept {
194 id,
195 host: host.into(),
196 method: method.into(),
197 path: path.into(),
198 since: Instant::now(),
199 },
200 );
201 (id, rx)
202 }
203
204 pub fn gc_dead_senders(&self) -> usize {
208 let mut inner = self
209 .inner
210 .lock()
211 .unwrap_or_else(std::sync::PoisonError::into_inner);
212 let dead: Vec<u64> = inner
213 .senders
214 .iter()
215 .filter(|(_, tx)| tx.is_closed())
216 .map(|(id, _)| *id)
217 .collect();
218 let n = dead.len();
219 for id in dead {
220 inner.senders.remove(&id);
221 inner.pending.remove(&id);
222 }
223 n
224 }
225
226 pub fn resolve(&self, id: u64, decision: InterceptDecision) -> bool {
229 let mut inner = self
230 .inner
231 .lock()
232 .unwrap_or_else(std::sync::PoisonError::into_inner);
233 inner.pending.remove(&id);
234 if let Some(tx) = inner.senders.remove(&id) {
235 let _ = tx.send(decision);
236 true
237 } else {
238 false
239 }
240 }
241
242 pub fn cancel(&self, id: u64) -> bool {
250 let mut inner = self
251 .inner
252 .lock()
253 .unwrap_or_else(std::sync::PoisonError::into_inner);
254 let removed_pending = inner.pending.remove(&id).is_some();
255 let removed_sender = inner.senders.remove(&id).is_some();
256 removed_pending || removed_sender
257 }
258
259 pub fn drain_release(&self) -> usize {
263 let mut inner = self
264 .inner
265 .lock()
266 .unwrap_or_else(std::sync::PoisonError::into_inner);
267 let ids: Vec<u64> = inner.senders.keys().copied().collect();
268 let mut released = 0;
269 for id in ids {
270 if let Some(tx) = inner.senders.remove(&id) {
271 inner.pending.remove(&id);
272 let _ = tx.send(InterceptDecision::Release);
273 released += 1;
274 }
275 }
276 released
277 }
278
279 pub fn snapshot(&self) -> Vec<PendingIntercept> {
281 let inner = self
282 .inner
283 .lock()
284 .unwrap_or_else(std::sync::PoisonError::into_inner);
285 inner.pending.values().cloned().collect()
286 }
287
288 pub fn pending_count(&self) -> usize {
290 let inner = self
291 .inner
292 .lock()
293 .unwrap_or_else(std::sync::PoisonError::into_inner);
294 inner.pending.len()
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    // Each test builds its own private store; none of them touch `global_store()` or the
    // process-wide intercept-mode flag, so they can run in parallel.
    fn store() -> InterceptStore {
        InterceptStore::new()
    }

    #[tokio::test]
    async fn register_then_release_unblocks_with_release() {
        let s = store();
        let (id, rx) = s.register("h", "GET", "/");
        let s2 = s.clone();
        // Resolve from another task after a short delay so `rx.await` actually parks first.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            s2.resolve(id, InterceptDecision::Release);
        });
        let decision = rx.await.expect("rx");
        assert_eq!(decision, InterceptDecision::Release);
        assert_eq!(s.pending_count(), 0, "pending must drain after resolve");
    }

    #[tokio::test]
    async fn register_then_kill_unblocks_with_kill() {
        let s = store();
        let (id, rx) = s.register("h", "POST", "/admin");
        let s2 = s.clone();
        tokio::spawn(async move {
            s2.resolve(id, InterceptDecision::Kill);
        });
        assert_eq!(rx.await.unwrap(), InterceptDecision::Kill);
    }

    #[tokio::test]
    async fn snapshot_shows_pending_until_resolved() {
        let s = store();
        // Receivers are kept alive in named bindings so the second `register` does not purge
        // the first entry as dead.
        let (id1, _r1) = s.register("a.com", "GET", "/x");
        let (id2, _r2) = s.register("b.com", "POST", "/y");
        let snap = s.snapshot();
        assert_eq!(snap.len(), 2);
        assert!(snap.iter().any(|p| p.id == id1 && p.host == "a.com"));
        assert!(snap.iter().any(|p| p.id == id2 && p.host == "b.com"));
    }

    #[tokio::test]
    async fn drain_release_unblocks_every_pending() {
        let s = store();
        let (_, rx1) = s.register("a", "GET", "/");
        let (_, rx2) = s.register("b", "GET", "/");
        let n = s.drain_release();
        assert_eq!(n, 2);
        assert_eq!(rx1.await.unwrap(), InterceptDecision::Release);
        assert_eq!(rx2.await.unwrap(), InterceptDecision::Release);
        assert_eq!(s.pending_count(), 0);
    }

    #[tokio::test]
    async fn resolve_unknown_id_is_idempotent_no_op() {
        let s = store();
        let acted = s.resolve(999, InterceptDecision::Release);
        assert!(!acted, "resolve of unknown id must report it didn't fire");
    }

    #[tokio::test]
    async fn resolve_twice_only_fires_once() {
        let s = store();
        let (id, rx) = s.register("h", "GET", "/");
        assert!(s.resolve(id, InterceptDecision::Release));
        // The first decision wins; the second call must find nothing to resolve.
        assert!(
            !s.resolve(id, InterceptDecision::Kill),
            "second resolve must no-op"
        );
        assert_eq!(rx.await.unwrap(), InterceptDecision::Release);
    }

    #[tokio::test]
    async fn timeout_default_release_via_select() {
        let s = store();
        // Nobody resolves this id, so the receiver must still be pending after 50 ms: the store
        // itself never times out; any timeout/default decision is the awaiting caller's job.
        let (_id, rx) = s.register("h", "GET", "/");
        let result = tokio::time::timeout(Duration::from_millis(50), rx).await;
        assert!(result.is_err(), "rx must NOT complete on its own");
    }

    #[tokio::test]
    async fn ids_are_monotonic_per_register() {
        let s = store();
        // Receivers are dropped immediately (`_`), so earlier entries get purged as dead, but
        // the id counter must still advance by exactly one per registration.
        let (id1, _) = s.register("a", "GET", "/");
        let (id2, _) = s.register("a", "GET", "/");
        let (id3, _) = s.register("a", "GET", "/");
        assert_eq!(id2, id1 + 1);
        assert_eq!(id3, id2 + 1);
    }

    #[test]
    fn id_zero_is_reserved_and_resolve_cancel_return_false() {
        let s = store();
        // Id 0 is never issued, both before and after the first registration.
        assert!(!s.resolve(0, InterceptDecision::Release));
        assert!(!s.cancel(0));
        let (id, _rx) = s.register("h", "GET", "/");
        assert_eq!(id, 1, "first id must be 1 (0 is reserved)");
        assert!(!s.resolve(0, InterceptDecision::Release));
        assert!(!s.cancel(0));
    }

    #[test]
    fn id_wraparound_skips_zero() {
        let s = store();
        // Reach into the private state to put the counter right before the u64 boundary.
        {
            let mut inner = s.inner.lock().unwrap();
            inner.next_id = u64::MAX - 1;
        }
        let (id1, _rx1) = s.register("h", "GET", "/");
        assert_eq!(id1, u64::MAX, "pre-wraparound id must be u64::MAX");
        let (id2, _rx2) = s.register("h", "GET", "/");
        assert_eq!(id2, 1, "post-wraparound id must skip 0 and return 1");
        assert_ne!(id2, 0, "id=0 must never be issued");
    }

    #[test]
    fn cancel_removes_from_both_maps() {
        let s = store();
        let (id, _rx) = s.register("h", "GET", "/path");
        assert_eq!(s.pending_count(), 1);
        let removed = s.cancel(id);
        assert!(removed, "cancel must return true for a valid id");
        assert_eq!(s.pending_count(), 0, "cancel must drain from pending map");
        assert!(!s.cancel(id), "second cancel returns false (already gone)");
    }

    #[test]
    fn gc_dead_senders_removes_disconnected_rx() {
        let s = store();
        let (id, rx) = s.register("h", "GET", "/");
        // Dropping the receiver closes the channel, making the entry eligible for GC.
        drop(rx);
        let removed = s.gc_dead_senders();
        assert_eq!(removed, 1, "exactly one dead sender must be GCd");
        assert_eq!(s.pending_count(), 0);
        assert!(!s.cancel(id));
    }

    #[test]
    fn resolve_zero_id_never_matches_real_intercept() {
        let s = store();
        // Counter at u64::MAX: the next registration wraps and must receive id 1, not 0.
        {
            let mut inner = s.inner.lock().unwrap();
            inner.next_id = u64::MAX;
        }
        let (id, _rx) = s.register("h", "GET", "/");
        assert_eq!(id, 1);
        assert!(!s.resolve(0, InterceptDecision::Kill));
        assert_eq!(
            s.pending_count(),
            1,
            "id=1 must still be pending after resolve(0)"
        );
    }
}
476}