1use std::{
2 collections::HashMap,
3 future::Future,
4 hash::Hash,
5 sync::{Arc, Weak},
6};
7
8use parking_lot::Mutex as SyncMutex;
9use tokio::sync::Mutex;
10
11type SharedMapping<K, T> = Arc<SyncMutex<HashMap<K, BroadcastOnce<T>>>>;
12
13#[derive(Debug)]
16pub struct SingleFlight<K, T> {
17 mapping: SharedMapping<K, T>,
18}
19
20impl<K, T> Default for SingleFlight<K, T> {
21 fn default() -> Self {
22 Self {
23 mapping: Default::default(),
24 }
25 }
26}
27
28struct Shared<T> {
29 slot: Mutex<Option<T>>,
30}
31
32impl<T> Default for Shared<T> {
33 fn default() -> Self {
34 Self {
35 slot: Mutex::new(None),
36 }
37 }
38}
39
40#[derive(Clone)]
42struct BroadcastOnce<T> {
43 shared: Weak<Shared<T>>,
44}
45
46impl<T> BroadcastOnce<T> {
47 fn new() -> (Self, Arc<Shared<T>>) {
48 let shared = Arc::new(Shared::default());
49 (
50 Self {
51 shared: Arc::downgrade(&shared),
52 },
53 shared,
54 )
55 }
56}
57
58struct BroadcastOnceWaiter<K, T, F> {
61 func: F,
62 shared: Arc<Shared<T>>,
63
64 key: K,
65 mapping: SharedMapping<K, T>,
66}
67
68impl<T> std::fmt::Debug for BroadcastOnce<T> {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 write!(f, "BroadcastOnce")
71 }
72}
73
74#[allow(clippy::type_complexity)]
75impl<T> BroadcastOnce<T> {
76 fn try_waiter<K, F>(
77 &self,
78 func: F,
79 key: K,
80 mapping: SharedMapping<K, T>,
81 ) -> Result<BroadcastOnceWaiter<K, T, F>, (F, K, SharedMapping<K, T>)> {
82 let Some(upgraded) = self.shared.upgrade() else {
83 return Err((func, key, mapping));
84 };
85 Ok(BroadcastOnceWaiter {
86 func,
87 shared: upgraded,
88 key,
89 mapping,
90 })
91 }
92
93 #[inline]
94 const fn waiter<K, F>(
95 shared: Arc<Shared<T>>,
96 func: F,
97 key: K,
98 mapping: SharedMapping<K, T>,
99 ) -> BroadcastOnceWaiter<K, T, F> {
100 BroadcastOnceWaiter {
101 func,
102 shared,
103 key,
104 mapping,
105 }
106 }
107}
108
109impl<K, T, F, Fut> BroadcastOnceWaiter<K, T, F>
112where
113 K: Hash + Eq,
114 F: FnOnce() -> Fut,
115 Fut: Future<Output = T>,
116 T: Clone,
117{
118 async fn wait(self) -> T {
119 let mut slot = self.shared.slot.lock().await;
120 if let Some(value) = (*slot).as_ref() {
121 return value.clone();
122 }
123
124 let value = (self.func)().await;
125 *slot = Some(value.clone());
126
127 self.mapping.lock().remove(&self.key);
128
129 value
130 }
131}
132
133impl<K, T> SingleFlight<K, T>
134where
135 K: Hash + Eq + Clone,
136{
137 #[inline]
139 pub fn new() -> Self {
140 Self::default()
141 }
142
143 pub fn work<F, Fut>(&self, key: K, func: F) -> impl Future<Output = T>
147 where
148 F: FnOnce() -> Fut,
149 Fut: Future<Output = T>,
150 T: Clone,
151 {
152 let owned_mapping = self.mapping.clone();
153 let mut mapping = self.mapping.lock();
154 let val = mapping.get_mut(&key);
155 match val {
156 Some(call) => {
157 let (func, key, owned_mapping) = match call.try_waiter(func, key, owned_mapping) {
158 Ok(waiter) => return waiter.wait(),
159 Err(fm) => fm,
160 };
161 let (new_call, shared) = BroadcastOnce::new();
162 *call = new_call;
163 let waiter = BroadcastOnce::waiter(shared, func, key, owned_mapping);
164 waiter.wait()
165 }
166 None => {
167 let (call, shared) = BroadcastOnce::new();
168 mapping.insert(key.clone(), call);
169 let waiter = BroadcastOnce::waiter(shared, func, key, owned_mapping);
170 waiter.wait()
171 }
172 }
173 }
174}
175
#[cfg(test)]
mod tests {
    use std::{
        sync::atomic::{
            AtomicUsize,
            Ordering::{AcqRel, Acquire},
        },
        time::Duration,
    };

