uniflight/lib.rs
1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Coalesces duplicate async tasks into a single execution.
5//!
6//! This crate provides [`Merger`], a mechanism for deduplicating concurrent async operations.
7//! When multiple tasks request the same work (identified by a key), only the first task (the
8//! "leader") performs the actual work while subsequent tasks (the "followers") wait and receive
9//! a clone of the result.
10//!
11//! # When to Use
12//!
13//! Use `Merger` when you have expensive or rate-limited operations that may be requested
14//! concurrently with the same parameters:
15//!
16//! - **Cache population**: Prevent thundering herd when a cache entry expires
17//! - **API calls**: Deduplicate concurrent requests to the same endpoint
18//! - **Database queries**: Coalesce identical queries issued simultaneously
19//! - **File I/O**: Avoid reading the same file multiple times concurrently
20//!
21//! # Example
22//!
23//! ```
24//! use uniflight::Merger;
25//!
26//! # async fn example() {
27//! let group: Merger<String, String> = Merger::new();
28//!
29//! // Multiple concurrent calls with the same key will share a single execution.
30//! // Note: you can pass &str directly when the key type is String.
31//! let result = group.execute("user:123", || async {
32//! // This expensive operation runs only once, even if called concurrently
33//! "expensive_result".to_string()
34//! }).await.expect("leader should not panic");
35//! # }
36//! ```
37//!
38//! # Flexible Key Types
39//!
40//! The [`Merger::execute`] method accepts keys using [`Borrow`] semantics, allowing you to pass
//! borrowed forms of the key type. For example, with `Merger<String, T>`, you can pass `&str`
//! directly, with no `.to_string()` at the call site (the owned key is created internally):
43//!
44//! ```
45//! # use uniflight::Merger;
46//! # async fn example() {
47//! let merger: Merger<String, i32> = Merger::new();
48//!
49//! // Pass &str directly - no need to call .to_string()
50//! let result = merger.execute("my-key", || async { 42 }).await;
51//! assert_eq!(result, Ok(42));
52//! # }
53//! ```
54//!
55//! # Thread-Aware Scoping
56//!
57//! `Merger` supports thread-aware scoping via a [`Strategy`]
58//! type parameter. This controls how the internal state is partitioned across threads/NUMA nodes:
59//!
60//! - [`PerProcess`] (default): Single global state, maximum deduplication
61//! - [`PerNuma`]: Separate state per NUMA node, NUMA-local memory access
//! - [`PerCore`]: Separate state per core, no cross-core deduplication (useful for already-partitioned work)
63//!
64//! ```
65//! use uniflight::Merger;
66//! use thread_aware::PerNuma;
67//!
68//! # async fn example() {
69//! // NUMA-aware merger - each NUMA node gets its own deduplication scope
70//! let merger: Merger<String, String, PerNuma> = Merger::new_per_numa();
71//! # }
72//! ```
73//!
74//! # Cancellation and Panic Handling
75//!
76//! `Merger` handles task cancellation and panics explicitly:
77//!
78//! - If the leader task is cancelled or dropped, a follower becomes the new leader
//! - If the leader task panics, followers receive a [`LeaderPanicked`] error with the panic message
80//! - Followers that join before the leader completes receive the value the leader returns
81//!
82//! When a panic occurs, followers are notified via the error type rather than silently
83//! retrying. The panic message is captured and available via [`LeaderPanicked::message`]:
84//!
85//! ```
86//! # use uniflight::Merger;
87//! # async fn example() {
88//! let merger: Merger<String, String> = Merger::new();
89//! match merger.execute("key", || async { "result".to_string() }).await {
90//! Ok(value) => println!("got {value}"),
91//! Err(err) => {
92//! println!("leader panicked: {}", err.message());
93//! // Decide whether to retry
94//! }
95//! }
96//! # }
97//! ```
98//!
99//! # Memory Management
100//!
101//! Completed entries are automatically removed from the internal map when the last caller
102//! finishes. This ensures no stale entries accumulate over time.
103//!
104//! # Type Requirements
105//!
106//! The value type `T` must implement [`Clone`] because followers receive a clone of the
107//! leader's result. The key type `K` must implement [`Hash`] and [`Eq`].
108//!
109//! # Thread Safety
110//!
111//! [`Merger`] is `Send` and `Sync`, and can be shared across threads. The returned futures
112//! are `Send` when the closure, future, key, and value types are `Send`.
113//!
114//! # Performance
115//!
116//! Run benchmarks with `cargo bench -p uniflight`. The suite covers:
117//!
118//! - `single_call`: Baseline latency with no contention
119//! - `high_contention_100`: 100 concurrent tasks on the same key
120//! - `distributed_10x10`: 10 keys with 10 tasks each
121//!
122//! Use `--save-baseline` and `--baseline` flags to track regressions over time.
123
124#![doc(html_logo_url = "https://media.githubusercontent.com/media/microsoft/oxidizer/refs/heads/main/crates/uniflight/logo.png")]
125#![doc(html_favicon_url = "https://media.githubusercontent.com/media/microsoft/oxidizer/refs/heads/main/crates/uniflight/favicon.ico")]
126
127use std::{
128 borrow::Borrow,
129 fmt::Debug,
130 hash::Hash,
131 panic::AssertUnwindSafe,
132 sync::{Arc, Weak},
133};
134
135use ahash::RandomState;
136use async_once_cell::OnceCell;
137use dashmap::{
138 DashMap,
139 Entry::{Occupied, Vacant},
140};
141use futures_util::FutureExt; // catch_unwind, map
142use thread_aware::{
143 Arc as TaArc, PerCore, PerNuma, PerProcess, ThreadAware,
144 affinity::{MemoryAffinity, PinnedAffinity},
145 storage::Strategy,
146};
147
148/// Suppresses duplicate async operations identified by a key.
149///
150/// The `S` type parameter controls the thread-aware scoping strategy:
151/// - [`PerProcess`]: Single global scope (default, maximum deduplication)
152/// - [`PerNuma`]: Per-NUMA-node scope (NUMA-local memory access)
153/// - [`PerCore`]: Per-core scope (no deduplication)
154pub struct Merger<K, T, S: Strategy = PerProcess> {
155 inner: TaArc<DashMap<K, Weak<PanicAwareCell<T>>, RandomState>, S>,
156}
157
158impl<K, T, S: Strategy> Debug for Merger<K, T, S> {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 f.debug_struct("Merger").field("inner", &format_args!("DashMap<...>")).finish()
161 }
162}
163
164impl<K, T, S: Strategy> Clone for Merger<K, T, S> {
165 fn clone(&self) -> Self {
166 Self { inner: self.inner.clone() }
167 }
168}
169
170impl<K, T, S> Default for Merger<K, T, S>
171where
172 K: Hash + Eq + Send + Sync + 'static,
173 T: Send + Sync + 'static,
174 S: Strategy,
175{
176 fn default() -> Self {
177 Self {
178 inner: TaArc::new(|| DashMap::with_hasher(RandomState::new())),
179 }
180 }
181}
182
183impl<K, T, S> Merger<K, T, S>
184where
185 K: Hash + Eq + Send + Sync + 'static,
186 T: Send + Sync + 'static,
187 S: Strategy,
188{
189 /// Creates a new `Merger` instance.
190 ///
191 /// The scoping strategy is determined by the type parameter `S`:
192 /// - [`PerProcess`] (default): Process-wide scope, maximum deduplication
193 /// - [`PerNuma`]: Per-NUMA-node scope, NUMA-local memory access
194 /// - [`PerCore`]: Per-core scope, no cross-core deduplication
195 ///
196 /// # Examples
197 ///
198 /// ```
199 /// use uniflight::Merger;
200 /// use thread_aware::{PerNuma, PerCore};
201 ///
202 /// // Default (PerProcess) - type can be inferred
203 /// let global: Merger<String, String> = Merger::new();
204 ///
205 /// // NUMA-local scope
206 /// let numa: Merger<String, String, PerNuma> = Merger::new();
207 ///
208 /// // Per-core scope
209 /// let core: Merger<String, String, PerCore> = Merger::new();
210 /// ```
211 #[inline]
212 #[must_use]
213 pub fn new() -> Self {
214 Self::default()
215 }
216}
217
218impl<K, T> Merger<K, T, PerProcess>
219where
220 K: Hash + Eq + Send + Sync + 'static,
221 T: Send + Sync + 'static,
222{
223 /// Creates a new `Merger` with process-wide scoping (default).
224 ///
225 /// All threads share a single deduplication scope, providing maximum
226 /// work deduplication across the entire process.
227 ///
228 /// # Example
229 ///
230 /// ```
231 /// use uniflight::Merger;
232 ///
233 /// let merger = Merger::<String, String, _>::new_per_process();
234 /// ```
235 #[inline]
236 #[must_use]
237 #[cfg_attr(test, mutants::skip)] // Equivalent mutant: delegates to Default
238 pub fn new_per_process() -> Self {
239 Self::default()
240 }
241}
242
243impl<K, T> Merger<K, T, PerNuma>
244where
245 K: Hash + Eq + Send + Sync + 'static,
246 T: Send + Sync + 'static,
247{
248 /// Creates a new `Merger` with per-NUMA-node scoping.
249 ///
250 /// Each NUMA node gets its own deduplication scope, ensuring memory
251 /// locality for cached results while still deduplicating within each node.
252 ///
253 /// # Example
254 ///
255 /// ```
256 /// use uniflight::Merger;
257 ///
258 /// let merger = Merger::<String, String, _>::new_per_numa();
259 /// ```
260 #[inline]
261 #[must_use]
262 #[cfg_attr(test, mutants::skip)] // Equivalent mutant: delegates to Default
263 pub fn new_per_numa() -> Self {
264 Self::default()
265 }
266}
267
268impl<K, T> Merger<K, T, PerCore>
269where
270 K: Hash + Eq + Send + Sync + 'static,
271 T: Send + Sync + 'static,
272{
273 /// Creates a new `Merger` with per-core scoping.
274 ///
275 /// Each core gets its own deduplication scope. This is useful when work
276 /// is already partitioned by core and cross-core deduplication is not needed.
277 ///
278 /// # Example
279 ///
280 /// ```
281 /// use uniflight::Merger;
282 ///
283 /// let merger = Merger::<String, String, _>::new_per_core();
284 /// ```
285 #[inline]
286 #[must_use]
287 #[cfg_attr(test, mutants::skip)] // Equivalent mutant: delegates to Default
288 pub fn new_per_core() -> Self {
289 Self::default()
290 }
291}
292
293impl<K, T, S: Strategy> Merger<K, T, S>
294where
295 K: Hash + Eq,
296{
297 /// Returns the number of in-flight operations.
298 #[cfg(test)]
299 fn len(&self) -> usize {
300 self.inner.len()
301 }
302
303 /// Returns `true` if there are no in-flight operations.
304 #[cfg(test)]
305 fn is_empty(&self) -> bool {
306 self.inner.is_empty()
307 }
308}
309
310impl<K, T, S> ThreadAware for Merger<K, T, S>
311where
312 S: Strategy,
313{
314 fn relocated(self, source: MemoryAffinity, destination: PinnedAffinity) -> Self {
315 Self {
316 inner: self.inner.relocated(source, destination),
317 }
318 }
319}
320
321impl<K, T, S> Merger<K, T, S>
322where
323 K: Hash + Eq + Send + Sync,
324 T: Send + Sync,
325 S: Strategy + Send + Sync,
326{
327 /// Execute and return the value for a given function, making sure that only one
328 /// operation is in-flight at a given moment. If a duplicate call comes in,
329 /// that caller will wait until the leader completes and return the same value.
330 ///
331 /// # Errors
332 ///
333 /// Returns [`LeaderPanicked`] if the leader task panicked during execution.
334 /// Callers can retry by calling `execute` again if desired.
335 ///
336 /// # Example
337 ///
338 /// The key can be passed as any borrowed form of `K`. For example, if `K` is `String`,
339 /// you can pass `&str` directly:
340 ///
341 /// ```
342 /// # use uniflight::Merger;
343 /// # async fn example() {
344 /// let merger: Merger<String, i32> = Merger::new();
345 /// let result = merger.execute("my-key", || async { 42 }).await;
346 /// assert_eq!(result, Ok(42));
347 /// # }
348 /// ```
349 pub fn execute<Q, F, Fut>(&self, key: &Q, func: F) -> impl Future<Output = Result<T, LeaderPanicked>> + Send + use<Q, F, Fut, K, T, S>
350 where
351 K: Borrow<Q>,
352 Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
353 F: FnOnce() -> Fut + Send,
354 Fut: Future<Output = T> + Send,
355 T: Clone,
356 {
357 // Clone the TaArc - the async block owns this clone
358 let inner = self.inner.clone();
359 let cell = Self::get_or_create_cell(&inner, key);
360 let owned_key = key.to_owned();
361 async move {
362 let result = cell.get_or_init(func()).await.clone();
363 drop(cell); // Release our Arc before cleanup check
364 // Remove entry if no one else is using it (weak can't upgrade)
365 inner.remove_if(owned_key.borrow(), |_, weak| weak.upgrade().is_none());
366 result
367 }
368 }
369
370 /// Gets an existing cell for the key, or creates a new one.
371 fn get_or_create_cell<Q>(map: &DashMap<K, Weak<PanicAwareCell<T>>, RandomState>, key: &Q) -> Arc<PanicAwareCell<T>>
372 where
373 K: Borrow<Q>,
374 Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
375 {
376 // Fast path: check if entry exists and is still valid
377 if let Some(entry) = map.get(key)
378 && let Some(cell) = entry.value().upgrade()
379 {
380 return cell;
381 }
382
383 // Slow path: need to insert or replace expired entry
384 Self::insert_or_get_existing(map, key)
385 }
386
387 /// Inserts a new cell or returns an existing live cell (handling races).
388 ///
389 /// This is the slow path of `get_or_create_cell`, separated for testability.
390 /// It handles the case where another thread may have inserted a cell between
391 /// our fast-path check and this insertion attempt.
392 fn insert_or_get_existing<Q>(map: &DashMap<K, Weak<PanicAwareCell<T>>, RandomState>, key: &Q) -> Arc<PanicAwareCell<T>>
393 where
394 K: Borrow<Q>,
395 Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
396 {
397 let cell = Arc::new(PanicAwareCell::new());
398 let weak = Arc::downgrade(&cell);
399
400 // Use Entry enum to atomically check-and-return or insert
401 match map.entry(key.to_owned()) {
402 Occupied(mut entry) => {
403 // Entry exists - check if still alive
404 if let Some(existing) = entry.get().upgrade() {
405 // Another thread's cell is still alive - use it
406 return existing;
407 }
408 // Expired - replace with ours
409 entry.insert(weak);
410 }
411 Vacant(entry) => {
412 entry.insert(weak);
413 }
414 }
415
416 // We inserted our cell, return it
417 cell
418 }
419}
420
/// Error reported when the leader's operation panicked instead of producing a value.
///
/// Followers are handed this error rather than being retried silently; whether to
/// call `execute` again is left to the caller.
///
/// The captured panic message can be read through [`std::fmt::Display`] or
/// [`LeaderPanicked::message`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LeaderPanicked {
    message: Arc<str>,
}

