Skip to main content

uniflight/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Coalesces duplicate async tasks into a single execution.
5//!
6//! This crate provides [`Merger`], a mechanism for deduplicating concurrent async operations.
7//! When multiple tasks request the same work (identified by a key), only the first task (the
8//! "leader") performs the actual work while subsequent tasks (the "followers") wait and receive
9//! a clone of the result.
10//!
11//! # When to Use
12//!
13//! Use `Merger` when you have expensive or rate-limited operations that may be requested
14//! concurrently with the same parameters:
15//!
16//! - **Cache population**: Prevent thundering herd when a cache entry expires
17//! - **API calls**: Deduplicate concurrent requests to the same endpoint
18//! - **Database queries**: Coalesce identical queries issued simultaneously
19//! - **File I/O**: Avoid reading the same file multiple times concurrently
20//!
21//! # Example
22//!
23//! ```
24//! use uniflight::Merger;
25//!
26//! # async fn example() {
27//! let group: Merger<String, String> = Merger::new();
28//!
29//! // Multiple concurrent calls with the same key will share a single execution.
30//! // Note: you can pass &str directly when the key type is String.
31//! let result = group.execute("user:123", || async {
32//!     // This expensive operation runs only once, even if called concurrently
33//!     "expensive_result".to_string()
34//! }).await.expect("leader should not panic");
35//! # }
36//! ```
37//!
38//! # Flexible Key Types
39//!
40//! The [`Merger::execute`] method accepts keys using [`Borrow`] semantics, allowing you to pass
41//! borrowed forms of the key type. For example, with `Merger<String, T>`, you can pass `&str`
42//! directly without allocating:
43//!
44//! ```
45//! # use uniflight::Merger;
46//! # async fn example() {
47//! let merger: Merger<String, i32> = Merger::new();
48//!
49//! // Pass &str directly - no need to call .to_string()
50//! let result = merger.execute("my-key", || async { 42 }).await;
51//! assert_eq!(result, Ok(42));
52//! # }
53//! ```
54//!
55//! # Thread-Aware Scoping
56//!
57//! `Merger` supports thread-aware scoping via a [`Strategy`]
58//! type parameter. This controls how the internal state is partitioned across threads/NUMA nodes:
59//!
60//! - [`PerProcess`] (default): Single global state, maximum deduplication
61//! - [`PerNuma`]: Separate state per NUMA node, NUMA-local memory access
//! - [`PerCore`]: Separate state per core, no cross-core deduplication (useful for already-partitioned work)
63//!
64//! ```
65//! use uniflight::Merger;
66//! use thread_aware::PerNuma;
67//!
68//! # async fn example() {
69//! // NUMA-aware merger - each NUMA node gets its own deduplication scope
70//! let merger: Merger<String, String, PerNuma> = Merger::new_per_numa();
71//! # }
72//! ```
73//!
74//! # Cancellation and Panic Handling
75//!
76//! `Merger` handles task cancellation and panics explicitly:
77//!
78//! - If the leader task is cancelled or dropped, a follower becomes the new leader
79//! - If the leader task panics, followers receive [`LeaderPanicked`] error with the panic message
80//! - Followers that join before the leader completes receive the value the leader returns
81//!
82//! When a panic occurs, followers are notified via the error type rather than silently
83//! retrying. The panic message is captured and available via [`LeaderPanicked::message`]:
84//!
85//! ```
86//! # use uniflight::Merger;
87//! # async fn example() {
88//! let merger: Merger<String, String> = Merger::new();
89//! match merger.execute("key", || async { "result".to_string() }).await {
90//!     Ok(value) => println!("got {value}"),
91//!     Err(err) => {
92//!         println!("leader panicked: {}", err.message());
93//!         // Decide whether to retry
94//!     }
95//! }
96//! # }
97//! ```
98//!
99//! # Memory Management
100//!
101//! Completed entries are automatically removed from the internal map when the last caller
102//! finishes. This ensures no stale entries accumulate over time.
103//!
104//! # Type Requirements
105//!
106//! The value type `T` must implement [`Clone`] because followers receive a clone of the
107//! leader's result. The key type `K` must implement [`Hash`] and [`Eq`].
108//!
109//! # Thread Safety
110//!
111//! [`Merger`] is `Send` and `Sync`, and can be shared across threads. The returned futures
112//! are `Send` when the closure, future, key, and value types are `Send`.
113//!
114//! # Performance
115//!
116//! Run benchmarks with `cargo bench -p uniflight`. The suite covers:
117//!
118//! - `single_call`: Baseline latency with no contention
119//! - `high_contention_100`: 100 concurrent tasks on the same key
120//! - `distributed_10x10`: 10 keys with 10 tasks each
121//!
122//! Use `--save-baseline` and `--baseline` flags to track regressions over time.
123
124#![doc(html_logo_url = "https://media.githubusercontent.com/media/microsoft/oxidizer/refs/heads/main/crates/uniflight/logo.png")]
125#![doc(html_favicon_url = "https://media.githubusercontent.com/media/microsoft/oxidizer/refs/heads/main/crates/uniflight/favicon.ico")]
126
127use std::{
128    borrow::Borrow,
129    fmt::Debug,
130    hash::Hash,
131    panic::AssertUnwindSafe,
132    sync::{Arc, Weak},
133};
134
135use ahash::RandomState;
136use async_once_cell::OnceCell;
137use dashmap::{
138    DashMap,
139    Entry::{Occupied, Vacant},
140};
141use futures_util::FutureExt; // catch_unwind, map
142use thread_aware::{
143    Arc as TaArc, PerCore, PerNuma, PerProcess, ThreadAware,
144    affinity::{MemoryAffinity, PinnedAffinity},
145    storage::Strategy,
146};
147
148/// Suppresses duplicate async operations identified by a key.
149///
150/// The `S` type parameter controls the thread-aware scoping strategy:
151/// - [`PerProcess`]: Single global scope (default, maximum deduplication)
152/// - [`PerNuma`]: Per-NUMA-node scope (NUMA-local memory access)
153/// - [`PerCore`]: Per-core scope (no deduplication)
154pub struct Merger<K, T, S: Strategy = PerProcess> {
155    inner: TaArc<DashMap<K, Weak<PanicAwareCell<T>>, RandomState>, S>,
156}
157
158impl<K, T, S: Strategy> Debug for Merger<K, T, S> {
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        f.debug_struct("Merger").field("inner", &format_args!("DashMap<...>")).finish()
161    }
162}
163
164impl<K, T, S: Strategy> Clone for Merger<K, T, S> {
165    fn clone(&self) -> Self {
166        Self { inner: self.inner.clone() }
167    }
168}
169
170impl<K, T, S> Default for Merger<K, T, S>
171where
172    K: Hash + Eq + Send + Sync + 'static,
173    T: Send + Sync + 'static,
174    S: Strategy,
175{
176    fn default() -> Self {
177        Self {
178            inner: TaArc::new(|| DashMap::with_hasher(RandomState::new())),
179        }
180    }
181}
182
183impl<K, T, S> Merger<K, T, S>
184where
185    K: Hash + Eq + Send + Sync + 'static,
186    T: Send + Sync + 'static,
187    S: Strategy,
188{
189    /// Creates a new `Merger` instance.
190    ///
191    /// The scoping strategy is determined by the type parameter `S`:
192    /// - [`PerProcess`] (default): Process-wide scope, maximum deduplication
193    /// - [`PerNuma`]: Per-NUMA-node scope, NUMA-local memory access
194    /// - [`PerCore`]: Per-core scope, no cross-core deduplication
195    ///
196    /// # Examples
197    ///
198    /// ```
199    /// use uniflight::Merger;
200    /// use thread_aware::{PerNuma, PerCore};
201    ///
202    /// // Default (PerProcess) - type can be inferred
203    /// let global: Merger<String, String> = Merger::new();
204    ///
205    /// // NUMA-local scope
206    /// let numa: Merger<String, String, PerNuma> = Merger::new();
207    ///
208    /// // Per-core scope
209    /// let core: Merger<String, String, PerCore> = Merger::new();
210    /// ```
211    #[inline]
212    #[must_use]
213    pub fn new() -> Self {
214        Self::default()
215    }
216}
217
218impl<K, T> Merger<K, T, PerProcess>
219where
220    K: Hash + Eq + Send + Sync + 'static,
221    T: Send + Sync + 'static,
222{
223    /// Creates a new `Merger` with process-wide scoping (default).
224    ///
225    /// All threads share a single deduplication scope, providing maximum
226    /// work deduplication across the entire process.
227    ///
228    /// # Example
229    ///
230    /// ```
231    /// use uniflight::Merger;
232    ///
233    /// let merger = Merger::<String, String, _>::new_per_process();
234    /// ```
235    #[inline]
236    #[must_use]
237    #[cfg_attr(test, mutants::skip)] // Equivalent mutant: delegates to Default
238    pub fn new_per_process() -> Self {
239        Self::default()
240    }
241}
242
243impl<K, T> Merger<K, T, PerNuma>
244where
245    K: Hash + Eq + Send + Sync + 'static,
246    T: Send + Sync + 'static,
247{
248    /// Creates a new `Merger` with per-NUMA-node scoping.
249    ///
250    /// Each NUMA node gets its own deduplication scope, ensuring memory
251    /// locality for cached results while still deduplicating within each node.
252    ///
253    /// # Example
254    ///
255    /// ```
256    /// use uniflight::Merger;
257    ///
258    /// let merger = Merger::<String, String, _>::new_per_numa();
259    /// ```
260    #[inline]
261    #[must_use]
262    #[cfg_attr(test, mutants::skip)] // Equivalent mutant: delegates to Default
263    pub fn new_per_numa() -> Self {
264        Self::default()
265    }
266}
267
268impl<K, T> Merger<K, T, PerCore>
269where
270    K: Hash + Eq + Send + Sync + 'static,
271    T: Send + Sync + 'static,
272{
273    /// Creates a new `Merger` with per-core scoping.
274    ///
275    /// Each core gets its own deduplication scope. This is useful when work
276    /// is already partitioned by core and cross-core deduplication is not needed.
277    ///
278    /// # Example
279    ///
280    /// ```
281    /// use uniflight::Merger;
282    ///
283    /// let merger = Merger::<String, String, _>::new_per_core();
284    /// ```
285    #[inline]
286    #[must_use]
287    #[cfg_attr(test, mutants::skip)] // Equivalent mutant: delegates to Default
288    pub fn new_per_core() -> Self {
289        Self::default()
290    }
291}
292
293impl<K, T, S: Strategy> Merger<K, T, S>
294where
295    K: Hash + Eq,
296{
297    /// Returns the number of in-flight operations.
298    #[cfg(test)]
299    fn len(&self) -> usize {
300        self.inner.len()
301    }
302
303    /// Returns `true` if there are no in-flight operations.
304    #[cfg(test)]
305    fn is_empty(&self) -> bool {
306        self.inner.is_empty()
307    }
308}
309
310impl<K, T, S> ThreadAware for Merger<K, T, S>
311where
312    S: Strategy,
313{
314    fn relocated(self, source: MemoryAffinity, destination: PinnedAffinity) -> Self {
315        Self {
316            inner: self.inner.relocated(source, destination),
317        }
318    }
319}
320
321impl<K, T, S> Merger<K, T, S>
322where
323    K: Hash + Eq + Send + Sync,
324    T: Send + Sync,
325    S: Strategy + Send + Sync,
326{
327    /// Execute and return the value for a given function, making sure that only one
328    /// operation is in-flight at a given moment. If a duplicate call comes in,
329    /// that caller will wait until the leader completes and return the same value.
330    ///
331    /// # Errors
332    ///
333    /// Returns [`LeaderPanicked`] if the leader task panicked during execution.
334    /// Callers can retry by calling `execute` again if desired.
335    ///
336    /// # Example
337    ///
338    /// The key can be passed as any borrowed form of `K`. For example, if `K` is `String`,
339    /// you can pass `&str` directly:
340    ///
341    /// ```
342    /// # use uniflight::Merger;
343    /// # async fn example() {
344    /// let merger: Merger<String, i32> = Merger::new();
345    /// let result = merger.execute("my-key", || async { 42 }).await;
346    /// assert_eq!(result, Ok(42));
347    /// # }
348    /// ```
349    pub fn execute<Q, F, Fut>(&self, key: &Q, func: F) -> impl Future<Output = Result<T, LeaderPanicked>> + Send + use<Q, F, Fut, K, T, S>
350    where
351        K: Borrow<Q>,
352        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
353        F: FnOnce() -> Fut + Send,
354        Fut: Future<Output = T> + Send,
355        T: Clone,
356    {
357        // Clone the TaArc - the async block owns this clone
358        let inner = self.inner.clone();
359        let cell = Self::get_or_create_cell(&inner, key);
360        let owned_key = key.to_owned();
361        async move {
362            let result = cell.get_or_init(func()).await.clone();
363            drop(cell); // Release our Arc before cleanup check
364            // Remove entry if no one else is using it (weak can't upgrade)
365            inner.remove_if(owned_key.borrow(), |_, weak| weak.upgrade().is_none());
366            result
367        }
368    }
369
370    /// Gets an existing cell for the key, or creates a new one.
371    fn get_or_create_cell<Q>(map: &DashMap<K, Weak<PanicAwareCell<T>>, RandomState>, key: &Q) -> Arc<PanicAwareCell<T>>
372    where
373        K: Borrow<Q>,
374        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
375    {
376        // Fast path: check if entry exists and is still valid
377        if let Some(entry) = map.get(key)
378            && let Some(cell) = entry.value().upgrade()
379        {
380            return cell;
381        }
382
383        // Slow path: need to insert or replace expired entry
384        Self::insert_or_get_existing(map, key)
385    }
386
387    /// Inserts a new cell or returns an existing live cell (handling races).
388    ///
389    /// This is the slow path of `get_or_create_cell`, separated for testability.
390    /// It handles the case where another thread may have inserted a cell between
391    /// our fast-path check and this insertion attempt.
392    fn insert_or_get_existing<Q>(map: &DashMap<K, Weak<PanicAwareCell<T>>, RandomState>, key: &Q) -> Arc<PanicAwareCell<T>>
393    where
394        K: Borrow<Q>,
395        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
396    {
397        let cell = Arc::new(PanicAwareCell::new());
398        let weak = Arc::downgrade(&cell);
399
400        // Use Entry enum to atomically check-and-return or insert
401        match map.entry(key.to_owned()) {
402            Occupied(mut entry) => {
403                // Entry exists - check if still alive
404                if let Some(existing) = entry.get().upgrade() {
405                    // Another thread's cell is still alive - use it
406                    return existing;
407                }
408                // Expired - replace with ours
409                entry.insert(weak);
410            }
411            Vacant(entry) => {
412                entry.insert(weak);
413            }
414        }
415
416        // We inserted our cell, return it
417        cell
418    }
419}
420
/// Error handed to callers when the leader task panicked while executing.
///
/// Followers of a panicked leader receive this error rather than silently
/// retrying; whether to call `execute` again is left to the caller.
///
/// The captured panic message can be read through [`std::fmt::Display`] or [`LeaderPanicked::message`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LeaderPanicked {
    /// Panic message captured from the leader's panic payload.
    message: Arc<str>,
}

