//! Async error-handling utilities: retry execution, circuit breaking,
//! timeouts, progress tracking, and error aggregation for async operations.

use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};

use super::recovery::{CircuitBreaker, RecoverableError, RecoveryStrategy};
use crate::error::{CoreError, CoreResult, ErrorContext};

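/// Executes async operations with retries according to a [`RecoveryStrategy`].
///
/// A minimal usage sketch (illustrative only; `fetch` stands in for any
/// fallible async call):
///
/// ```ignore
/// let executor = AsyncRetryExecutor::new(RecoveryStrategy::LinearBackoff {
///     max_attempts: 3,
///     delay: Duration::from_millis(10),
/// });
/// let value = executor.execute(|| async { fetch().await }).await?;
/// ```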
#[derive(Debug)]
pub struct AsyncRetryExecutor {
    strategy: RecoveryStrategy,
}

impl AsyncRetryExecutor {
    /// Creates an executor that retries according to `strategy`.
    pub fn new(strategy: RecoveryStrategy) -> Self {
        Self { strategy }
    }

    /// Runs `f` until it succeeds or the strategy's attempt budget is
    /// exhausted, sleeping between failed attempts as the strategy dictates.
    pub async fn execute<F, Fut, T>(&self, mut f: F) -> CoreResult<T>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = CoreResult<T>>,
    {
        match &self.strategy {
            RecoveryStrategy::FailFast => f().await,

            RecoveryStrategy::ExponentialBackoff {
                max_attempts,
                initialdelay,
                maxdelay,
                multiplier,
            } => {
                let mut delay = *initialdelay;
                let mut last_error = None;

                for attempt in 0..*max_attempts {
                    match f().await {
                        Ok(result) => return Ok(result),
                        Err(err) => {
                            last_error = Some(err);

                            // Grow the delay geometrically, capped at `maxdelay`.
                            if attempt < max_attempts - 1 {
                                tokio::time::sleep(delay).await;
                                delay = std::cmp::min(
                                    Duration::from_nanos(
                                        (delay.as_nanos() as f64 * multiplier) as u64,
                                    ),
                                    *maxdelay,
                                );
                            }
                        }
                    }
                }

                Err(last_error.expect("retry made no attempts; max_attempts must be >= 1"))
            }

            RecoveryStrategy::LinearBackoff {
                max_attempts,
                delay,
            } => {
                let mut last_error = None;

                for attempt in 0..*max_attempts {
                    match f().await {
                        Ok(result) => return Ok(result),
                        Err(err) => {
                            last_error = Some(err);

                            // Wait a fixed interval between attempts.
                            if attempt < max_attempts - 1 {
                                tokio::time::sleep(*delay).await;
                            }
                        }
                    }
                }

                Err(last_error.expect("retry made no attempts; max_attempts must be >= 1"))
            }

            RecoveryStrategy::CustomBackoff {
                max_attempts,
                delays,
            } => {
                let mut last_error = None;

                for attempt in 0..*max_attempts {
                    match f().await {
                        Ok(result) => return Ok(result),
                        Err(err) => {
                            last_error = Some(err);

                            // Use the caller-supplied delay for this attempt,
                            // if one was provided.
                            if attempt < max_attempts - 1 {
                                if let Some(&delay) = delays.get(attempt) {
                                    tokio::time::sleep(delay).await;
                                }
                            }
                        }
                    }
                }

                Err(last_error.expect("retry made no attempts; max_attempts must be >= 1"))
            }

            // Strategies without dedicated async handling fall back to a
            // single attempt.
            _ => f().await,
        }
    }
}

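/// Circuit breaker for async operations, backed by the synchronous
/// [`CircuitBreaker`] from `super::recovery`.
///
/// A minimal usage sketch (illustrative only; `flaky_call` stands in for any
/// fallible async call):
///
/// ```ignore
/// let breaker = AsyncCircuitBreaker::new(
///     5,                       // failure threshold
///     Duration::from_secs(1),  // operation timeout
///     Duration::from_secs(30), // recovery timeout
/// );
/// let value = breaker.execute(|| async { flaky_call().await }).await?;
/// ```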
#[derive(Debug)]
pub struct AsyncCircuitBreaker {
    #[allow(dead_code)]
    inner: Arc<CircuitBreaker>,
}

impl AsyncCircuitBreaker {
    pub fn new(failure_threshold: usize, timeout: Duration, recovery_timeout: Duration) -> Self {
        Self {
            inner: Arc::new(CircuitBreaker::new(
                failure_threshold,
                timeout,
                recovery_timeout,
            )),
        }
    }

    /// Runs `f` if the breaker allows execution, recording the outcome.
    pub async fn execute<F, Fut, T>(&self, f: F) -> CoreResult<T>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = CoreResult<T>>,
    {
        if !self.should_allow_execution() {
            return Err(CoreError::ComputationError(ErrorContext::new(
                "Async circuit breaker is open - too many recent failures",
            )));
        }

        match f().await {
            Ok(result) => {
                self.on_success();
                Ok(result)
            }
            Err(err) => {
                self.on_failure();
                Err(err)
            }
        }
    }

    // State tracking is not yet wired up to the inner `CircuitBreaker`;
    // until it is, execution is always allowed and outcomes are dropped.
    fn should_allow_execution(&self) -> bool {
        true
    }

    fn on_success(&self) {}

    fn on_failure(&self) {}
}

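/// A future adapter that fails with [`CoreError::TimeoutError`] once its
/// deadline elapses without the inner future completing.
///
/// A minimal usage sketch (illustrative only; `slow_operation` stands in for
/// any fallible async call):
///
/// ```ignore
/// let result = TimeoutWrapper::new(slow_operation(), Duration::from_millis(50)).await;
/// ```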
pub struct TimeoutWrapper<F> {
    future: F,
    timeout: Duration,
    sleep: Pin<Box<tokio::time::Sleep>>,
}

impl<F> TimeoutWrapper<F> {
    pub fn new(future: F, timeout: Duration) -> Self {
        Self {
            future,
            timeout,
            sleep: Box::pin(tokio::time::sleep(timeout)),
        }
    }
}

impl<F, T> Future for TimeoutWrapper<F>
where
    F: Future<Output = CoreResult<T>>,
{
    type Output = CoreResult<T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: `future` is structurally pinned and never moved out of
        // `this` once the wrapper has been pinned.
        let this = unsafe { self.get_unchecked_mut() };
        let future = unsafe { Pin::new_unchecked(&mut this.future) };

        // Prefer the inner result when it is ready; otherwise fail once the
        // deadline timer fires.
        match future.poll(cx) {
            Poll::Ready(result) => Poll::Ready(result),
            Poll::Pending => match this.sleep.as_mut().poll(cx) {
                Poll::Ready(()) => Poll::Ready(Err(CoreError::TimeoutError(
                    ErrorContext::new(format!("Operation timed out after {:?}", this.timeout)),
                ))),
                Poll::Pending => Poll::Pending,
            },
        }
    }
}

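/// Shared progress and error bookkeeping for a multi-step async operation.
///
/// A minimal usage sketch (illustrative only):
///
/// ```ignore
/// let tracker = AsyncProgressTracker::new(100);
/// tracker.complete_step();
/// println!("{}", tracker.progress_report());
/// ```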
#[derive(Debug)]
pub struct AsyncProgressTracker {
    total_steps: usize,
    completed_steps: Arc<Mutex<usize>>,
    errors: Arc<Mutex<Vec<RecoverableError>>>,
    start_time: Instant,
}

impl AsyncProgressTracker {
    pub fn new(total_steps: usize) -> Self {
        Self {
            total_steps,
            completed_steps: Arc::new(Mutex::new(0)),
            errors: Arc::new(Mutex::new(Vec::new())),
            start_time: Instant::now(),
        }
    }

    /// Marks one step as completed.
    pub fn complete_step(&self) {
        let mut completed = self.completed_steps.lock().expect("progress mutex poisoned");
        *completed += 1;
    }

    /// Records an error without interrupting the tracked operation.
    pub fn record_error(&self, error: RecoverableError) {
        let mut errors = self.errors.lock().expect("errors mutex poisoned");
        errors.push(error);
    }

    /// Returns the completed fraction of the operation in `[0.0, 1.0]`.
    pub fn progress(&self) -> f64 {
        // An operation with zero planned steps counts as complete; this also
        // avoids dividing by zero below.
        if self.total_steps == 0 {
            return 1.0;
        }
        let completed = *self.completed_steps.lock().expect("progress mutex poisoned") as f64;
        completed / self.total_steps as f64
    }

    pub fn elapsed_time(&self) -> Duration {
        self.start_time.elapsed()
    }

    /// Estimates the remaining time by linearly extrapolating the elapsed
    /// time; `None` before the first completed step or once finished.
    pub fn estimated_remaining_time(&self) -> Option<Duration> {
        let progress = self.progress();
        if progress > 0.0 && progress < 1.0 {
            let elapsed = self.elapsed_time();
            let total_estimated = elapsed.as_secs_f64() / progress;
            let remaining = total_estimated - elapsed.as_secs_f64();
            Some(Duration::from_secs_f64(remaining.max(0.0)))
        } else {
            None
        }
    }

    pub fn errors(&self) -> Vec<RecoverableError> {
        self.errors.lock().expect("errors mutex poisoned").clone()
    }

    pub fn has_errors(&self) -> bool {
        !self.errors.lock().expect("errors mutex poisoned").is_empty()
    }

    /// Formats a one-line, human-readable status report.
    pub fn progress_report(&self) -> String {
        let completed = *self.completed_steps.lock().expect("progress mutex poisoned");
        let progress_pct = (self.progress() * 100.0) as u32;
        let elapsed = self.elapsed_time();
        let error_count = self.errors.lock().expect("errors mutex poisoned").len();

        let mut report = format!(
            "Progress: {}/{} steps ({}%) | Elapsed: {:?}",
            completed, self.total_steps, progress_pct, elapsed
        );

        if let Some(remaining) = self.estimated_remaining_time() {
            report.push_str(&format!(" | Remaining: {:?}", remaining));
        }

        if error_count > 0 {
            report.push_str(&format!(" | Errors: {}", error_count));
        }

        report
    }
}

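/// Collects errors from a batch of async operations so a single result can be
/// reported at the end.
///
/// A minimal usage sketch (illustrative only):
///
/// ```ignore
/// let aggregator = AsyncErrorAggregator::with_max_errors(10);
/// aggregator
///     .add_simple_error(CoreError::ValueError(ErrorContext::new("bad input")))
///     .await;
/// let outcome = aggregator.into_result(());
/// ```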
#[derive(Debug)]
pub struct AsyncErrorAggregator {
    errors: Arc<Mutex<Vec<RecoverableError>>>,
    max_errors: Option<usize>,
}

impl AsyncErrorAggregator {
    pub fn new() -> Self {
        Self {
            errors: Arc::new(Mutex::new(Vec::new())),
            max_errors: None,
        }
    }

    /// Creates an aggregator that silently drops errors beyond `max_errors`.
    pub fn with_max_errors(max_errors: usize) -> Self {
        Self {
            errors: Arc::new(Mutex::new(Vec::new())),
            max_errors: Some(max_errors),
        }
    }

    pub async fn add_error(&self, error: RecoverableError) {
        let mut errors = self.errors.lock().expect("errors mutex poisoned");

        // Once the cap is reached, additional errors are dropped.
        if let Some(max) = self.max_errors {
            if errors.len() >= max {
                return;
            }
        }

        errors.push(error);
    }

    /// Wraps a plain [`CoreError`] as recoverable and records it.
    pub async fn add_simple_error(&self, error: CoreError) {
        self.add_error(RecoverableError::error(error)).await;
    }

    pub fn has_errors(&self) -> bool {
        !self.errors.lock().expect("errors mutex poisoned").is_empty()
    }

    pub fn error_count(&self) -> usize {
        self.errors.lock().expect("errors mutex poisoned").len()
    }

    pub fn get_errors(&self) -> Vec<RecoverableError> {
        self.errors.lock().expect("errors mutex poisoned").clone()
    }

    /// Returns the recorded error with the highest severity, if any.
    pub fn most_severe_error(&self) -> Option<RecoverableError> {
        self.get_errors().into_iter().max_by_key(|err| err.severity)
    }

    /// Consumes the aggregator, yielding `Ok(success_value)` when no errors
    /// were recorded and the most severe error otherwise.
    pub fn into_result<T>(self, success_value: T) -> Result<T, RecoverableError> {
        if let Some(most_severe) = self.most_severe_error() {
            Err(most_severe)
        } else {
            Ok(success_value)
        }
    }
}

impl Default for AsyncErrorAggregator {
    fn default() -> Self {
        Self::new()
    }
}

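/// Awaits `future`, converting an elapsed deadline into
/// [`CoreError::TimeoutError`].
///
/// A minimal usage sketch (illustrative only):
///
/// ```ignore
/// let value = with_timeout(async { Ok(42) }, Duration::from_millis(50)).await?;
/// ```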
pub async fn with_timeout<F, T>(future: F, timeout: Duration) -> CoreResult<T>
where
    F: Future<Output = CoreResult<T>>,
{
    match tokio::time::timeout(timeout, future).await {
        Ok(result) => result,
        Err(_) => Err(CoreError::TimeoutError(ErrorContext::new(format!(
            "Operation timed out after {:?}",
            timeout
        )))),
    }
}

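/// Convenience wrapper that retries `f` with exponential backoff.
///
/// A minimal usage sketch (illustrative only; `fetch` stands in for any
/// fallible async call):
///
/// ```ignore
/// let value = retry_with_exponential_backoff(
///     || async { fetch().await },
///     3,                         // max attempts
///     Duration::from_millis(10), // initial delay
///     Duration::from_secs(1),    // delay cap
///     2.0,                       // multiplier
/// )
/// .await?;
/// ```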
pub async fn retry_with_exponential_backoff<F, Fut, T>(
    f: F,
    max_attempts: usize,
    initial_delay: Duration,
    max_delay: Duration,
    multiplier: f64,
) -> CoreResult<T>
where
    F: Fn() -> Fut,
    Fut: Future<Output = CoreResult<T>>,
{
    let executor = AsyncRetryExecutor::new(RecoveryStrategy::ExponentialBackoff {
        max_attempts,
        initialdelay: initial_delay,
        maxdelay: max_delay,
        multiplier,
    });

    executor.execute(f).await
}

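/// Awaits `operations` in order, collecting failures into an
/// [`AsyncErrorAggregator`]. When `fail_fast` is set, the first failure
/// aborts the remaining operations.
///
/// A minimal usage sketch (illustrative only):
///
/// ```ignore
/// match execute_with_error_aggregation(operations, false).await {
///     Ok(results) => println!("{} results", results.len()),
///     Err(aggregator) => eprintln!("{} errors", aggregator.error_count()),
/// }
/// ```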
pub async fn execute_with_error_aggregation<T>(
    operations: Vec<impl Future<Output = CoreResult<T>>>,
    fail_fast: bool,
) -> Result<Vec<T>, AsyncErrorAggregator> {
    let aggregator = AsyncErrorAggregator::new();
    let mut results = Vec::new();

    for operation in operations {
        match operation.await {
            Ok(result) => results.push(result),
            Err(error) => {
                aggregator.add_simple_error(error).await;

                if fail_fast {
                    return Err(aggregator);
                }
            }
        }
    }

    if aggregator.has_errors() {
        Err(aggregator)
    } else {
        Ok(results)
    }
}

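/// A future adapter that records the outcome of the wrapped operation in an
/// [`AsyncProgressTracker`].
///
/// A minimal usage sketch (illustrative only):
///
/// ```ignore
/// let op = TrackedAsyncOperation::new(async { Ok(42) }, 1);
/// let value = op.await?;
/// ```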
pub struct TrackedAsyncOperation<F> {
    operation: F,
    tracker: AsyncProgressTracker,
    // Stored by `with_retry` but not yet consulted while polling; the
    // operation currently runs exactly once.
    #[allow(dead_code)]
    retry_strategy: Option<RecoveryStrategy>,
}

impl<F> TrackedAsyncOperation<F> {
    pub fn new(operation: F, total_steps: usize) -> Self {
        Self {
            operation,
            tracker: AsyncProgressTracker::new(total_steps),
            retry_strategy: None,
        }
    }

    pub fn with_retry(mut self, strategy: RecoveryStrategy) -> Self {
        self.retry_strategy = Some(strategy);
        self
    }

    pub const fn tracker(&self) -> &AsyncProgressTracker {
        &self.tracker
    }
}

impl<F, T> Future for TrackedAsyncOperation<F>
where
    F: Future<Output = CoreResult<T>>,
{
    type Output = CoreResult<T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: `operation` is structurally pinned and never moved out of
        // `this` once the wrapper has been pinned.
        let this = unsafe { self.get_unchecked_mut() };
        let operation = unsafe { Pin::new_unchecked(&mut this.operation) };

        match operation.poll(cx) {
            Poll::Ready(result) => {
                // Record the outcome before forwarding it to the caller.
                match &result {
                    Ok(_) => this.tracker.complete_step(),
                    Err(error) => {
                        let recoverable_error = RecoverableError::error(error.clone());
                        this.tracker.record_error(recoverable_error);
                    }
                }
                Poll::Ready(result)
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

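/// Awaits `$operation` wrapped in a `TrackedAsyncOperation`, optionally
/// attaching a retry strategy.
///
/// A minimal usage sketch (illustrative only):
///
/// ```ignore
/// let value = async_with_recovery!(async { Ok(42) }, 1)?;
/// ```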
#[macro_export]
macro_rules! async_with_recovery {
    ($operation:expr, $steps:expr) => {{
        let tracked_op =
            $crate::error::async_handling::TrackedAsyncOperation::new($operation, $steps);
        tracked_op.await
    }};

    ($operation:expr, $steps:expr, $retry_strategy:expr) => {{
        let tracked_op =
            $crate::error::async_handling::TrackedAsyncOperation::new($operation, $steps)
                .with_retry($retry_strategy);
        tracked_op.await
    }};
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn test_async_retry_executor() {
        let executor = AsyncRetryExecutor::new(RecoveryStrategy::LinearBackoff {
            max_attempts: 3,
            delay: Duration::from_millis(1),
        });

        let attempt_count = Arc::new(AtomicUsize::new(0));
        let attempt_count_clone = attempt_count.clone();

        // Fail twice, then succeed on the third attempt.
        let result = executor
            .execute(|| {
                let count = attempt_count_clone.clone();
                async move {
                    let current = count.fetch_add(1, Ordering::SeqCst);
                    if current < 2 {
                        Err(CoreError::ComputationError(ErrorContext::new("Test error")))
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;

        assert_eq!(result.expect("retry should eventually succeed"), 42);
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_with_timeout() {
        let result = with_timeout(
            async {
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok(42)
            },
            Duration::from_millis(50),
        )
        .await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), CoreError::TimeoutError(_)));
    }

    #[tokio::test]
    async fn test_progress_tracker() {
        let tracker = AsyncProgressTracker::new(10);

        assert_eq!(tracker.progress(), 0.0);

        tokio::time::sleep(Duration::from_millis(1)).await;

        tracker.complete_step();
        tracker.complete_step();

        assert_eq!(tracker.progress(), 0.2);
        assert!(tracker.elapsed_time().as_nanos() > 0);
    }

    #[tokio::test]
    async fn test_async_error_aggregator() {
        let aggregator = AsyncErrorAggregator::new();

        assert!(!aggregator.has_errors());

        aggregator
            .add_simple_error(CoreError::ValueError(ErrorContext::new("Error 1")))
            .await;
        aggregator
            .add_simple_error(CoreError::DomainError(ErrorContext::new("Error 2")))
            .await;

        assert_eq!(aggregator.error_count(), 2);
        assert!(aggregator.has_errors());
    }
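
    // Additional test sketches (not in the original suite) covering the
    // deadline behavior of `TimeoutWrapper`, the `with_max_errors` cap, and
    // fail-fast aggregation.

    #[tokio::test]
    async fn test_timeout_wrapper_deadline() {
        // The wrapper itself should fail once its deadline elapses.
        let result = TimeoutWrapper::new(
            async {
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok(42)
            },
            Duration::from_millis(10),
        )
        .await;

        assert!(matches!(result, Err(CoreError::TimeoutError(_))));
    }

    #[tokio::test]
    async fn test_error_aggregator_cap() {
        // `with_max_errors` should silently drop errors beyond the cap.
        let aggregator = AsyncErrorAggregator::with_max_errors(1);

        aggregator
            .add_simple_error(CoreError::ValueError(ErrorContext::new("Error 1")))
            .await;
        aggregator
            .add_simple_error(CoreError::ValueError(ErrorContext::new("Error 2")))
            .await;

        assert_eq!(aggregator.error_count(), 1);
    }

    #[tokio::test]
    async fn test_execute_with_error_aggregation_fail_fast() {
        // With `fail_fast` set, the first failure aborts the remaining
        // operations and returns the aggregator.
        let operations: Vec<Pin<Box<dyn Future<Output = CoreResult<i32>>>>> = vec![
            Box::pin(async { Err(CoreError::ValueError(ErrorContext::new("boom"))) }),
            Box::pin(async { Ok(42) }),
        ];

        let result = execute_with_error_aggregation(operations, true).await;
        assert_eq!(result.expect_err("first operation fails").error_count(), 1);
    }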
}