1use async_trait::async_trait;
5use bytes::Bytes;
6use futures::stream::{BoxStream, StreamExt};
7use object_store::path::Path;
8use object_store::{
9 GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
10 PutMultipartOptions, PutOptions, PutPayload, PutResult, Result as StoreResult,
11};
12use std::fmt::{Debug, Display};
13use std::ops::Range;
14use std::sync::Arc;
15use std::sync::atomic::{AtomicI64, AtomicU32, Ordering};
16use std::time::{Duration, SystemTime, UNIX_EPOCH};
17use tracing::warn;
18use uni_common::config::ObjectStoreConfig;
19
/// Lock-free circuit breaker guarding the underlying object store.
///
/// The state is implicit in the two atomics:
/// - Closed: `failures < threshold` — requests allowed.
/// - Open: `failures >= threshold` and the last failure is more recent than
///   `reset_timeout` — requests rejected.
/// - Half-open: `failures >= threshold` but `reset_timeout` has elapsed —
///   requests allowed again; a success fully closes the breaker, a failure
///   refreshes `last_failure` and re-opens it.
#[derive(Debug)]
struct CircuitBreaker {
    /// Consecutive failures observed since the last success.
    failures: AtomicU32,
    /// Unix timestamp (milliseconds) of the most recent failure.
    last_failure: AtomicI64,
    /// Failure count at which the breaker opens.
    threshold: u32,
    /// How long the breaker stays open before probe requests are allowed.
    reset_timeout: Duration,
}

impl CircuitBreaker {
    /// Creates a closed breaker that opens after `threshold` consecutive
    /// failures and allows probes again once `reset_timeout` has elapsed.
    fn new(threshold: u32, reset_timeout: Duration) -> Self {
        Self {
            failures: AtomicU32::new(0),
            last_failure: AtomicI64::new(0),
            threshold,
            reset_timeout,
        }
    }

