singleflight_async/
lib.rs

1use std::{
2    collections::HashMap,
3    future::Future,
4    hash::Hash,
5    sync::{Arc, Weak},
6};
7
8use parking_lot::Mutex as SyncMutex;
9use tokio::sync::Mutex;
10
/// The table of in-flight work shared between a `SingleFlight` and every waiter it hands
/// out: each key maps to a `BroadcastOnce`, i.e. a weak handle on the result slot of the
/// current flight for that key.
///
/// Guarded by a synchronous (`parking_lot`) mutex; within this file it is only locked for
/// short map operations and never held across an `.await`.
type SharedMapping<K, T> = Arc<SyncMutex<HashMap<K, BroadcastOnce<T>>>>;
12
13/// SingleFlight represents a class of work and creates a space in which units of work
14/// can be executed with duplicate suppression.
15#[derive(Debug)]
16pub struct SingleFlight<K, T> {
17    mapping: SharedMapping<K, T>,
18}
19
20impl<K, T> Default for SingleFlight<K, T> {
21    fn default() -> Self {
22        Self {
23            mapping: Default::default(),
24        }
25    }
26}
27
28struct Shared<T> {
29    slot: Mutex<Option<T>>,
30}
31
32impl<T> Default for Shared<T> {
33    fn default() -> Self {
34        Self {
35            slot: Mutex::new(None),
36        }
37    }
38}
39
40/// `BroadcastOnce` consists of shared slot and notify.
41#[derive(Clone)]
42struct BroadcastOnce<T> {
43    shared: Weak<Shared<T>>,
44}
45
46impl<T> BroadcastOnce<T> {
47    fn new() -> (Self, Arc<Shared<T>>) {
48        let shared = Arc::new(Shared::default());
49        (
50            Self {
51                shared: Arc::downgrade(&shared),
52            },
53            shared,
54        )
55    }
56}
57
// One caller's pending participation in a flight, produced by `BroadcastOnce::try_waiter`
// or `BroadcastOnce::waiter`. It owns a strong reference to the flight's slot, so the
// mapping's weak handle keeps upgrading (the flight stays joinable) while this waiter, or
// the future that `wait` turns it into, is alive.
struct BroadcastOnceWaiter<K, T, F> {
    // This caller's producer; only invoked if this waiter ends up executing the flight.
    func: F,
    // Strong reference to the flight's result slot.
    shared: Arc<Shared<T>>,

    // Key and table used to remove the mapping entry once the value has been produced.
    key: K,
    mapping: SharedMapping<K, T>,
}
67
68impl<T> std::fmt::Debug for BroadcastOnce<T> {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        write!(f, "BroadcastOnce")
71    }
72}
73
74#[allow(clippy::type_complexity)]
75impl<T> BroadcastOnce<T> {
76    fn try_waiter<K, F>(
77        &self,
78        func: F,
79        key: K,
80        mapping: SharedMapping<K, T>,
81    ) -> Result<BroadcastOnceWaiter<K, T, F>, (F, K, SharedMapping<K, T>)> {
82        let Some(upgraded) = self.shared.upgrade() else {
83            return Err((func, key, mapping));
84        };
85        Ok(BroadcastOnceWaiter {
86            func,
87            shared: upgraded,
88            key,
89            mapping,
90        })
91    }
92
93    #[inline]
94    const fn waiter<K, F>(
95        shared: Arc<Shared<T>>,
96        func: F,
97        key: K,
98        mapping: SharedMapping<K, T>,
99    ) -> BroadcastOnceWaiter<K, T, F> {
100        BroadcastOnceWaiter {
101            func,
102            shared,
103            key,
104            mapping,
105        }
106    }
107}
108
109// We already in WaitList, so wait will be fine, we won't miss
110// anything after Waiter generated.
111impl<K, T, F, Fut> BroadcastOnceWaiter<K, T, F>
112where
113    K: Hash + Eq,
114    F: FnOnce() -> Fut,
115    Fut: Future<Output = T>,
116    T: Clone,
117{
118    async fn wait(self) -> T {
119        let mut slot = self.shared.slot.lock().await;
120        if let Some(value) = (*slot).as_ref() {
121            return value.clone();
122        }
123
124        let value = (self.func)().await;
125        *slot = Some(value.clone());
126
127        self.mapping.lock().remove(&self.key);
128
129        value
130    }
131}
132
133impl<K, T> SingleFlight<K, T>
134where
135    K: Hash + Eq + Clone,
136{
137    /// Create a new BroadcastOnce to do work with.
138    #[inline]
139    pub fn new() -> Self {
140        Self::default()
141    }
142
143    /// Execute and return the value for a given function, making sure that only one
144    /// operation is in-flight at a given moment. If a duplicate call comes in, that caller will
145    /// wait until the original call completes and return the same value.
146    pub fn work<F, Fut>(&self, key: K, func: F) -> impl Future<Output = T>
147    where
148        F: FnOnce() -> Fut,
149        Fut: Future<Output = T>,
150        T: Clone,
151    {
152        let owned_mapping = self.mapping.clone();
153        let mut mapping = self.mapping.lock();
154        let val = mapping.get_mut(&key);
155        match val {
156            Some(call) => {
157                let (func, key, owned_mapping) = match call.try_waiter(func, key, owned_mapping) {
158                    Ok(waiter) => return waiter.wait(),
159                    Err(fm) => fm,
160                };
161                let (new_call, shared) = BroadcastOnce::new();
162                *call = new_call;
163                let waiter = BroadcastOnce::waiter(shared, func, key, owned_mapping);
164                waiter.wait()
165            }
166            None => {
167                let (call, shared) = BroadcastOnce::new();
168                mapping.insert(key.clone(), call);
169                let waiter = BroadcastOnce::waiter(shared, func, key, owned_mapping);
170                waiter.wait()
171            }
172        }
173    }
174}
175
// Behavioural tests for `SingleFlight`: basic calls, deduplication of concurrent calls,
// several key types, late pollers and cancellation of the executing caller.
#[cfg(test)]
mod tests {
    use std::{
        sync::atomic::{
            AtomicUsize,
            Ordering::{AcqRel, Acquire},
        },
        time::Duration,
    };

