wait_for_me/
lib.rs

//! This library provides an implementation of an async [`CountDownLatch`],
//! which keeps a counter synchronized by an [`async_lock::Mutex`] in its internal state and allows tasks to wait until
//! the counter reaches zero.
4//!
5//! # Example
6//! ```rust,no_run
7//! use wait_for_me::CountDownLatch;
8//! use smol::{self,Task};
9//! fn main() -> Result<(), Box<dyn std::error::Error>> {
10//!    smol::block_on(async {
11//!         let latch = CountDownLatch::new(1);
12//!         let latch1 = latch.clone();
13//!         smol::spawn(async move {
14//!             latch1.count_down().await;
15//!         }).detach();
16//!         latch.wait().await;
17//!         Ok(())
18//!    })
19//!
20//!}
21//! ```
22//!
//! Waiting with a timeout:
24//!
25//! ```rust,no_run
26//! use wait_for_me::CountDownLatch;
27//! use smol::{Task,Timer};
28//! use std::time::Duration;
29//! #[smol_potat::main]
30//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
31//!    let latch = CountDownLatch::new(10);
32//!    for _ in 0..10 {
33//!        let latch1 = latch.clone();
34//!        smol::spawn(async move {
35//!            Timer::after(Duration::from_secs(3)).await;
36//!            latch1.count_down().await;
37//!        }).detach();
38//!    }
39//!    let result = latch.wait_for(Duration::from_secs(1)).await;
40//!
41//!    assert_eq!(false,result);
42//!
43//!    Ok(())
44//!}
45//!```
46//!
47
48use futures::future::Either;
49use futures_timer::Delay;
50use std::future::Future;
51use std::pin::Pin;
52use std::sync::Arc;
53use std::task::{Context, Poll, Waker};
54use std::time::Duration;
55use async_lock::{Mutex};
56use async_lock::futures::{LockArc};
57
/// Mutable latch state, always accessed through the latch's mutex.
struct CountDownState {
    /// Wakers of tasks that polled a wait future while `count` was non-zero;
    /// they are drained and woken once `count` drops to zero.
    wakers: Vec<Waker>,
    /// Remaining number of count-downs before waiters are released.
    count: usize,
}
62
63impl CountDownLatch {
64    /// Creates a new [`CountDownLatch`] with a given count.
65    pub fn new(count: usize) -> CountDownLatch {
66        CountDownLatch {
67            state: Arc::new(Mutex::new(CountDownState {
68                count,
69                wakers: vec![],
70            })),
71        }
72    }
73
74    /// Returns the current count.
75    pub async fn count(&self) -> usize {
76        let state = self.state.lock().await;
77        state.count
78    }
79
80    /// Cause the current task to wait until the counter reaches zero
81    pub fn wait(&self) -> impl Future<Output = ()> {
82        WaitFuture {
83            latch: self.clone(),
84            state_lock: None,
85        }
86    }
87
88    /// Cause the current task to wait until the counter reaches zero with timeout.
89    ///
90    /// If the specified timeout elapsed `false` is returned. Otherwise `true`.
91    pub async fn wait_for(&self, timeout: Duration) -> bool {
92        let delay = Delay::new(timeout);
93        match futures::future::select(delay, self.wait()).await {
94            Either::Left(_) => false,
95            Either::Right(_) => true,
96        }
97    }
98
99    /// Decrement the counter of one unit. If the counter reaches zero all the waiting tasks are released.
100    pub async fn count_down(&self) {
101        let mut state = self.state.lock().await;
102        let count = state.count.saturating_sub(1);
103        state.set(count);
104    }
105
106    /// Sets the internal count.
107    pub async fn set(&self, count: usize) {
108        let mut state = self.state.lock().await;
109        state.set(count);
110    }
111}
112
113impl CountDownState {
114    fn set(&mut self, count: usize) {
115        self.count = count;
116        if count == 0 {
117            for waker in self.wakers.drain(..) {
118                waker.wake();
119            }
120        }
121    }
122}
123
124struct WaitFuture {
125    latch: CountDownLatch,
126    state_lock: Option<Box<LockArc<CountDownState>>>,
127}
128
129impl Future for WaitFuture {
130    type Output = ();
131
132    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
133        loop {
134            match self.state_lock.take() {
135                Some(mut state_lock) => {
136                    return match unsafe { Pin::new_unchecked(state_lock.as_mut()) }.poll(cx) {
137                        Poll::Ready(mut guard) => {
138                            if guard.count > 0 {
139                                for waker in guard.wakers.iter() {
140                                    if waker.will_wake(cx.waker()) {
141                                        return Poll::Pending
142                                    }
143                                }
144                                guard.wakers.push(cx.waker().clone());
145                                Poll::Pending
146                            } else {
147                                for waker in guard.wakers.drain(..) {
148                                    waker.wake();
149                                }
150                                Poll::Ready(())
151                            }
152                        }
153                        Poll::Pending => {
154                            // Do not drop state_lock otherwise our waker from the poll call
155                            // would be dropped as well.
156                            self.state_lock = Some(state_lock);
157                            Poll::Pending
158                        }
159                    }
160                }
161                None => {
162                    self.state_lock = Some(Box::new(self.latch.state.lock_arc()));
163                }
164            }
165        }
166    }
167}
168
169/// A synchronization primitive that allows one or more tasks to wait until the given counter reaches zero.
170/// This is an async port of [CountDownLatch](https://docs.oracle.com/javase/7/docs/api/java/util/concurrent/CountDownLatch.html) in Java.
171#[derive(Clone)]
172pub struct CountDownLatch {
173    state: Arc<Mutex<CountDownState>>,
174}
175
#[cfg(test)]
mod tests {
    // Tests drive the latch on both a single-threaded `LocalPool` and a
    // multi-threaded `ThreadPool`.
    use super::CountDownLatch;
    use futures_executor::{LocalPool, ThreadPool};
    use futures_util::future::{join, join_all};
    use futures_util::task::SpawnExt;
    use std::time::Duration;

