this_state/
lib.rs

1//! A library for managing state changes.
2//!
3//! # Examples
4//! The examples below use the following state:
5//! ```
6//! #[derive(Clone, Debug, PartialEq)]
7//! enum MyState {
8//!     A,
9//!     B,
10//!     C
11//! }
12//! ```
13//!
14//! ## Waiting for a state change
15//!
16//! ```
17//! # use this_state::State;
18//! # use tokio::runtime;
19//! #
20//! # #[derive(Clone, Debug, PartialEq)]
21//! # enum MyState {
22//! #     A,
23//! #     B,
24//! #     C
25//! # }
26//! #
27//! # let mut rt = runtime::Builder::new_current_thread()
28//! #     .enable_all()
29//! #     .build()
30//! #     .unwrap();
31//! #
32//! # rt.block_on(async {
33//! let state = State::new(MyState::A);
34//!
35//! let state_clone = state.clone();
36//! tokio::spawn(async move {
37//!     // do some work
38//!     # tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
39//!     state_clone.set(MyState::B);
40//!     // do some more work
41//!     # tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
42//!     state_clone.set(MyState::C);
43//! });
44//!
45//! state.wait_for_state(MyState::C).await;
46//!
47//! assert_eq!(state.get(), MyState::C);
48//! # })
49//! ```
50
51use std::cell::UnsafeCell;
52use std::error::Error;
53use std::fmt;
54use std::future::Future;
55use std::marker::PhantomPinned;
56use std::ops::Deref;
57use std::pin::Pin;
58use std::ptr::{addr_of_mut, NonNull};
59use std::sync::Arc;
60use std::task::{Context, Poll, Waker};
61
62use parking_lot::{RwLock, RwLockReadGuard};
63
64use crate::util::linked_list;
65use crate::util::linked_list::LinkedList;
66
67mod util;
68
69/// A thread-safe state, that can be used to share application state globally.
70///
71/// It is similar to a `RWLock<S>`, but it also allows asynchronous waiting for state changes.
72/// This can be useful to coordinate between different parts of an application.
73#[derive(Clone)]
74pub struct State<S> {
75    /// The state wraps an `Arc` to allow easy cloning.
76    inner: Arc<StateInner<S>>,
77}
78
79/// The inner state of a `State`, this contains the actual state and the wait queue.
80struct StateInner<S> {
81    /// The actual state.
82    state: RwLock<S>,
83    /// The wait queue, containing all tasks waiting for a state change.
84    waiters: RwLock<LinkedList<Waiter, <Waiter as linked_list::Link>::Target>>,
85    /// Callback that is called when the state changes.
86    on_change: Box<dyn Fn(&S, &S) + 'static>,
87}
88
/// Error produced by `State::set_on_change` when the callback cannot be swapped because the
/// state is still shared with other handles.
pub struct UpdateOnChangeError {
    /// Strong reference count of the shared state at the time of the call, including the
    /// handle on which `set_on_change` was invoked.
    ///
    /// # Notes
    /// This number is only accurate if the state has not been cloned in another thread.
    pub state_references: usize,
    /// Private marker field: keeps the struct from being constructed outside of this crate.
    _p: (),
}
99
100/// An entry in the wait queue.
101struct Waiter {
102    /// Indicates whether the task is queued on the wait queue.
103    queued: bool,
104
105    /// Task waiting on a state change.
106    waker: Option<Waker>,
107
108    /// Intrusive linked-list pointers.
109    pointers: linked_list::Pointers<Waiter>,
110
111    /// Should not be `Unpin`.
112    _p: PhantomPinned,
113}
114
115/// A future that completes when the state matches the given predicate.
116/// This is returned by `State::wait_for`.
117///
118/// # Notes
119/// Unlike most futures, this future can be polled multiple times, even after it has completed.
120#[must_use = "futures do nothing unless you `.await` or poll them"]
121pub struct StateFuture<S, C> {
122    state: State<S>,
123    waiter: UnsafeCell<Waiter>,
124    wait_for: C,
125}
126
127/// A reference to the current state, returned by `State::get_ref`.
128/// It wraps a `RwLockReadGuard` and can be used to avoid cloning the state.
129#[must_use]
130pub struct StateRef<'a, S>(RwLockReadGuard<'a, S>);
131
132unsafe impl<S> Send for State<S> {}
133unsafe impl<S> Sync for State<S> {}
134
135unsafe impl<S, C> Send for StateFuture<S, C> {}
136unsafe impl<S, C> Sync for StateFuture<S, C> {}
137
138impl<S> State<S> {
139    /// Creates a new state.
140    pub fn new(state: S) -> Self {
141        Self {
142            inner: Arc::new(StateInner {
143                state: RwLock::new(state),
144                waiters: RwLock::new(LinkedList::new()),
145                on_change: Box::new(|_, _| {}),
146            }),
147        }
148    }
149
150    /// Creates a new state with the given `on_change` callback.
151    ///
152    /// # Notes
153    /// The callback is not called when the state is set for the first time, as well as on
154    /// the `State::update` method. You must call the callback manually in these cases.
155    pub fn new_with_on_change(state: S, on_change: impl Fn(&S, &S) + 'static) -> Self {
156        Self {
157            inner: Arc::new(StateInner {
158                state: RwLock::new(state),
159                waiters: RwLock::new(LinkedList::new()),
160                on_change: Box::new(on_change),
161            }),
162        }
163    }
164
165    /// Tries to set the `on_change` callback, to the new callback.
166    ///
167    /// # Notes
168    /// The callback is not called when the state is set for the first time, as well as on
169    /// the `State::update` method. You must call the callback manually in these cases.
170    pub fn set_on_change(&mut self, on_change: impl Fn(&S, &S) + 'static) -> Result<(), UpdateOnChangeError> {
171        if let Some(inner) = Arc::get_mut(&mut self.inner) {
172            inner.on_change = Box::new(on_change);
173            Ok(())
174        } else {
175            Err(UpdateOnChangeError::new(self.ref_count()))
176        }
177    }
178
179    /// Returns the number of references to the state.
180    /// This can be used to check if there are any other references to the state.
181    pub fn ref_count(&self) -> usize {
182        Arc::strong_count(&self.inner)
183    }
184
185    /// Returns a reference to the current state.
186    /// This can be used if the state does not implement `Clone` or if you want to avoid cloning.
187    pub fn get_ref(&self) -> StateRef<S> {
188        StateRef(self.inner.state.read())
189    }
190
191    /// Returns a future that completes when the state matches the given predicate.
192    pub fn wait_for<C>(&self, wait_for: C) -> StateFuture<S, C>
193    where
194        C: Fn(&S) -> bool,
195    {
196        StateFuture::new(
197            State {
198                inner: self.inner.clone(),
199            },
200            wait_for,
201        )
202    }
203
204    /// Sets the state to the given value.
205    pub fn set(&self, state: S) {
206        let mut write = self.inner.state.write();
207        (self.inner.on_change)(&*write, &state);
208        *write = state;
209        drop(write);
210        self.wake_waiters();
211    }
212
213    /// Updates the state using the given function.
214    /// This avoids having to create a new state value, which can be useful for large state values.
215    ///
216    /// # Notes
217    /// This *DOES NOT* call the `on_change` callback, as it is not possible to get the old state.
218    pub fn update(&self, f: impl FnOnce(&mut S)) {
219        let mut write = self.inner.state.write();
220        f(&mut write);
221        drop(write);
222        self.wake_waiters();
223    }
224
225    /// Wakes all waiters.
226    fn wake_waiters(&self) {
227        let mut waiters = self.inner.waiters.write();
228
229        for mut waiter in waiters.iter() {
230            // Safety: list lock is still held.
231            let waiter = unsafe { waiter.as_mut() };
232
233            assert!(waiter.queued);
234
235            if let Some(waker) = waiter.waker.take() {
236                waker.wake();
237            }
238        }
239    }
240}
241
242impl<S> State<S>
243where
244    S: Clone,
245{
246    /// Returns a clone of the current state.
247    /// This is particularly useful for `State`s that implement `Copy`.
248    pub fn get(&self) -> S {
249        self.get_ref().clone()
250    }
251}
252
253impl<S> State<S>
254where
255    S: PartialEq<S>,
256{
257    /// Returns a future that resolves when the state is equal to the given value.
258    pub fn wait_for_state(&self, wait_for: S) -> StateFuture<S, impl Fn(&S) -> bool> {
259        self.wait_for(move |s| wait_for.eq(s))
260    }
261}
262
263impl<S, O> PartialEq<O> for State<S>
264where
265    S: PartialEq<O>,
266{
267    fn eq(&self, other: &O) -> bool {
268        self.get_ref().eq(other)
269    }
270}
271
272impl<S: fmt::Debug> fmt::Debug for State<S> {
273    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274        f.debug_tuple("State").field(&self.get_ref()).finish()
275    }
276}
277
278impl<S: Default> Default for State<S> {
279    fn default() -> Self {
280        Self::new(Default::default())
281    }
282}
283
284impl UpdateOnChangeError {
285    /// Private constructor.
286    fn new(state_references: usize) -> Self {
287        Self {
288            state_references,
289            _p: (),
290        }
291    }
292}
293
294impl fmt::Debug for UpdateOnChangeError {
295    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296        f.debug_struct("UpdateOnChangeError")
297            .field("state_references", &self.state_references)
298            .finish()
299    }
300}
301
302impl fmt::Display for UpdateOnChangeError {
303    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304        write!(f, "Cannot update state, as there are {} other references to the state.", self.state_references)
305    }
306}
307
308impl Error for UpdateOnChangeError {}
309
310impl Waiter {
311    fn new() -> Waiter {
312        Waiter {
313            queued: false,
314            waker: None,
315            pointers: linked_list::Pointers::new(),
316            _p: PhantomPinned,
317        }
318    }
319}
320
321/// # Safety
322///
323/// `Waiter` is forced to be !Unpin.
324unsafe impl linked_list::Link for Waiter {
325    type Handle = NonNull<Waiter>;
326    type Target = Waiter;
327
328    fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> {
329        *handle
330    }
331
332    unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> {
333        ptr
334    }
335
336    unsafe fn pointers(target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> {
337        let me = target.as_ptr();
338        let field = addr_of_mut!((*me).pointers);
339        NonNull::new_unchecked(field)
340    }
341}
342
343impl<S, C> StateFuture<S, C> {
344    /// Returns a reference to the current state.
345    ///
346    /// This may be useful to create other futures or simply getting the current state.
347    pub fn state(&self) -> &State<S> {
348        &self.state
349    }
350
351    fn queue_waker(self: Pin<&mut Self>, waker: &Waker) {
352        // Acquire a read lock so we guarantee the list is not used while we're modifying the waiter.
353        let lock = self.state.inner.waiters.read();
354        // Safety: We have a read lock, so the list is not being modified, and only one thread can
355        // poll the future at a time.
356        let waiter = unsafe { &mut *self.waiter.get() };
357
358        if !waiter.queued {
359            drop(lock);
360            // Acquire a write lock to add ourselves to the list.
361            let mut lock = self.state.inner.waiters.write();
362
363            // Note: We dont need to check if we got queued in the meantime the lock was acquired,
364            // since only the future itself adds the waiter to the list.
365            waiter.queued = true;
366            waiter.waker = Some(waker.clone());
367
368            lock.push_front(unsafe { NonNull::new_unchecked(waiter) });
369            return;
370        }
371
372        // Safety: list lock is held.
373        match waiter.waker {
374            Some(ref w) if w.will_wake(waker) => {}
375            _ => {
376                waiter.waker = Some(waker.clone());
377            }
378        }
379    }
380
381    fn remove_waiter(&self) {
382        let waiters = self.state.inner.waiters.read();
383
384        let waiter = unsafe { &mut *self.waiter.get() };
385        if !waiter.queued {
386            // Return since the waiter is not queued.
387            return;
388        }
389
390        drop(waiters);
391        let mut waiters = self.state.inner.waiters.write();
392
393        // We don't have to check if the waiter was dropped in the meantime, since only the future
394        // itself removes the waiter from the list.
395
396        unsafe {
397            // Safety: waiter is not null and !Unpin.
398            let nonnull = NonNull::new_unchecked(self.waiter.get());
399            // Safety: we have checked that the waiter is queued and therefore in the list.
400            waiters.remove(nonnull);
401        }
402
403        drop(waiters);
404    }
405}
406
407impl<S, C> StateFuture<S, C>
408where
409    C: Fn(&S) -> bool,
410{
411    fn new(state: State<S>, wait_for: C) -> Self {
412        Self {
413            state,
414            waiter: UnsafeCell::new(Waiter::new()),
415            wait_for,
416        }
417    }
418}
419
420impl<S, C> Future for StateFuture<S, C>
421where
422    C: Fn(&S) -> bool,
423{
424    type Output = ();
425
426    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
427        let state = self.state.inner.state.read();
428        if (self.wait_for)(&*state) {
429            drop(state);
430            // remove the waiter from the list, since we're done waiting.
431            self.remove_waiter();
432            return Poll::Ready(());
433        }
434        drop(state);
435
436        self.queue_waker(cx.waker());
437        Poll::Pending
438    }
439}
440
441impl<S, C> Drop for StateFuture<S, C> {
442    fn drop(&mut self) {
443        // remove the waiter from the list, since we're done waiting.
444        self.remove_waiter();
445    }
446}
447
448impl<'a, S> Deref for StateRef<'a, S> {
449    type Target = S;
450
451    fn deref(&self) -> &Self::Target {
452        &self.0
453    }
454}
455
456impl<'a, S: fmt::Debug> fmt::Debug for StateRef<'a, S> {
457    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
458        (**self).fmt(f)
459    }
460}
461
462impl<'a, S: fmt::Display> fmt::Display for StateRef<'a, S> {
463    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
464        (**self).fmt(f)
465    }
466}
467
#[cfg(test)]
mod test {
    use super::*;
    use tokio::time;