    use futures_util::{stream::FuturesUnordered, StreamExt};

    use super::*;

    /// A single caller simply gets the result of its own closure.
    #[tokio::test]
    async fn direct_call() {
        let group = SingleFlight::new();
        let result = group
            .work("key", || async {
                tokio::time::sleep(Duration::from_millis(10)).await;
                "Result".to_string()
            })
            .await;
        assert_eq!(result, "Result");
    }

    /// Ten callers polled concurrently share one execution.
    #[tokio::test]
    async fn parallel_call() {
        let call_counter = AtomicUsize::default();

        let group = SingleFlight::new();
        let futures = FuturesUnordered::new();
        for _ in 0..10 {
            futures.push(group.work("key", || async {
                tokio::time::sleep(Duration::from_millis(100)).await;
                call_counter.fetch_add(1, AcqRel);
                "Result".to_string()
            }));
        }

        assert!(futures.all(|out| async move { out == "Result" }).await);
        assert_eq!(
            call_counter.load(Acquire),
            1,
            "future should only be executed once"
        );
    }

    /// Ten callers registered up front but awaited one after another still
    /// share one execution.
    #[tokio::test]
    async fn parallel_call_seq_await() {
        let call_counter = AtomicUsize::default();

        let group = SingleFlight::new();
        let mut futures = Vec::new();
        for _ in 0..10 {
            futures.push(group.work("key", || async {
                tokio::time::sleep(Duration::from_millis(100)).await;
                call_counter.fetch_add(1, AcqRel);
                "Result".to_string()
            }));
        }

        for fut in futures.into_iter() {
            assert_eq!(fut.await, "Result");
        }
        assert_eq!(
            call_counter.load(Acquire),
            1,
            "future should only be executed once"
        );
    }

    /// Keys may be `&'static str`.
    #[tokio::test]
    async fn call_with_static_str_key() {
        let group = SingleFlight::new();
        let result = group
            // A string literal, not a `String`: this is what distinguishes
            // this test from `call_with_static_string_key`.
            .work("key", || async {
                tokio::time::sleep(Duration::from_millis(1)).await;
                "Result".to_string()
            })
            .await;
        assert_eq!(result, "Result");
    }

    /// Keys may be owned `String`s.
    #[tokio::test]
    async fn call_with_static_string_key() {
        let group = SingleFlight::new();
        let result = group
            .work("key".to_string(), || async {
                tokio::time::sleep(Duration::from_millis(1)).await;
                "Result".to_string()
            })
            .await;
        assert_eq!(result, "Result");
    }

    /// Keys may be any user type that is `Hash + Eq + Clone`.
    #[tokio::test]
    async fn call_with_custom_key() {
        #[derive(Clone, PartialEq, Eq, Hash)]
        struct K(i32);
        let group = SingleFlight::new();
        let result = group
            .work(K(1), || async {
                tokio::time::sleep(Duration::from_millis(1)).await;
                "Result".to_string()
            })
            .await;
        assert_eq!(result, "Result");
    }

    /// A caller registered during the first call but awaited long after it
    /// finished still receives the stored result without running its closure.
    #[tokio::test]
    async fn late_wait() {
        let group = SingleFlight::new();
        let fut_early = group.work("key".to_string(), || async {
            tokio::time::sleep(Duration::from_millis(20)).await;
            "Result".to_string()
        });
        let fut_late = group.work("key".into(), || async { panic!("unexpected") });
        assert_eq!(fut_early.await, "Result");
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(fut_late.await, "Result");
    }

    /// Dropping the only waiter must not wedge the key.
    #[tokio::test]
    async fn cancel() {
        let group = SingleFlight::new();

        // The only waiter is dropped by the timeout before it can finish, so
        // the next caller has nothing to join and must run its own closure.
        let fut_cancel = group.work("key".to_string(), || async {
            tokio::time::sleep(Duration::from_millis(2000)).await;
            "Result1".to_string()
        });
        let _ = tokio::time::timeout(Duration::from_millis(10), fut_cancel).await;
        let fut_late = group.work("key".to_string(), || async { "Result2".to_string() });
        assert_eq!(fut_late.await, "Result2");

        // After the cancellation the key works normally again: the second
        // caller joins the first one's slow call instead of running its own
        // (panicking) closure.
        let begin = tokio::time::Instant::now();
        let fut_1 = group.work("key".to_string(), || async {
            tokio::time::sleep(Duration::from_millis(2000)).await;
            "Result1".to_string()
        });
        let fut_2 = group.work("key".to_string(), || async { panic!() });
        let (v1, v2) = tokio::join!(fut_1, fut_2);
        assert_eq!(v1, "Result1");
        assert_eq!(v2, "Result1");
        assert!(begin.elapsed() > Duration::from_millis(1500));
    }
}