Skip to main content

snapdir_stores/
transfer.rs

//! Transfer configuration, rate limiting, and bounded-concurrency driver.
//!
//! This module is the foundation for concurrent object transfers and bandwidth
//! limiting. It provides:
//!
//! - [`TransferConfig`] — how many objects to transfer in parallel and an
//!   optional aggregate byte-rate cap.
//! - [`RateLimiter`] — a zero-dependency async token bucket built on
//!   [`tokio::time`], shareable across tasks via [`Arc`].
//! - [`BlockingRateLimiter`] — the synchronous sibling of [`RateLimiter`]; it
//!   parks the calling OS thread instead of awaiting, for callers running on
//!   plain (e.g. rayon) worker threads.
//! - [`run_concurrent`] — a generic bounded-concurrency driver that runs up to
//!   `concurrency` async operations in flight and returns the first error.
//!
//! Nothing here changes the existing (sequential) push / fetch loops yet; the
//! stores merely carry a [`TransferConfig`] so later gates can wire these
//! primitives into their transfer loops.
16
17use std::num::NonZeroUsize;
18use std::sync::Arc;
19use std::time::Duration;
20
21use futures::stream::{self, StreamExt, TryStreamExt};
22use snapdir_core::store::StoreError;
23use tokio::sync::Mutex;
24
/// Ceiling applied to the auto-detected default concurrency.
const DEFAULT_CONCURRENCY_CAP: usize = 16;

/// How object transfers are run: the parallelism limit plus an optional
/// aggregate byte-rate ceiling.
///
/// The [`Default`] impl uses the machine's available parallelism (never more
/// than [`DEFAULT_CONCURRENCY_CAP`]) and applies no bandwidth limit.
#[derive(Debug, Clone)]
pub struct TransferConfig {
    /// Upper bound on the number of object transfers in flight at once.
    pub concurrency: NonZeroUsize,
    /// Aggregate bandwidth ceiling in bytes per second; `None` disables
    /// throttling.
    pub max_bytes_per_sec: Option<u64>,
}

impl TransferConfig {
    /// Creates a config. A `concurrency` of `0` is raised to `1`, so the
    /// result always allows at least one transfer in flight.
    #[must_use]
    pub fn new(concurrency: usize, max_bytes_per_sec: Option<u64>) -> Self {
        let concurrency = NonZeroUsize::new(concurrency).unwrap_or(NonZeroUsize::MIN);
        Self {
            concurrency,
            max_bytes_per_sec,
        }
    }
}

