Skip to main content

soft_cycle/
lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4pub mod payload;
5
6use payload::AtomicPrimitive;
7#[doc(inline)]
8pub use payload::Payload;
9
10use std::{
11    future::Future,
12    pin::Pin,
13    sync::{
14        Arc,
15        atomic::{AtomicU8, AtomicU32, Ordering},
16    },
17    task::{Context, Poll},
18};
19
20use tokio::sync::{Notify, futures::OwnedNotified};
21
22/// Future that completes when a notification from the associated controller is observed.
23///
24/// Returned by [`SoftCycleController::listener`]. Resolves to `Ok(payload)` when a notification
25/// is observed. If the controller is already notified when the listener is created or first
26/// polled, it may complete immediately with that payload. If multiple notify/clear cycles
27/// occur after the listener is created, it returns one of those payloads (no guarantee of
28/// returning the earliest or latest). See the [crate-level documentation](crate) for full
29/// completion guarantees.
30pub struct SoftCycleListener<'a, T: Payload> {
31    notify: OwnedNotified,
32    controller: &'a SoftCycleController<T>,
33}
34
35impl<T: Payload> Future for SoftCycleListener<'_, T> {
36    type Output = Result<T, ()>;
37
38    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
39        let notify_pin = unsafe { self.as_mut().map_unchecked_mut(|s| &mut s.notify) };
40
41        match notify_pin.poll(cx) {
42            Poll::Pending => {}
43            // CORRECTNESS: If the `OwnedNotified` is ready, then a call to
44            // `try_notify` must have written the payload and the written
45            // value must be visible here.
46            Poll::Ready(()) => return Poll::Ready(Ok(self.controller.read_payload())),
47        }
48
49        // If `OwnedNotified` was not ready in the last statement, then either:
50        //
51        // - `listener` was called after `try_notify`, in which case we
52        //   should return the payload immediately, or
53        // - it was just not notified yet.
54        //
55        // We check the status anyway here. In the first case, it ensures the
56        // correctness of the implementation. In the second case, it's a
57        // speculative attempt to return early (if we're lucky enough that
58        // `try_notify` was called after we polled the `OwnedNotified` and
59        // before we checked the status). In both cases, it's correct.
60        if self.controller.is_notified() {
61            return Poll::Ready(Ok(self.controller.read_payload()));
62        }
63
64        Poll::Pending
65    }
66}
67
// Controller status state machine (the value stored in `SoftCycleController::status`):
//
//   NOT_NOTIFIED --try_notify CAS--> STORING_PAYLOAD --payload stored--> NOTIFIED
//        ^                                                                 |
//        +------- sequence read ------- CLEARING <------ try_clear CAS ----+
//
// Each transient state (STORING_PAYLOAD, CLEARING) is entered by a successful
// compare-exchange, so at most one caller holds it at a time. Concurrent callers fail
// fast instead of blocking.

