1use futures::future::Either;
49use futures_timer::Delay;
50use std::future::Future;
51use std::pin::Pin;
52use std::sync::Arc;
53use std::task::{Context, Poll, Waker};
54use std::time::Duration;
55use async_lock::{Mutex};
56use async_lock::futures::{LockArc};
57
/// Mutable state shared by every clone of a [`CountDownLatch`], guarded by its mutex.
struct CountDownState {
    /// Remaining count; waiters are released once this is zero.
    count: usize,
    /// Wakers of tasks that polled a wait while `count` was non-zero.
    wakers: Vec<Waker>,
}
62
63impl CountDownLatch {
64 pub fn new(count: usize) -> CountDownLatch {
66 CountDownLatch {
67 state: Arc::new(Mutex::new(CountDownState {
68 count,
69 wakers: vec![],
70 })),
71 }
72 }
73
74 pub async fn count(&self) -> usize {
76 let state = self.state.lock().await;
77 state.count
78 }
79
80 pub fn wait(&self) -> impl Future<Output = ()> {
82 WaitFuture {
83 latch: self.clone(),
84 state_lock: None,
85 }
86 }
87
88 pub async fn wait_for(&self, timeout: Duration) -> bool {
92 let delay = Delay::new(timeout);
93 match futures::future::select(delay, self.wait()).await {
94 Either::Left(_) => false,
95 Either::Right(_) => true,
96 }
97 }
98
99 pub async fn count_down(&self) {
101 let mut state = self.state.lock().await;
102 let count = state.count.saturating_sub(1);
103 state.set(count);
104 }
105
106 pub async fn set(&self, count: usize) {
108 let mut state = self.state.lock().await;
109 state.set(count);
110 }
111}
112
113impl CountDownState {
114 fn set(&mut self, count: usize) {
115 self.count = count;
116 if count == 0 {
117 for waker in self.wakers.drain(..) {
118 waker.wake();
119 }
120 }
121 }
122}
123
124struct WaitFuture {
125 latch: CountDownLatch,
126 state_lock: Option<Box<LockArc<CountDownState>>>,
127}
128
129impl Future for WaitFuture {
130 type Output = ();
131
132 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
133 loop {
134 match self.state_lock.take() {
135 Some(mut state_lock) => {
136 return match unsafe { Pin::new_unchecked(state_lock.as_mut()) }.poll(cx) {
137 Poll::Ready(mut guard) => {
138 if guard.count > 0 {
139 for waker in guard.wakers.iter() {
140 if waker.will_wake(cx.waker()) {
141 return Poll::Pending
142 }
143 }
144 guard.wakers.push(cx.waker().clone());
145 Poll::Pending
146 } else {
147 for waker in guard.wakers.drain(..) {
148 waker.wake();
149 }
150 Poll::Ready(())
151 }
152 }
153 Poll::Pending => {
154 self.state_lock = Some(state_lock);
157 Poll::Pending
158 }
159 }
160 }
161 None => {
162 self.state_lock = Some(Box::new(self.latch.state.lock_arc()));
163 }
164 }
165 }
166 }
167}
168
169#[derive(Clone)]
172pub struct CountDownLatch {
173 state: Arc<Mutex<CountDownState>>,
174}
175
#[cfg(test)]
mod tests {
    use super::CountDownLatch;
    use futures_executor::{LocalPool, ThreadPool};
    use futures_util::future::{join, join_all};
    use futures_util::task::SpawnExt;
    use std::time::Duration;

    #[test]
    fn countdownlatch_test() {
        let mut pool = LocalPool::new();

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(2);
        let latch1 = latch.clone();
        spawner
            .spawn(async move { latch1.count_down().await })
            .unwrap();

        let latch2 = latch.clone();
        spawner
            .spawn(async move { latch2.count_down().await })
            .unwrap();

        let latch3 = latch.clone();
        spawner
            .spawn(async move {
                latch3.wait().await;
            })
            .unwrap();

        spawner
            .spawn(async move {
                latch.wait().await;
            })
            .unwrap();

        pool.run();
    }

