scirs2_core/error/async_handling.rs

//! Async error handling and recovery mechanisms for `SciRS2`
//!
//! This module provides error handling patterns specifically designed for asynchronous operations:
//! - Async retry mechanisms with backoff
//! - Timeout handling for long-running operations
//! - Error propagation in async contexts
//! - Async circuit breakers
//! - Progress tracking with error recovery
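//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore`: the crate path is an assumption here,
//! and a Tokio runtime is presumed to drive the future):
//!
//! ```ignore
//! use std::time::Duration;
//! use scirs2_core::error::CoreResult;
//! use scirs2_core::error::async_handling::with_timeout;
//!
//! async fn bounded_computation() -> CoreResult<u64> {
//!     // Abort the inner computation if it takes longer than one second.
//!     with_timeout(async { Ok(42) }, Duration::from_secs(1)).await
//! }
//! ```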

use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};

use super::recovery::{CircuitBreaker, RecoverableError, RecoveryStrategy};
use crate::error::{CoreError, CoreResult, ErrorContext};

/// Async retry executor with configurable backoff strategies
#[derive(Debug)]
pub struct AsyncRetryExecutor {
    strategy: RecoveryStrategy,
}

impl AsyncRetryExecutor {
    /// Create a new async retry executor with the given strategy
    pub fn new(strategy: RecoveryStrategy) -> Self {
        Self { strategy }
    }

    /// Execute an async function with retry logic
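    ///
    /// # Example
    ///
    /// A sketch of retrying a flaky async call with a linear backoff (marked
    /// `ignore`; the crate paths are assumptions):
    ///
    /// ```ignore
    /// use std::time::Duration;
    /// use scirs2_core::error::recovery::RecoveryStrategy;
    /// use scirs2_core::error::async_handling::AsyncRetryExecutor;
    ///
    /// # async fn demo() -> scirs2_core::error::CoreResult<()> {
    /// let executor = AsyncRetryExecutor::new(RecoveryStrategy::LinearBackoff {
    ///     max_attempts: 3,
    ///     delay: Duration::from_millis(10),
    /// });
    /// let value = executor.execute(|| async { Ok(1 + 1) }).await?;
    /// assert_eq!(value, 2);
    /// # Ok(())
    /// # }
    /// ```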
    pub async fn execute<F, Fut, T>(&self, mut f: F) -> CoreResult<T>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = CoreResult<T>>,
    {
        match &self.strategy {
            RecoveryStrategy::FailFast => f().await,

            RecoveryStrategy::ExponentialBackoff {
                max_attempts,
                initialdelay,
                maxdelay,
                multiplier,
            } => {
                let mut delay = *initialdelay;
                let mut lasterror = None;

                for attempt in 0..*max_attempts {
                    match f().await {
                        Ok(result) => return Ok(result),
                        Err(err) => {
                            lasterror = Some(err);

                            if attempt < max_attempts - 1 {
                                tokio::time::sleep(delay).await;
                                delay = std::cmp::min(
                                    Duration::from_nanos(
                                        (delay.as_nanos() as f64 * multiplier) as u64,
                                    ),
                                    *maxdelay,
                                );
                            }
                        }
                    }
                }

                Err(lasterror.expect("Operation failed"))
            }

            RecoveryStrategy::LinearBackoff {
                max_attempts,
                delay,
            } => {
                let mut lasterror = None;

                for attempt in 0..*max_attempts {
                    match f().await {
                        Ok(result) => return Ok(result),
                        Err(err) => {
                            lasterror = Some(err);

                            if attempt < max_attempts - 1 {
                                tokio::time::sleep(*delay).await;
                            }
                        }
                    }
                }

                Err(lasterror.expect("Operation failed"))
            }

            RecoveryStrategy::CustomBackoff {
                max_attempts,
                delays,
            } => {
                let mut lasterror = None;

                for attempt in 0..*max_attempts {
                    match f().await {
                        Ok(result) => return Ok(result),
                        Err(err) => {
                            lasterror = Some(err);

                            if attempt < max_attempts - 1 {
                                if let Some(&delay) = delays.get(attempt) {
                                    tokio::time::sleep(delay).await;
                                }
                            }
                        }
                    }
                }

                Err(lasterror.expect("Operation failed"))
            }

            _ => f().await, // Other strategies are not applicable to retry
        }
    }
}

/// Async circuit breaker for handling repeated failures in async contexts
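///
/// # Example
///
/// A sketch of guarding an async call (marked `ignore`; the crate path is an
/// assumption, and the current state checks are placeholders):
///
/// ```ignore
/// use std::time::Duration;
/// use scirs2_core::error::async_handling::AsyncCircuitBreaker;
///
/// # async fn demo() -> scirs2_core::error::CoreResult<()> {
/// let breaker = AsyncCircuitBreaker::new(
///     5,                      // failure threshold
///     Duration::from_secs(1), // call timeout
///     Duration::from_secs(30), // recovery timeout
/// );
/// let value = breaker.execute(|| async { Ok("done") }).await?;
/// assert_eq!(value, "done");
/// # Ok(())
/// # }
/// ```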
#[derive(Debug)]
pub struct AsyncCircuitBreaker {
    #[allow(dead_code)]
    inner: Arc<CircuitBreaker>,
}

impl AsyncCircuitBreaker {
    /// Create a new async circuit breaker
    pub fn new(failure_threshold: usize, timeout: Duration, recoverytimeout: Duration) -> Self {
        Self {
            inner: Arc::new(CircuitBreaker::new(
                failure_threshold,
                timeout,
                recoverytimeout,
            )),
        }
    }

    /// Execute an async function with circuit breaker protection
    pub async fn execute<F, Fut, T>(&self, f: F) -> CoreResult<T>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = CoreResult<T>>,
    {
        // Check if circuit should allow execution
        if !self.should_allow_execution() {
            return Err(CoreError::ComputationError(ErrorContext::new(
                "Async circuit breaker is open - too many recent failures",
            )));
        }

        // Execute the async function
        match f().await {
            Ok(result) => {
                self.on_success();
                Ok(result)
            }
            Err(err) => {
                self.on_failure();
                Err(err)
            }
        }
    }

    fn should_allow_execution(&self) -> bool {
        // Delegate to the inner circuit breaker
        // This is a simplified check - in a real implementation,
        // you'd need to expose the internal state checking logic
        true // Placeholder
    }

    fn on_success(&self) {
        // Update circuit breaker state on success
        // This would typically involve updating internal counters
    }

    fn on_failure(&self) {
        // Update circuit breaker state on failure
        // This would typically involve updating failure counters
    }
}

/// Timeout wrapper for async operations.
///
/// Note: this wrapper is a simplified placeholder and does not currently
/// enforce the timeout; prefer [`with_timeout`] for enforced timeouts.
pub struct TimeoutWrapper<F> {
    future: F,
    #[allow(dead_code)]
    timeout: Duration,
}

impl<F> TimeoutWrapper<F> {
    /// Create a new timeout wrapper
    pub fn new(future: F, timeout: Duration) -> Self {
        Self { future, timeout }
    }
}

impl<F, T> Future for TimeoutWrapper<F>
where
    F: Future<Output = CoreResult<T>>,
{
    type Output = CoreResult<T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // This is a simplified implementation.
        // A real implementation would use tokio::time::timeout
        // or implement proper deadline handling here.

        let this = unsafe { self.get_unchecked_mut() };
        let future = unsafe { Pin::new_unchecked(&mut this.future) };

        match future.poll(cx) {
            Poll::Ready(result) => Poll::Ready(result),
            Poll::Pending => Poll::Pending,
        }
    }
}

/// Progress tracker for long-running async operations with error recovery
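///
/// # Example
///
/// A sketch of tracking a multi-step pipeline (marked `ignore`; the crate path
/// is an assumption):
///
/// ```ignore
/// use scirs2_core::error::async_handling::AsyncProgressTracker;
///
/// let tracker = AsyncProgressTracker::new(4);
/// tracker.complete_step();
/// tracker.complete_step();
/// assert_eq!(tracker.progress(), 0.5);
/// println!("{}", tracker.progress_report());
/// ```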
#[derive(Debug)]
pub struct AsyncProgressTracker {
    total_steps: usize,
    completed_steps: Arc<Mutex<usize>>,
    errors: Arc<Mutex<Vec<RecoverableError>>>,
    start_time: Instant,
}

impl AsyncProgressTracker {
    /// Create a new progress tracker
    pub fn new(totalsteps: usize) -> Self {
        Self {
            total_steps: totalsteps,
            completed_steps: Arc::new(Mutex::new(0)),
            errors: Arc::new(Mutex::new(Vec::new())),
            start_time: Instant::now(),
        }
    }

    /// Mark a step as completed
    pub fn complete_step(&self) {
        let mut completed = self.completed_steps.lock().expect("Operation failed");
        *completed += 1;
    }

    /// Record an error that occurred during processing
    pub fn recorderror(&self, error: RecoverableError) {
        let mut errors = self.errors.lock().expect("Operation failed");
        errors.push(error);
    }

    /// Get current progress (0.0 to 1.0)
    pub fn progress(&self) -> f64 {
        let completed = *self.completed_steps.lock().expect("Operation failed") as f64;
        completed / self.total_steps as f64
    }

    /// Get elapsed time
    pub fn elapsed_time(&self) -> Duration {
        self.start_time.elapsed()
    }

    /// Estimate remaining time based on current progress
    pub fn estimated_remaining_time(&self) -> Option<Duration> {
        let progress = self.progress();
        if progress > 0.0 && progress < 1.0 {
            let elapsed = self.elapsed_time();
            let total_estimated = elapsed.as_secs_f64() / progress;
            let remaining = total_estimated - elapsed.as_secs_f64();
            Some(Duration::from_secs_f64(remaining.max(0.0)))
        } else {
            None
        }
    }

    /// Get all recorded errors
    pub fn errors(&self) -> Vec<RecoverableError> {
        self.errors.lock().expect("Operation failed").clone()
    }

    /// Check if any errors have been recorded
    pub fn haserrors(&self) -> bool {
        !self.errors.lock().expect("Operation failed").is_empty()
    }

    /// Get a progress report
    pub fn progress_report(&self) -> String {
        let completed = *self.completed_steps.lock().expect("Operation failed");
        let progress_pct = (self.progress() * 100.0) as u32;
        let elapsed = self.elapsed_time();
        let error_count = self.errors.lock().expect("Operation failed").len();

        let mut report = format!(
            "Progress: {}/{} steps ({}%) | Elapsed: {:?}",
            completed, self.total_steps, progress_pct, elapsed
        );

        if let Some(remaining) = self.estimated_remaining_time() {
            report.push_str(&format!(" | Remaining: {:?}", remaining));
        }

        if error_count > 0 {
            report.push_str(&format!(" | Errors: {}", error_count));
        }

        report
    }
}

/// Async error aggregator for collecting errors from multiple async operations
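///
/// # Example
///
/// A sketch of collecting errors from several operations (marked `ignore`;
/// the crate paths are assumptions):
///
/// ```ignore
/// use scirs2_core::error::{CoreError, ErrorContext};
/// use scirs2_core::error::async_handling::AsyncErrorAggregator;
///
/// # async fn demo() {
/// let aggregator = AsyncErrorAggregator::with_maxerrors(10);
/// aggregator
///     .add_simpleerror(CoreError::ValueError(ErrorContext::new("bad input")))
///     .await;
/// assert_eq!(aggregator.error_count(), 1);
/// # }
/// ```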
#[derive(Debug)]
pub struct AsyncErrorAggregator {
    errors: Arc<Mutex<Vec<RecoverableError>>>,
    maxerrors: Option<usize>,
}

impl AsyncErrorAggregator {
    /// Create a new async error aggregator
    pub fn new() -> Self {
        Self {
            errors: Arc::new(Mutex::new(Vec::new())),
            maxerrors: None,
        }
    }

    /// Create a new async error aggregator with maximum error limit
    pub fn with_maxerrors(maxerrors: usize) -> Self {
        Self {
            errors: Arc::new(Mutex::new(Vec::new())),
            maxerrors: Some(maxerrors),
        }
    }

    /// Add an error to the aggregator (async-safe)
    pub async fn adderror(&self, error: RecoverableError) {
        let mut errors = self.errors.lock().expect("Operation failed");

        if let Some(max) = self.maxerrors {
            if errors.len() >= max {
                return; // Ignore additional errors
            }
        }

        errors.push(error);
    }

    /// Add a simple error to the aggregator
    pub async fn add_simpleerror(&self, error: CoreError) {
        self.adderror(RecoverableError::error(error)).await;
    }

    /// Check if there are any errors
    pub fn haserrors(&self) -> bool {
        !self.errors.lock().expect("Operation failed").is_empty()
    }

    /// Get the number of errors
    pub fn error_count(&self) -> usize {
        self.errors.lock().expect("Operation failed").len()
    }

    /// Get all errors
    pub fn geterrors(&self) -> Vec<RecoverableError> {
        self.errors.lock().expect("Operation failed").clone()
    }

    /// Get the most severe error
    pub fn most_severeerror(&self) -> Option<RecoverableError> {
        self.geterrors().into_iter().max_by_key(|err| err.severity)
    }

    /// Consume the aggregator, returning the success value if no errors were
    /// recorded, or the most severe recorded error otherwise
    pub fn into_result<T>(self, successvalue: T) -> Result<T, RecoverableError> {
        if let Some(most_severe) = self.most_severeerror() {
            Err(most_severe)
        } else {
            Ok(successvalue)
        }
    }
}

impl Default for AsyncErrorAggregator {
    fn default() -> Self {
        Self::new()
    }
}

/// Convenience function to add a timeout to any async operation
pub async fn with_timeout<F, T>(future: F, timeout: Duration) -> CoreResult<T>
where
    F: Future<Output = CoreResult<T>>,
{
    match tokio::time::timeout(timeout, future).await {
        Ok(result) => result,
        Err(_) => Err(CoreError::TimeoutError(ErrorContext::new(format!(
            "Operation timed out after {:?}",
            timeout
        )))),
    }
}

/// Convenience function to retry an async operation with exponential backoff
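///
/// # Example
///
/// A sketch of wrapping a flaky call with exponential backoff (marked
/// `ignore`; the crate path is an assumption):
///
/// ```ignore
/// use std::time::Duration;
/// use scirs2_core::error::async_handling::retry_with_exponential_backoff;
///
/// # async fn demo() -> scirs2_core::error::CoreResult<()> {
/// let value = retry_with_exponential_backoff(
///     || async { Ok(7) },        // operation to retry
///     5,                         // max attempts
///     Duration::from_millis(10), // initial delay
///     Duration::from_secs(1),    // maximum delay
///     2.0,                       // backoff multiplier
/// )
/// .await?;
/// assert_eq!(value, 7);
/// # Ok(())
/// # }
/// ```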
pub async fn retry_with_exponential_backoff<F, Fut, T>(
    f: F,
    max_attempts: usize,
    initialdelay: Duration,
    maxdelay: Duration,
    multiplier: f64,
) -> CoreResult<T>
where
    F: Fn() -> Fut,
    Fut: Future<Output = CoreResult<T>>,
{
    let executor = AsyncRetryExecutor::new(RecoveryStrategy::ExponentialBackoff {
        max_attempts,
        initialdelay,
        maxdelay,
        multiplier,
    });

    executor.execute(f).await
}

/// Convenience function to execute multiple async operations with error aggregation
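///
/// # Example
///
/// A sketch of running several operations and collecting their failures
/// (marked `ignore`; the crate paths are assumptions):
///
/// ```ignore
/// use scirs2_core::error::CoreResult;
/// use scirs2_core::error::async_handling::execute_witherror_aggregation;
///
/// async fn work(x: i32) -> CoreResult<i32> {
///     Ok(x * 2)
/// }
///
/// # async fn demo() {
/// let operations = vec![work(1), work(2), work(3)];
/// // With `fail_fast` disabled, every operation runs before errors are reported.
/// match execute_witherror_aggregation(operations, false).await {
///     Ok(results) => assert_eq!(results, vec![2, 4, 6]),
///     Err(aggregator) => eprintln!("{} operations failed", aggregator.error_count()),
/// }
/// # }
/// ```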
pub async fn execute_witherror_aggregation<T>(
    operations: Vec<impl Future<Output = CoreResult<T>>>,
    fail_fast: bool,
) -> Result<Vec<T>, AsyncErrorAggregator> {
    let aggregator = AsyncErrorAggregator::new();
    let mut results = Vec::new();

    for operation in operations {
        match operation.await {
            Ok(result) => results.push(result),
            Err(error) => {
                aggregator.add_simpleerror(error).await;

                if fail_fast {
                    return Err(aggregator);
                }
            }
        }
    }

    if aggregator.haserrors() {
        Err(aggregator)
    } else {
        Ok(results)
    }
}

/// Async operation with built-in progress tracking and error recovery
pub struct TrackedAsyncOperation<F> {
    operation: F,
    tracker: AsyncProgressTracker,
    retry_strategy: Option<RecoveryStrategy>,
}

impl<F> TrackedAsyncOperation<F> {
    /// Create a new tracked async operation
    pub fn new(operation: F, totalsteps: usize) -> Self {
        Self {
            operation,
            tracker: AsyncProgressTracker::new(totalsteps),
            retry_strategy: None,
        }
    }

    /// Add a retry strategy to the operation
    pub fn with_retry(mut self, strategy: RecoveryStrategy) -> Self {
        self.retry_strategy = Some(strategy);
        self
    }

    /// Get a reference to the progress tracker
    pub const fn tracker(&self) -> &AsyncProgressTracker {
        &self.tracker
    }
}

impl<F, T> Future for TrackedAsyncOperation<F>
where
    F: Future<Output = CoreResult<T>>,
{
    type Output = CoreResult<T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = unsafe { self.get_unchecked_mut() };
        let operation = unsafe { Pin::new_unchecked(&mut this.operation) };

        match operation.poll(cx) {
            Poll::Ready(result) => {
                match &result {
                    Ok(_) => this.tracker.complete_step(),
                    Err(error) => {
                        let recoverableerror = RecoverableError::error(error.clone());
                        this.tracker.recorderror(recoverableerror);
                    }
                }
                Poll::Ready(result)
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

/// Macro to create an async operation with automatic error handling and progress tracking
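///
/// # Example
///
/// A sketch of wrapping a single-step async computation (marked `ignore`; the
/// crate-root macro path is an assumption):
///
/// ```ignore
/// use scirs2_core::async_with_recovery;
///
/// # async fn demo() -> scirs2_core::error::CoreResult<()> {
/// let value = async_with_recovery!(async { Ok(3) }, 1)?;
/// assert_eq!(value, 3);
/// # Ok(())
/// # }
/// ```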
#[macro_export]
macro_rules! async_with_recovery {
    ($operation:expr, $steps:expr) => {{
        let tracked_op =
            $crate::error::async_handling::TrackedAsyncOperation::new($operation, $steps);
        tracked_op.await
    }};

    ($operation:expr, $steps:expr, $retry_strategy:expr) => {{
        let tracked_op =
            $crate::error::async_handling::TrackedAsyncOperation::new($operation, $steps)
                .with_retry($retry_strategy);
        tracked_op.await
    }};
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn test_async_retry_executor() {
        let executor = AsyncRetryExecutor::new(RecoveryStrategy::LinearBackoff {
            max_attempts: 3,
            delay: Duration::from_millis(1),
        });

        let attempt_count = Arc::new(AtomicUsize::new(0));
        let attempt_count_clone = attempt_count.clone();

        let result = executor
            .execute(|| {
                let count = attempt_count_clone.clone();
                async move {
                    let current = count.fetch_add(1, Ordering::SeqCst);
                    if current < 2 {
                        Err(CoreError::ComputationError(ErrorContext::new("Test error")))
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;

        assert_eq!(result.expect("Operation failed"), 42);
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_timeout_wrapper() {
        let result = with_timeout(
            async {
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok(42)
            },
            Duration::from_millis(50),
        )
        .await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), CoreError::TimeoutError(_)));
    }

    #[tokio::test]
    async fn test_progress_tracker() {
        let tracker = AsyncProgressTracker::new(10);

        assert_eq!(tracker.progress(), 0.0);

        // Add a small delay to ensure measurable elapsed time
        tokio::time::sleep(Duration::from_millis(1)).await;

        tracker.complete_step();
        tracker.complete_step();

        assert_eq!(tracker.progress(), 0.2);
        assert!(tracker.elapsed_time().as_nanos() > 0);
    }

    #[tokio::test]
    async fn test_asyncerror_aggregator() {
        let aggregator = AsyncErrorAggregator::new();

        assert!(!aggregator.haserrors());

        aggregator
            .add_simpleerror(CoreError::ValueError(ErrorContext::new("Error 1")))
            .await;
        aggregator
            .add_simpleerror(CoreError::DomainError(ErrorContext::new("Error 2")))
            .await;

        assert_eq!(aggregator.error_count(), 2);
        assert!(aggregator.haserrors());
    }
}