/// Controller status `0`: not notified; the only state from which `try_notify` succeeds.
const STATUS_NOT_NOTIFIED: u8 = 0;
/// Controller status `1`: a `try_notify` call won the race and is writing the payload.
const STATUS_STORING_PAYLOAD: u8 = 1;
/// Controller status `2`: notified; the payload is published and `try_clear` can succeed.
const STATUS_NOTIFIED: u8 = 2;
/// Controller status `3`: a `try_clear` call won the race and is resetting the controller.
const STATUS_CLEARING: u8 = 3;
76
77/// Coordination controller for soft restarts and graceful shutdowns.
78///
79/// `SoftCycleController` exposes a tiny async-friendly coordination protocol:
80///
81/// - producers call [`try_notify`](Self::try_notify) to publish a notification with a payload and
82///   notify waiters;
83/// - producers call [`try_clear`](Self::try_clear) to move back to the non-notified state;
84/// - consumers create a [`SoftCycleListener`] via [`listener`](Self::listener) to wait for
85///   notifications with payloads.
86///
87/// See the [crate-level documentation](crate) for detailed documentation and usage examples.
88pub struct SoftCycleController<T: Payload = ()> {
89    /// Notify used to signal listeners.
90    notify: Arc<Notify>,
91
92    /// Next notify sequence number.
93    next_notify_sequence: AtomicU32,
94
95    /// Controller status.
96    status: AtomicU8,
97
98    /// Atomic payload slot; always holds a valid value. Readers may see a newer
99    /// payload than the notification they observed, but never invalid data.
100    payload: <T as Payload>::UnderlyingAtomic,
101}
102
103impl<T: Payload> SoftCycleController<T> {
104    /// Creates a new [`SoftCycleController`].
105    #[allow(clippy::new_without_default)]
106    pub fn new() -> Self {
107        Self {
108            notify: Arc::new(Notify::new()),
109            next_notify_sequence: AtomicU32::new(0),
110            status: AtomicU8::new(STATUS_NOT_NOTIFIED),
111            payload: <T as Payload>::UnderlyingAtomic::new_default(),
112        }
113    }
114
115    /// Attempts to notify. On success returns `Ok(sequence_number)` where the sequence number
116    /// starts from 0 and increases. On failure (already notified) returns
117    /// `Err(payload)` with the payload unchanged. Never blocks.
118    #[must_use = "Caller must check if the operation was successful"]
119    pub fn try_notify(&self, payload: T) -> Result<u32, T> {
120        match self.status.compare_exchange(
121            STATUS_NOT_NOTIFIED,
122            STATUS_STORING_PAYLOAD,
123            Ordering::AcqRel,
124            Ordering::Relaxed,
125        ) {
126            Ok(_) => {
127                let sequence_number = self.next_notify_sequence.fetch_add(1, Ordering::AcqRel);
128                self.payload.store(payload.into());
129                self.status.store(STATUS_NOTIFIED, Ordering::Release);
130                self.notify.notify_waiters();
131                Ok(sequence_number)
132            }
133            Err(_) => Err(payload),
134        }
135    }
136
137    /// Clears the notified state. On success returns `Ok(sequence_number)` for the notification
138    /// sequence number that was cleared. On failure (not currently notified) returns `Err(())`.
139    /// Never blocks.
140    ///
141    /// Clearing after a notify does not prevent listeners already waiting from receiving the notification.
142    #[allow(clippy::result_unit_err)]
143    pub fn try_clear(&self) -> Result<u32, ()> {
144        match self.status.compare_exchange(
145            STATUS_NOTIFIED,
146            STATUS_CLEARING,
147            Ordering::AcqRel,
148            Ordering::Relaxed,
149        ) {
150            Ok(_) => {
151                let sequence_number = self
152                    .next_notify_sequence
153                    .load(Ordering::Acquire)
154                    .saturating_sub(1);
155                self.status.store(STATUS_NOT_NOTIFIED, Ordering::Release);
156                Ok(sequence_number)
157            }
158            Err(_) => Err(()),
159        }
160    }
161
162    /// Returns a listener that resolves when a notification is observed, with `Ok(payload)`.
163    /// If already notified, it may complete in a finite number of polls; if not yet
164    /// notified, it completes in a finite number of polls after the next [`try_notify`](Self::try_notify).
165    /// After multiple notify/clear cycles, the listener returns one of the payloads (no
166    /// guarantee of earliest or latest). See the [crate-level documentation](crate).
167    #[must_use = "Caller must await the listener to receive the signal"]
168    pub fn listener<'a>(&'a self) -> SoftCycleListener<'a, T> {
169        SoftCycleListener {
170            notify: self.notify.clone().notified_owned(),
171            controller: self,
172        }
173    }
174
175    /// Returns true if the controller is in the notified state.
176    fn is_notified(&self) -> bool {
177        self.status.load(Ordering::Acquire) == STATUS_NOTIFIED
178    }
179
180    /// Reads payload atomically. It always returns a valid value because even
181    /// if [`try_notify`] was never called, the payload is initialized to the
182    /// default value.
183    fn read_payload(&self) -> T {
184        let inner = self.payload.load();
185        T::from(inner)
186    }
187}
188
#[cfg(test)]
mod tests {
    use std::time::Duration;
    use std::time::Instant;

    use super::*;

    // --- A. Global order / return value semantics ---

