Skip to main content

xet_runtime/utils/
singleflight.rs

1//! A singleflight implementation for tokio.
2//!
3//! Inspired by [async_singleflight](https://crates.io/crates/async_singleflight).
4//!
5//! # Examples
6//!
7//! ```no_run
8//! use std::sync::Arc;
9//! use std::time::Duration;
10//!
11//! use futures::future::join_all;
12//! use xet_runtime::utils::singleflight::Group;
13//!
14//! const RES: usize = 7;
15//!
16//! async fn expensive_fn() -> Result<usize, ()> {
17//!     tokio::time::sleep(Duration::new(1, 500)).await;
18//!     Ok(RES)
19//! }
20//!
21//! #[tokio::main]
22//! async fn main() {
23//!     let g = Arc::new(Group::<_, ()>::new());
24//!     let mut handlers = Vec::new();
25//!     for _ in 0..10 {
26//!         let g = g.clone();
27//!         handlers.push(tokio::spawn(async move {
28//!             let res = g.work("key", expensive_fn()).await.0;
29//!             let r = res.unwrap();
30//!             println!("{}", r);
31//!         }));
32//!     }
33//!
34//!     join_all(handlers).await;
35//! }
36//! ```
37
38use std::collections::HashMap;
39use std::fmt::Debug;
40use std::future::Future;
41use std::marker::PhantomData;
42use std::pin::Pin;
43use std::sync::atomic::{AtomicBool, AtomicU16, Ordering};
44use std::sync::{Arc, Mutex, RwLock};
45use std::task::{Context, Poll, ready};
46
47use futures::future::Either;
48use pin_project::{pin_project, pinned_drop};
49use tokio::runtime::Handle;
50use tokio::sync::Notify;
51use tracing::{debug, error};
52
53pub use super::errors::SingleflightError;
54use crate::error_printer::ErrorPrinter;
55
56type SingleflightResult<T, E> = Result<T, SingleflightError<E>>;
57type CallMap<T, E> = HashMap<String, Arc<Call<T, E>>>;
58type CallCreate<'a, T, E> = (Arc<Call<T, E>>, CreateGuard<'a, T, E>);
59
60// Marker Traits to help make the code a bit cleaner.
61
/// ResultType is the bound on the success type `T` of a singleflight [Group].
///
/// The actual processing may run on a different thread than the caller, so
/// the type must be [Send] + [Sync]. It must also be [Clone], so that one
/// response can be copied out to every task waiting on it. [Debug] is
/// required as well.
pub trait ResultType: Send + Clone + Sync + Debug {}
impl<T: Send + Clone + Sync + Debug> ResultType for T {}
68
/// Bound on the error type `E` of a singleflight [Group].
///
/// An error may be produced on one thread and observed on another, so `E`
/// must be both [Send] and [Sync]. It must also implement [Debug].
pub trait ResultError: Debug + Send + Sync {}
impl<E> ResultError for E where E: Debug + Send + Sync {}
74
/// Bound on the futures handed to a singleflight [Group].
///
/// Such a future must resolve to `Result<T, E>`. It must also be [Send],
/// because the owning caller may spawn it as a tokio task.
pub trait TaskFuture<T, E>: Send + Future<Output = Result<T, E>> {}
impl<T, E, F> TaskFuture<T, E> for F where F: Send + Future<Output = Result<T, E>> {}
80
81/// Call represents the (eventual) results of running some Future.
82///
83/// It consists of a condition variable that can be waited upon until the
84/// owner task [completes](Call::complete) it.
85///
86/// Tasks can get the Call's result using [get_future](Call::get_future)
87/// to get a Future to await. Or they can call [get](Call::get)
88/// to try and get the result synchronously if the Call is already complete.
89#[derive(Debug, Clone)]
90struct Call<T, E>
91where
92    T: ResultType,
93    E: ResultError,
94{
95    // The condition variable
96    nt: Arc<Notify>,
97
98    // The result of the operation. Kept under a RWLock that is expected
99    // to be write-once, read-many.
100    // We use a lock instead of an AtomicPtr since updating the result and
101    // notifying the waiters needs to be atomic to avoid tasks missing the
102    // notification or to avoid tasks reading an empty value.
103    //
104    // Also important to note is that this lock is synchronous as we need
105    // to be able to store the value in the [OwnerTask::drop] function if
106    // the underlying future panics. Thus, complete() must be synchronous.
107    // This is ok since we are never holding the mutex across an await
108    // boundary (all functions are synchronous), and the critical section
109    // is fast.
110    res: Arc<RwLock<Option<SingleflightResult<T, E>>>>,
111
112    // Number of tasks that were waiting
113    num_waiters: Arc<AtomicU16>,
114}
115
116impl<T, E> Call<T, E>
117where
118    T: ResultType,
119    E: ResultError,
120{
121    fn new() -> Self {
122        Self {
123            nt: Arc::new(Notify::new()),
124            res: Arc::new(RwLock::new(None)),
125            num_waiters: Arc::new(AtomicU16::new(0)),
126        }
127    }
128
129    /// Completes the Call. This involves storing the provided result into the Call
130    /// and notifying all waiters that there is a value.
131    fn complete(&self, res: SingleflightResult<T, E>) {
132        // write-lock
133        let mut val = self.res.write().unwrap();
134        *val = Some(res);
135        self.nt.notify_waiters();
136        let num_waiters = self.num_waiters.load(Ordering::SeqCst);
137        debug!("Completed Call with: {} waiters", num_waiters);
138    }
139
140    /// Gets a Future that can be awaited to get the singleflight results, whenever that
141    /// might occur.
142    fn get_future(&self) -> impl Future<Output = SingleflightResult<T, E>> + '_ {
143        // read-lock
144        let res = self.res.read().unwrap();
145        if let Some(result) = res.clone() {
146            // we already have the result, provide it back to the caller.
147            debug!("Call already completed");
148            Either::Left(async move { result })
149        } else {
150            // no result yet, we are a waiter task.
151            self.num_waiters.fetch_add(1, Ordering::SeqCst);
152            debug!("Adding to Call's Notify");
153
154            // Note that the `notified()` needs to be performed outside the async
155            // block since we need to register our waiting within this read-lock
156            // or else, we might miss the owner task's notification.
157            let notified = self.nt.notified();
158            Either::Right(async move {
159                notified.await;
160                self.get()
161            })
162        }
163    }
164
165    /// Gets the result for the Call if set.
166    /// If not set, then [SingleflightError::NoResult] is returned
167    fn get(&self) -> SingleflightResult<T, E> {
168        let res = self.res.read().unwrap();
169        res.clone().unwrap_or(Err(SingleflightError::NoResult))
170    }
171}
172
173/// Group represents a class of work and creates a space in which units of work
174/// can be executed with duplicate suppression.
175#[derive(Debug)]
176pub struct Group<T, E>
177where
178    T: ResultType + 'static,
179    E: ResultError,
180{
181    call_map: Arc<Mutex<CallMap<T, E>>>,
182    _marker: PhantomData<fn(E)>,
183}
184
185impl<T, E: 'static> Default for Group<T, E>
186where
187    T: ResultType + 'static,
188    E: ResultError,
189{
190    fn default() -> Self {
191        Self {
192            call_map: Arc::new(Default::default()),
193            _marker: Default::default(),
194        }
195    }
196}
197
198impl<T, E: 'static> Group<T, E>
199where
200    T: ResultType + 'static,
201    E: ResultError,
202{
203    /// Create a new Group to do work with.
204    pub fn new() -> Group<T, E> {
205        Self::default()
206    }
207
208    /// Execute and return the value for a given function, making sure that only one
209    /// operation is in-flight at a given moment. If a duplicate call comes in, that caller will
210    /// wait until the original call completes and return the same value.
211    /// The second return value indicates whether the call is the owner.
212    ///
213    /// On error, the owner will receive the original error returned from the function
214    /// as a SingleflightError::InternalError, all waiters will receive a copy of the
215    /// error message wrapped in a SingleflightError::WaiterInternalError.
216    /// This is due to the fact that most error types don't implement Clone (e.g. anyhow::Error)
217    /// and thus we can't clone the original error for all the waiters.
218    pub async fn work(
219        &self,
220        key: &str,
221        fut: impl TaskFuture<T, E> + 'static,
222    ) -> (Result<T, SingleflightError<E>>, bool) {
223        // Get the call to use and a handle for retrieving the results
224        let (call, create_guard) = match self.get_call_or_create(key) {
225            Ok((call, create_guard)) => (call, create_guard),
226            Err(err) => return (Err(err), false),
227        };
228        // Use reference for created since we don't want it to drop until after this is done.
229        match &create_guard {
230            CreateGuard::Owned(_, _) => {
231                // spawn the owner task and wait
232                let owner_task = OwnerTask::new(fut, call.clone());
233                let owner_handle = Handle::current().spawn(owner_task);
234
235                // wait for the owner task to come back with results
236                match owner_handle.await {
237                    Ok(res) => (res, true),
238                    Err(e) => (Err(SingleflightError::JoinError(e.to_string())), true),
239                }
240            },
241            CreateGuard::Waiter => (call.get_future().await, false),
242        }
243    }
244
245    /// Like work but only returns the result, dumps the bool result value
246    pub async fn work_dump_caller_info(
247        &self,
248        key: &str,
249        fut: impl TaskFuture<T, E> + 'static,
250    ) -> Result<T, SingleflightError<E>> {
251        let (result, _) = self.work(key, fut).await;
252        result
253    }
254
255    /// Gets the [Call] to use from the call_map or else inserts a new Call
256    /// into the map.  
257    ///
258    /// Returns the [Call] that should be used and whether it was created or
259    /// not.
260    ///
261    /// Returns an error if the underlying `call_map` Mutex is poisoned.
262    fn get_call_or_create<'a>(&'a self, key: &'a str) -> Result<CallCreate<'a, T, E>, SingleflightError<E>> {
263        let mut m = self
264            .call_map
265            .lock()
266            .log_error("Failed to lock call map")
267            .map_err(|_| SingleflightError::GroupLockPoisoned)?;
268        if let Some(c) = m.get(key).cloned() {
269            Ok((c, CreateGuard::Waiter))
270        } else {
271            let c = Arc::new(Call::new());
272            let our_call = c.clone();
273            m.insert(key.to_owned(), c);
274            Ok((our_call, CreateGuard::Owned(self, key)))
275        }
276    }
277
278    /// Removes the [Call] associated with the Key. If there is no such [Call],
279    /// or the Group's `call_map` Mutex is poisoned, then an error is returned.
280    fn remove_call(&self, key: &str) -> SingleflightResult<(), E> {
281        let mut m = self
282            .call_map
283            .lock()
284            .log_error("Failed to lock call map")
285            .map_err(|_| SingleflightError::GroupLockPoisoned)?;
286        m.remove(key).ok_or(SingleflightError::CallMissing)?;
287        Ok(())
288    }
289}
290
291/// RAII for creating a Call in a Group. The guard indicates whether the Call is:
292/// - Owned - the current task owns the Call and will remove it from the Group's CallMap on [Self::drop]
293/// - Waiter - the current task is a waiter
294enum CreateGuard<'a, T, E>
295where
296    T: ResultType + 'static,
297    E: ResultError + 'static,
298{
299    Owned(&'a Group<T, E>, &'a str),
300    Waiter,
301}
302
303impl<T, E> Drop for CreateGuard<'_, T, E>
304where
305    T: ResultType + 'static,
306    E: ResultError + 'static,
307{
308    fn drop(&mut self) {
309        match self {
310            CreateGuard::Owned(group, key) => group
311                .remove_call(key)
312                .inspect_err(|err| error!(?err, "Couldn't remove call from map"))
313                .unwrap(),
314            CreateGuard::Waiter => {},
315        }
316    }
317}
318
319/// Defines a task to own the polling the Future and ensure the call is
320/// updated (i.e. result stored and waiters notified) when the Future
321/// completes (even if the future panics).
322///
323/// We can guarantee that the [Call] gets notified even during a Panic
324/// since tokio tasks will catch panics and call the `drop()` function.
325///
326/// For more info, see: https://github.com/tokio-rs/tokio/blob/4eed411519783ef6f58cbf74f886f91142b5cfa6/tokio/src/runtime/task/harness.rs#L453-L459
327/// and the discussion on: https://users.rust-lang.org/t/how-panic-calls-drop-functions/53663/8
328///
329/// Pin'ed since it is a Future implementation.
330#[pin_project(PinnedDrop)]
331#[must_use = "futures do nothing unless you `.await` or poll them"]
332struct OwnerTask<T, E, F>
333where
334    T: ResultType,
335    E: ResultError,
336    F: TaskFuture<T, E>,
337{
338    #[pin]
339    fut: F,
340    got_response: AtomicBool,
341    call: Arc<Call<T, E>>,
342}
343
344impl<T, E, F> OwnerTask<T, E, F>
345where
346    T: ResultType,
347    E: ResultError,
348    F: TaskFuture<T, E>,
349{
350    fn new(fut: F, call: Arc<Call<T, E>>) -> Self {
351        Self {
352            fut,
353            got_response: AtomicBool::new(false),
354            call,
355        }
356    }
357}
358
359impl<T, E, F> Future for OwnerTask<T, E, F>
360where
361    T: ResultType,
362    E: ResultError,
363    F: TaskFuture<T, E>,
364{
365    type Output = Result<T, SingleflightError<E>>;
366
367    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
368        let this = self.project();
369        let res: Result<T, E> = ready!(this.fut.poll(cx));
370        let res = res.map_err(|e| SingleflightError::InternalError(e));
371        // we have a result, so store it into our call and notify all waiters.
372        let call = this.call;
373        this.got_response.store(true, Ordering::SeqCst);
374        call.complete(res.clone());
375        Poll::Ready(res)
376    }
377}
378
379#[pinned_drop]
380impl<T, E, F> PinnedDrop for OwnerTask<T, E, F>
381where
382    T: ResultType,
383    E: ResultError,
384    F: TaskFuture<T, E>,
385{
386    fn drop(self: Pin<&mut Self>) {
387        // If we don't have a result stored in the call, then we panicked and
388        // should store an error, notifying all waiters of the panic.
389        let this = self.project();
390        if !this.got_response.load(Ordering::SeqCst) {
391            let call = this.call;
392            call.complete(Err(SingleflightError::OwnerPanicked));
393        }
394    }
395}
396
#[cfg(test)]
pub(crate) mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    use futures::future::join_all;
    use tokio::runtime::Handle;
    use tokio::task::JoinHandle;
    use tokio::time::timeout;

    use super::super::errors::SingleflightError;
    use super::{Call, Group, OwnerTask};
    use crate::core::XetRuntime;

    /// How long waiters wait for a notification from the owner task.
    ///
    /// This should be ample time for the test futures to complete, so hitting
    /// this timeout most likely means something is wrong with the [Call]
    /// notifications.
    pub(crate) const WAITER_TIMEOUT: Duration = Duration::from_millis(100);

    const RES: usize = 7;

    async fn return_res() -> Result<usize, ()> {
        Ok(RES)
    }

    /// Sleeps for one second, bumps `x`, then returns `resp`.
    async fn expensive_fn(x: Arc<AtomicU32>, resp: usize) -> Result<usize, ()> {
        tokio::time::sleep(Duration::new(1, 0)).await;
        x.fetch_add(1, Ordering::SeqCst);
        Ok(resp)
    }

    #[test]
    fn test_simple_with_threadpool() {
        let threadpool = Arc::new(XetRuntime::new().unwrap());
        let g = Group::new();
        let res = threadpool
            .bridge_sync(async move { g.work("key", return_res()).await })
            .unwrap()
            .0;
        let r = res.unwrap();
        assert_eq!(r, RES);
    }

    #[tokio::test]
    async fn test_simple() {
        let g = Group::new();
        let res = g.work("key", return_res()).await.0;
        let r = res.unwrap();
        assert_eq!(r, RES);
    }

    #[test]
    #[cfg_attr(feature = "smoke-test", ignore)]
    fn test_multiple_threads_with_threadpool() {
        let times_called = Arc::new(AtomicU32::new(0));
        let threadpool = Arc::new(XetRuntime::new().unwrap());
        let g: Arc<Group<usize, ()>> = Arc::new(Group::new());
        let mut handlers: Vec<JoinHandle<(usize, bool)>> = Vec::new();
        let threadpool_ = threadpool.clone();
        let tasks = async move {
            for _ in 0..10 {
                let g = g.clone();
                let counter = times_called.clone();
                handlers.push(threadpool_.spawn(async move {
                    let tup = g.work("key", expensive_fn(counter, RES)).await;
                    let res = tup.0;
                    let fn_response = res.unwrap();
                    (fn_response, tup.1)
                }));
            }

            let num_callers = join_all(handlers)
                .await
                .into_iter()
                .map(|r| r.unwrap())
                .filter(|(val, is_caller)| {
                    assert_eq!(*val, RES);
                    *is_caller
                })
                .count();
            assert_eq!(1, num_callers);
            assert_eq!(1, times_called.load(Ordering::SeqCst));
        };
        threadpool.bridge_sync(tasks).unwrap();
    }

    #[tokio::test]
    #[cfg_attr(feature = "smoke-test", ignore)]
    async fn test_multiple_threads() {
        let times_called = Arc::new(AtomicU32::new(0));
        let g: Arc<Group<usize, ()>> = Arc::new(Group::new());
        let mut handlers: Vec<JoinHandle<(usize, bool)>> = Vec::new();
        for _ in 0..10 {
            let g = g.clone();
            let counter = times_called.clone();
            handlers.push(Handle::current().spawn(async move {
                let tup = g.work("key", expensive_fn(counter, RES)).await;
                let res = tup.0;
                let fn_response = res.unwrap();
                (fn_response, tup.1)
            }));
        }

        let num_callers = join_all(handlers)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .filter(|(val, is_caller)| {
                assert_eq!(*val, RES);
                *is_caller
            })
            .count();
        assert_eq!(1, num_callers);
        assert_eq!(1, times_called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    #[cfg_attr(feature = "smoke-test", ignore)]
    async fn test_error() {
        let times_called = Arc::new(AtomicU32::new(0));

        async fn expensive_error_fn(x: Arc<AtomicU32>) -> Result<usize, &'static str> {
            tokio::time::sleep(Duration::new(1, 500)).await;
            x.fetch_add(1, Ordering::SeqCst);
            Err("Error")
        }

        let g: Arc<Group<usize, &'static str>> = Arc::new(Group::new());
        let mut handlers = Vec::new();

        for _ in 0..10 {
            let g = g.clone();
            let counter = times_called.clone();
            handlers.push(Handle::current().spawn(async move {
                let tup = g.work("key", expensive_error_fn(counter)).await;
                let res = tup.0;
                assert!(res.is_err());
                tup.1
            }));
        }

        let num_callers = join_all(handlers).await.into_iter().map(|r| r.unwrap()).filter(|b| *b).count();
        assert_eq!(1, num_callers);
        assert_eq!(1, times_called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    #[cfg_attr(feature = "smoke-test", ignore)]
    async fn test_multiple_keys() {
        let times_called_x = Arc::new(AtomicU32::new(0));
        let times_called_y = Arc::new(AtomicU32::new(0));

        // Both keys must go through the SAME group. With a separate group per
        // key, this test could not detect keys being conflated within a group.
        let g: Arc<Group<usize, ()>> = Arc::new(Group::new());
        let mut handlers = call_success_n_times(&g, 5, "key", times_called_x.clone(), 7);
        handlers.extend(call_success_n_times(&g, 5, "key2", times_called_y.clone(), 13));
        let count_x = AtomicU32::new(0);
        let count_y = AtomicU32::new(0);

        let num_callers = join_all(handlers)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .filter(|(val, is_caller)| {
                if *val == 7 {
                    count_x.fetch_add(1, Ordering::SeqCst);
                } else if *val == 13 {
                    count_y.fetch_add(1, Ordering::SeqCst);
                } else {
                    panic!("joined a number not expected: {}", *val);
                }
                *is_caller
            })
            .count();
        // Exactly one owner per key.
        assert_eq!(2, num_callers);
        assert_eq!(5, count_x.load(Ordering::SeqCst));
        assert_eq!(5, count_y.load(Ordering::SeqCst));
        assert_eq!(1, times_called_x.load(Ordering::SeqCst));
        assert_eq!(1, times_called_y.load(Ordering::SeqCst));
    }

    /// Spawns `times` tasks that each call `g.work(key, ...)`.
    ///
    /// Each task's future increments `c` and returns `val`. Every task's
    /// `(value, is_owner)` is returned through its `JoinHandle`.
    /// Must be run inside a `#[tokio::test]` (uses `Handle::current()`).
    fn call_success_n_times(
        g: &Arc<Group<usize, ()>>,
        times: usize,
        key: &str,
        c: Arc<AtomicU32>,
        val: usize,
    ) -> Vec<JoinHandle<(usize, bool)>> {
        let mut handlers = Vec::with_capacity(times);
        for _ in 0..times {
            let g = g.clone();
            let counter = c.clone();
            let k = key.to_owned();
            handlers.push(Handle::current().spawn(async move {
                let tup = g.work(k.as_str(), expensive_fn(counter, val)).await;
                let res = tup.0;
                let fn_response = res.unwrap();
                (fn_response, tup.1)
            }));
        }
        handlers
    }

    #[tokio::test]
    async fn test_owner_task_future_impl() {
        const VAL: i32 = 10;
        let future = async { Ok::<i32, String>(VAL) };
        let call = Arc::new(Call::new());
        let owner_task = OwnerTask::new(future, call.clone());
        let result = tokio::spawn(owner_task).await;
        assert_eq!(VAL, result.unwrap().unwrap());
        assert_eq!(VAL, call.get().unwrap());
    }

    #[tokio::test]
    async fn test_owner_task_future_notify() {
        const VAL: i32 = 10;
        let future = async { Ok::<i32, String>(VAL) };
        let call = Arc::new(Call::new());
        let call_waiter = call.clone();
        let waiter_task = async move {
            let waiter_future = call_waiter.get_future();
            assert_eq!(VAL, waiter_future.await.unwrap());
        };
        let waiter_handle = tokio::spawn(waiter_task);
        let owner_task = OwnerTask::new(future, call.clone());
        let result = tokio::spawn(owner_task).await;
        timeout(WAITER_TIMEOUT, waiter_handle).await.unwrap().unwrap();
        assert_eq!(VAL, result.unwrap().unwrap());
        assert_eq!(VAL, call.get().unwrap());
        assert_eq!(1, call.num_waiters.load(Ordering::SeqCst)) // we should have had 1 waiter
    }

    #[tokio::test]
    async fn test_owner_task_future_panic() {
        let future = async { panic!("failing task") };
        let call = Arc::new(Call::<i32, String>::new());
        let call_waiter = call.clone();
        let waiter_task = async move {
            let waiter_future = call_waiter.get_future();
            let result = waiter_future.await;
            assert!(matches!(result, Err(SingleflightError::OwnerPanicked)));
        };
        let waiter_handle = tokio::spawn(waiter_task);

        let owner_task = OwnerTask::new(future, call.clone());
        let result = tokio::spawn(owner_task).await;
        assert!(result.is_err());
        timeout(WAITER_TIMEOUT, waiter_handle).await.unwrap().unwrap();
        assert_eq!(1, call.num_waiters.load(Ordering::SeqCst)) // we should have had 1 waiter
    }
}
646
// Regression test: a stalled owner caller must not stall the waiters of the
// same singleflight key (see the scenario described inside `test_deadlock`).
#[cfg(test)]
mod test_deadlock {
    use std::collections::HashMap;
    use std::sync::Arc;

    use futures::StreamExt;
    use futures::stream::iter;
    use tests::WAITER_TIMEOUT;
    use tokio::runtime::Handle;
    use tokio::sync::mpsc::error::SendError;
    use tokio::sync::mpsc::{Sender, channel};
    use tokio::sync::{Mutex, Notify};
    use tokio::time::timeout;

    use super::{Group, tests};

    #[tokio::test]
    async fn test_deadlock() {
        /*
        Each spawned tokio task is expected to send some ints to the main task via a bounded buffer.
        The ints are fetched using a futures::Buffered stream over some futures. These futures
        call into singleflight to fetch an int.

        To set up the deadlock, we have 3 tasks: main, t1, and t2 with the following dependency:
        main is waiting to read from t1, t1 is a waiter on some element that t2 is working on,
        and t2 is blocked writing to its buffer (i.e. waiting for main to read).

        To accomplish this, we spawn t1 and t2. Each starts up its sub-tasks (3 at a time).
        However, there is a dependency where task2[2] runs for some int x and task1[4] needs
        that value, thus triggering a dependency within singleflight.
         */

        let group: Arc<Group<usize, ()>> = Arc::new(Group::new());
        // communication channels (capacity 1, so a second unread send blocks the sender)
        let (send1, mut recv1) = channel::<usize>(1);
        let (send2, mut recv2) = channel::<usize>(1);
        // Items to return on the channels from the tasks.
        // SHARED_ITEM is vals1[4] and vals2[2]: the key both tasks request.
        let vals1: Vec<usize> = vec![1, 2, 3, 4, SHARED_ITEM];
        let vals2: Vec<usize> = vec![6, 7, SHARED_ITEM, 8, 9];

        // waiters allows us to define the order that sub-tasks run in the underlying tasks.
        // We need this for 2 reasons:
        // 1. The SHARED_ITEM sub-task in t2 needs to block until we can ensure that it has a waiter.
        // 2. vals2[1] needs to block to ensure that t2's SHARED_ITEM starts.
        let waiters: Arc<Mutex<HashMap<usize, Arc<Notify>>>> = Arc::new(Mutex::new(HashMap::new()));
        {
            let mut guard = waiters.lock().await;
            guard.insert(vals2[1], Arc::new(Notify::new()));
            guard.insert(SHARED_ITEM, Arc::new(Notify::new()));
        }

        // spawn tasks (t1 is expected to be a waiter on SHARED_ITEM, t2 its owner)
        let t1 = Handle::current().spawn(run_task(1, group.clone(), waiters.clone(), send1, false, vals1.clone()));
        let t2 = Handle::current().spawn(run_task(2, group.clone(), waiters.clone(), send2, true, vals2.clone()));

        // try to receive all the values from task1 without getting stuck.
        // recv2 is deliberately not drained yet, so t2 can end up blocked on send2.send().
        for (i, expected_val) in vals1.into_iter().enumerate() {
            if i == 3 {
                // resume vals2[1] to allow task2 to get "stuck" waiting on send2.send()
                println!("[main] notifying val: {}", vals2[1]);
                let guard = waiters.lock().await;
                guard.get(&vals2[1]).unwrap().notify_one();
                println!("[main] notified val: {}", vals2[1])
            }
            if i == 4 {
                // resume task2's SHARED_ITEM sub-task since we now have a waiter (i.e. vals1[4]).
                println!("[main] notifying val: {}", SHARED_ITEM);
                let guard = waiters.lock().await;
                guard.get(&SHARED_ITEM).unwrap().notify_one();
                println!("[main] notified val: {}", SHARED_ITEM);
            }
            println!("[main] getting t1[{}]", i);
            // A timeout here (WAITER_TIMEOUT) is treated as a deadlock.
            let res = timeout(WAITER_TIMEOUT, recv1.recv())
                .await
                .map_err(|_| format!("Timed out on task1 waiting for val: {}. Likely deadlock.", expected_val));
            let val = res.unwrap().unwrap();
            println!("[main] got val: {} from t1[{}]", val, i);
            assert_eq!(expected_val, val);
        }

        // try to receive all the values from task2 without getting stuck.
        for expected_val in vals2 {
            let res = timeout(WAITER_TIMEOUT, recv2.recv())
                .await
                .map_err(|_| format!("Timed out on task2 waiting for val: {}. Likely deadlock.", expected_val));
            let val = res.unwrap().unwrap();
            assert_eq!(expected_val, val);
        }

        // make sure t1,t2 completed successfully.
        t1.await.unwrap().unwrap();
        t2.await.unwrap().unwrap();
    }

    // Value requested by both t1 and t2, i.e. the key they contend on in the group.
    const SHARED_ITEM: usize = 5;

    /// Runs one "task" of the test.
    ///
    /// For each item in `vals`, in order, it calls `g.work(<item as string>, run_fut(..))`,
    /// keeping at most 3 of those calls in flight, and forwards each result to `send_chan`.
    /// For `SHARED_ITEM`, asserts that the caller's owner/waiter role matches `should_own`.
    /// Returns the `SendError` if the receiving side was dropped.
    async fn run_task(
        id: i32,
        g: Arc<Group<usize, ()>>,
        waiters: Arc<Mutex<HashMap<usize, Arc<Notify>>>>,
        send_chan: Sender<usize>,
        should_own: bool,
        vals: Vec<usize>,
    ) -> Result<(), SendError<usize>> {
        // create a buffered stream that will run at most 3 sub-tasks concurrently.
        let mut strm = iter(vals.into_iter().map(|v| {
            let g = g.clone();
            let waiters = waiters.clone();
            // get the sub-task for the given item.
            async move {
                println!("[task: {}] running task for: {}", id, v);
                let (res, is_owner) = g.work(format!("{}", v).as_str(), run_fut(v, waiters)).await;
                println!("[task: {}] completed task for: {}, is_owner: {}", id, v, is_owner);
                if v == SHARED_ITEM {
                    assert_eq!(should_own, is_owner);
                }
                res.unwrap()
            }
        }))
        .buffered(3);

        // `buffered` yields results in input order, so one slow item holds back later ones.
        while let Some(val) = strm.next().await {
            println!("[task: {}] sending next element: {}", id, val);
            send_chan.send(val).await?;
            println!("[task: {}] sent next element: {}", id, val);
        }
        println!("[task: {}] done executing", id);
        Ok(())
    }

    /// The work passed to singleflight for item `v`: returns `Ok(v)`.
    ///
    /// If `waiters` has a `Notify` registered for `v`, it first waits for the
    /// main task to signal it.
    async fn run_fut(v: usize, waiters: Arc<Mutex<HashMap<usize, Arc<Notify>>>>) -> Result<usize, ()> {
        // Clone the Notify out so the async mutex is not held while waiting.
        let waiter = {
            let x = waiters.lock().await;
            x.get(&v).cloned()
        };
        // wait for the main task to tell us to proceed.
        if let Some(waiter) = waiter {
            println!("val: {}, waiting for signal", v);
            waiter.notified().await;
            println!("val: {}, woke up from signal", v);
        }
        Ok(v)
    }
}
791
#[cfg(test)]
mod test_futures_unordered {
    //! Tests covering what happens to an in-flight singleflight call when the caller that owns
    //! it is dropped (cancelled) before the work completes.

    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::time::Duration;

    use futures_util::TryStreamExt;
    use futures_util::stream::FuturesUnordered;
    use tokio::sync::mpsc;
    use tokio::time::sleep;

    use super::super::errors::SingleflightError;
    use super::Group;

    /// Boxed `Send` future produced by `get_fut`. It resolves to the work's value paired with
    /// the `is_owner` flag returned by `Group::work`, or to the singleflight error.
    type FutType = Pin<Box<dyn Future<Output = Result<(i32, bool), SingleflightError<String>>> + Send>>;

    #[tokio::test]
    async fn test_dropped_owner() {
        /*
         We test a situation where the owner of a task is dropped before the task can complete.
         This is done by making the owner task part of a FuturesUnordered execution in which a
         separate task errors out, cancelling the others.

         We expect that when an owning task is dropped, the spawned owning task is still able
         to complete in the background, the Call state is properly cleaned up, and a
         new `work()` invocation for the key runs as an owner.

             main       fut_err       fut_owner     owner_task    fut_waiter
         try_collect()====>|------------->|              |
              |        start(k2)      start(k1)------>start()
              |<----------err             |              |         start(k1)
             err----------------------->drop()           |             |
              |                                        Ok(1)-------->Ok(1)
        */
        let group = Arc::new(Group::new());

        // ready channels let the owner's work tell the waiter task that the "key1" call has started
        let (ready_tx, mut ready_rx) = mpsc::channel(1);
        // done channels help the owner task signal to main that the operation completed,
        // even though fut_owner was dropped.
        let (done_tx, mut done_rx) = mpsc::channel(1);

        // Owner task for "key1" that will delay then return a `1`.
        // (get_fut only builds a boxed future; nothing runs until it is polled.)
        let fut_owner = get_fut(group.clone(), "key1", async move {
            ready_tx.send(true).await.unwrap();
            sleep(Duration::from_millis(100)).await;
            done_tx.send(true).await.unwrap();
            Ok(1)
        });
        // Waiter task for "key1" that should not get called (uses the results of owner task)
        let fut_waiter =
            get_fut(group.clone(), "key1", async { Err("Test BUG: waiter should not be called".to_string()) });

        // Task for "key2" that will fail and cause fut_owner to be dropped.
        let fut_err = get_fut(group.clone(), "key2", async { Err("failed".to_string()) });

        // spawn a task to wait for fut_owner to be ready then run fut_waiter
        let handle = tokio::spawn(async move {
            assert!(ready_rx.recv().await.unwrap());
            let (i, is_owner) = fut_waiter.await.unwrap();
            assert!(!is_owner);
            assert_eq!(i, 1);
        });

        // Use FuturesUnordered to run `fut_owner` and `fut_err`. Since `fut_err` immediately fails,
        // it will complete first, causing the try_collect to short-circuit and drop `fut_owner`.
        //
        // Implementation note: the order of the vec matters since FuturesUnordered will try to
        // run the futures in-order (until it hits an await). If `fut_err` is first, since it has
        // no awaits, it will immediately finish (i.e. err), causing fut_owner to never get run.
        let futures: Result<Vec<(i32, bool)>, SingleflightError<String>> =
            FuturesUnordered::from_iter(vec![fut_owner, fut_err]).try_collect().await;

        assert!(futures.is_err());
        // "key1" should be deleted from the call_map even though fut_owner was dropped before finishing
        assert!(!group.call_map.lock().unwrap().contains_key("key1"));
        // The owner's work still ran to completion (it reached its `done_tx.send`) after its
        // caller was dropped.
        assert!(done_rx.recv().await.unwrap());
        // Awaiting the join handle surfaces any assertion failure from the waiter task.
        handle.await.unwrap();

        // Ensure that subsequent calls to the same key are able to go through as there are
        // no currently running tasks.
        let fut_after = get_fut(group, "key1", async { Ok(5) });
        let (i, is_owner) = fut_after.await.unwrap();
        assert!(is_owner);
        assert_eq!(i, 5);
    }

    /// Wraps `g.work(key, f)` in a boxed future that resolves to `(value, is_owner)`.
    ///
    /// The key is copied into an owned `String` so that the returned `FutType`, whose boxed
    /// trait object is implicitly `'static`, does not borrow from the caller. An error from
    /// `work` is propagated through `?` as the future's error.
    fn get_fut(
        g: Arc<Group<i32, String>>,
        key: &str,
        f: impl Future<Output = Result<i32, String>> + Send + 'static,
    ) -> FutType {
        let key = key.to_string();
        Box::pin(async move {
            let (res, is_owner) = g.work(&key, f).await;
            let i = res?;
            Ok((i, is_owner))
        })
    }
}