    #[derive(Clone, Copy, Debug, PartialEq)]
    enum StateEnum {
        A,
        B,
        C,
    }

    /// Sleeps long enough for spawned waiter tasks to be polled (and to finish, if woken).
    async fn settle() {
        time::sleep(time::Duration::from_millis(100)).await;
    }

    /// Spawns a task that completes once `state` becomes `StateEnum::B`.
    fn spawn_waiter_for_b(state: &State<StateEnum>) -> tokio::task::JoinHandle<()> {
        let state = state.clone();
        tokio::spawn(async move { state.wait_for_state(StateEnum::B).await })
    }

    #[test]
    fn test_state() {
        let state = State::new(StateEnum::A);
        assert_eq!(state.get(), StateEnum::A);

        state.set(StateEnum::B);
        assert_eq!(state.get(), StateEnum::B);
    }

    #[tokio::test]
    async fn test_future1() {
        let state = State::new(StateEnum::A);
        let waiter = spawn_waiter_for_b(&state);
        assert_eq!(state.get(), StateEnum::A);

        state.set(StateEnum::B);
        assert_eq!(state.get(), StateEnum::B);

        settle().await;
        assert!(waiter.is_finished());
    }

    #[tokio::test]
    async fn test_future2() {
        let state = State::new(StateEnum::A);
        let waiter = spawn_waiter_for_b(&state);
        assert_eq!(state.get(), StateEnum::A);

        // Switching to a non-matching state must not complete the future.
        state.set(StateEnum::C);
        assert_eq!(state.get(), StateEnum::C);
        settle().await;
        assert!(!waiter.is_finished());

        state.set(StateEnum::B);
        assert_eq!(state.get(), StateEnum::B);
        settle().await;
        assert!(waiter.is_finished());
    }

    #[tokio::test]
    async fn multiple_waiters() {
        const NUM_WAITERS: usize = 100;

        let state = State::new(StateEnum::A);
        let handles: Vec<_> = (0..NUM_WAITERS)
            .map(|_| spawn_waiter_for_b(&state))
            .collect();
        assert_eq!(state.get(), StateEnum::A);

        state.set(StateEnum::C);
        assert_eq!(state.get(), StateEnum::C);
        settle().await;
        assert!(handles.iter().all(|handle| !handle.is_finished()));

        state.set(StateEnum::B);
        assert_eq!(state.get(), StateEnum::B);
        settle().await;
        assert!(handles.iter().all(|handle| handle.is_finished()));
    }
}