// stratum_apps/fallback_coordinator.rs

use std::sync::{
    atomic::{AtomicBool, AtomicUsize, Ordering},
    Arc,
};
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;

8/// coordinates fallback operations across multiple components with acknowledgement.
9///
10/// this is meant to be used together with [`crate::task_manager::TaskManager`],
11/// as it allows triggering a fallback event (via [`CancellationToken`]) and waiting
12/// until all registered components have completed their cleanup.
13///
14/// in summary, every time we spawn a fallback-relevant task inside the manager, we MUST:
15/// - call [`FallbackCoordinator::register`] at task bootstrap
16/// - call [`FallbackCoordinator::done`] at task completion
17///
18/// when a fallback trigger arrives to the main status loop, we MUST call
19/// [`FallbackCoordinator::trigger_and_wait`] to wait for all registered components to complete
20/// their cleanup before re-initializing them under the new upstream server.
21///
22/// finally, a new [`FallbackCoordinator`] must be instantiated for the next fallback cycle.
23#[derive(Debug, Clone)]
24pub struct FallbackCoordinator {
25    signal: CancellationToken,
26    pending_tasks: Arc<AtomicUsize>,
27    notify: Arc<Notify>,
28}
29
30impl Default for FallbackCoordinator {
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36impl FallbackCoordinator {
37    pub fn new() -> Self {
38        Self {
39            signal: CancellationToken::new(),
40            pending_tasks: Arc::new(AtomicUsize::new(0)),
41            notify: Arc::new(Notify::new()),
42        }
43    }
44
45    /// register a component that will participate in fallback coordination
46    /// returns a [`FallbackHandler`] that must be called when the component is done
47    #[must_use]
48    pub fn register(&self) -> FallbackHandler {
49        tracing::debug!("FallbackCoordinator: registering component");
50        self.pending_tasks.fetch_add(1, Ordering::Relaxed);
51
52        FallbackHandler {
53            coordinator: self.clone(),
54            done: AtomicBool::new(false),
55        }
56    }
57
58    /// get the cancellation token that signals fallback
59    pub fn token(&self) -> CancellationToken {
60        self.signal.clone()
61    }
62
63    /// trigger fallback and wait for all registered components to acknowledge
64    pub async fn trigger_fallback_and_wait(&self) {
65        tracing::debug!("FallbackCoordinator: triggering fallback");
66        self.signal.cancel();
67
68        if self.pending_tasks.load(Ordering::Acquire) == 0 {
69            return; // all tasks already done
70        }
71
72        // there's still some tasks running,
73        // wait for the last task to notify us
74        self.notify.notified().await;
75        tracing::debug!("FallbackCoordinator: finished waiting for components to complete cleanup");
76    }
77}
78
79pub struct FallbackHandler {
80    coordinator: FallbackCoordinator,
81    done: AtomicBool,
82}
83
84/// Handler for a component that will participate in fallback coordination
85///
86/// ⚠️ Warning: dropping this handler without calling [`FallbackHandler::done`] will result in a
87/// panic.
88impl FallbackHandler {
89    /// Mark this handler as finished
90    /// Takes ownership of `self`, preventing double-calling
91    pub fn done(self) {
92        tracing::debug!("FallbackHandler: done called");
93        self.done.store(true, Ordering::Release);
94
95        let prev = self
96            .coordinator
97            .pending_tasks
98            .fetch_sub(1, Ordering::Release);
99
100        // Notify if fallback has been triggered and this is the last handler
101        if self.coordinator.signal.is_cancelled() && prev == 1 {
102            self.coordinator.notify.notify_one();
103        }
104    }
105}
106
107impl Drop for FallbackHandler {
108    fn drop(&mut self) {
109        if !self.done.load(Ordering::Acquire) {
110            panic!("FallbackHandler dropped without calling done()");
111        }
112    }
113}