impl LeaderPanicked {
    /// Returns the panic message from the leader task.
    #[must_use]
    pub fn message(&self) -> &str {
        self.message.as_ref()
    }
}

impl std::fmt::Display for LeaderPanicked {
    /// Writes `leader task panicked: ` followed by the captured message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("leader task panicked: ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for LeaderPanicked {}
448
/// Turns a panic payload into a shareable message string.
///
/// Payloads that are a `&'static str` or a `String` yield their text; any other
/// payload type yields the fixed text `"unknown panic"`.
fn extract_panic_message(payload: &(dyn std::any::Any + Send)) -> Arc<str> {
    let text: &str = match payload.downcast_ref::<&str>() {
        Some(literal) => *literal,
        None => payload
            .downcast_ref::<String>()
            .map_or("unknown panic", String::as_str),
    };
    Arc::from(text)
}
461
462struct PanicAwareCell<T> {
463 inner: OnceCell<Result<T, LeaderPanicked>>,
464}
465
466impl<T> PanicAwareCell<T> {
467 fn new() -> Self {
468 Self { inner: OnceCell::new() }
469 }
470
471 #[expect(clippy::future_not_send, reason = "Send bounds enforced by Merger::execute")]
472 async fn get_or_init<F>(&self, f: F) -> &Result<T, LeaderPanicked>
473 where
474 F: Future<Output = T>,
475 {
476 // Use map combinator instead of async block to avoid extra state machine
477 self.inner
478 .get_or_init(AssertUnwindSafe(f).catch_unwind().map(|result| {
479 result.map_err(|payload| LeaderPanicked {
480 message: extract_panic_message(&*payload),
481 })
482 }))
483 .await
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use thread_aware::affinity::pinned_affinities;

    /// The concrete map type behind `Merger<String, String>`.
    type CellMap = DashMap<String, Weak<PanicAwareCell<String>>, RandomState>;

    /// Builds an empty map with the same hasher type the production code uses.
    fn empty_map() -> CellMap {
        DashMap::with_hasher(RandomState::new())
    }

    /// Relocating a merger must yield a usable merger backed by the relocated inner state.
    #[test]
    fn relocated_delegates_to_inner() {
        let pinned = pinned_affinities(&[2]);
        let from = pinned[0].into();
        let to = pinned[1];

        let original: Merger<String, String> = Merger::new();
        let moved = original.relocated(from, to);

        // Verify the relocated merger still works
        assert!(moved.is_empty());
    }

    /// A live entry already in the map is returned as-is by the fast path.
    #[test]
    fn fast_path_returns_existing() {
        let map = empty_map();
        let live = Arc::new(PanicAwareCell::new());
        map.insert("key".to_string(), Arc::downgrade(&live));

        let found = Merger::<String, String>::get_or_create_cell(&map, "key");

        assert!(Arc::ptr_eq(&found, &live));
    }

    /// An entry whose cell has been dropped is replaced with a fresh cell.
    #[test]
    fn replaces_expired_entry() {
        let map = empty_map();
        // The temporary Arc is dropped at the end of this statement, leaving a dead Weak.
        let dead = Arc::downgrade(&Arc::new(PanicAwareCell::<String>::new()));
        map.insert("key".to_string(), dead);

        let fresh = Merger::<String, String>::get_or_create_cell(&map, "key");

        let stored = map.get("key").unwrap();
        assert!(Arc::ptr_eq(&fresh, &stored.value().upgrade().unwrap()));
    }

    /// Simulates a race where another thread inserted between fast-path check and `entry()`.
    #[test]
    fn race_returns_existing() {
        let map = empty_map();
        let winner = Arc::new(PanicAwareCell::new());
        map.insert("key".to_string(), Arc::downgrade(&winner));

        let found = Merger::<String, String>::insert_or_get_existing(&map, "key");

        assert!(Arc::ptr_eq(&found, &winner));
    }

    /// Entries must disappear from the map once every caller for a key has finished.
    #[tokio::test]
    async fn cleanup_after_completion() {
        let merger: Merger<String, String> = Merger::new();
        assert!(merger.is_empty());

        // Single call should clean up after completion
        let single = merger.execute("key1", || async { "Result".to_string() }).await;
        assert_eq!(single, Ok("Result".to_string()));
        assert!(merger.is_empty(), "Map should be empty after single call completes");

        // Multiple concurrent calls should clean up after all complete
        let mut pending = Vec::with_capacity(10);
        for _ in 0..10 {
            pending.push(merger.execute("key2", || async {
                tokio::time::sleep(Duration::from_millis(50)).await;
                "Result".to_string()
            }));
        }

        // The futures exist but have not been awaited yet: the key is already registered.
        assert_eq!(merger.len(), 1);

        for call in pending {
            assert_eq!(call.await, Ok("Result".to_string()));
        }

        assert!(merger.is_empty(), "Map should be empty after all concurrent calls complete");

        // Multiple different keys should all be cleaned up
        let call_a = merger.execute("a", || async { "A".to_string() });
        let call_b = merger.execute("b", || async { "B".to_string() });
        let call_c = merger.execute("c", || async { "C".to_string() });

        assert_eq!(merger.len(), 3);

        let (out_a, out_b, out_c) = tokio::join!(call_a, call_b, call_c);
        assert_eq!(out_a, Ok("A".to_string()));
        assert_eq!(out_b, Ok("B".to_string()));
        assert_eq!(out_c, Ok("C".to_string()));

        assert!(merger.is_empty(), "Map should be empty after all keys complete");
    }

    /// Sanity check that `catch_unwind` intercepts a panic raised inside async code.
    #[tokio::test]
    async fn catch_unwind_works() {
        let panicking = async {
            panic!("test panic");
            #[expect(unreachable_code, reason = "Required to satisfy return type after panic")]
            42i32
        };

        let outcome = AssertUnwindSafe(panicking).catch_unwind().await;

        assert!(outcome.is_err(), "catch_unwind should catch the panic");
    }

    /// A panicking initializer is stored in the cell as a `LeaderPanicked` with its message.
    #[tokio::test]
    async fn panic_aware_cell_catches_panic() {
        let cell = PanicAwareCell::<String>::new();
        let panicking = async {
            panic!("test panic");
            #[expect(unreachable_code, reason = "Required to satisfy return type after panic")]
            "never".to_string()
        };

        let stored = cell.get_or_init(panicking).await;

        let err = stored.as_ref().unwrap_err();
        assert_eq!(err.message(), "test panic");
    }

    /// `String` payloads yield their text.
    #[test]
    fn extract_panic_message_from_string() {
        let payload: Box<dyn std::any::Any + Send> = Box::new(String::from("owned string panic"));
        let text = extract_panic_message(&*payload);
        assert_eq!(&*text, "owned string panic");
    }

    /// Payloads that are neither `&str` nor `String` fall back to the default text.
    #[test]
    fn extract_panic_message_unknown_type() {
        let payload: Box<dyn std::any::Any + Send> = Box::new(42i32);
        let text = extract_panic_message(&*payload);
        assert_eq!(&*text, "unknown panic");
    }
}