    /// Guarantee: try_notify success returns increasing sequence numbers from 0 upward.
    #[tokio::test]
    async fn guarantee_a_try_notify_sn_from_zero() {
        let ctrl = SoftCycleController::<u32>::new();
        assert_eq!(ctrl.try_notify(10), Ok(0));
        assert_eq!(ctrl.try_clear(), Ok(0));
        assert_eq!(ctrl.try_notify(20), Ok(1));
        assert_eq!(ctrl.try_clear(), Ok(1));
        assert_eq!(ctrl.try_notify(30), Ok(2));
    }

    /// Guarantee: try_clear when not notified returns Err(()).
    #[tokio::test]
    async fn guarantee_a_try_clear_fails_when_not_notified() {
        let ctrl = SoftCycleController::<u32>::new();
        assert_eq!(ctrl.try_clear(), Err(()));
        assert_eq!(ctrl.try_notify(1), Ok(0));
        assert_eq!(ctrl.try_clear(), Ok(0));
        assert_eq!(ctrl.try_clear(), Err(()));
    }

    /// Guarantee: try_clear success returns the sequence number that was cleared. Rejected
    /// calls (try_notify while notified, try_clear while not notified) consume no sequence
    /// number and leave the published payload untouched.
    #[tokio::test]
    async fn guarantee_a_sn_sequence_notify_clear_interleaved() {
        let ctrl = SoftCycleController::<u32>::new();
        assert_eq!(ctrl.try_notify(100), Ok(0));
        // Rejected while notified: the payload is handed back unchanged.
        assert_eq!(ctrl.try_notify(150), Err(150));
        assert_eq!(ctrl.listener().await, Ok(100));
        assert_eq!(ctrl.try_clear(), Ok(0));
        assert_eq!(ctrl.try_clear(), Err(()));
        assert_eq!(ctrl.try_notify(200), Ok(1));
        assert_eq!(ctrl.try_clear(), Ok(1));
        assert_eq!(ctrl.try_notify(300), Ok(2));
    }

    // --- B. Non-blocking (try_notify and try_clear are synchronous) ---

    /// Guarantee: try_clear completes immediately even with many listeners (no reader-count barrier).
    #[tokio::test]
    async fn guarantee_b_try_clear_nonblocking_many_listeners() {
        let ctrl = Arc::new(SoftCycleController::<u32>::new());
        let mut listener_handles = Vec::new();
        for _ in 0..100 {
            let c = ctrl.clone();
            listener_handles.push(tokio::spawn(async move { c.listener().await }));
        }
        assert_eq!(ctrl.try_notify(1), Ok(0));
        // `try_clear` is synchronous, so a `tokio::time::timeout` around it could never fire.
        // Measure wall-clock time instead, like the `try_notify` test below.
        let start = Instant::now();
        assert_eq!(ctrl.try_clear(), Ok(0));
        assert!(
            start.elapsed() < Duration::from_millis(100),
            "try_clear must not block"
        );
        assert_eq!(ctrl.try_clear(), Err(()));
        assert_eq!(ctrl.try_notify(2), Ok(1));
        let mut got = 0;
        for h in listener_handles {
            if let Ok(Ok(v)) = tokio::time::timeout(Duration::from_secs(2), h).await {
                assert!(v == Ok(1) || v == Ok(2));
                got += 1;
            }
        }
        assert!(got > 0, "at least one listener should get a value");
    }

    /// Guarantee: try_notify completes synchronously even with many listeners already waiting.
    #[tokio::test]
    async fn guarantee_b_try_notify_nonblocking_many_listeners() {
        let ctrl = Arc::new(SoftCycleController::<u32>::new());
        for _ in 0..50 {
            let c = ctrl.clone();
            tokio::spawn(async move {
                let _ = c.listener().await;
            });
        }
        tokio::time::sleep(Duration::from_millis(20)).await;
        let start = Instant::now();
        let res = ctrl.try_notify(1);
        assert!(res.is_ok(), "try_notify must succeed");
        assert!(
            start.elapsed() < Duration::from_millis(50),
            "try_notify must not block"
        );
    }

    // --- C. Listener completion ---

    /// Guarantee: listener created after notify, before clear, completes in finite polls (here, immediately).
    #[tokio::test]
    async fn guarantee_c_listener_created_while_notified_completes() {
        let ctrl = Arc::new(SoftCycleController::<u32>::new());
        assert_eq!(ctrl.try_notify(42), Ok(0));
        let v = ctrl.listener().await;
        assert_eq!(v, Ok(42));
    }

