use anyhow::{Error, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant, SystemTime};
use uuid::Uuid;

/// Configuration for the error recovery manager.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoveryConfig {
    /// Maximum number of retry attempts per operation.
    pub max_retries: usize,
    /// Initial delay between retries, in milliseconds.
    pub base_delay_ms: u64,
    /// Upper bound on the backoff delay, in milliseconds.
    pub max_delay_ms: u64,
    /// Multiplier applied to the delay after each failed attempt.
    pub backoff_multiplier: f64,
    /// Whether fallback implementations may be used.
    pub enable_fallback: bool,
    /// Whether model checkpointing is enabled.
    pub enable_checkpointing: bool,
    /// Memory pressure threshold, in megabytes.
    pub memory_pressure_threshold_mb: f64,
    /// Consecutive failures before a circuit breaker opens.
    pub circuit_breaker_threshold: usize,
    /// Seconds an open circuit breaker waits before allowing a trial call.
    pub circuit_breaker_timeout_s: u64,
    /// Whether recovery monitoring is enabled.
    pub enable_monitoring: bool,
    /// Maximum number of recovery attempts kept in history.
    pub max_error_history: usize,
}

impl Default for RecoveryConfig {
    fn default() -> Self {
        Self {
            max_retries: 3,
            base_delay_ms: 100,
            max_delay_ms: 30000,
            backoff_multiplier: 2.0,
            enable_fallback: true,
            enable_checkpointing: true,
            memory_pressure_threshold_mb: 1024.0,
            circuit_breaker_threshold: 5,
            circuit_breaker_timeout_s: 60,
            enable_monitoring: true,
            max_error_history: 1000,
        }
    }
}

impl RecoveryConfig {
    /// Sets the maximum number of retries.
    pub fn with_max_retries(mut self, max_retries: usize) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Enables or disables fallback implementations.
    pub fn with_fallback_enabled(mut self, enabled: bool) -> Self {
        self.enable_fallback = enabled;
        self
    }

    /// Sets the memory pressure threshold in megabytes.
    pub fn with_memory_threshold(mut self, threshold_mb: f64) -> Self {
        self.memory_pressure_threshold_mb = threshold_mb;
        self
    }
}
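
// A minimal usage sketch (values are illustrative, not defaults): the
// builder methods consume and return `self`, so a config is assembled in
// a single chain:
//
//     let config = RecoveryConfig::default()
//         .with_max_retries(5)
//         .with_memory_threshold(2048.0);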

/// Coarse error categories used to pick a recovery strategy.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ErrorCategory {
    /// Out-of-memory and allocation failures.
    Memory,
    /// GPU/device and computation failures.
    Compute,
    /// Connectivity and timeout failures.
    Network,
    /// Model-level failures (e.g. tensor shape mismatches).
    Model,
    /// Invalid, missing, or corrupted input data.
    Data,
    /// Exhausted or unavailable system resources.
    Resource,
    /// Configuration and initialization failures.
    Configuration,
    /// Anything that could not be classified.
    Unknown,
}

/// Recovery actions, tried in the order registered for a category.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RecoveryStrategy {
    /// Retry the operation after a delay.
    Retry {
        max_attempts: usize,
        base_delay_ms: u64,
    },
    /// Switch to an alternative implementation.
    Fallback { fallback_implementation: String },
    /// Scale down resource usage by the given factor.
    ResourceReduction { reduction_factor: f64 },
    /// Restart a named component.
    Restart { component: String },
    /// Free memory and trigger cleanup.
    MemoryCleanup,
    /// Restore model state from a stored checkpoint.
    CheckpointRestore { checkpoint_id: String },
    /// Continue in a reduced-functionality mode.
    Degrade { degraded_mode: String },
    /// No recovery is possible; propagate the error.
    NoRecovery,
}

/// Record of a single recovery attempt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoveryAttempt {
    pub attempt_id: Uuid,
    pub timestamp: SystemTime,
    pub error_category: ErrorCategory,
    pub strategy: RecoveryStrategy,
    pub success: bool,
    pub duration_ms: u64,
    pub error_message: String,
    pub context: HashMap<String, String>,
}

#[derive(Debug, Clone, PartialEq)]
enum CircuitBreakerState {
    Closed,
    Open,
    HalfOpen,
}

/// Per-operation circuit breaker that trips after repeated failures.
#[derive(Debug)]
struct CircuitBreaker {
    state: CircuitBreakerState,
    failure_count: usize,
    last_failure_time: Option<Instant>,
    failure_threshold: usize,
    timeout: Duration,
}

impl CircuitBreaker {
    fn new(failure_threshold: usize, timeout: Duration) -> Self {
        Self {
            state: CircuitBreakerState::Closed,
            failure_count: 0,
            last_failure_time: None,
            failure_threshold,
            timeout,
        }
    }

    fn can_execute(&mut self) -> bool {
        match self.state {
            CircuitBreakerState::Closed => true,
            CircuitBreakerState::Open => {
                if let Some(last_failure) = self.last_failure_time {
                    if last_failure.elapsed() >= self.timeout {
                        // Timeout elapsed: allow one trial call.
                        self.state = CircuitBreakerState::HalfOpen;
                        true
                    } else {
                        false
                    }
                } else {
                    true
                }
            },
            CircuitBreakerState::HalfOpen => true,
        }
    }

    fn on_success(&mut self) {
        self.failure_count = 0;
        self.state = CircuitBreakerState::Closed;
    }

    fn on_failure(&mut self) {
        self.failure_count += 1;
        self.last_failure_time = Some(Instant::now());

        if self.failure_count >= self.failure_threshold {
            self.state = CircuitBreakerState::Open;
        }
    }
}
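
// A minimal sketch of the Open -> HalfOpen transition implemented above:
// once `timeout` has elapsed after the tripping failure, a single trial
// call is let through. Kept in its own test module because it needs a
// real (short) sleep; the threshold and timeout values are illustrative.
#[cfg(test)]
mod circuit_breaker_half_open_sketch {
    use super::*;

    #[test]
    fn open_breaker_half_opens_after_timeout() {
        let mut breaker = CircuitBreaker::new(1, Duration::from_millis(10));

        breaker.on_failure();
        assert!(!breaker.can_execute()); // Open: calls are rejected

        std::thread::sleep(Duration::from_millis(20));
        assert!(breaker.can_execute()); // timeout elapsed: trial allowed
        assert_eq!(breaker.state, CircuitBreakerState::HalfOpen);
    }
}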

/// Snapshot of model state used for checkpoint/restore recovery.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCheckpoint {
    pub checkpoint_id: String,
    pub timestamp: SystemTime,
    pub model_state: HashMap<String, Vec<u8>>,
    pub metadata: HashMap<String, String>,
    pub size_bytes: usize,
}

impl ModelCheckpoint {
    pub fn new(model_state: HashMap<String, Vec<u8>>, metadata: HashMap<String, String>) -> Self {
        let size_bytes = model_state.values().map(|v| v.len()).sum();

        Self {
            checkpoint_id: Uuid::new_v4().to_string(),
            timestamp: SystemTime::now(),
            model_state,
            metadata,
            size_bytes,
        }
    }
}
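
// A small sketch of the constructor's size accounting: `size_bytes` is the
// sum of all state-blob lengths (keys and metadata are not counted). The
// key names here are illustrative.
#[cfg(test)]
mod model_checkpoint_sketch {
    use super::*;

    #[test]
    fn size_is_sum_of_state_blobs() {
        let mut state = HashMap::new();
        state.insert("layer0".to_string(), vec![0u8; 128]);
        state.insert("layer1".to_string(), vec![0u8; 64]);

        let checkpoint = ModelCheckpoint::new(state, HashMap::new());
        assert_eq!(checkpoint.size_bytes, 192);
        assert!(!checkpoint.checkpoint_id.is_empty());
    }
}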

/// Aggregated recovery statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoveryMetrics {
    pub total_errors: usize,
    pub successful_recoveries: usize,
    pub failed_recoveries: usize,
    pub average_recovery_time_ms: f64,
    /// Successful recoveries divided by total errors.
    pub recovery_rate: f64,
    /// Errors per second since the manager started.
    pub error_frequency: f64,
    pub most_common_errors: HashMap<ErrorCategory, usize>,
    pub most_effective_strategies: HashMap<String, f64>,
}

/// Central coordinator for error classification, recovery, and reporting.
pub struct ErrorRecoveryManager {
    config: RecoveryConfig,
    error_history: VecDeque<RecoveryAttempt>,
    circuit_breakers: HashMap<String, CircuitBreaker>,
    checkpoints: HashMap<String, ModelCheckpoint>,
    recovery_strategies: HashMap<ErrorCategory, Vec<RecoveryStrategy>>,
    metrics: Arc<Mutex<RecoveryMetrics>>,
    start_time: Instant,
}

impl ErrorRecoveryManager {
    pub fn new(config: RecoveryConfig) -> Self {
        // Default strategies per category; they are tried in the order listed.
        let mut recovery_strategies = HashMap::new();

        recovery_strategies.insert(
            ErrorCategory::Memory,
            vec![
                RecoveryStrategy::MemoryCleanup,
                RecoveryStrategy::ResourceReduction {
                    reduction_factor: 0.5,
                },
                RecoveryStrategy::CheckpointRestore {
                    checkpoint_id: "latest".to_string(),
                },
            ],
        );

        recovery_strategies.insert(
            ErrorCategory::Compute,
            vec![
                RecoveryStrategy::Retry {
                    max_attempts: 3,
                    base_delay_ms: 1000,
                },
                RecoveryStrategy::Fallback {
                    fallback_implementation: "cpu".to_string(),
                },
                RecoveryStrategy::Restart {
                    component: "compute_engine".to_string(),
                },
            ],
        );

        recovery_strategies.insert(
            ErrorCategory::Network,
            vec![
                RecoveryStrategy::Retry {
                    max_attempts: 5,
                    base_delay_ms: 2000,
                },
                RecoveryStrategy::Fallback {
                    fallback_implementation: "local".to_string(),
                },
            ],
        );

        recovery_strategies.insert(
            ErrorCategory::Model,
            vec![
                RecoveryStrategy::CheckpointRestore {
                    checkpoint_id: "latest".to_string(),
                },
                RecoveryStrategy::Degrade {
                    degraded_mode: "simple".to_string(),
                },
                RecoveryStrategy::Restart {
                    component: "model".to_string(),
                },
            ],
        );

        recovery_strategies.insert(
            ErrorCategory::Data,
            vec![
                RecoveryStrategy::Retry {
                    max_attempts: 2,
                    base_delay_ms: 100,
                },
                RecoveryStrategy::Fallback {
                    fallback_implementation: "default_data".to_string(),
                },
            ],
        );

        recovery_strategies.insert(
            ErrorCategory::Resource,
            vec![
                RecoveryStrategy::Retry {
                    max_attempts: 3,
                    base_delay_ms: 5000,
                },
                RecoveryStrategy::ResourceReduction {
                    reduction_factor: 0.7,
                },
            ],
        );

        Self {
            config,
            error_history: VecDeque::new(),
            circuit_breakers: HashMap::new(),
            checkpoints: HashMap::new(),
            recovery_strategies,
            metrics: Arc::new(Mutex::new(RecoveryMetrics {
                total_errors: 0,
                successful_recoveries: 0,
                failed_recoveries: 0,
                average_recovery_time_ms: 0.0,
                recovery_rate: 0.0,
                error_frequency: 0.0,
                most_common_errors: HashMap::new(),
                most_effective_strategies: HashMap::new(),
            })),
            start_time: Instant::now(),
        }
    }

    /// Runs `operation`, classifying failures and applying recovery
    /// strategies with exponential backoff, up to `max_retries` retries.
    pub fn execute_with_recovery<T, F>(&mut self, operation: F) -> Result<T>
    where
        F: Fn() -> Result<T>,
    {
        let operation_name = "default_operation";

        if !self.get_or_create_circuit_breaker(operation_name).can_execute() {
            return Err(anyhow::anyhow!(
                "Circuit breaker is open for operation: {}",
                operation_name
            ));
        }

        let mut last_error = None;

        for attempt in 0..=self.config.max_retries {
            let start_time = Instant::now();

            match operation() {
                Ok(result) => {
                    self.get_or_create_circuit_breaker(operation_name).on_success();

                    if attempt > 0 {
                        self.record_successful_recovery(attempt, start_time);
                    }

                    return Ok(result);
                },
                Err(error) => {
                    let error_category = self.classify_error(&error);
                    let recovery_success = self
                        .attempt_recovery(&error, error_category.clone(), attempt)
                        .unwrap_or(false);

                    let give_up = !recovery_success && attempt == self.config.max_retries;
                    if give_up {
                        self.get_or_create_circuit_breaker(operation_name).on_failure();
                        self.record_failed_recovery(error_category, start_time, &error);
                    }

                    // Keep the original error (and its context chain) instead
                    // of flattening it to a string.
                    last_error = Some(error);

                    if give_up {
                        break;
                    }

                    if attempt < self.config.max_retries {
                        let delay = self.calculate_backoff_delay(attempt);
                        std::thread::sleep(delay);
                    }
                },
            }
        }

        Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Unknown error occurred")))
    }

    /// Classifies an error by keyword-matching its message.
    fn classify_error(&self, error: &Error) -> ErrorCategory {
        let error_string = error.to_string().to_lowercase();

        if error_string.contains("memory")
            || error_string.contains("oom")
            || error_string.contains("allocation")
        {
            ErrorCategory::Memory
        } else if error_string.contains("cuda")
            || error_string.contains("gpu")
            || error_string.contains("device")
        {
            ErrorCategory::Compute
        } else if error_string.contains("network")
            || error_string.contains("connection")
            || error_string.contains("timeout")
        {
            ErrorCategory::Network
        } else if error_string.contains("dimension")
            || error_string.contains("shape")
            || error_string.contains("tensor")
        {
            ErrorCategory::Model
        } else if error_string.contains("data")
            || error_string.contains("input")
            || error_string.contains("corrupted")
        {
            ErrorCategory::Data
        } else if error_string.contains("resource")
            || error_string.contains("unavailable")
            || error_string.contains("busy")
        {
            ErrorCategory::Resource
        } else if error_string.contains("config")
            || error_string.contains("setup")
            || error_string.contains("initialization")
        {
            ErrorCategory::Configuration
        } else {
            ErrorCategory::Unknown
        }
    }

    /// Tries the registered strategies for `category` in order until one
    /// reports success.
    fn attempt_recovery(
        &mut self,
        error: &Error,
        category: ErrorCategory,
        _attempt: usize,
    ) -> Result<bool> {
        let strategies = self.recovery_strategies.get(&category).cloned().unwrap_or_else(|| {
            vec![RecoveryStrategy::Retry {
                max_attempts: 1,
                base_delay_ms: 1000,
            }]
        });

        for strategy in strategies {
            if self.execute_recovery_strategy(&strategy, error, &category)? {
                self.record_recovery_attempt(category.clone(), strategy, true, error);
                return Ok(true);
            }
        }

        self.record_recovery_attempt(category, RecoveryStrategy::NoRecovery, false, error);
        Ok(false)
    }

    fn execute_recovery_strategy(
        &mut self,
        strategy: &RecoveryStrategy,
        _error: &Error,
        _category: &ErrorCategory,
    ) -> Result<bool> {
        match strategy {
            RecoveryStrategy::Retry {
                max_attempts: _,
                base_delay_ms,
            } => {
                // Wait, then signal the caller's retry loop to try again.
                std::thread::sleep(Duration::from_millis(*base_delay_ms));
                Ok(true)
            },

            RecoveryStrategy::MemoryCleanup => {
                self.perform_memory_cleanup()?;
                Ok(true)
            },

            RecoveryStrategy::ResourceReduction { reduction_factor } => {
                self.reduce_resource_usage(*reduction_factor)?;
                Ok(true)
            },

            RecoveryStrategy::CheckpointRestore { checkpoint_id } => {
                self.restore_from_checkpoint(checkpoint_id)
            },

            RecoveryStrategy::Fallback {
                fallback_implementation,
            } => {
                self.switch_to_fallback(fallback_implementation)?;
                Ok(true)
            },

            RecoveryStrategy::Restart { component } => {
                self.restart_component(component)?;
                Ok(true)
            },

            RecoveryStrategy::Degrade { degraded_mode } => {
                self.enable_degraded_mode(degraded_mode)?;
                Ok(true)
            },

            RecoveryStrategy::NoRecovery => Ok(false),
        }
    }

    // The handlers below are integration points: they log the action taken
    // and return `Ok(())`; the embedding system is expected to wire in real
    // cleanup, fallback, and restart logic.

    fn perform_memory_cleanup(&self) -> Result<()> {
        println!("[INFO] Performing memory cleanup");
        Ok(())
    }

    fn reduce_resource_usage(&self, reduction_factor: f64) -> Result<()> {
        println!(
            "[INFO] Reducing resource usage by factor: {}",
            reduction_factor
        );
        Ok(())
    }

    fn switch_to_fallback(&self, fallback: &str) -> Result<()> {
        println!("[INFO] Switching to fallback implementation: {}", fallback);
        Ok(())
    }

    fn restart_component(&self, component: &str) -> Result<()> {
        println!("[INFO] Restarting component: {}", component);
        Ok(())
    }

    fn enable_degraded_mode(&self, mode: &str) -> Result<()> {
        println!("[INFO] Enabling degraded mode: {}", mode);
        Ok(())
    }

    fn restore_from_checkpoint(&self, checkpoint_id: &str) -> Result<bool> {
        if let Some(_checkpoint) = self.checkpoints.get(checkpoint_id) {
            println!("[INFO] Restoring from checkpoint: {}", checkpoint_id);
            Ok(true)
        } else {
            println!("[WARN] Checkpoint not found: {}", checkpoint_id);
            Ok(false)
        }
    }

    /// Stores a checkpoint under its own id and under the "latest" alias,
    /// evicting the oldest checkpoints once more than ten are held.
    pub fn create_checkpoint(
        &mut self,
        model_state: HashMap<String, Vec<u8>>,
        metadata: HashMap<String, String>,
    ) -> String {
        let checkpoint = ModelCheckpoint::new(model_state, metadata);
        let checkpoint_id = checkpoint.checkpoint_id.clone();

        self.checkpoints.insert("latest".to_string(), checkpoint.clone());
        self.checkpoints.insert(checkpoint_id.clone(), checkpoint);

        const MAX_CHECKPOINTS: usize = 10;
        if self.checkpoints.len() > MAX_CHECKPOINTS {
            // `HashMap` iteration order is arbitrary, so pick eviction
            // victims by timestamp (oldest first) rather than key order.
            let mut candidates: Vec<(String, SystemTime)> = self
                .checkpoints
                .iter()
                .filter(|(key, _)| key.as_str() != "latest")
                .map(|(key, cp)| (key.clone(), cp.timestamp))
                .collect();
            candidates.sort_by_key(|(_, timestamp)| *timestamp);

            let excess = self.checkpoints.len() - MAX_CHECKPOINTS;
            for (key, _) in candidates.into_iter().take(excess) {
                self.checkpoints.remove(&key);
            }
        }

        println!("[INFO] Created checkpoint: {}", checkpoint_id);
        checkpoint_id
    }

    fn calculate_backoff_delay(&self, attempt: usize) -> Duration {
        let delay_ms =
            self.config.base_delay_ms as f64 * self.config.backoff_multiplier.powi(attempt as i32);
        let delay_ms = delay_ms.min(self.config.max_delay_ms as f64) as u64;
        Duration::from_millis(delay_ms)
    }
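
    // With the default config, `calculate_backoff_delay` above yields
    // delay(n) = 100ms * 2^n, i.e. 100ms, 200ms, 400ms, ..., clamped to
    // `max_delay_ms`; e.g. attempt 9 would be 51_200ms, capped at 30_000ms.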

    fn get_or_create_circuit_breaker(&mut self, operation: &str) -> &mut CircuitBreaker {
        // Read the config first so the `or_insert_with` closure does not
        // capture `self` while `circuit_breakers` is mutably borrowed.
        let threshold = self.config.circuit_breaker_threshold;
        let timeout = Duration::from_secs(self.config.circuit_breaker_timeout_s);

        self.circuit_breakers
            .entry(operation.to_string())
            .or_insert_with(|| CircuitBreaker::new(threshold, timeout))
    }

    fn record_recovery_attempt(
        &mut self,
        category: ErrorCategory,
        strategy: RecoveryStrategy,
        success: bool,
        error: &Error,
    ) {
        let attempt = RecoveryAttempt {
            attempt_id: Uuid::new_v4(),
            timestamp: SystemTime::now(),
            error_category: category.clone(),
            strategy,
            success,
            duration_ms: 0,
            error_message: error.to_string(),
            context: HashMap::new(),
        };

        self.error_history.push_back(attempt);

        // Bound the history to the configured size.
        while self.error_history.len() > self.config.max_error_history {
            self.error_history.pop_front();
        }

        if let Ok(mut metrics) = self.metrics.lock() {
            metrics.total_errors += 1;
            if success {
                metrics.successful_recoveries += 1;
            } else {
                metrics.failed_recoveries += 1;
            }

            metrics.recovery_rate =
                metrics.successful_recoveries as f64 / metrics.total_errors as f64;
            // Errors per second since the manager started.
            metrics.error_frequency =
                metrics.total_errors as f64 / self.start_time.elapsed().as_secs_f64().max(1.0);

            *metrics.most_common_errors.entry(category).or_insert(0) += 1;
        }
    }

    fn record_successful_recovery(&mut self, _attempts: usize, start_time: Instant) {
        if let Ok(mut metrics) = self.metrics.lock() {
            let duration = start_time.elapsed().as_millis() as f64;
            let total_recoveries = metrics.successful_recoveries + metrics.failed_recoveries;

            if total_recoveries > 0 {
                metrics.average_recovery_time_ms =
                    (metrics.average_recovery_time_ms * total_recoveries as f64 + duration)
                        / (total_recoveries + 1) as f64;
            } else {
                metrics.average_recovery_time_ms = duration;
            }
        }
    }

    fn record_failed_recovery(
        &mut self,
        category: ErrorCategory,
        _start_time: Instant,
        error: &Error,
    ) {
        self.record_recovery_attempt(category, RecoveryStrategy::NoRecovery, false, error);
    }

    pub fn get_metrics(&self) -> RecoveryMetrics {
        self.metrics
            .lock()
            .expect("recovery metrics mutex poisoned")
            .clone()
    }

    /// Builds a point-in-time report covering metrics, recent errors,
    /// trends, and recommendations.
    pub fn generate_recovery_report(&self) -> RecoveryReport {
        let metrics = self.get_metrics();
        let uptime = self.start_time.elapsed();

        let recent_errors: Vec<_> = self.error_history.iter().rev().take(10).cloned().collect();

        let error_trends = self.analyze_error_trends();
        let recommendations = self.generate_recommendations(&metrics, &error_trends);

        RecoveryReport {
            timestamp: SystemTime::now(),
            uptime,
            metrics,
            recent_errors,
            error_trends,
            recommendations,
            circuit_breaker_states: self.get_circuit_breaker_states(),
            checkpoint_count: self.checkpoints.len(),
        }
    }

    fn analyze_error_trends(&self) -> ErrorTrends {
        let now = SystemTime::now();
        let one_hour_ago = now.checked_sub(Duration::from_secs(3600)).unwrap_or(now);

        let recent_errors: Vec<_> = self
            .error_history
            .iter()
            .filter(|attempt| attempt.timestamp >= one_hour_ago)
            .collect();

        // Errors per second over the last hour.
        let error_rate = recent_errors.len() as f64 / 3600.0;
        let recovery_success_rate = if !recent_errors.is_empty() {
            recent_errors.iter().filter(|a| a.success).count() as f64 / recent_errors.len() as f64
        } else {
            1.0
        };

        // Crude trend signal: more than half of all recorded errors fall
        // within the last hour.
        let trending_up = recent_errors.len() > self.error_history.len() / 2;

        ErrorTrends {
            error_rate,
            recovery_success_rate,
            trending_up,
            most_frequent_category: self.get_most_frequent_error_category(&recent_errors),
        }
    }

    fn get_most_frequent_error_category(
        &self,
        errors: &[&RecoveryAttempt],
    ) -> Option<ErrorCategory> {
        let mut category_counts = HashMap::new();

        for error in errors {
            *category_counts.entry(error.error_category.clone()).or_insert(0) += 1;
        }

        category_counts
            .into_iter()
            .max_by_key(|(_, count)| *count)
            .map(|(category, _)| category)
    }

    fn generate_recommendations(
        &self,
        metrics: &RecoveryMetrics,
        trends: &ErrorTrends,
    ) -> Vec<String> {
        let mut recommendations = Vec::new();

        if metrics.recovery_rate < 0.8 {
            recommendations
                .push("Recovery rate is low. Consider reviewing recovery strategies.".to_string());
        }

        if trends.error_rate > 0.1 {
            recommendations.push("High error rate detected. Investigate root causes.".to_string());
        }

        if trends.trending_up {
            recommendations
                .push("Error frequency is increasing. Monitor system closely.".to_string());
        }

        if metrics.average_recovery_time_ms > 5000.0 {
            recommendations
                .push("Recovery time is high. Optimize recovery strategies.".to_string());
        }

        if let Some(category) = &trends.most_frequent_category {
            recommendations.push(format!(
                "Most frequent error category: {:?}. Focus optimization efforts here.",
                category
            ));
        }

        if recommendations.is_empty() {
            recommendations.push("Error recovery system is operating normally.".to_string());
        }

        recommendations
    }

    fn get_circuit_breaker_states(&self) -> HashMap<String, String> {
        self.circuit_breakers
            .iter()
            .map(|(name, breaker)| {
                let state = match breaker.state {
                    CircuitBreakerState::Closed => "CLOSED",
                    CircuitBreakerState::Open => "OPEN",
                    CircuitBreakerState::HalfOpen => "HALF_OPEN",
                };
                (name.clone(), state.to_string())
            })
            .collect()
    }
}

/// Hourly error-trend summary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorTrends {
    pub error_rate: f64,
    pub recovery_success_rate: f64,
    pub trending_up: bool,
    pub most_frequent_category: Option<ErrorCategory>,
}

/// Point-in-time snapshot produced by `generate_recovery_report`.
#[derive(Debug, Serialize, Deserialize)]
pub struct RecoveryReport {
    pub timestamp: SystemTime,
    pub uptime: Duration,
    pub metrics: RecoveryMetrics,
    pub recent_errors: Vec<RecoveryAttempt>,
    pub error_trends: ErrorTrends,
    pub recommendations: Vec<String>,
    pub circuit_breaker_states: HashMap<String, String>,
    pub checkpoint_count: usize,
}

/// Adapter that lets any fallible closure run under a recovery manager.
pub trait RecoverableOperation<T> {
    fn with_recovery(self, manager: &mut ErrorRecoveryManager) -> Result<T>;
}

impl<T, F> RecoverableOperation<T> for F
where
    F: Fn() -> Result<T>,
{
    fn with_recovery(self, manager: &mut ErrorRecoveryManager) -> Result<T> {
        manager.execute_with_recovery(self)
    }
}
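
// A minimal usage sketch for the adapter above; the closure and values are
// illustrative, not part of the module's API.
#[cfg(test)]
mod recoverable_operation_sketch {
    use super::*;

    #[test]
    fn closure_runs_with_recovery() {
        let mut manager = ErrorRecoveryManager::new(RecoveryConfig::default());

        let result = (|| Ok::<_, anyhow::Error>(42)).with_recovery(&mut manager);
        assert_eq!(result.unwrap(), 42);
    }
}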

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_classification() {
        let manager = ErrorRecoveryManager::new(RecoveryConfig::default());

        let memory_error = anyhow::anyhow!("Out of memory error occurred");
        assert_eq!(manager.classify_error(&memory_error), ErrorCategory::Memory);

        let cuda_error = anyhow::anyhow!("CUDA device error");
        assert_eq!(manager.classify_error(&cuda_error), ErrorCategory::Compute);
    }

    #[test]
    fn test_circuit_breaker() {
        let mut breaker = CircuitBreaker::new(2, Duration::from_secs(1));

        assert!(breaker.can_execute());

        breaker.on_failure();
        assert!(breaker.can_execute());

        breaker.on_failure();
        assert!(!breaker.can_execute());

        breaker.on_success();
        assert!(breaker.can_execute());
    }

    #[test]
    fn test_backoff_calculation() {
        let config = RecoveryConfig::default();
        let manager = ErrorRecoveryManager::new(config);

        let delay0 = manager.calculate_backoff_delay(0);
        let delay1 = manager.calculate_backoff_delay(1);
        let delay2 = manager.calculate_backoff_delay(2);

        assert!(delay1 > delay0);
        assert!(delay2 > delay1);
    }
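
    // Sketch of end-to-end retry behavior (assumed workload: an operation
    // that fails twice, then succeeds). The error text is chosen so that
    // `classify_error` maps it to `ErrorCategory::Data`, whose retry
    // strategy uses a short base delay, keeping the test fast. A `Cell`
    // counter is used because the closure must be `Fn`, not `FnMut`.
    #[test]
    fn test_execute_with_recovery_retries() {
        use std::cell::Cell;

        let mut manager = ErrorRecoveryManager::new(RecoveryConfig::default());
        let calls = Cell::new(0usize);

        let result = manager.execute_with_recovery(|| {
            calls.set(calls.get() + 1);
            if calls.get() < 3 {
                Err(anyhow::anyhow!("corrupted input record"))
            } else {
                Ok(calls.get())
            }
        });

        assert_eq!(result.unwrap(), 3);
        assert_eq!(calls.get(), 3);
    }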

    #[test]
    fn test_recovery_config_builder() {
        let config = RecoveryConfig::default()
            .with_max_retries(5)
            .with_fallback_enabled(false)
            .with_memory_threshold(2048.0);

        assert_eq!(config.max_retries, 5);
        assert!(!config.enable_fallback);
        assert_eq!(config.memory_pressure_threshold_mb, 2048.0);
    }
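
    // Sketch of the checkpoint lifecycle: `create_checkpoint` stores the
    // snapshot under its own id and under the "latest" alias, so both keys
    // should restore, while an unknown id should not. The state contents
    // here are illustrative.
    #[test]
    fn test_checkpoint_roundtrip() {
        let mut manager = ErrorRecoveryManager::new(RecoveryConfig::default());

        let mut state = HashMap::new();
        state.insert("weights".to_string(), vec![1u8, 2, 3]);
        let id = manager.create_checkpoint(state, HashMap::new());

        assert!(manager.restore_from_checkpoint(&id).unwrap());
        assert!(manager.restore_from_checkpoint("latest").unwrap());
        assert!(!manager.restore_from_checkpoint("missing").unwrap());
    }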
}