impl Default for TransferConfig {
    fn default() -> Self {
        // If the platform cannot report its parallelism, fall back to a
        // single worker; otherwise never exceed the configured cap.
        let available = std::thread::available_parallelism().unwrap_or(NonZeroUsize::MIN);
        Self::new(available.get().min(DEFAULT_CONCURRENCY_CAP), None)
    }
}
65
66/// Shared token-bucket state, guarded by an async mutex.
67#[derive(Debug)]
68struct Bucket {
69    /// Currently available tokens (bytes).
70    tokens: f64,
71    /// Last time the bucket was refilled.
72    last_refill: tokio::time::Instant,
73}
74
75/// Inner state of a [`RateLimiter`].
76#[derive(Debug)]
77struct Inner {
78    /// Refill rate in bytes per second. `0` is impossible here (unlimited is
79    /// modelled by `bucket = None`).
80    rate: f64,
81    /// Maximum burst capacity, in bytes (~1 second's worth of budget).
82    capacity: f64,
83    /// `None` when unlimited; otherwise the live bucket state.
84    bucket: Option<Mutex<Bucket>>,
85}
86
87/// An async token-bucket rate limiter that throttles aggregate transfer
88/// throughput.
89///
90/// Construct with [`RateLimiter::new`]. When `max_bytes_per_sec` is `None` (or
91/// `Some(0)`), the limiter is unlimited and [`acquire`](RateLimiter::acquire)
92/// returns immediately. Otherwise tokens refill at `max_bytes_per_sec` per
93/// second, allowing a burst of up to ~1 second's worth of budget.
94///
95/// The limiter is [`Arc`]-shareable and [`Clone`] (cloning shares the same
96/// underlying bucket).
97#[derive(Debug, Clone)]
98pub struct RateLimiter {
99    inner: Arc<Inner>,
100}
101
102impl RateLimiter {
103    /// Builds a limiter. `None` (or `Some(0)`) yields an unlimited, no-op
104    /// limiter whose [`acquire`](RateLimiter::acquire) never waits.
105    #[must_use]
106    pub fn new(max_bytes_per_sec: Option<u64>) -> Self {
107        let inner = match max_bytes_per_sec {
108            Some(rate) if rate > 0 => {
109                #[allow(clippy::cast_precision_loss)]
110                let rate = rate as f64;
111                Inner {
112                    rate,
113                    capacity: rate,
114                    bucket: Some(Mutex::new(Bucket {
115                        tokens: rate,
116                        last_refill: tokio::time::Instant::now(),
117                    })),
118                }
119            }
120            _ => Inner {
121                rate: 0.0,
122                capacity: 0.0,
123                bucket: None,
124            },
125        };
126        Self {
127            inner: Arc::new(inner),
128        }
129    }
130
131    /// Blocks until `n` bytes of budget are available, refilling the bucket at
132    /// the configured rate. Unlimited limiters return immediately.
133    ///
134    /// A single request larger than the bucket capacity is still satisfied: the
135    /// bucket is allowed to go negative and the caller waits out the deficit,
136    /// so throttling is correct even for objects bigger than one second's
137    /// worth of budget.
138    pub async fn acquire(&self, n: u64) {
139        let Some(bucket) = self.inner.bucket.as_ref() else {
140            return; // unlimited fast path
141        };
142        if n == 0 {
143            return;
144        }
145        #[allow(clippy::cast_precision_loss)]
146        let need = n as f64;
147
148        loop {
149            let wait = {
150                let mut state = bucket.lock().await;
151                let now = tokio::time::Instant::now();
152                let elapsed = now.duration_since(state.last_refill).as_secs_f64();
153                state.tokens = (state.tokens + elapsed * self.inner.rate).min(self.inner.capacity);
154                state.last_refill = now;
155
156                if state.tokens >= need {
157                    state.tokens -= need;
158                    return;
159                }
160                // Not enough budget: compute how long until the deficit is
161                // covered, then sleep (releasing the lock first).
162                let deficit = need - state.tokens;
163                deficit / self.inner.rate
164            };
165            tokio::time::sleep(Duration::from_secs_f64(wait)).await;
166        }
167    }
168}
169
/// Shared token-bucket state for [`BlockingRateLimiter`], guarded by a
/// **synchronous** [`std::sync::Mutex`] (not tokio's async mutex).
#[derive(Debug)]
struct BlockingBucket {
    /// Available tokens (bytes). Goes negative when a request reserves more
    /// budget than is currently available; refill repays the deficit.
    tokens: f64,
    /// Last time the bucket was refilled.
    last_refill: std::time::Instant,
}

/// Inner state of a [`BlockingRateLimiter`].
#[derive(Debug)]
struct BlockingInner {
    /// Refill rate in bytes per second. Strictly positive whenever `bucket`
    /// is `Some`; it is `0.0` (and never read) for an unlimited limiter.
    rate: f64,
    /// Maximum burst capacity, in bytes (one second's worth of budget).
    capacity: f64,
    /// `None` when unlimited; otherwise the live bucket state.
    bucket: Option<std::sync::Mutex<BlockingBucket>>,
}

