snapdir_stores/transfer.rs
1//! Transfer configuration, rate limiting, and bounded-concurrency driver.
2//!
3//! This module is the foundation for concurrent object transfers and bandwidth
4//! limiting. It provides:
5//!
6//! - [`TransferConfig`] — how many objects to transfer in parallel and an
7//! optional aggregate byte-rate cap.
8//! - [`RateLimiter`] — a zero-dependency async token bucket built on
9//! [`tokio::time`], shareable across tasks via [`Arc`].
10//! - [`run_concurrent`] — a generic bounded-concurrency driver that runs up to
11//! `concurrency` async operations in flight and returns the first error.
12//!
13//! Nothing here changes the existing (sequential) push / fetch loops yet; the
14//! stores merely carry a [`TransferConfig`] so later gates can wire these
15//! primitives into their transfer loops.
16
17use std::num::NonZeroUsize;
18use std::sync::Arc;
19use std::time::Duration;
20
21use futures::stream::{self, StreamExt, TryStreamExt};
22use snapdir_core::store::StoreError;
23use tokio::sync::Mutex;
24
/// Ceiling on the concurrency that [`TransferConfig::default`] derives from the
/// host's available parallelism.
const DEFAULT_CONCURRENCY_CAP: usize = 16;
27
/// Settings that govern object transfers: the degree of parallelism and an
/// optional ceiling on aggregate throughput.
///
/// [`TransferConfig::default`] sizes `concurrency` from the host's available
/// parallelism (never above `DEFAULT_CONCURRENCY_CAP`) and applies no bandwidth
/// limit; [`TransferConfig::new`] takes explicit values.
///
/// The type is a plain value, so it derives `PartialEq`/`Eq` and configs can be
/// compared directly (e.g. in tests or when detecting configuration changes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TransferConfig {
    /// Upper bound on how many object transfers may be in flight at once.
    pub concurrency: NonZeroUsize,
    /// Aggregate bandwidth ceiling in bytes per second; `None` disables
    /// throttling ([`RateLimiter::new`] also treats `Some(0)` as unlimited).
    pub max_bytes_per_sec: Option<u64>,
}
41
42impl TransferConfig {
43 /// Builds a config, clamping `concurrency` to at least 1.
44 #[must_use]
45 pub fn new(concurrency: usize, max_bytes_per_sec: Option<u64>) -> Self {
46 Self {
47 concurrency: NonZeroUsize::new(concurrency.max(1)).unwrap_or(NonZeroUsize::MIN),
48 max_bytes_per_sec,
49 }
50 }
51}
52
53impl Default for TransferConfig {
54 fn default() -> Self {
55 let detected = std::thread::available_parallelism()
56 .map_or(1, NonZeroUsize::get)
57 .clamp(1, DEFAULT_CONCURRENCY_CAP);
58 Self {
59 // `detected` is >= 1, so the NonZeroUsize is always Some.
60 concurrency: NonZeroUsize::new(detected).unwrap_or(NonZeroUsize::MIN),
61 max_bytes_per_sec: None,
62 }
63 }
64}
65
66/// Shared token-bucket state, guarded by an async mutex.
67#[derive(Debug)]
68struct Bucket {
69 /// Currently available tokens (bytes).
70 tokens: f64,
71 /// Last time the bucket was refilled.
72 last_refill: tokio::time::Instant,
73}
74
75/// Inner state of a [`RateLimiter`].
76#[derive(Debug)]
77struct Inner {
78 /// Refill rate in bytes per second. `0` is impossible here (unlimited is
79 /// modelled by `bucket = None`).
80 rate: f64,
81 /// Maximum burst capacity, in bytes (~1 second's worth of budget).
82 capacity: f64,
83 /// `None` when unlimited; otherwise the live bucket state.
84 bucket: Option<Mutex<Bucket>>,
85}
86
87/// An async token-bucket rate limiter that throttles aggregate transfer
88/// throughput.
89///
90/// Construct with [`RateLimiter::new`]. When `max_bytes_per_sec` is `None` (or
91/// `Some(0)`), the limiter is unlimited and [`acquire`](RateLimiter::acquire)
92/// returns immediately. Otherwise tokens refill at `max_bytes_per_sec` per
93/// second, allowing a burst of up to ~1 second's worth of budget.
94///
95/// The limiter is [`Arc`]-shareable and [`Clone`] (cloning shares the same
96/// underlying bucket).
97#[derive(Debug, Clone)]
98pub struct RateLimiter {
99 inner: Arc<Inner>,
100}
101
102impl RateLimiter {
103 /// Builds a limiter. `None` (or `Some(0)`) yields an unlimited, no-op
104 /// limiter whose [`acquire`](RateLimiter::acquire) never waits.
105 #[must_use]
106 pub fn new(max_bytes_per_sec: Option<u64>) -> Self {
107 let inner = match max_bytes_per_sec {
108 Some(rate) if rate > 0 => {
109 #[allow(clippy::cast_precision_loss)]
110 let rate = rate as f64;
111 Inner {
112 rate,
113 capacity: rate,
114 bucket: Some(Mutex::new(Bucket {
115 tokens: rate,
116 last_refill: tokio::time::Instant::now(),
117 })),
118 }
119 }
120 _ => Inner {
121 rate: 0.0,
122 capacity: 0.0,
123 bucket: None,
124 },
125 };
126 Self {
127 inner: Arc::new(inner),
128 }
129 }
130
131 /// Blocks until `n` bytes of budget are available, refilling the bucket at
132 /// the configured rate. Unlimited limiters return immediately.
133 ///
134 /// A single request larger than the bucket capacity is still satisfied: the
135 /// bucket is allowed to go negative and the caller waits out the deficit,
136 /// so throttling is correct even for objects bigger than one second's
137 /// worth of budget.
138 pub async fn acquire(&self, n: u64) {
139 let Some(bucket) = self.inner.bucket.as_ref() else {
140 return; // unlimited fast path
141 };
142 if n == 0 {
143 return;
144 }
145 #[allow(clippy::cast_precision_loss)]
146 let need = n as f64;
147
148 loop {
149 let wait = {
150 let mut state = bucket.lock().await;
151 let now = tokio::time::Instant::now();
152 let elapsed = now.duration_since(state.last_refill).as_secs_f64();
153 state.tokens = (state.tokens + elapsed * self.inner.rate).min(self.inner.capacity);
154 state.last_refill = now;
155
156 if state.tokens >= need {
157 state.tokens -= need;
158 return;
159 }
160 // Not enough budget: compute how long until the deficit is
161 // covered, then sleep (releasing the lock first).
162 let deficit = need - state.tokens;
163 deficit / self.inner.rate
164 };
165 tokio::time::sleep(Duration::from_secs_f64(wait)).await;
166 }
167 }
168}
169
170/// Runs `op` over `items` with at most `concurrency` operations in flight,
171/// collecting their results in completion-independent order and returning the
172/// first error encountered (remaining in-flight work is cancelled).
173///
174/// This is the engine later gates use to drive concurrent uploads/downloads.
175///
176/// # Errors
177///
178/// Returns the first [`StoreError`] produced by any operation.
179pub async fn run_concurrent<I, T, F, Fut>(
180 items: I,
181 concurrency: NonZeroUsize,
182 op: F,
183) -> Result<Vec<T>, StoreError>
184where
185 I: IntoIterator,
186 F: Fn(I::Item) -> Fut,
187 Fut: std::future::Future<Output = Result<T, StoreError>>,
188{
189 stream::iter(items)
190 .map(op)
191 .buffer_unordered(concurrency.get())
192 .try_collect()
193 .await
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a current-thread tokio runtime with time enabled, avoiding a
    /// dependency on the `#[tokio::test]` macro (keeps tokio's feature set
    /// minimal).
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_time()
            .build()
            .expect("build tokio runtime")
    }

    #[test]
    fn transfer_config_default_caps_concurrency() {
        let cfg = TransferConfig::default();
        assert!(cfg.concurrency.get() >= 1, "concurrency must be >= 1");
        assert!(
            cfg.concurrency.get() <= DEFAULT_CONCURRENCY_CAP,
            "default concurrency must be capped at {DEFAULT_CONCURRENCY_CAP}, got {}",
            cfg.concurrency.get()
        );
        assert_eq!(cfg.max_bytes_per_sec, None);

        // The clamping ctor never yields 0.
        assert_eq!(TransferConfig::new(0, None).concurrency.get(), 1);
        assert_eq!(TransferConfig::new(7, Some(99)).concurrency.get(), 7);
        assert_eq!(TransferConfig::new(7, Some(99)).max_bytes_per_sec, Some(99));
    }

    /// Drives `run_concurrent` over `items` operations with the given
    /// `concurrency` and returns the peak number of operations that were
    /// running simultaneously (the high-water mark). Each operation sleeps
    /// briefly so overlap is actually observable. The calling test asserts
    /// that the peak equals `min(concurrency, items)`.
    ///
    /// Panics if `concurrency` is 0 or if `run_concurrent` returns an error.
    fn max_in_flight_for(concurrency: usize, items: usize) -> usize {
        let in_flight = Arc::new(AtomicUsize::new(0));
        let high_water = Arc::new(AtomicUsize::new(0));

        let rt = runtime();
        let result = rt.block_on(async {
            let in_flight = Arc::clone(&in_flight);
            let high_water = Arc::clone(&high_water);
            run_concurrent(
                0..items,
                NonZeroUsize::new(concurrency).unwrap(),
                move |_item| {
                    let in_flight = Arc::clone(&in_flight);
                    let high_water = Arc::clone(&high_water);
                    async move {
                        let cur = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                        high_water.fetch_max(cur, Ordering::SeqCst);
                        tokio::time::sleep(Duration::from_millis(20)).await;
                        in_flight.fetch_sub(1, Ordering::SeqCst);
                        Ok::<_, StoreError>(())
                    }
                },
            )
            .await
        });
        assert!(result.is_ok());
        high_water.load(Ordering::SeqCst)
    }

    #[test]
    fn transfer_config_run_concurrent_max_in_flight() {
        // concurrency=4 over 12 items: peak in-flight is exactly 4.
        assert_eq!(max_in_flight_for(4, 12), 4);
        // concurrency=1 over 5 items: strictly sequential, peak in-flight is 1.
        assert_eq!(max_in_flight_for(1, 5), 1);
        // concurrency greater than item count is bounded by the item count.
        assert_eq!(max_in_flight_for(8, 3), 3);
    }

    #[test]
    fn transfer_config_run_concurrent_propagates_error() {
        let rt = runtime();
        let result: Result<Vec<()>, StoreError> = rt.block_on(async {
            run_concurrent(0..10, NonZeroUsize::new(3).unwrap(), |item| async move {
                if item == 5 {
                    Err(StoreError::Backend {
                        message: "boom".to_owned(),
                        source: None,
                    })
                } else {
                    tokio::time::sleep(Duration::from_millis(5)).await;
                    Ok(())
                }
            })
            .await
        });
        let err = result.expect_err("must surface the failing op's error");
        assert!(
            matches!(err, StoreError::Backend { ref message, .. } if message == "boom"),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn transfer_config_rate_limiter() {
        let rt = runtime();
        rt.block_on(async {
            // Unlimited: acquiring a large amount returns essentially instantly.
            let unlimited = RateLimiter::new(None);
            let start = tokio::time::Instant::now();
            unlimited.acquire(1_000_000).await;
            assert!(
                start.elapsed() < Duration::from_millis(200),
                "unlimited acquire should not block"
            );

            // Limited to 1000 bytes/sec. The bucket starts full (1000 bytes),
            // so the first 1000-byte acquire completes immediately and drains
            // it. The second 1000-byte acquire must wait for the bucket to
            // refill by 1000 bytes, which takes roughly 1s.
            let limiter = RateLimiter::new(Some(1000));
            let start = tokio::time::Instant::now();
            limiter.acquire(1000).await; // drains the initial burst
            limiter.acquire(1000).await; // must wait ~1s to refill
            let elapsed = start.elapsed();
            assert!(
                elapsed >= Duration::from_millis(900),
                "throttled acquire should take ~1s, took {elapsed:?}"
            );
        });
    }
}