// trillium_server_common/clone_counter.rs

1use event_listener::{Event, EventListener};
2use std::{
3    fmt::{Debug, Formatter, Result},
4    future::{Future, IntoFuture},
5    pin::{pin, Pin},
6    sync::{
7        atomic::{AtomicUsize, Ordering},
8        Arc,
9    },
10    task::{ready, Context, Poll},
11};
12
13pub struct CloneCounterInner {
14    count: AtomicUsize,
15    event: Event,
16}
17
18impl CloneCounterInner {
19    fn new(start: usize) -> Self {
20        Self {
21            count: AtomicUsize::new(start),
22            event: Event::new(),
23        }
24    }
25}
26
27impl Debug for CloneCounterInner {
28    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
29        f.debug_struct("CloneCounterInner")
30            .field("count", &self.count)
31            .finish()
32    }
33}
34
35/**
36# an atomic counter that increments on clone & decrements on drop
37
38One-indexed, because the first CloneCounter is included. If you don't
39want the original to count, construct a [`CloneCounterObserver`]
40instead and use [`CloneCounterObserver::counter`] to increment.
41
42Awaiting a [`CloneCounter`] will be pending until it is the only remaining
43counter and resolve to `()` when the count is 1.
44
45*/
46
47#[derive(Debug)]
48pub struct CloneCounter(Arc<CloneCounterInner>);
49
50impl Default for CloneCounter {
51    fn default() -> Self {
52        Self(Arc::new(CloneCounterInner::new(1)))
53    }
54}
55
56impl CloneCounter {
57    /// Constructs a new CloneCounter. Identical to CloneCounter::default()
58    pub fn new() -> Self {
59        Self::default()
60    }
61
62    /// Returns the current count. The first CloneCounter is one, so
63    /// this can either be considered a one-indexed count of the
64    /// total number of CloneCounters in memory
65    pub fn current(&self) -> usize {
66        self.0.current()
67    }
68
69    /// Manually decrement the count. This is useful when taking a
70    /// clone of the counter that does not represent an increase in
71    /// the underlying property or resource being counted. This is
72    /// called automatically on drop and is usually unnecessary to
73    /// call directly
74    pub fn decrement(&self) {
75        let previously = self.0.count.fetch_sub(1, Ordering::SeqCst);
76        debug_assert!(previously > 0);
77        self.0.wake();
78        if previously > 0 {
79            log::trace!("decrementing from {} -> {}", previously, previously - 1);
80        } else {
81            log::trace!("decrementing from 0");
82        }
83    }
84
85    /// Manually increment the count. unless paired with a decrement,
86    /// this will prevent the clone counter from ever reaching
87    /// zero. This is called automatically on clone.
88    pub fn increment(&self) {
89        let previously = self.0.count.fetch_add(1, Ordering::SeqCst);
90        log::trace!("incrementing from {} -> {}", previously, previously + 1);
91    }
92
93    /// Returns an observer that can be cloned any number of times
94    /// without modifying the clone counter. See
95    /// [`CloneCounterObserver`] for more.
96    pub fn observer(&self) -> CloneCounterObserver {
97        CloneCounterObserver(Arc::clone(&self.0))
98    }
99}
100
101impl IntoFuture for CloneCounter {
102    type Output = ();
103
104    type IntoFuture = CloneCounterFuture;
105
106    fn into_future(self) -> Self::IntoFuture {
107        CloneCounterFuture {
108            inner: Arc::clone(&self.0),
109            listener: EventListener::new(),
110        }
111    }
112}
113
114impl Future for &CloneCounter {
115    type Output = ();
116
117    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
118        let mut listener = pin!(EventListener::new());
119        loop {
120            if 1 == self.0.current() {
121                return Poll::Ready(());
122            }
123
124            if listener.is_listening() {
125                ready!(listener.as_mut().poll(cx));
126            } else {
127                listener.as_mut().listen(&self.0.event)
128            }
129        }
130    }
131}
132impl Clone for CloneCounter {
133    fn clone(&self) -> Self {
134        self.increment();
135        Self(self.0.clone())
136    }
137}
138
139impl Drop for CloneCounter {
140    fn drop(&mut self) {
141        self.decrement();
142    }
143}
144
145impl CloneCounterInner {
146    fn current(&self) -> usize {
147        self.count.load(Ordering::SeqCst)
148    }
149
150    fn wake(&self) {
151        self.event.notify(usize::MAX);
152    }
153}
154
155impl PartialEq<usize> for CloneCounter {
156    fn eq(&self, other: &usize) -> bool {
157        self.current() == *other
158    }
159}
160
161/**
162An observer that can be cloned without modifying the clone
163counter, but can be used to inspect its state and awaited
164
165Zero-indexed, but each [`CloneCounter`] retrieved with
166[`CloneCounterObserver::counter`] increments the count by 1.
167
168Awaiting a [`CloneCounterObserver`] will be pending until all
169associated [`CloneCounter`]s have been dropped, and will resolve to
170`()` when the count is 0.
171
172*/
173
174#[derive(Debug)]
175pub struct CloneCounterObserver(Arc<CloneCounterInner>);
176
177impl Clone for CloneCounterObserver {
178    fn clone(&self) -> Self {
179        Self(self.0.clone())
180    }
181}
182
183impl Default for CloneCounterObserver {
184    fn default() -> Self {
185        Self(Arc::new(CloneCounterInner::new(0)))
186    }
187}
188
189impl PartialEq<usize> for CloneCounterObserver {
190    fn eq(&self, other: &usize) -> bool {
191        self.current() == *other
192    }
193}
194
195impl CloneCounterObserver {
196    /// returns a new observer with a zero count. use [`CloneCounterObserver::counter`] to
197    pub fn new() -> Self {
198        Self::default()
199    }
200    /// returns the current counter value
201    pub fn current(&self) -> usize {
202        self.0.current()
203    }
204
205    /// creates a new CloneCounter from this observer, incrementing the count
206    pub fn counter(&self) -> CloneCounter {
207        let counter = CloneCounter(Arc::clone(&self.0));
208        counter.increment();
209        counter
210    }
211}
212
213impl IntoFuture for CloneCounterObserver {
214    type Output = ();
215
216    type IntoFuture = CloneCounterFuture;
217
218    fn into_future(self) -> Self::IntoFuture {
219        CloneCounterFuture {
220            listener: EventListener::new(),
221            inner: self.0,
222        }
223    }
224}
225
226impl From<CloneCounter> for CloneCounterObserver {
227    fn from(value: CloneCounter) -> Self {
228        // value will be decremented on drop of the original
229        Self(Arc::clone(&value.0))
230    }
231}
232
233impl From<CloneCounterObserver> for CloneCounter {
234    fn from(value: CloneCounterObserver) -> Self {
235        let counter = Self(value.0);
236        counter.increment();
237        counter
238    }
239}
240
241pin_project_lite::pin_project! {
242    /// A future that waits for the clone counter to decrement to zero
243    #[derive(Debug)]
244    pub struct CloneCounterFuture {
245        inner: Arc<CloneCounterInner>,
246        #[pin]
247        listener: EventListener,
248    }
249}
250
251impl Clone for CloneCounterFuture {
252    fn clone(&self) -> Self {
253        let listener = EventListener::new();
254        Self {
255            inner: Arc::clone(&self.inner),
256            listener,
257        }
258    }
259}
260
261impl Future for CloneCounterFuture {
262    type Output = ();
263
264    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
265        let mut this = self.project();
266        loop {
267            if 0 == this.inner.current() {
268                return Poll::Ready(());
269            };
270            if this.listener.is_listening() {
271                ready!(this.listener.as_mut().poll(cx));
272            } else {
273                this.listener.as_mut().listen(&this.inner.event);
274            }
275        }
276    }
277}
278
#[cfg(test)]
mod test {
    use super::{CloneCounter, CloneCounterObserver};
    use futures_lite::future::poll_once;
    use std::future::{Future, IntoFuture};
    use test_harness::test;