    #[test]
    fn countdownlatch_test() {
        let mut pool = LocalPool::new();

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(2);
        let latch1 = latch.clone();
        spawner
            .spawn(async move { latch1.count_down().await })
            .unwrap();

        let latch2 = latch.clone();
        spawner
            .spawn(async move { latch2.count_down().await })
            .unwrap();

        let latch3 = latch.clone();
        spawner
            .spawn(async move {
                latch3.wait().await;
            })
            .unwrap();

        spawner
            .spawn(async move {
                latch.wait().await;
            })
            .unwrap();

        pool.run();
    }

    #[test]
    fn countdownlatch_pre_wait_test() {
        let mut pool = LocalPool::new();

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(1);

        let latch1 = latch.clone();
        spawner
            .spawn(async move { latch1.wait().await })
            .unwrap();

        spawner
            .spawn(async move { latch.count_down().await })
            .unwrap();

        pool.run();
    }

    #[test]
    fn countdownlatch_parallel_pre_wait_test() {
        let pool = ThreadPool::builder().pool_size(4).create().unwrap();

        let latch = CountDownLatch::new(1);

        let latch1 = latch.clone();
        let handle1 = pool
            .spawn_with_handle(async move { latch1.wait().await })
            .unwrap();

        let handle2 = pool
            .spawn_with_handle(async move { latch.count_down().await })
            .unwrap();

        futures_executor::block_on(join(handle1, handle2));
    }

    #[test]
    fn countdownlatch_concurrent_test() {
        let mut pool = LocalPool::new();

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(100);

        for _ in 0..200 {
            let latch1 = latch.clone();
            spawner
                .spawn(async move { latch1.count_down().await })
                .unwrap();
        }

        for _ in 0..100 {
            let latch1 = latch.clone();
            spawner.spawn(async move { latch1.wait().await }).unwrap();
        }

        pool.run();
    }

    #[test]
    fn countdownlatch_no_wait_test() {
        let mut pool = LocalPool::new();

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(100);

        for _ in 0..200 {
            let latch1 = latch.clone();
            spawner
                .spawn(async move { latch1.count_down().await })
                .unwrap();
        }

        pool.run();
    }

    #[test]
    fn countdownlatch_post_wait_test() {
        let mut pool = LocalPool::new();

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(100);

        for _ in 0..200 {
            let latch1 = latch.clone();
            spawner
                .spawn(async move { latch1.count_down().await })
                .unwrap();
        }

        pool.run();

        for _ in 0..100 {
            let latch1 = latch.clone();
            spawner.spawn(async move { latch1.wait().await }).unwrap();
        }

        pool.run();
    }

    #[test]
    fn countdownlatch_count_test() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let mut pool = LocalPool::new();
        let pre_counter = Arc::new(AtomicUsize::new(0));
        let post_counter = Arc::new(AtomicUsize::new(0));

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(1);

        let latch1 = latch.clone();
        let pre_counter1 = pre_counter.clone();
        let post_counter1 = post_counter.clone();
        spawner
            .spawn(async move {
                pre_counter1.store(latch1.count().await, Ordering::Relaxed);
                latch1.count_down().await;
                post_counter1.store(latch1.count().await, Ordering::Relaxed);
            })
            .unwrap();