    /// Guarantee: listener created after clear, before next notify, completes in finite polls after next try_notify.
    #[tokio::test]
    async fn guarantee_c_listener_created_before_notify_completes_after_notify() {
        let ctrl = Arc::new(SoftCycleController::<u32>::new());
        let c = ctrl.clone();
        let listener_task = tokio::spawn(async move { c.listener().await });
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(ctrl.try_notify(7), Ok(0));
        let r = tokio::time::timeout(Duration::from_secs(1), listener_task)
            .await
            .expect("listener must complete within timeout")
            .expect("task must not panic");
        assert_eq!(r, Ok(7));
    }

    // --- D. Multi-round: listener returns one of the payloads (no guarantee of earliest/latest) ---

    /// Guarantee: after multiple notify/clear cycles, listener returns one of the payloads.
    #[tokio::test]
    async fn guarantee_d_listener_returns_one_payload_after_multi_round() {
        let ctrl = Arc::new(SoftCycleController::<u32>::new());
        let c = ctrl.clone();
        let listener_task = tokio::spawn(async move { c.listener().await });
        assert!(ctrl.try_notify(1).is_ok());
        let _ = ctrl.try_clear();
        assert!(ctrl.try_notify(2).is_ok());
        let r = listener_task.await.unwrap();
        let allowed = [Ok(1), Ok(2)];
        assert!(
            allowed.contains(&r),
            "listener must return one of the payloads, got {:?}",
            r
        );
    }

    /// Guarantee: a reader looping over listeners across many notify/clear rounds only ever
    /// observes payloads that were actually published.
    #[tokio::test]
    async fn concurrent_multi_round_collects_subset_of_payloads() {
        let ctrl = Arc::new(SoftCycleController::<u32>::new());
        let reader = ctrl.clone();
        let reader_handle = tokio::spawn(async move {
            let mut seen = Vec::new();
            for _ in 0..20 {
                if let Ok(x) = reader.listener().await {
                    seen.push(x);
                }
            }
            seen
        });
        for i in 0..10u32 {
            assert!(ctrl.try_notify(i).is_ok());
            tokio::time::sleep(Duration::from_millis(2)).await;
            let _ = ctrl.try_clear();
            tokio::time::sleep(Duration::from_millis(2)).await;
        }
        // Leave the controller notified so the reader's remaining listeners complete.
        ctrl.try_notify(99).ok();
        let collected = reader_handle.await.unwrap();
        assert!(!collected.is_empty());
        assert!(collected.iter().all(|&x| (0..10).contains(&x) || x == 99));
    }

    /// Stress: one writer cycling clear/notify while a reader repeatedly awaits listeners; the
    /// reader must only see published (non-default) payloads.
    #[tokio::test]
    async fn stress_many_cycles_and_listeners() {
        let ctrl = Arc::new(SoftCycleController::<u32>::new());
        let writer = ctrl.clone();
        let writer_handle = tokio::spawn(async move {
            for i in 1..=400u32 {
                let _ = writer.try_clear();
                assert!(writer.try_notify(i).is_ok(), "notify failed");
                tokio::time::sleep(Duration::from_millis(30)).await;
            }
        });
        let reader = ctrl.clone();
        let reader_handle = tokio::spawn(async move {
            for _ in 0..3000 {
                if let Ok(v) = reader.listener().await {
                    assert!(0 < v && v <= 400);
                    tokio::time::sleep(Duration::from_millis(3)).await;
                }
            }
        });
        // Propagate panics from the spawned tasks; discarding the join results would let the
        // assertions above fail silently.
        let (writer_res, reader_res) = tokio::join!(writer_handle, reader_handle);
        writer_res.expect("writer task must not panic");
        reader_res.expect("reader task must not panic");
    }

    // --- Regression: concurrent try_notify / try_clear sequence number semantics ---

