Skip to main content

snapdir_stores/
transfer.rs

//! Transfer configuration, rate limiting, and bounded-concurrency driver.
//!
//! This module is the foundation for concurrent object transfers and bandwidth
//! limiting. It provides:
//!
//! - [`TransferConfig`] — how many objects to transfer in parallel and an
//!   optional aggregate byte-rate cap.
//! - [`RateLimiter`] — a zero-dependency async token bucket built on
//!   [`tokio::time`], shareable across tasks via [`Arc`].
//! - [`run_concurrent`] — a generic bounded-concurrency driver that runs up to
//!   `concurrency` async operations in flight and returns the first error.
//!
//! Nothing here changes the existing (sequential) push / fetch loops yet; the
//! stores merely carry a [`TransferConfig`] so later gates can wire these
//! primitives into their transfer loops.
16
17use std::num::NonZeroUsize;
18use std::sync::Arc;
19use std::time::Duration;
20
21use futures::stream::{self, StreamExt, TryStreamExt};
22use snapdir_core::store::StoreError;
23use tokio::sync::Mutex;
24
/// Upper bound on the auto-detected default concurrency.
///
/// [`TransferConfig::default`] clamps the detected parallelism into
/// `1..=DEFAULT_CONCURRENCY_CAP`; explicitly constructed configs are not
/// limited by this value.
const DEFAULT_CONCURRENCY_CAP: usize = 16;
27
/// Configuration for object transfers: how many to run in parallel and an
/// optional aggregate byte-rate cap.
///
/// `Default` auto-detects the available parallelism (capped at
/// [`DEFAULT_CONCURRENCY_CAP`]) and leaves bandwidth unlimited.
///
/// The type is plain data, so it derives `PartialEq`/`Eq` to let callers and
/// tests compare whole configurations instead of going field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TransferConfig {
    /// Maximum number of object transfers to run concurrently. The
    /// `NonZeroUsize` type rules out a zero-width (deadlocked) pipeline.
    pub concurrency: NonZeroUsize,
    /// Optional aggregate bandwidth cap, in bytes per second. `None` means
    /// unlimited.
    pub max_bytes_per_sec: Option<u64>,
}
41
42impl TransferConfig {
43    /// Builds a config, clamping `concurrency` to at least 1.
44    #[must_use]
45    pub fn new(concurrency: usize, max_bytes_per_sec: Option<u64>) -> Self {
46        Self {
47            concurrency: NonZeroUsize::new(concurrency.max(1)).unwrap_or(NonZeroUsize::MIN),
48            max_bytes_per_sec,
49        }
50    }
51}
52
53impl Default for TransferConfig {
54    fn default() -> Self {
55        let detected = std::thread::available_parallelism()
56            .map_or(1, NonZeroUsize::get)
57            .clamp(1, DEFAULT_CONCURRENCY_CAP);
58        Self {
59            // `detected` is >= 1, so the NonZeroUsize is always Some.
60            concurrency: NonZeroUsize::new(detected).unwrap_or(NonZeroUsize::MIN),
61            max_bytes_per_sec: None,
62        }
63    }
64}
65
/// Shared token-bucket state, guarded by an async mutex.
///
/// Both fields are updated together under the lock each time a caller
/// refills the bucket.
#[derive(Debug)]
struct Bucket {
    /// Currently available tokens (bytes). Kept as `f64` so fractional
    /// refills (elapsed seconds × rate) accumulate without rounding.
    tokens: f64,
    /// Last time the bucket was refilled; the next refill credits the time
    /// elapsed since this instant.
    last_refill: tokio::time::Instant,
}
74
/// Inner state of a [`RateLimiter`].
///
/// `rate` and `capacity` are fixed at construction; only the bucket behind
/// the mutex changes afterwards.
#[derive(Debug)]
struct Inner {
    /// Refill rate in bytes per second. Strictly positive whenever `bucket`
    /// is `Some`; the unlimited limiter stores `0.0` here and is recognised
    /// by `bucket = None`, so the zero rate is never used.
    rate: f64,
    /// Maximum burst capacity, in bytes (~1 second's worth of budget).
    /// `0.0` for the unlimited limiter.
    capacity: f64,
    /// `None` when unlimited; otherwise the live bucket state.
    bucket: Option<Mutex<Bucket>>,
}
86
/// An async token-bucket rate limiter that throttles aggregate transfer
/// throughput.
///
/// Construct with [`RateLimiter::new`]. When `max_bytes_per_sec` is `None` (or
/// `Some(0)`), the limiter is unlimited and [`acquire`](RateLimiter::acquire)
/// returns immediately. Otherwise tokens refill at `max_bytes_per_sec` per
/// second, allowing a burst of up to ~1 second's worth of budget.
///
/// The limiter is [`Arc`]-shareable and [`Clone`] (cloning shares the same
/// underlying bucket).
#[derive(Debug, Clone)]
pub struct RateLimiter {
    /// Shared state; clones of the limiter point at the same `Inner`.
    inner: Arc<Inner>,
}
101
102impl RateLimiter {
103    /// Builds a limiter. `None` (or `Some(0)`) yields an unlimited, no-op
104    /// limiter whose [`acquire`](RateLimiter::acquire) never waits.
105    #[must_use]
106    pub fn new(max_bytes_per_sec: Option<u64>) -> Self {
107        let inner = match max_bytes_per_sec {
108            Some(rate) if rate > 0 => {
109                #[allow(clippy::cast_precision_loss)]
110                let rate = rate as f64;
111                Inner {
112                    rate,
113                    capacity: rate,
114                    bucket: Some(Mutex::new(Bucket {
115                        tokens: rate,
116                        last_refill: tokio::time::Instant::now(),
117                    })),
118                }
119            }
120            _ => Inner {
121                rate: 0.0,
122                capacity: 0.0,
123                bucket: None,
124            },
125        };
126        Self {
127            inner: Arc::new(inner),
128        }
129    }
130
131    /// Blocks until `n` bytes of budget are available, refilling the bucket at
132    /// the configured rate. Unlimited limiters return immediately.
133    ///
134    /// A single request larger than the bucket capacity is still satisfied: the
135    /// bucket is allowed to go negative and the caller waits out the deficit,
136    /// so throttling is correct even for objects bigger than one second's
137    /// worth of budget.
138    pub async fn acquire(&self, n: u64) {
139        let Some(bucket) = self.inner.bucket.as_ref() else {
140            return; // unlimited fast path
141        };
142        if n == 0 {
143            return;
144        }
145        #[allow(clippy::cast_precision_loss)]
146        let need = n as f64;
147
148        loop {
149            let wait = {
150                let mut state = bucket.lock().await;
151                let now = tokio::time::Instant::now();
152                let elapsed = now.duration_since(state.last_refill).as_secs_f64();
153                state.tokens = (state.tokens + elapsed * self.inner.rate).min(self.inner.capacity);
154                state.last_refill = now;
155
156                if state.tokens >= need {
157                    state.tokens -= need;
158                    return;
159                }
160                // Not enough budget: compute how long until the deficit is
161                // covered, then sleep (releasing the lock first).
162                let deficit = need - state.tokens;
163                deficit / self.inner.rate
164            };
165            tokio::time::sleep(Duration::from_secs_f64(wait)).await;
166        }
167    }
168}
169
170/// Runs `op` over `items` with at most `concurrency` operations in flight,
171/// collecting their results in completion-independent order and returning the
172/// first error encountered (remaining in-flight work is cancelled).
173///
174/// This is the engine later gates use to drive concurrent uploads/downloads.
175///
176/// # Errors
177///
178/// Returns the first [`StoreError`] produced by any operation.
179pub async fn run_concurrent<I, T, F, Fut>(
180    items: I,
181    concurrency: NonZeroUsize,
182    op: F,
183) -> Result<Vec<T>, StoreError>
184where
185    I: IntoIterator,
186    F: Fn(I::Item) -> Fut,
187    Fut: std::future::Future<Output = Result<T, StoreError>>,
188{
189    stream::iter(items)
190        .map(op)
191        .buffer_unordered(concurrency.get())
192        .try_collect()
193        .await
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a current-thread tokio runtime with time enabled, avoiding a
    /// dependency on the `#[tokio::test]` macro (keeps tokio's feature set
    /// minimal).
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_time()
            .build()
            .expect("build tokio runtime")
    }

    /// Checks the `Default` bounds and the clamping behaviour of
    /// `TransferConfig::new`.
    #[test]
    fn transfer_config_default_caps_concurrency() {
        let cfg = TransferConfig::default();
        assert!(cfg.concurrency.get() >= 1, "concurrency must be >= 1");
        assert!(
            cfg.concurrency.get() <= DEFAULT_CONCURRENCY_CAP,
            "default concurrency must be capped at {DEFAULT_CONCURRENCY_CAP}, got {}",
            cfg.concurrency.get()
        );
        assert_eq!(cfg.max_bytes_per_sec, None);

        // The clamping ctor never yields 0.
        assert_eq!(TransferConfig::new(0, None).concurrency.get(), 1);
        assert_eq!(TransferConfig::new(7, Some(99)).concurrency.get(), 7);
        assert_eq!(TransferConfig::new(7, Some(99)).max_bytes_per_sec, Some(99));
    }

    /// Drives `run_concurrent` over `items` operations with the given
    /// `concurrency` and returns the peak number of simultaneously-running
    /// ops. The helper only asserts that the run succeeded; callers assert
    /// on the returned peak.
    fn max_in_flight_for(concurrency: usize, items: usize) -> usize {
        // `in_flight` tracks ops currently running; `high_water` records the
        // largest value it ever reached.
        let in_flight = Arc::new(AtomicUsize::new(0));
        let high_water = Arc::new(AtomicUsize::new(0));

        let rt = runtime();
        let result = rt.block_on(async {
            let in_flight = Arc::clone(&in_flight);
            let high_water = Arc::clone(&high_water);
            run_concurrent(
                0..items,
                NonZeroUsize::new(concurrency).unwrap(),
                move |_item| {
                    let in_flight = Arc::clone(&in_flight);
                    let high_water = Arc::clone(&high_water);
                    async move {
                        let cur = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                        high_water.fetch_max(cur, Ordering::SeqCst);
                        // Hold the slot long enough for sibling ops to start
                        // and overlap.
                        tokio::time::sleep(Duration::from_millis(20)).await;
                        in_flight.fetch_sub(1, Ordering::SeqCst);
                        Ok::<_, StoreError>(())
                    }
                },
            )
            .await
        });
        assert!(result.is_ok());
        high_water.load(Ordering::SeqCst)
    }

    /// The observed peak must equal `min(concurrency, item count)`.
    #[test]
    fn transfer_config_run_concurrent_max_in_flight() {
        // concurrency=4 over 12 items: peak in-flight is exactly 4.
        assert_eq!(max_in_flight_for(4, 12), 4);
        // concurrency=1 over 5 items: strictly sequential, peak in-flight is 1.
        assert_eq!(max_in_flight_for(1, 5), 1);
        // concurrency greater than item count is bounded by the item count.
        assert_eq!(max_in_flight_for(8, 3), 3);
    }

    /// A single failing op must surface its error from `run_concurrent`.
    #[test]
    fn transfer_config_run_concurrent_propagates_error() {
        let rt = runtime();
        let result: Result<Vec<()>, StoreError> = rt.block_on(async {
            run_concurrent(0..10, NonZeroUsize::new(3).unwrap(), |item| async move {
                if item == 5 {
                    Err(StoreError::Backend {
                        message: "boom".to_owned(),
                        source: None,
                    })
                } else {
                    tokio::time::sleep(Duration::from_millis(5)).await;
                    Ok(())
                }
            })
            .await
        });
        let err = result.expect_err("must surface the failing op's error");
        assert!(
            matches!(err, StoreError::Backend { ref message, .. } if message == "boom"),
            "unexpected error: {err:?}"
        );
    }

    /// Exercises both limiter modes against the real clock: unlimited never
    /// blocks, limited blocks once the initial burst is spent.
    #[test]
    fn transfer_config_rate_limiter() {
        let rt = runtime();
        rt.block_on(async {
            // Unlimited: acquiring a large amount returns essentially instantly.
            let unlimited = RateLimiter::new(None);
            let start = tokio::time::Instant::now();
            unlimited.acquire(1_000_000).await;
            assert!(
                start.elapsed() < Duration::from_millis(200),
                "unlimited acquire should not block"
            );

            // Limited to 1000 bytes/sec. The bucket starts full (1000), so the
            // first 1000 bytes are free; acquiring another ~2000 bytes total
            // must wait for the deficit to refill — at least ~1s.
            let limiter = RateLimiter::new(Some(1000));
            let start = tokio::time::Instant::now();
            limiter.acquire(1000).await; // drains the initial burst
            limiter.acquire(1000).await; // must wait ~1s to refill
            let elapsed = start.elapsed();
            // 900ms rather than 1s leaves slack for timer granularity.
            assert!(
                elapsed >= Duration::from_millis(900),
                "throttled acquire should take ~1s, took {elapsed:?}"
            );
        });
    }
}