/// A **synchronous** token-bucket rate limiter for the store-to-store sync
/// path.
///
/// This is the blocking sibling of [`RateLimiter`]. The
/// [`StreamStore`](crate::stream::StreamStore) methods are synchronous and
/// drive their backends' async SDK calls on an internal runtime via
/// `block_on`. The store-to-store sync orchestrator therefore parallelizes
/// them across a **rayon** thread pool of plain OS threads, and it cannot use
/// the async [`RateLimiter`]: awaiting inside a `block_on`-ing rayon worker
/// would nest tokio runtimes. Instead,
/// [`acquire_blocking`](BlockingRateLimiter::acquire_blocking) parks the
/// calling OS thread with [`std::thread::sleep`].
///
/// When `max_bytes_per_sec` is `None` (or `Some(0)`), the limiter is unlimited
/// and [`acquire_blocking`](BlockingRateLimiter::acquire_blocking) returns
/// immediately. Otherwise tokens refill at `max_bytes_per_sec` per second,
/// allowing a burst of up to one second's worth of budget.
///
/// The limiter is [`Arc`]-shareable and [`Clone`] (cloning shares the same
/// underlying bucket), so every rayon worker throttles against one aggregate
/// budget.
#[derive(Debug, Clone)]
pub struct BlockingRateLimiter {
    inner: Arc<BlockingInner>,
}

impl BlockingRateLimiter {
    /// Builds a synchronous limiter. `None` (or `Some(0)`) yields an unlimited,
    /// no-op limiter whose
    /// [`acquire_blocking`](BlockingRateLimiter::acquire_blocking) never waits.
    #[must_use]
    pub fn new(max_bytes_per_sec: Option<u64>) -> Self {
        let inner = match max_bytes_per_sec {
            Some(rate) if rate > 0 => {
                #[allow(clippy::cast_precision_loss)] // byte rates far below 2^53
                let rate = rate as f64;
                BlockingInner {
                    rate,
                    capacity: rate,
                    // Start full so the first second's worth of bytes is free.
                    bucket: Some(std::sync::Mutex::new(BlockingBucket {
                        tokens: rate,
                        last_refill: std::time::Instant::now(),
                    })),
                }
            }
            _ => BlockingInner {
                rate: 0.0,
                capacity: 0.0,
                bucket: None,
            },
        };
        Self {
            inner: Arc::new(inner),
        }
    }