        pool.run();

        assert_eq!(1, pre_counter.load(Ordering::Relaxed));
        assert_eq!(0, post_counter.load(Ordering::Relaxed));
    }

    #[test]
    fn count_down_saturates_at_zero_test() {
        let latch = CountDownLatch::new(1);
        futures_executor::block_on(async {
            latch.count_down().await;
            // A second count-down on a released latch must not wrap around.
            latch.count_down().await;
            assert_eq!(0, latch.count().await);
            latch.wait().await;
        });
    }

    #[test]
    fn wait_with_timeout_test() {
        use futures_timer::Delay;
        use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
        use std::sync::Arc;

        let mut pool = LocalPool::new();
        // Sentinels that the waiting task must overwrite, so the assertions
        // below prove that the task actually ran to completion.
        let counter = Arc::new(AtomicUsize::new(usize::MAX));
        let no_timeout = Arc::new(AtomicBool::new(true));

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(1);

        let latch1 = latch.clone();
        spawner
            .spawn(async move {
                // Released well after the waiter's timeout below.
                Delay::new(Duration::from_millis(500)).await;
                latch1.count_down().await;
            })
            .unwrap();

        let counter1 = counter.clone();
        let no_timeout1 = no_timeout.clone();
        spawner
            .spawn(async move {
                let result = latch.wait_for(Duration::from_millis(100)).await;
                counter1.store(latch.count().await, Ordering::Relaxed);
                no_timeout1.store(result, Ordering::Relaxed);
            })
            .unwrap();

        pool.run();

        assert_eq!(1, counter.load(Ordering::Relaxed));
        assert!(!no_timeout.load(Ordering::Relaxed));
    }

    #[test]
    fn wait_for_released_latch_test() {
        let latch = CountDownLatch::new(0);
        let released = futures_executor::block_on(latch.wait_for(Duration::from_secs(1)));
        assert!(released);
    }

    #[test]
    fn stress_test() {
        let mut pool = LocalPool::new();

        let n = 10_000;
        let latch = CountDownLatch::new(n);

        let spawner = pool.spawner();

        for _ in 0..(2 * n) {
            let latch1 = latch.clone();
            spawner
                .spawn(async move {
                    latch1.wait().await;
                })
                .unwrap();
        }

        for _ in 0..n {
            let latch2 = latch.clone();
            spawner
                .spawn(async move {
                    latch2.count_down().await;
                })
                .unwrap();
        }

        for _ in 0..(2 * n) {
            let latch3 = latch.clone();
            spawner
                .spawn(async move {
                    latch3.wait().await;
                })
                .unwrap();
        }

        pool.run();
    }

    #[test]
    fn parallel_stress_test() {
        let pool = ThreadPool::builder().pool_size(4).create().unwrap();

        let n = 10_000;
        let latch = CountDownLatch::new(n);

        // 2n early waiters + n count-downs + 2n late waiters.
        let mut handles = Vec::with_capacity(5 * n);

        for _ in 0..(2 * n) {
            let latch1 = latch.clone();
            handles.push(
                pool.spawn_with_handle(async move {
                    latch1.wait().await;
                })
                .unwrap(),
            );
        }

        for _ in 0..n {
            let latch2 = latch.clone();
            handles.push(
                pool.spawn_with_handle(async move {
                    latch2.count_down().await;
                })
                .unwrap(),
            );
        }

        for _ in 0..(2 * n) {
            let latch3 = latch.clone();
            handles.push(
                pool.spawn_with_handle(async move {
                    latch3.wait().await;
                })
                .unwrap(),
            );
        }

        futures_executor::block_on(join_all(handles));
    }

    #[test]
    fn countdownlatch_set_zero_test() {
        let mut pool = LocalPool::new();

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(1);

        let latch1 = latch.clone();
        spawner.spawn(latch1.wait()).unwrap();

        let latch2 = latch.clone();
        spawner
            .spawn(async move {
                latch2.set(0).await;
            })
            .unwrap();

        pool.run();
    }

    #[test]
    fn countdownlatch_reuse_test() {
        let mut pool = LocalPool::new();

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(0);

        let latch1 = latch.clone();
        spawner
            .spawn(async move {
                latch1.set(1).await;
            })
            .unwrap();

        pool.run();

        let latch2 = latch.clone();
        spawner.spawn(latch2.wait()).unwrap();

        let latch3 = latch.clone();
        spawner
            .spawn(async move {
                latch3.count_down().await;
            })
            .unwrap();

        pool.run();
    }
}