    #[test]
    fn countdownlatch_pre_wait_test() {
        let mut pool = LocalPool::new();

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(1);

        let latch1 = latch.clone();
        spawner
            .spawn(async move { latch1.wait().await })
            .unwrap();

        spawner
            .spawn(async move { latch.count_down().await })
            .unwrap();

        pool.run();
    }

    #[test]
    fn countdownlatch_parallel_pre_wait_test() {
        let pool = ThreadPool::builder().pool_size(4).create().unwrap();

        let latch = CountDownLatch::new(1);

        let latch1 = latch.clone();
        let handle1 = pool
            .spawn_with_handle(async move { latch1.wait().await })
            .unwrap();

        let handle2 = pool
            .spawn_with_handle(async move { latch.count_down().await })
            .unwrap();

        futures_executor::block_on(join(handle1, handle2));
    }

    #[test]
    fn countdownlatch_concurrent_test() {
        let mut pool = LocalPool::new();

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(100);

        for _ in 0..200 {
            let latch1 = latch.clone();
            spawner
                .spawn(async move { latch1.count_down().await })
                .unwrap();
        }

        for _ in 0..100 {
            let latch1 = latch.clone();
            spawner.spawn(async move { latch1.wait().await }).unwrap();
        }

        pool.run();
    }

    #[test]
    fn countdownlatch_no_wait_test() {
        let mut pool = LocalPool::new();

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(100);

        for _ in 0..200 {
            let latch1 = latch.clone();
            spawner
                .spawn(async move { latch1.count_down().await })
                .unwrap();
        }

        pool.run();
    }

    #[test]
    fn countdownlatch_post_wait_test() {
        let mut pool = LocalPool::new();

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(100);

        for _ in 0..200 {
            let latch1 = latch.clone();
            spawner
                .spawn(async move { latch1.count_down().await })
                .unwrap();
        }

        pool.run();

        for _ in 0..100 {
            let latch1 = latch.clone();
            spawner.spawn(async move { latch1.wait().await }).unwrap();
        }

        pool.run();
    }

    #[test]
    fn countdownlatch_count_test() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let mut pool = LocalPool::new();
        let pre_counter = Arc::new(AtomicUsize::new(0));
        let post_counter = Arc::new(AtomicUsize::new(0));

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(1);

        let latch1 = latch.clone();
        let pre_counter1 = pre_counter.clone();
        let post_counter1 = post_counter.clone();
        spawner
            .spawn(async move {
                pre_counter1.store(latch1.count().await, Ordering::Relaxed);
                latch1.count_down().await;
                post_counter1.store(latch1.count().await, Ordering::Relaxed);
            })
            .unwrap();

        pool.run();