    use futures_util::{stream::FuturesUnordered, StreamExt};

    use super::*;

    #[tokio::test]
    async fn direct_call() {
        let group = SingleFlight::new();
        let result = group
            .work("key", || async {
                tokio::time::sleep(Duration::from_millis(10)).await;
                "Result".to_string()
            })
            .await;
        assert_eq!(result, "Result");
    }

    #[tokio::test]
    async fn parallel_call() {
        let call_counter = AtomicUsize::default();

        let group = SingleFlight::new();
        let futures = FuturesUnordered::new();
        for _ in 0..10 {
            futures.push(group.work("key", || async {
                tokio::time::sleep(Duration::from_millis(100)).await;
                call_counter.fetch_add(1, AcqRel);
                "Result".to_string()
            }));
        }

        assert!(futures.all(|out| async move { out == "Result" }).await);
        assert_eq!(
            call_counter.load(Acquire),
            1,
            "future should only be executed once"
        );
    }

    #[tokio::test]
    async fn parallel_call_seq_await() {
        let call_counter = AtomicUsize::default();

        let group = SingleFlight::new();
        let mut futures = Vec::new();
        for _ in 0..10 {
            futures.push(group.work("key", || async {
                tokio::time::sleep(Duration::from_millis(100)).await;
                call_counter.fetch_add(1, AcqRel);
                "Result".to_string()
            }));
        }

        for fut in futures.into_iter() {
            assert_eq!(fut.await, "Result");
        }
        assert_eq!(
            call_counter.load(Acquire),
            1,
            "future should only be executed once"
        );
    }

    #[tokio::test]
    async fn call_with_static_str_key() {
        // Key type is `&'static str`; the `String` key case is covered separately below.
        let group = SingleFlight::new();
        let result = group
            .work("key", || async {
                tokio::time::sleep(Duration::from_millis(1)).await;
                "Result".to_string()
            })
            .await;
        assert_eq!(result, "Result");
    }

    #[tokio::test]
    async fn call_with_static_string_key() {
        let group = SingleFlight::new();
        let result = group
            .work("key".to_string(), || async {
                tokio::time::sleep(Duration::from_millis(1)).await;
                "Result".to_string()
            })
            .await;
        assert_eq!(result, "Result");
    }

    #[tokio::test]
    async fn call_with_custom_key() {
        #[derive(Clone, PartialEq, Eq, Hash)]
        struct K(i32);
        let group = SingleFlight::new();
        let result = group
            .work(K(1), || async {
                tokio::time::sleep(Duration::from_millis(1)).await;
                "Result".to_string()
            })
            .await;
        assert_eq!(result, "Result");
    }

    #[tokio::test]
    async fn late_wait() {
        let group = SingleFlight::new();
        let fut_early = group.work("key".to_string(), || async {
            tokio::time::sleep(Duration::from_millis(20)).await;
            "Result".to_string()
        });
        let fut_late = group.work("key".into(), || async { panic!("unexpected") });
        assert_eq!(fut_early.await, "Result");
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(fut_late.await, "Result");
    }

    #[tokio::test]
    async fn cancel() {
        let group = SingleFlight::new();

        // the executer cancelled and the other awaiter will create a new future and execute.
        let fut_cancel = group.work("key".to_string(), || async {
            tokio::time::sleep(Duration::from_millis(2000)).await;
            "Result1".to_string()
        });
        let _ = tokio::time::timeout(Duration::from_millis(10), fut_cancel).await;
        let fut_late = group.work("key".to_string(), || async { "Result2".to_string() });
        assert_eq!(fut_late.await, "Result2");

        // the first executer is slow but not dropped, so the result will be the first ones.
        let begin = tokio::time::Instant::now();
        let fut_1 = group.work("key".to_string(), || async {
            tokio::time::sleep(Duration::from_millis(2000)).await;
            "Result1".to_string()
        });
        let fut_2 = group.work("key".to_string(), || async { panic!() });
        let (v1, v2) = tokio::join!(fut_1, fut_2);
        assert_eq!(v1, "Result1");
        assert_eq!(v2, "Result1");
        assert!(begin.elapsed() > Duration::from_millis(1500));
    }
}
324}