1use std::cell::UnsafeCell;
52use std::error::Error;
53use std::fmt;
54use std::future::Future;
55use std::marker::PhantomPinned;
56use std::ops::Deref;
57use std::pin::Pin;
58use std::ptr::{addr_of_mut, NonNull};
59use std::sync::Arc;
60use std::task::{Context, Poll, Waker};
61
62use parking_lot::{RwLock, RwLockReadGuard};
63
64use crate::util::linked_list;
65use crate::util::linked_list::LinkedList;
66
67mod util;
68
69#[derive(Clone)]
74pub struct State<S> {
75 inner: Arc<StateInner<S>>,
77}
78
/// Shared backing storage behind every handle to one [`State`].
struct StateInner<S> {
    /// The current value.
    state: RwLock<S>,
    /// Intrusive list of waiters belonging to pending `StateFuture`s; each entry is a raw
    /// pointer to a `Waiter` stored inline in a pinned future.
    waiters: RwLock<LinkedList<Waiter, <Waiter as linked_list::Link>::Target>>,
    /// Callback invoked by `State::set` with `(old, new)` just before the value is replaced.
    on_change: Box<dyn Fn(&S, &S) + 'static>,
}
88
/// Error returned by [`State::set_on_change`] when exclusive access to the shared state could
/// not be obtained, i.e. other handles (clones or pending futures) still reference it.
pub struct UpdateOnChangeError {
    /// Reference count reported by [`State::ref_count`] when the update was attempted.
    pub state_references: usize,
    // Private field: prevents constructing this error outside this module.
    _p: (),
}
99
/// Intrusive list node tracking one pending [`StateFuture`].
struct Waiter {
    /// Whether this node is currently linked into `StateInner::waiters`.
    queued: bool,

    /// Waker to notify on the next state change; taken (left `None`) by `wake_waiters`.
    waker: Option<Waker>,

    /// Prev/next links used by the intrusive `LinkedList`.
    pointers: linked_list::Pointers<Waiter>,