impl LeaderPanicked {
    /// Returns the panic message from the leader task.
    #[must_use]
    pub fn message(&self) -> &str {
        &*self.message
    }
}

impl std::fmt::Display for LeaderPanicked {
    /// Writes the fixed prefix followed by the captured panic message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("leader task panicked: {}", self.message))
    }
}

impl std::error::Error for LeaderPanicked {}
448
/// Turns a panic payload into a shareable message.
///
/// A `&str` payload is tried first, then a `String` payload; any other payload
/// type yields the fixed text `"unknown panic"`.
fn extract_panic_message(payload: &(dyn std::any::Any + Send)) -> Arc<str> {
    let text: &str = if let Some(literal) = payload.downcast_ref::<&str>() {
        *literal
    } else if let Some(owned) = payload.downcast_ref::<String>() {
        owned.as_str()
    } else {
        "unknown panic"
    };
    Arc::from(text)
}
461
462struct PanicAwareCell<T> {
463    inner: OnceCell<Result<T, LeaderPanicked>>,
464}
465
466impl<T> PanicAwareCell<T> {
467    fn new() -> Self {
468        Self { inner: OnceCell::new() }
469    }
470
471    #[expect(clippy::future_not_send, reason = "Send bounds enforced by Merger::execute")]
472    async fn get_or_init<F>(&self, f: F) -> &Result<T, LeaderPanicked>
473    where
474        F: Future<Output = T>,
475    {
476        // Use map combinator instead of async block to avoid extra state machine
477        self.inner
478            .get_or_init(AssertUnwindSafe(f).catch_unwind().map(|result| {
479                result.map_err(|payload| LeaderPanicked {
480                    message: extract_panic_message(&*payload),
481                })
482            }))
483            .await
484    }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use thread_aware::affinity::pinned_affinities;

    /// Map shape used by the cell-lookup tests.
    type CellMap = DashMap<String, Weak<PanicAwareCell<String>>, RandomState>;

    /// Builds an empty map with a fresh hasher.
    fn empty_map() -> CellMap {
        DashMap::with_hasher(RandomState::new())
    }

    #[test]
    fn relocated_delegates_to_inner() {
        let affinities = pinned_affinities(&[2]);
        let source = affinities[0].into();
        let destination = affinities[1];

        let original: Merger<String, String> = Merger::new();
        let moved = original.relocated(source, destination);

        // The relocated merger must still be usable
        assert!(moved.is_empty());
    }

    #[test]
    fn fast_path_returns_existing() {
        let map = empty_map();
        let live = Arc::new(PanicAwareCell::new());
        map.insert("key".to_string(), Arc::downgrade(&live));

        let found = Merger::<String, String>::get_or_create_cell(&map, "key");

        assert!(Arc::ptr_eq(&found, &live));
    }

    #[test]
    fn replaces_expired_entry() {
        let map = empty_map();
        // The only strong reference is dropped at the end of the block, so the weak is expired.
        let dead = {
            let short_lived = Arc::new(PanicAwareCell::<String>::new());
            Arc::downgrade(&short_lived)
        };
        map.insert("key".to_string(), dead);

        let created = Merger::<String, String>::get_or_create_cell(&map, "key");

        let stored = map.get("key").unwrap();
        let stored_cell = stored.value().upgrade().unwrap();
        assert!(Arc::ptr_eq(&created, &stored_cell));
    }

    /// Simulates a race where another thread inserted between fast-path check and `entry()`.
    #[test]
    fn race_returns_existing() {
        let map = empty_map();
        let winner = Arc::new(PanicAwareCell::new());
        map.insert("key".to_string(), Arc::downgrade(&winner));

        let found = Merger::<String, String>::insert_or_get_existing(&map, "key");

        assert!(Arc::ptr_eq(&found, &winner));
    }

    #[tokio::test]
    async fn cleanup_after_completion() {
        let group: Merger<String, String> = Merger::new();
        assert!(group.is_empty());

        // A lone call leaves nothing behind once it completes
        let single = group.execute("key1", || async { "Result".to_string() }).await;
        assert_eq!(single, Ok("Result".to_string()));
        assert!(group.is_empty(), "Map should be empty after single call completes");

        // Several callers on one key leave nothing behind once all complete
        let mut pending = Vec::with_capacity(10);
        for _ in 0..10 {
            pending.push(group.execute("key2", || async {
                tokio::time::sleep(Duration::from_millis(50)).await;
                "Result".to_string()
            }));
        }

        // The shared key occupies exactly one entry while calls are outstanding
        assert_eq!(group.len(), 1);

        for call in pending {
            assert_eq!(call.await, Ok("Result".to_string()));
        }

        assert!(group.is_empty(), "Map should be empty after all concurrent calls complete");

        // Distinct keys each get an entry, and all of them are cleaned up
        let call_a = group.execute("a", || async { "A".to_string() });
        let call_b = group.execute("b", || async { "B".to_string() });
        let call_c = group.execute("c", || async { "C".to_string() });

        assert_eq!(group.len(), 3);

        let (out_a, out_b, out_c) = tokio::join!(call_a, call_b, call_c);
        assert_eq!(out_a, Ok("A".to_string()));
        assert_eq!(out_b, Ok("B".to_string()));
        assert_eq!(out_c, Ok("C".to_string()));

        assert!(group.is_empty(), "Map should be empty after all keys complete");
    }

    #[tokio::test]
    async fn catch_unwind_works() {
        // Sanity check: catch_unwind intercepts a panic raised inside async code
        let panicking = AssertUnwindSafe(async {
            panic!("test panic");
            #[expect(unreachable_code, reason = "Required to satisfy return type after panic")]
            42i32
        });

        let outcome = panicking.catch_unwind().await;

        assert!(outcome.is_err(), "catch_unwind should catch the panic");
    }

    #[tokio::test]
    async fn panic_aware_cell_catches_panic() {
        let cell = PanicAwareCell::<String>::new();
        let outcome = cell
            .get_or_init(async {
                panic!("test panic");
                #[expect(unreachable_code, reason = "Required to satisfy return type after panic")]
                "never".to_string()
            })
            .await;

        let captured = outcome.as_ref().unwrap_err();
        assert_eq!(captured.message(), "test panic");
    }

    #[test]
    fn extract_panic_message_from_string() {
        let payload: Box<dyn std::any::Any + Send> = Box::new(String::from("owned string panic"));
        let extracted = extract_panic_message(&*payload);
        assert_eq!(&*extracted, "owned string panic");
    }

    #[test]
    fn extract_panic_message_unknown_type() {
        let payload: Box<dyn std::any::Any + Send> = Box::new(42i32);
        let extracted = extract_panic_message(&*payload);
        assert_eq!(&*extracted, "unknown panic");
    }
}