    /// Blocks the calling OS thread until `n` bytes of budget have been
    /// granted, refilling the bucket at the configured rate. Unlimited
    /// limiters (and `n == 0`) return immediately.
    ///
    /// The budget is reserved up front while holding the lock. The bucket is
    /// debited by `n` even if that drives it negative, and the thread then
    /// sleeps (with [`std::thread::sleep`], after releasing the lock) exactly
    /// as long as the refill needs to cover the deficit. Consequently:
    ///
    /// - a single request larger than the bucket capacity is still satisfied
    ///   (it simply waits longer);
    /// - threads contending for the limiter queue behind each other's
    ///   reservations, so the aggregate rate is respected.
    pub fn acquire_blocking(&self, n: u64) {
        let Some(bucket) = self.inner.bucket.as_ref() else {
            return; // unlimited fast path
        };
        if n == 0 {
            return;
        }
        #[allow(clippy::cast_precision_loss)] // sub-byte precision is irrelevant here
        let need = n as f64;

        let wait = {
            // A poisoned bucket only means a thread panicked mid-acquire;
            // the token state is still usable, so recover the guard.
            let mut state = bucket
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let now = std::time::Instant::now();
            let elapsed = now.duration_since(state.last_refill).as_secs_f64();
            // Refill, capping only the *positive* side at the burst capacity;
            // a negative balance (outstanding debt) is repaid linearly.
            state.tokens = (state.tokens + elapsed * self.inner.rate).min(self.inner.capacity);
            state.last_refill = now;

            // Reserve unconditionally. Waiting for `tokens >= need` instead
            // would never finish when `need > capacity`, because refill is
            // capped at `capacity`.
            state.tokens -= need;
            if state.tokens >= 0.0 {
                return;
            }
            // Time for the refill to bring the balance back to zero; the
            // guard drops at the end of this block, before sleeping.
            -state.tokens / self.inner.rate
        };
        std::thread::sleep(Duration::from_secs_f64(wait));
    }
}
291
292/// Runs `op` over `items` with at most `concurrency` operations in flight,
293/// collecting their results in completion-independent order and returning the
294/// first error encountered (remaining in-flight work is cancelled).
295///
296/// This is the engine later gates use to drive concurrent uploads/downloads.
297///
298/// # Errors
299///
300/// Returns the first [`StoreError`] produced by any operation.
301pub async fn run_concurrent<I, T, F, Fut>(
302    items: I,
303    concurrency: NonZeroUsize,
304    op: F,
305) -> Result<Vec<T>, StoreError>
306where
307    I: IntoIterator,
308    F: Fn(I::Item) -> Fut,
309    Fut: std::future::Future<Output = Result<T, StoreError>>,
310{
311    stream::iter(items)
312        .map(op)
313        .buffer_unordered(concurrency.get())
314        .try_collect()
315        .await
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a current-thread tokio runtime with time enabled, avoiding a
    /// dependency on the `#[tokio::test]` macro (keeps tokio's feature set
    /// minimal).
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_time()
            .build()
            .expect("build tokio runtime")
    }

    #[test]
    fn transfer_config_default_caps_concurrency() {
        let cfg = TransferConfig::default();
        assert!(cfg.concurrency.get() >= 1, "concurrency must be >= 1");
        assert!(
            cfg.concurrency.get() <= DEFAULT_CONCURRENCY_CAP,
            "default concurrency must be capped at {DEFAULT_CONCURRENCY_CAP}, got {}",
            cfg.concurrency.get()
        );
        assert_eq!(cfg.max_bytes_per_sec, None);

        // The clamping ctor never yields 0.
        assert_eq!(TransferConfig::new(0, None).concurrency.get(), 1);
        assert_eq!(TransferConfig::new(7, Some(99)).concurrency.get(), 7);
        assert_eq!(TransferConfig::new(7, Some(99)).max_bytes_per_sec, Some(99));
    }

    /// Drives `run_concurrent` over `items` items with the given
    /// `concurrency`. Each op sleeps briefly so the ops overlap. Returns the
    /// peak number of simultaneously-running ops observed; callers assert
    /// that this peak equals `min(concurrency, items)`.
    fn max_in_flight_for(concurrency: usize, items: usize) -> usize {
        let in_flight = Arc::new(AtomicUsize::new(0));
        let high_water = Arc::new(AtomicUsize::new(0));

        let rt = runtime();
        let result = rt.block_on(async {
            let in_flight = Arc::clone(&in_flight);
            let high_water = Arc::clone(&high_water);
            run_concurrent(
                0..items,
                NonZeroUsize::new(concurrency).unwrap(),
                move |_item| {
                    let in_flight = Arc::clone(&in_flight);
                    let high_water = Arc::clone(&high_water);
                    async move {
                        let cur = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                        high_water.fetch_max(cur, Ordering::SeqCst);
                        tokio::time::sleep(Duration::from_millis(20)).await;
                        in_flight.fetch_sub(1, Ordering::SeqCst);
                        Ok::<_, StoreError>(())
                    }
                },
            )
            .await
        });
        assert!(result.is_ok());
        high_water.load(Ordering::SeqCst)
    }

    #[test]
    fn transfer_config_run_concurrent_max_in_flight() {
        // concurrency=4 over 12 items: peak in-flight is exactly 4.
        assert_eq!(max_in_flight_for(4, 12), 4);
        // concurrency=1 over 5 items: strictly sequential, peak in-flight is 1.
        assert_eq!(max_in_flight_for(1, 5), 1);
        // concurrency greater than item count is bounded by the item count.
        assert_eq!(max_in_flight_for(8, 3), 3);
    }

    #[test]
    fn transfer_config_run_concurrent_propagates_error() {
        let rt = runtime();
        let result: Result<Vec<()>, StoreError> = rt.block_on(async {
            run_concurrent(0..10, NonZeroUsize::new(3).unwrap(), |item| async move {
                if item == 5 {
                    Err(StoreError::Backend {
                        message: "boom".to_owned(),
                        source: None,
                    })
                } else {
                    tokio::time::sleep(Duration::from_millis(5)).await;
                    Ok(())
                }
            })
            .await
        });
        let err = result.expect_err("must surface the failing op's error");
        assert!(
            matches!(err, StoreError::Backend { ref message, .. } if message == "boom"),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn transfer_config_run_concurrent_empty_input() {
        let rt = runtime();
        let result: Result<Vec<u8>, StoreError> = rt.block_on(async {
            run_concurrent(Vec::<u8>::new(), NonZeroUsize::MIN, |item| async move {
                Ok(item)
            })
            .await
        });
        assert!(result.expect("empty input must succeed").is_empty());
    }

    #[test]
    fn transfer_config_run_concurrent_returns_every_result() {
        let rt = runtime();
        let mut out = rt
            .block_on(async {
                run_concurrent(0..10u64, NonZeroUsize::new(3).unwrap(), |item| async move {
                    // Later items sleep less, so completion order differs from
                    // input order; only the multiset of results is pinned.
                    tokio::time::sleep(Duration::from_millis(10 - item)).await;
                    Ok::<_, StoreError>(item * 2)
                })
                .await
            })
            .expect("no op fails");
        out.sort_unstable();
        assert_eq!(out, (0..10).map(|i| i * 2).collect::<Vec<u64>>());
    }

    #[test]
    fn sync_snapshot_blocking_rate_limiter() {
        use std::time::Instant;

        // Unlimited: acquiring a large amount returns essentially instantly.
        let unlimited = BlockingRateLimiter::new(None);
        let start = Instant::now();
        unlimited.acquire_blocking(1_000_000);
        assert!(
            start.elapsed() < Duration::from_millis(200),
            "unlimited acquire_blocking should not block"
        );
        // Some(0) is also unlimited.
        let zero = BlockingRateLimiter::new(Some(0));
        let start = Instant::now();
        zero.acquire_blocking(1_000_000);
        assert!(
            start.elapsed() < Duration::from_millis(200),
            "Some(0) acquire_blocking should not block"
        );

        // Limited to 1000 bytes/sec. The bucket starts full (1000), so the
        // first 1000 bytes are free. Acquiring another 1000 bytes (2000 in
        // total, twice the per-second budget) must wait for the deficit to
        // refill, which takes at least ~1s.
        let limiter = BlockingRateLimiter::new(Some(1000));
        let start = Instant::now();
        limiter.acquire_blocking(1000); // drains the initial burst
        limiter.acquire_blocking(1000); // must wait ~1s to refill
        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_millis(900),
            "throttled acquire_blocking should take ~1s, took {elapsed:?}"
        );
    }

    #[test]
    fn transfer_config_rate_limiter() {
        let rt = runtime();
        rt.block_on(async {
            // Unlimited: acquiring a large amount returns essentially instantly.
            let unlimited = RateLimiter::new(None);
            let start = tokio::time::Instant::now();
            unlimited.acquire(1_000_000).await;
            assert!(
                start.elapsed() < Duration::from_millis(200),
                "unlimited acquire should not block"
            );

            // Limited to 1000 bytes/sec. The bucket starts full (1000), so the
            // first 1000 bytes are free. Acquiring another 1000 bytes (2000
            // in total, twice the per-second budget) must wait for the
            // deficit to refill, which takes at least ~1s.
            let limiter = RateLimiter::new(Some(1000));
            let start = tokio::time::Instant::now();
            limiter.acquire(1000).await; // drains the initial burst
            limiter.acquire(1000).await; // must wait ~1s to refill
            let elapsed = start.elapsed();
            assert!(
                elapsed >= Duration::from_millis(900),
                "throttled acquire should take ~1s, took {elapsed:?}"
            );
        });
    }
}