    /// Makes the node `!Unpin`: the list stores raw pointers to it, so it must not move
    /// while linked.
    _p: PhantomPinned,
}
114
/// Future returned by [`State::wait_for`] and [`State::wait_for_state`].
///
/// Resolves once the `wait_for` predicate returns `true` for the current state. While pending,
/// it links its inline waiter node into the state's waiter list; this is why the future is
/// `!Unpin` and unlinks itself when dropped.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct StateFuture<S, C> {
    /// Handle keeping the shared state alive for as long as the future exists.
    state: State<S>,
    /// Intrusive list node, mutated both by this future and by `State::wake_waiters`
    /// (through the list's raw pointer), hence the `UnsafeCell`.
    waiter: UnsafeCell<Waiter>,
    /// Predicate deciding when the future resolves.
    wait_for: C,
}
126
/// Read guard over the current value of a [`State`], obtained from [`State::get_ref`].
///
/// Dereferences to `S`. The state's read lock is held until this guard is dropped, so
/// [`State::set`] and [`State::update`] block while it is alive (and deadlock if called on the
/// same thread, since the lock is not reentrant).
#[must_use]
pub struct StateRef<'a, S>(RwLockReadGuard<'a, S>);
131
132unsafe impl<S> Send for State<S> {}
133unsafe impl<S> Sync for State<S> {}
134
135unsafe impl<S, C> Send for StateFuture<S, C> {}
136unsafe impl<S, C> Sync for StateFuture<S, C> {}
137
138impl<S> State<S> {
139 pub fn new(state: S) -> Self {
141 Self {
142 inner: Arc::new(StateInner {
143 state: RwLock::new(state),
144 waiters: RwLock::new(LinkedList::new()),
145 on_change: Box::new(|_, _| {}),
146 }),
147 }
148 }
149
150 pub fn new_with_on_change(state: S, on_change: impl Fn(&S, &S) + 'static) -> Self {
156 Self {
157 inner: Arc::new(StateInner {
158 state: RwLock::new(state),
159 waiters: RwLock::new(LinkedList::new()),
160 on_change: Box::new(on_change),
161 }),
162 }
163 }
164
165 pub fn set_on_change(&mut self, on_change: impl Fn(&S, &S) + 'static) -> Result<(), UpdateOnChangeError> {
171 if let Some(inner) = Arc::get_mut(&mut self.inner) {
172 inner.on_change = Box::new(on_change);
173 Ok(())
174 } else {
175 Err(UpdateOnChangeError::new(self.ref_count()))
176 }
177 }
178
179 pub fn ref_count(&self) -> usize {
182 Arc::strong_count(&self.inner)
183 }
184
185 pub fn get_ref(&self) -> StateRef<S> {
188 StateRef(self.inner.state.read())
189 }
190
191 pub fn wait_for<C>(&self, wait_for: C) -> StateFuture<S, C>
193 where
194 C: Fn(&S) -> bool,
195 {
196 StateFuture::new(
197 State {
198 inner: self.inner.clone(),
199 },
200 wait_for,
201 )
202 }
203
204 pub fn set(&self, state: S) {
206 let mut write = self.inner.state.write();
207 (self.inner.on_change)(&*write, &state);
208 *write = state;
209 drop(write);
210 self.wake_waiters();
211 }
212
213 pub fn update(&self, f: impl FnOnce(&mut S)) {
219 let mut write = self.inner.state.write();
220 f(&mut write);
221 drop(write);
222 self.wake_waiters();
223 }
224
225 fn wake_waiters(&self) {
227 let mut waiters = self.inner.waiters.write();
228
229 for mut waiter in waiters.iter() {
230 let waiter = unsafe { waiter.as_mut() };
232
233 assert!(waiter.queued);
234
235 if let Some(waker) = waiter.waker.take() {
236 waker.wake();
237 }
238 }
239 }
240}
241
242impl<S> State<S>
243where
244 S: Clone,
245{
246 pub fn get(&self) -> S {
249 self.get_ref().clone()
250 }
251}
252
253impl<S> State<S>
254where
255 S: PartialEq<S>,
256{
257 pub fn wait_for_state(&self, wait_for: S) -> StateFuture<S, impl Fn(&S) -> bool> {
259 self.wait_for(move |s| wait_for.eq(s))
260 }
261}
262
263impl<S, O> PartialEq<O> for State<S>
264where
265 S: PartialEq<O>,
266{
267 fn eq(&self, other: &O) -> bool {
268 self.get_ref().eq(other)
269 }
270}
271
272impl<S: fmt::Debug> fmt::Debug for State<S> {
273 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274 f.debug_tuple("State").field(&self.get_ref()).finish()
275 }
276}
277
278impl<S: Default> Default for State<S> {
279 fn default() -> Self {
280 Self::new(Default::default())
281 }
282}
283
284impl UpdateOnChangeError {
285 fn new(state_references: usize) -> Self {
287 Self {
288 state_references,
289 _p: (),
290 }
291 }
292}
293
294impl fmt::Debug for UpdateOnChangeError {
295 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296 f.debug_struct("UpdateOnChangeError")
297 .field("state_references", &self.state_references)
298 .finish()
299 }
300}
301
302impl fmt::Display for UpdateOnChangeError {
303 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304 write!(f, "Cannot update state, as there are {} other references to the state.", self.state_references)
305 }
306}
307
308impl Error for UpdateOnChangeError {}
309
310impl Waiter {
311 fn new() -> Waiter {
312 Waiter {
313 queued: false,
314 waker: None,
315 pointers: linked_list::Pointers::new(),
316 _p: PhantomPinned,
317 }
318 }
319}
320
// SAFETY: `Waiter`s are only pushed onto the list from `StateFuture::queue_waker`, which requires
// the owning future to be pinned (and the future is `!Unpin` via `PhantomPinned`). They are
// unlinked in `StateFuture::remove_waiter`, which `Drop` calls, so the raw handles stored in the
// list remain valid for as long as they are linked.
unsafe impl linked_list::Link for Waiter {
    // Handles are plain raw pointers: the list never owns or frees its entries.
    type Handle = NonNull<Waiter>;
    type Target = Waiter;

    fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> {
        *handle
    }

    unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> {
        ptr
    }

    unsafe fn pointers(target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> {
        let me = target.as_ptr();
        // `addr_of_mut!` computes the field address without materialising an intermediate
        // `&mut Waiter`, which could alias other references to the node.
        let field = addr_of_mut!((*me).pointers);
        // The address of a field of a non-null pointer is itself non-null.
        NonNull::new_unchecked(field)
    }
}
342
343impl<S, C> StateFuture<S, C> {
344 pub fn state(&self) -> &State<S> {
348 &self.state
349 }
350
351 fn queue_waker(self: Pin<&mut Self>, waker: &Waker) {
352 let lock = self.state.inner.waiters.read();
354 let waiter = unsafe { &mut *self.waiter.get() };
357
358 if !waiter.queued {
359 drop(lock);
360 let mut lock = self.state.inner.waiters.write();
362
363 waiter.queued = true;
366 waiter.waker = Some(waker.clone());
367
368 lock.push_front(unsafe { NonNull::new_unchecked(waiter) });
369 return;
370 }
371
372 match waiter.waker {
374 Some(ref w) if w.will_wake(waker) => {}
375 _ => {
376 waiter.waker = Some(waker.clone());
377 }
378 }
379 }
380
381 fn remove_waiter(&self) {
382 let waiters = self.state.inner.waiters.read();
383
384 let waiter = unsafe { &mut *self.waiter.get() };
385 if !waiter.queued {
386 return;
388 }
389
390 drop(waiters);
391 let mut waiters = self.state.inner.waiters.write();
392
393 unsafe {
397 let nonnull = NonNull::new_unchecked(self.waiter.get());
399 waiters.remove(nonnull);
401 }
402
403 drop(waiters);
404 }
405}
406
407impl<S, C> StateFuture<S, C>
408where
409 C: Fn(&S) -> bool,
410{
411 fn new(state: State<S>, wait_for: C) -> Self {
412 Self {
413 state,
414 waiter: UnsafeCell::new(Waiter::new()),
415 wait_for,
416 }
417 }
418}
419
420impl<S, C> Future for StateFuture<S, C>
421where
422 C: Fn(&S) -> bool,
423{
424 type Output = ();
425
426 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
427 let state = self.state.inner.state.read();
428 if (self.wait_for)(&*state) {
429 drop(state);
430 self.remove_waiter();
432 return Poll::Ready(());
433 }
434 drop(state);
435
436 self.queue_waker(cx.waker());
437 Poll::Pending
438 }
439}
440
441impl<S, C> Drop for StateFuture<S, C> {
442 fn drop(&mut self) {
443 self.remove_waiter();
445 }
446}
447
448impl<'a, S> Deref for StateRef<'a, S> {
449 type Target = S;
450
451 fn deref(&self) -> &Self::Target {
452 &self.0
453 }
454}
455
456impl<'a, S: fmt::Debug> fmt::Debug for StateRef<'a, S> {
457 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
458 (**self).fmt(f)
459 }
460}
461
462impl<'a, S: fmt::Display> fmt::Display for StateRef<'a, S> {
463 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
464 (**self).fmt(f)
465 }
466}
467
#[cfg(test)]
mod test {
    use super::*;
    use tokio::time;