        assert_eq!(1, pre_counter.load(Ordering::Relaxed));
        assert_eq!(0, post_counter.load(Ordering::Relaxed));
    }

    /// The waiter gives up after 50ms, while the count-down only happens after 500ms.
    /// `wait_for` must therefore report a timeout and the count must still be 1.
    #[test]
    fn wait_with_timeout_test() {
        use futures_timer::Delay;
        use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
        use std::sync::Arc;

        let mut pool = LocalPool::new();
        let counter = Arc::new(AtomicUsize::new(1));
        let no_timeout = Arc::new(AtomicBool::new(true));

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(1);

        let latch1 = latch.clone();
        spawner
            .spawn(async move {
                Delay::new(Duration::from_millis(500)).await;
                latch1.count_down().await;
            })
            .unwrap();

        let counter1 = counter.clone();
        let no_timeout1 = no_timeout.clone();
        spawner
            .spawn(async move {
                let result = latch.wait_for(Duration::from_millis(50)).await;
                counter1.store(latch.count().await, Ordering::Relaxed);
                no_timeout1.store(result, Ordering::Relaxed);
            })
            .unwrap();

        pool.run();

        assert_eq!(1, counter.load(Ordering::Relaxed));
        assert!(!no_timeout.load(Ordering::Relaxed));
    }

    /// An already-open latch reports success without waiting for the timeout.
    #[test]
    fn wait_for_open_latch_returns_true() {
        let latch = CountDownLatch::new(0);
        assert!(futures_executor::block_on(
            latch.wait_for(Duration::from_secs(5))
        ));
    }

    /// A latch that opens well before the timeout reports success.
    #[test]
    fn wait_for_returns_true_when_latch_opens_in_time() {
        use futures_timer::Delay;

        let latch = CountDownLatch::new(1);
        let latch1 = latch.clone();
        let opener = async move {
            Delay::new(Duration::from_millis(10)).await;
            latch1.count_down().await;
        };

        let (opened, ()) = futures_executor::block_on(join(
            latch.wait_for(Duration::from_secs(5)),
            opener,
        ));
        assert!(opened);
    }

    /// A timed-out wait can leave a stale waker registered, because nothing
    /// unregisters it on drop. That must not disturb the latch afterwards.
    #[test]
    fn wait_after_timed_out_wait_test() {
        let latch = CountDownLatch::new(1);

        assert!(!futures_executor::block_on(
            latch.wait_for(Duration::from_millis(10))
        ));

        futures_executor::block_on(latch.count_down());
        futures_executor::block_on(latch.wait());
        assert_eq!(0, futures_executor::block_on(latch.count()));
    }

    #[test]
    fn stress_test() {
        let mut pool = LocalPool::new();

        let n = 10_000;
        let latch = CountDownLatch::new(n);

        let spawner = pool.spawner();

        for _ in 0..(2 * n) {
            let latch1 = latch.clone();
            spawner
                .spawn(async move {
                    latch1.wait().await;
                })
                .unwrap();
        }

        for _ in 0..n {
            let latch2 = latch.clone();
            spawner
                .spawn(async move {
                    latch2.count_down().await;
                })
                .unwrap();
        }

        for _ in 0..(2 * n) {
            let latch3 = latch.clone();
            spawner
                .spawn(async move {
                    latch3.wait().await;
                })
                .unwrap();
        }

        pool.run();
    }

    #[test]
    fn parallel_stress_test() {
        let pool = ThreadPool::builder().pool_size(4).create().unwrap();

        let n = 10_000;
        let latch = CountDownLatch::new(n);

        let mut handles = Vec::with_capacity(5 * n);

        for _ in 0..(2 * n) {
            let latch1 = latch.clone();
            handles.push(
                pool.spawn_with_handle(async move {
                    latch1.wait().await;
                })
                .unwrap(),
            );
        }

        for _ in 0..n {
            let latch2 = latch.clone();
            handles.push(
                pool.spawn_with_handle(async move {
                    latch2.count_down().await;
                })
                .unwrap(),
            );
        }

        for _ in 0..(2 * n) {
            let latch3 = latch.clone();
            handles.push(
                pool.spawn_with_handle(async move {
                    latch3.wait().await;
                })
                .unwrap(),
            );
        }

        futures_executor::block_on(join_all(handles));
    }

    #[test]
    fn countdownlatch_set_zero_test() {
        let mut pool = LocalPool::new();

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(1);

        let latch1 = latch.clone();
        spawner.spawn(latch1.wait()).unwrap();

        let latch2 = latch.clone();
        spawner
            .spawn(async move {
                latch2.set(0).await;
            })
            .unwrap();

        pool.run();
    }

    #[test]
    fn countdownlatch_reuse_test() {
        let mut pool = LocalPool::new();

        let spawner = pool.spawner();
        let latch = CountDownLatch::new(0);

        let latch1 = latch.clone();
        spawner
            .spawn(async move {
                latch1.set(1).await;
            })
            .unwrap();

        pool.run();

        let latch2 = latch.clone();
        spawner.spawn(latch2.wait()).unwrap();

        let latch3 = latch.clone();
        spawner
            .spawn(async move {
                latch3.count_down().await;
            })
            .unwrap();

        pool.run();
    }
}