    /// Harness for `#[test(harness = block_on)]`: builds the async test body
    /// and drives it to completion with `trillium_testing::block_on`.
    fn block_on<F, Fut>(test: F)
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = ()>,
    {
        let body = test();
        trillium_testing::block_on(body);
    }

    #[test(harness = block_on)]
    async fn doctest_example() {
        // A lone counter is ready as soon as it is awaited.
        let solo = CloneCounter::new();
        assert_eq!(solo.current(), 1);
        solo.await;

        let counter = CloneCounter::new();
        assert_eq!(counter.current(), 1);
        let first = counter.clone();
        assert_eq!(counter.current(), 2);
        let second = counter.clone();
        assert_eq!(counter.current(), 3);
        // Turning `second` into a future consumes (and decrements) it, but
        // `first` is still alive, so the future is pending.
        assert_eq!(poll_once(second.into_future()).await, None);
        assert_eq!(counter.current(), 2);
        drop(first);

        assert_eq!(counter.current(), 1);
        counter.await;
    }

    #[test(harness = block_on)]
    async fn observer_into_and_from() {
        let counter = CloneCounter::new();
        assert_eq!(counter, 1);
        // The temporary clone is counted only while the assertion runs.
        assert_eq!(counter.clone(), 2);
        assert_eq!(counter, 1);

        let observer = CloneCounterObserver::from(counter);
        assert_eq!(poll_once(observer.clone().into_future()).await, Some(()));
        assert_eq!(observer, 0);

        let counter = CloneCounter::from(observer);
        assert_eq!(counter, 1);
        assert_eq!(poll_once(counter.into_future()).await, Some(()));
    }

    #[test(harness = block_on)]
    async fn observer_test() {
        let solo = CloneCounter::new();
        assert_eq!(solo.current(), 1);
        solo.await;

        let counter = CloneCounter::new();
        let observer = counter.observer();
        assert_eq!(observer.current(), 1);

        let mut clones = Vec::with_capacity(10);
        for expected in 2..=11 {
            clones.push(counter.clone());
            assert_eq!(counter.current(), expected);
            assert_eq!(observer.current(), expected);
        }

        // Observers can be cloned or created freely without touching the count.
        let _cloned_observers: Vec<_> = (0..10).map(|_| observer.clone()).collect();
        assert_eq!(observer.current(), 11);

        let _fresh_observers: Vec<_> = (0..10).map(|_| counter.observer()).collect();
        assert_eq!(observer.current(), 11);

        for (dropped, clone) in clones.into_iter().enumerate() {
            let before = 11 - dropped;
            assert_eq!(clone.current(), before);
            assert_eq!(observer.current(), before);
            assert_eq!(poll_once(&clone).await, None);
            assert_eq!(poll_once(observer.clone().into_future()).await, None);
            drop(clone);
            assert_eq!(counter.current(), before - 1);
            assert_eq!(observer.current(), before - 1);
        }

        assert_eq!(counter.current(), 1);
        assert_eq!(poll_once(counter.into_future()).await, Some(()));
        assert_eq!(observer.current(), 0);
        assert_eq!(poll_once(observer.into_future()).await, Some(()));
    }
}
369        assert_eq!(poll_once(observer.into_future()).await, Some(()));
370    }
371}