    /// Upper bound for a woken waiter to finish; generous so slow CI machines don't flake.
    const WAKE_TIMEOUT: time::Duration = time::Duration::from_secs(5);
    /// Grace period before asserting that a waiter has *not* finished.
    const SETTLE: time::Duration = time::Duration::from_millis(100);

    #[derive(Clone, Copy, Debug, PartialEq)]
    enum StateEnum {
        A,
        B,
        C,
    }

    /// Spawns a task that waits until `state` becomes `target`.
    fn spawn_waiter(state: &State<StateEnum>, target: StateEnum) -> tokio::task::JoinHandle<()> {
        let state = state.clone();
        tokio::spawn(async move { state.wait_for_state(target).await })
    }

    /// Awaits `handle`, failing the test if it does not finish in time or if it panicked.
    async fn expect_finished(handle: tokio::task::JoinHandle<()>) {
        time::timeout(WAKE_TIMEOUT, handle)
            .await
            .expect("waiter was not woken in time")
            .expect("waiter task panicked");
    }

    #[test]
    fn test_state() {
        let state = State::new(StateEnum::A);

        assert_eq!(state.get(), StateEnum::A);

        state.set(StateEnum::B);

        assert_eq!(state.get(), StateEnum::B);
    }

    #[tokio::test]
    async fn test_future1() {
        let state = State::new(StateEnum::A);
        let fut = spawn_waiter(&state, StateEnum::B);

        assert_eq!(state.get(), StateEnum::A);

        state.set(StateEnum::B);

        assert_eq!(state.get(), StateEnum::B);
        expect_finished(fut).await;
    }

    #[tokio::test]
    async fn test_future2() {
        let state = State::new(StateEnum::A);
        let fut = spawn_waiter(&state, StateEnum::B);

        assert_eq!(state.get(), StateEnum::A);

        state.set(StateEnum::C);

        assert_eq!(state.get(), StateEnum::C);
        time::sleep(SETTLE).await;
        assert!(!fut.is_finished());

        state.set(StateEnum::B);

        assert_eq!(state.get(), StateEnum::B);
        expect_finished(fut).await;
    }

    #[tokio::test]
    async fn multiple_waiters() {
        const NUM_WAITERS: usize = 100;

        let state = State::new(StateEnum::A);
        let handles: Vec<_> = (0..NUM_WAITERS)
            .map(|_| spawn_waiter(&state, StateEnum::B))
            .collect();

        assert_eq!(state.get(), StateEnum::A);

        state.set(StateEnum::C);

        assert_eq!(state.get(), StateEnum::C);
        time::sleep(SETTLE).await;
        assert!(!handles.iter().any(|h| h.is_finished()));

        state.set(StateEnum::B);

        assert_eq!(state.get(), StateEnum::B);
        for handle in handles {
            expect_finished(handle).await;
        }
    }

    #[tokio::test]
    async fn already_satisfied_resolves_immediately() {
        let state = State::new(StateEnum::C);

        time::timeout(WAKE_TIMEOUT, state.wait_for_state(StateEnum::C))
            .await
            .expect("future for an already-satisfied predicate did not resolve");
    }

    #[tokio::test]
    async fn cancelled_waiter_is_unlinked() {
        let state = State::new(StateEnum::A);
        let handle = spawn_waiter(&state, StateEnum::B);

        // Let the waiter run once so it registers itself on the waiter list.
        time::sleep(SETTLE).await;
        handle.abort();
        assert!(handle.await.unwrap_err().is_cancelled());

        // The aborted future must have released its handle and left the waiter list, so
        // waking waiters afterwards must not touch its freed node.
        assert_eq!(state.ref_count(), 1);
        state.set(StateEnum::B);
        assert_eq!(state.get(), StateEnum::B);
    }
}