    /// Current wall-clock time as milliseconds since the Unix epoch.
    ///
    /// Returns 0 instead of panicking if the system clock is set before the
    /// epoch, so a misconfigured clock degrades to "breaker stays open" at
    /// worst rather than crashing the process.
    fn now_millis() -> i64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_millis() as i64)
            .unwrap_or(0)
    }

    /// Returns `true` if a request may proceed (breaker closed or half-open).
    fn allow_request(&self) -> bool {
        let failures = self.failures.load(Ordering::Relaxed);
        if failures < self.threshold {
            return true;
        }

        // Open: allow probe requests once the reset timeout has elapsed.
        let last = self.last_failure.load(Ordering::Relaxed);
        (Self::now_millis() - last) > self.reset_timeout.as_millis() as i64
    }

    /// Records a success, fully closing the breaker.
    fn report_success(&self) {
        self.failures.store(0, Ordering::Relaxed);
    }

    /// Records a failure, incrementing the counter and refreshing the
    /// failure timestamp (which re-arms the reset timeout when open).
    fn report_failure(&self) {
        self.failures.fetch_add(1, Ordering::Relaxed);
        self.last_failure.store(Self::now_millis(), Ordering::Relaxed);
    }
}
73
/// Decorator around an [`ObjectStore`] adding per-operation timeouts,
/// bounded retries with exponential backoff, and a circuit breaker.
#[derive(Debug)]
pub struct ResilientObjectStore {
    // The wrapped store that performs the actual I/O.
    inner: Arc<dyn ObjectStore>,
    // Timeout, retry-count and backoff settings.
    config: ObjectStoreConfig,
    // Sheds load after repeated failures; shared by all operations.
    cb: CircuitBreaker,
}
80
81impl ResilientObjectStore {
82 pub fn new(inner: Arc<dyn ObjectStore>, config: ObjectStoreConfig) -> Self {
83 let cb = CircuitBreaker::new(5, Duration::from_secs(30)); Self { inner, config, cb }
87 }
88
89 async fn retry<F, Fut, T>(&self, mut f: F, op_name: &str) -> StoreResult<T>
90 where
91 F: FnMut() -> Fut,
92 Fut: std::future::Future<Output = StoreResult<T>>,
93 {
94 if !self.cb.allow_request() {
95 return Err(object_store::Error::Generic {
96 store: "ResilientObjectStore",
97 source: Box::new(std::io::Error::other("Circuit breaker open")),
98 });
99 }
100
101 let mut attempt = 0;
102 let mut backoff = self.config.retry_backoff_base;
103
104 loop {
105 match f().await {
106 Ok(val) => {
107 self.cb.report_success();
108 return Ok(val);
109 }
110 Err(e) => {
111 attempt += 1;
112 if attempt > self.config.max_retries {
113 self.cb.report_failure();
114 return Err(e);
115 }
116
117 let msg = e.to_string().to_lowercase();
119 if msg.contains("not found") || msg.contains("already exists") {
120 return Err(e);
125 }
126
127 warn!(
128 error = %e,
129 attempt,
130 operation = op_name,
131 "ObjectStore operation failed, retrying"
132 );
133
134 tokio::time::sleep(backoff).await;
135 backoff = std::cmp::min(backoff * 2, self.config.retry_backoff_max);
136 }
137 }
138 }
139 }
140
141 async fn timeout<F, Fut, T>(&self, f: F, duration: std::time::Duration) -> StoreResult<T>
142 where
143 F: FnOnce() -> Fut,
144 Fut: std::future::Future<Output = StoreResult<T>>,
145 {
146 tokio::time::timeout(duration, f())
147 .await
148 .map_err(|_| object_store::Error::Generic {
149 store: "ResilientObjectStore",
150 source: Box::new(std::io::Error::new(
151 std::io::ErrorKind::TimedOut,
152 "operation timed out",
153 )),
154 })?
155 }
156}
157
// Delegates to the inner store's Display so logs identify the real backend.
impl Display for ResilientObjectStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "ResilientObjectStore({})", self.inner)
    }
}
163
164#[async_trait]
165impl ObjectStore for ResilientObjectStore {
166 async fn put(&self, location: &Path, payload: PutPayload) -> StoreResult<PutResult> {
167 let timeout = self.config.write_timeout;
168 if !self.cb.allow_request() {
171 return Err(object_store::Error::Generic {
172 store: "ResilientObjectStore",
173 source: Box::new(std::io::Error::other("Circuit breaker open")),
174 });
175 }
176
177 let res = self
178 .timeout(|| self.inner.put(location, payload), timeout)
179 .await;
180 match res {
181 Ok(_) => self.cb.report_success(),
182 Err(_) => self.cb.report_failure(), }
184 res
185 }
186
187 async fn put_opts(
188 &self,
189 location: &Path,
190 payload: PutPayload,
191 opts: PutOptions,
192 ) -> StoreResult<PutResult> {
193 let timeout = self.config.write_timeout;
194 if !self.cb.allow_request() {
195 return Err(object_store::Error::Generic {
196 store: "ResilientObjectStore",
197 source: Box::new(std::io::Error::other("Circuit breaker open")),
198 });
199 }
200 let res = self
201 .timeout(|| self.inner.put_opts(location, payload, opts), timeout)
202 .await;
203 match res {
204 Ok(_) => self.cb.report_success(),
205 Err(_) => self.cb.report_failure(),
206 }
207 res
208 }
209
210 async fn put_multipart(&self, location: &Path) -> StoreResult<Box<dyn MultipartUpload>> {
211 self.put_multipart_opts(location, PutMultipartOptions::default())
212 .await
213 }
214
215 async fn put_multipart_opts(
216 &self,
217 location: &Path,
218 opts: PutMultipartOptions,
219 ) -> StoreResult<Box<dyn MultipartUpload>> {
220 let timeout = self.config.write_timeout;
221 self.retry(
222 || async {
223 self.timeout(
224 || self.inner.put_multipart_opts(location, opts.clone()),
225 timeout,
226 )
227 .await
228 },
229 "put_multipart_opts",
230 )
231 .await
232 }
233
234 async fn get(&self, location: &Path) -> StoreResult<GetResult> {
235 self.get_opts(location, GetOptions::default()).await
236 }
237
238 async fn get_opts(&self, location: &Path, options: GetOptions) -> StoreResult<GetResult> {
239 let timeout = self.config.read_timeout;
240 self.retry(
241 || async {
242 self.timeout(|| self.inner.get_opts(location, options.clone()), timeout)
243 .await
244 },
245 "get_opts",
246 )
247 .await
248 }
249
250 async fn get_range(&self, location: &Path, range: Range<u64>) -> StoreResult<Bytes> {
251 let timeout = self.config.read_timeout;
252 self.retry(
253 || async {
254 self.timeout(|| self.inner.get_range(location, range.clone()), timeout)
255 .await
256 },
257 "get_range",
258 )
259 .await
260 }
261
262 async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> StoreResult<Vec<Bytes>> {
263 let timeout = self.config.read_timeout;
264 self.retry(
265 || async {
266 self.timeout(|| self.inner.get_ranges(location, ranges), timeout)
267 .await
268 },
269 "get_ranges",
270 )
271 .await
272 }
273
274 async fn head(&self, location: &Path) -> StoreResult<ObjectMeta> {
275 let timeout = self.config.read_timeout;
276 self.retry(
277 || async { self.timeout(|| self.inner.head(location), timeout).await },
278 "head",
279 )
280 .await
281 }
282
283 async fn delete(&self, location: &Path) -> StoreResult<()> {
284 let timeout = self.config.write_timeout;
285 self.retry(
286 || async { self.timeout(|| self.inner.delete(location), timeout).await },
287 "delete",
288 )
289 .await
290 }
291
292 fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, StoreResult<ObjectMeta>> {
293 if !self.cb.allow_request() {
296 return futures::stream::once(async {
300 Err(object_store::Error::Generic {
301 store: "ResilientObjectStore",
302 source: Box::new(std::io::Error::other("Circuit breaker open")),
303 })
304 })
305 .boxed();
306 }
307
308 self.inner.list(prefix)
311 }
312
313 fn list_with_offset(
314 &self,
315 prefix: Option<&Path>,
316 offset: &Path,
317 ) -> BoxStream<'static, StoreResult<ObjectMeta>> {
318 if !self.cb.allow_request() {
319 return futures::stream::once(async {
320 Err(object_store::Error::Generic {
321 store: "ResilientObjectStore",
322 source: Box::new(std::io::Error::other("Circuit breaker open")),
323 })
324 })
325 .boxed();
326 }
327 self.inner.list_with_offset(prefix, offset)
328 }
329
330 async fn list_with_delimiter(&self, prefix: Option<&Path>) -> StoreResult<ListResult> {
331 let timeout = self.config.read_timeout;
332 self.retry(
333 || async {
334 self.timeout(|| self.inner.list_with_delimiter(prefix), timeout)
335 .await
336 },
337 "list_with_delimiter",
338 )
339 .await
340 }
341
342 async fn copy(&self, from: &Path, to: &Path) -> StoreResult<()> {
343 let timeout = self.config.write_timeout;
344 self.retry(
345 || async { self.timeout(|| self.inner.copy(from, to), timeout).await },
346 "copy",
347 )
348 .await
349 }
350
351 async fn rename(&self, from: &Path, to: &Path) -> StoreResult<()> {
352 let timeout = self.config.write_timeout;
353 self.retry(
354 || async { self.timeout(|| self.inner.rename(from, to), timeout).await },
355 "rename",
356 )
357 .await
358 }
359
360 async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> StoreResult<()> {
361 let timeout = self.config.write_timeout;
362 self.retry(
363 || async {
364 self.timeout(|| self.inner.copy_if_not_exists(from, to), timeout)
365 .await
366 },
367 "copy_if_not_exists",
368 )
369 .await
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use object_store::memory::InMemory;

    #[test]
    fn test_circuit_breaker_starts_closed() {
        let breaker = CircuitBreaker::new(5, Duration::from_secs(30));
        assert_eq!(breaker.failures.load(Ordering::Relaxed), 0);
        assert!(breaker.allow_request());
    }

    #[test]
    fn test_cb_closed_to_open_after_threshold() {
        let breaker = CircuitBreaker::new(5, Duration::from_secs(30));
        (0..5).for_each(|_| breaker.report_failure());
        assert!(!breaker.allow_request(), "CB should be open after 5 failures");
    }

    #[test]
    fn test_cb_stays_closed_below_threshold() {
        let breaker = CircuitBreaker::new(5, Duration::from_secs(30));
        (0..4).for_each(|_| breaker.report_failure());
        assert!(breaker.allow_request(), "CB should stay closed with 4 failures");
    }

    #[test]
    fn test_cb_open_to_half_open_after_timeout() {
        let breaker = CircuitBreaker::new(2, Duration::from_millis(1));
        (0..2).for_each(|_| breaker.report_failure());
        assert!(!breaker.allow_request(), "CB should be open");

        std::thread::sleep(Duration::from_millis(5));
        assert!(
            breaker.allow_request(),
            "CB should allow probe after reset timeout"
        );
    }

    #[test]
    fn test_cb_half_open_to_closed_on_success() {
        let breaker = CircuitBreaker::new(2, Duration::from_millis(1));
        (0..2).for_each(|_| breaker.report_failure());
        std::thread::sleep(Duration::from_millis(5));

        breaker.report_success();
        assert!(breaker.allow_request());
        assert_eq!(breaker.failures.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_cb_half_open_to_open_on_failure() {
        let breaker = CircuitBreaker::new(2, Duration::from_millis(1));
        (0..2).for_each(|_| breaker.report_failure());
        std::thread::sleep(Duration::from_millis(5));

        breaker.report_failure();
        assert!(
            !breaker.allow_request(),
            "CB should be open again after failure in half-open"
        );
    }

    #[test]
    fn test_cb_success_resets_failures() {
        let breaker = CircuitBreaker::new(5, Duration::from_secs(30));
        (0..3).for_each(|_| breaker.report_failure());
        assert_eq!(breaker.failures.load(Ordering::Relaxed), 3);

        breaker.report_success();
        assert_eq!(breaker.failures.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_resilient_store_retry_succeeds() {
        let backend: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
        let cfg = ObjectStoreConfig {
            max_retries: 3,
            connect_timeout: Duration::from_secs(5),
            read_timeout: Duration::from_secs(5),
            write_timeout: Duration::from_secs(5),
            retry_backoff_base: Duration::from_millis(1),
            retry_backoff_max: Duration::from_millis(10),
        };
        let resilient = ResilientObjectStore::new(backend, cfg);

        let key = Path::from("test/key");
        resilient
            .put(&key, PutPayload::from_static(b"hello"))
            .await
            .unwrap();
        let fetched = resilient.get(&key).await.unwrap();
        let data = fetched.bytes().await.unwrap();
        assert_eq!(data.as_ref(), b"hello");

        assert_eq!(resilient.cb.failures.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_resilient_store_circuit_open_rejects() {
        let backend: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
        let cfg = ObjectStoreConfig {
            max_retries: 1,
            connect_timeout: Duration::from_secs(5),
            read_timeout: Duration::from_secs(5),
            write_timeout: Duration::from_secs(5),
            retry_backoff_base: Duration::from_millis(1),
            retry_backoff_max: Duration::from_millis(10),
        };
        let resilient = ResilientObjectStore::new(backend, cfg);

        (0..5).for_each(|_| resilient.cb.report_failure());

        let err = resilient.get(&Path::from("any")).await.unwrap_err();
        assert!(
            err.to_string().contains("Circuit breaker open"),
            "Expected circuit breaker error, got: {}",
            err
        );
    }

    #[tokio::test]
    async fn test_resilient_store_not_found_skips_retry() {
        let backend: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
        let cfg = ObjectStoreConfig {
            max_retries: 3,
            connect_timeout: Duration::from_secs(5),
            read_timeout: Duration::from_secs(5),
            write_timeout: Duration::from_secs(5),
            retry_backoff_base: Duration::from_millis(1),
            retry_backoff_max: Duration::from_millis(10),
        };
        let resilient = ResilientObjectStore::new(backend, cfg);

        let err = resilient.get(&Path::from("nonexistent")).await.unwrap_err();
        assert!(err.to_string().to_lowercase().contains("not found"));
        assert_eq!(
            resilient.cb.failures.load(Ordering::Relaxed),
            0,
            "Not-found errors should not count as CB failures"
        );
    }
}