    /// Guarantee: under truly concurrent try_notify/try_clear, every successful try_notify
    /// returns a unique sequence number, and together they form the range 0..n. Both calls are
    /// synchronous, so OS threads give real parallelism here. Every successful try_clear returns
    /// a sequence number previously returned by try_notify, at most once each. Every
    /// notification except possibly the last one is cleared.
    #[test]
    fn regression_concurrent_try_notify_try_clear_sequence_numbers_unique_and_consistent() {
        const THREADS: usize = 8;
        const ROUNDS: u32 = 200;

        let ctrl = SoftCycleController::<u32>::new();
        let ctrl = &ctrl;
        let (mut notify_seqs, mut clear_seqs) = std::thread::scope(|scope| {
            let workers: Vec<_> = (0..THREADS)
                .map(|_| {
                    scope.spawn(move || {
                        let mut my_notify = Vec::new();
                        let mut my_clear = Vec::new();
                        for i in 0..ROUNDS {
                            if let Ok(seq) = ctrl.try_notify(i) {
                                my_notify.push(seq);
                            }
                            if let Ok(seq) = ctrl.try_clear() {
                                my_clear.push(seq);
                            }
                        }
                        (my_notify, my_clear)
                    })
                })
                .collect();
            let mut notify_seqs: Vec<u32> = Vec::new();
            let mut clear_seqs: Vec<u32> = Vec::new();
            for worker in workers {
                let (n, cl) = worker.join().expect("worker thread must not panic");
                notify_seqs.extend(n);
                clear_seqs.extend(cl);
            }
            (notify_seqs, clear_seqs)
        });
        notify_seqs.sort_unstable();
        clear_seqs.sort_unstable();

        // try_notify sequence numbers must be unique and contiguous from 0.
        assert!(!notify_seqs.is_empty(), "at least one try_notify must succeed");
        let n = u32::try_from(notify_seqs.len()).expect("notify count fits in u32");
        let expected_notify: Vec<u32> = (0..n).collect();
        assert_eq!(
            notify_seqs, expected_notify,
            "every try_notify Ok(seq) must be unique and contiguous from 0"
        );

        // Every cleared sequence must have been notified.
        for &cleared in &clear_seqs {
            assert!(
                notify_seqs.binary_search(&cleared).is_ok(),
                "try_clear returned seq {} which was not returned by try_notify",
                cleared
            );
        }

        // A new notification can only start after the previous one was cleared, and each
        // notification is cleared at most once. So the cleared set is exactly 0..n or 0..n-1.
        let k = u32::try_from(clear_seqs.len()).expect("clear count fits in u32");
        assert!(
            k == n || k + 1 == n,
            "only the most recent notification may remain uncleared ({} notifies, {} clears)",
            n,
            k
        );
        let expected_clear: Vec<u32> = (0..k).collect();
        assert_eq!(
            clear_seqs, expected_clear,
            "each notification must be cleared at most once, in sequence"
        );
    }

    /// Guarantee: under concurrent contention, try_clear success returns the sequence number
    /// of the notification that was cleared (documented semantics).
    #[tokio::test]
    async fn regression_concurrent_try_clear_returns_notification_sequence() {
        let ctrl = Arc::new(SoftCycleController::<u32>::new());
        let ctrl2 = ctrl.clone();
        let notifier = tokio::spawn(async move {
            let mut notified = Vec::new();
            for i in 0u32..50 {
                if let Ok(seq) = ctrl2.try_notify(100 + i) {
                    notified.push(seq);
                    tokio::time::sleep(Duration::from_millis(1)).await;
                }
            }
            notified
        });
        let clearer = tokio::spawn(async move {
            let mut cleared = Vec::new();
            for _ in 0..60 {
                if let Ok(seq) = ctrl.try_clear() {
                    cleared.push(seq);
                }
                tokio::time::sleep(Duration::from_millis(1)).await;
            }
            cleared
        });
        let notified = notifier.await.expect("notifier task must not panic");
        let cleared = clearer.await.expect("clearer task must not panic");
        // Every try_clear Ok(seq) must be a sequence number that was returned by try_notify.
        for &s in &cleared {
            assert!(
                notified.contains(&s),
                "cleared seq {} must be from a prior notify",
                s
            );
        }
    }
}
463
464#[cfg(feature = "global_instance")]
465#[cfg_attr(docsrs, doc(cfg(feature = "global_instance")))]
466mod global;
467
468#[cfg(feature = "global_instance")]
469#[cfg_attr(docsrs, doc(cfg(feature = "global_instance")))]
470pub use global::*;