
tensorlogic_infer/recovery.rs

//! Error recovery and fault tolerance for execution.
//!
//! This module provides mechanisms for handling failures gracefully:
//! - Partial results on execution failure
//! - Checkpoint/restart capabilities
//! - Graceful degradation strategies
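//!
//! # Example
//!
//! A minimal usage sketch (assuming this module is exposed as
//! `tensorlogic_infer::recovery`):
//!
//! ```ignore
//! use std::time::Duration;
//! use tensorlogic_infer::recovery::{RecoveryConfig, RetryPolicy};
//!
//! // Retry failed operations up to three times, checkpointing every 25 operations.
//! let config = RecoveryConfig::retry(3)
//!     .with_checkpointing(25)
//!     .with_timeout(Duration::from_secs(60));
//!
//! // Delays between retries grow exponentially under the retry policy.
//! let policy = RetryPolicy::new(3);
//! assert!(policy.calculate_delay(1) > policy.calculate_delay(0));
//! ```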

use std::collections::HashMap;
use std::time::{Duration, Instant};

use tensorlogic_ir::EinsumGraph;

/// Recovery strategy for handling execution failures
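///
/// # Example
///
/// An illustrative sketch of how an executor loop might branch on the strategy
/// (the `on_failure` helper is hypothetical):
///
/// ```ignore
/// use tensorlogic_infer::recovery::RecoveryStrategy;
///
/// fn on_failure(strategy: RecoveryStrategy, retries_so_far: usize) -> bool {
///     match strategy {
///         RecoveryStrategy::FailFast => false,
///         RecoveryStrategy::ContinuePartial => true,
///         RecoveryStrategy::RetryWithBackoff { max_retries } => retries_so_far < max_retries,
///         RecoveryStrategy::GracefulDegradation => true,
///     }
/// }
/// ```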
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecoveryStrategy {
    /// Fail immediately on any error
    FailFast,
    /// Continue execution with partial results
    ContinuePartial,
    /// Retry failed operations with exponential backoff
    RetryWithBackoff { max_retries: usize },
    /// Degrade gracefully by skipping non-critical operations
    GracefulDegradation,
}

/// Configuration for error recovery
#[derive(Debug, Clone)]
pub struct RecoveryConfig {
    pub strategy: RecoveryStrategy,
    pub checkpoint_interval: Option<usize>,
    pub max_failures: Option<usize>,
    pub timeout: Option<Duration>,
}

impl RecoveryConfig {
    pub fn fail_fast() -> Self {
        RecoveryConfig {
            strategy: RecoveryStrategy::FailFast,
            checkpoint_interval: None,
            max_failures: None,
            timeout: None,
        }
    }

    pub fn partial_results() -> Self {
        RecoveryConfig {
            strategy: RecoveryStrategy::ContinuePartial,
            checkpoint_interval: Some(10),
            max_failures: Some(5),
            timeout: None,
        }
    }

    pub fn retry(max_retries: usize) -> Self {
        RecoveryConfig {
            strategy: RecoveryStrategy::RetryWithBackoff { max_retries },
            checkpoint_interval: Some(5),
            max_failures: None,
            timeout: Some(Duration::from_secs(300)), // 5 minutes
        }
    }

    pub fn graceful() -> Self {
        RecoveryConfig {
            strategy: RecoveryStrategy::GracefulDegradation,
            checkpoint_interval: Some(10),
            max_failures: Some(10),
            timeout: Some(Duration::from_secs(600)), // 10 minutes
        }
    }

    pub fn with_checkpointing(mut self, interval: usize) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }

    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    pub fn with_max_failures(mut self, max: usize) -> Self {
        self.max_failures = Some(max);
        self
    }
}

impl Default for RecoveryConfig {
    fn default() -> Self {
        Self::partial_results()
    }
}

/// Result of execution with recovery information
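///
/// # Example
///
/// A sketch of inspecting a partial result (assuming the module is exposed as
/// `tensorlogic_infer::recovery`):
///
/// ```ignore
/// use tensorlogic_infer::recovery::{
///     FailureInfo, RecoveryMetadata, RecoveryResult, RecoveryStrategy,
/// };
///
/// let failures = vec![FailureInfo::new(2, "kernel launch failed".to_string())];
/// let metadata = RecoveryMetadata::new(RecoveryStrategy::ContinuePartial);
/// let result: RecoveryResult<f32> = RecoveryResult::partial(vec![1.0, 2.0], failures, 3, metadata);
///
/// assert!(!result.success);
/// assert!(result.has_failures());
/// assert!((result.success_rate() - 2.0 / 3.0).abs() < 1e-9);
/// ```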
#[derive(Debug, Clone)]
pub struct RecoveryResult<T> {
    /// Successfully computed outputs
    pub outputs: Vec<T>,
    /// Details of the operations that failed
    pub failures: Vec<FailureInfo>,
    /// Total number of operations attempted
    pub total_operations: usize,
    /// Whether execution completed successfully
    pub success: bool,
    /// Recovery metadata
    pub metadata: RecoveryMetadata,
}

impl<T> RecoveryResult<T> {
    pub fn success(outputs: Vec<T>) -> Self {
        let total = outputs.len();
        RecoveryResult {
            outputs,
            failures: Vec::new(),
            total_operations: total,
            success: true,
            metadata: RecoveryMetadata::default(),
        }
    }

    pub fn partial(
        outputs: Vec<T>,
        failures: Vec<FailureInfo>,
        total_operations: usize,
        metadata: RecoveryMetadata,
    ) -> Self {
        RecoveryResult {
            outputs,
            failures,
            total_operations,
            success: false,
            metadata,
        }
    }

    pub fn success_rate(&self) -> f64 {
        if self.total_operations == 0 {
            // No operations were attempted, so nothing failed; treat the run as
            // fully successful so that `failure_rate()` stays at 0.0 for empty runs.
            return 1.0;
        }
        (self.outputs.len() as f64) / (self.total_operations as f64)
    }

    pub fn failure_rate(&self) -> f64 {
        1.0 - self.success_rate()
    }

    pub fn has_failures(&self) -> bool {
        !self.failures.is_empty()
    }
}

/// Information about a failed operation
#[derive(Debug, Clone)]
pub struct FailureInfo {
    pub operation_id: usize,
    pub error: String,
    pub retry_count: usize,
    pub timestamp: Instant,
}

impl FailureInfo {
    pub fn new(operation_id: usize, error: String) -> Self {
        FailureInfo {
            operation_id,
            error,
            retry_count: 0,
            timestamp: Instant::now(),
        }
    }

    pub fn with_retries(mut self, count: usize) -> Self {
        self.retry_count = count;
        self
    }
}

/// Metadata about recovery process
#[derive(Debug, Clone)]
pub struct RecoveryMetadata {
    pub total_retries: usize,
    pub checkpoints_created: usize,
    pub execution_time: Duration,
    pub recovery_strategy_used: RecoveryStrategy,
}

impl RecoveryMetadata {
    pub fn new(strategy: RecoveryStrategy) -> Self {
        RecoveryMetadata {
            total_retries: 0,
            checkpoints_created: 0,
            execution_time: Duration::default(),
            recovery_strategy_used: strategy,
        }
    }
}

impl Default for RecoveryMetadata {
    fn default() -> Self {
        Self::new(RecoveryStrategy::FailFast)
    }
}

/// Checkpoint for saving execution state
#[derive(Debug, Clone)]
pub struct Checkpoint<T> {
    pub checkpoint_id: usize,
    pub operation_index: usize,
    pub partial_results: Vec<T>,
    pub timestamp: Instant,
}

impl<T: Clone> Checkpoint<T> {
    pub fn new(checkpoint_id: usize, operation_index: usize, partial_results: Vec<T>) -> Self {
        Checkpoint {
            checkpoint_id,
            operation_index,
            partial_results,
            timestamp: Instant::now(),
        }
    }

    pub fn age(&self) -> Duration {
        self.timestamp.elapsed()
    }
}

/// Manager for checkpoints during execution
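///
/// # Example
///
/// A sketch of checkpointing with a bounded history; once the limit is exceeded the
/// oldest checkpoint is evicted:
///
/// ```ignore
/// use tensorlogic_infer::recovery::CheckpointManager;
///
/// let mut manager: CheckpointManager<i32> = CheckpointManager::new(2);
/// manager.create_checkpoint(0, vec![1, 2]);
/// manager.create_checkpoint(1, vec![3, 4]);
/// manager.create_checkpoint(2, vec![5, 6]); // evicts the checkpoint at operation 0
///
/// assert_eq!(manager.num_checkpoints(), 2);
/// assert_eq!(manager.latest_checkpoint().unwrap().operation_index, 2);
/// ```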
pub struct CheckpointManager<T> {
    checkpoints: Vec<Checkpoint<T>>,
    max_checkpoints: usize,
    next_checkpoint_id: usize,
}

impl<T: Clone> CheckpointManager<T> {
    pub fn new(max_checkpoints: usize) -> Self {
        CheckpointManager {
            checkpoints: Vec::new(),
            max_checkpoints,
            next_checkpoint_id: 0,
        }
    }

    pub fn create_checkpoint(&mut self, operation_index: usize, partial_results: Vec<T>) -> usize {
        // Use a monotonically increasing id so ids stay unique and stable even after
        // older checkpoints have been evicted.
        let checkpoint_id = self.next_checkpoint_id;
        self.next_checkpoint_id += 1;
        let checkpoint = Checkpoint::new(checkpoint_id, operation_index, partial_results);

        self.checkpoints.push(checkpoint);

        // Evict the oldest checkpoint if we exceed the maximum
        if self.checkpoints.len() > self.max_checkpoints {
            self.checkpoints.remove(0);
        }

        checkpoint_id
    }

    pub fn restore_checkpoint(&self, checkpoint_id: usize) -> Option<&Checkpoint<T>> {
        // Look up by id rather than by index; after eviction the two no longer coincide.
        self.checkpoints
            .iter()
            .find(|c| c.checkpoint_id == checkpoint_id)
    }

    pub fn latest_checkpoint(&self) -> Option<&Checkpoint<T>> {
        self.checkpoints.last()
    }

    pub fn num_checkpoints(&self) -> usize {
        self.checkpoints.len()
    }

    pub fn clear(&mut self) {
        self.checkpoints.clear();
    }
}

// Note: no `Default` impl; `max_checkpoints` has no sensible default and must be chosen explicitly.

/// Trait for executors with recovery capabilities
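///
/// # Example
///
/// A minimal sketch of an implementor; the `NoopExecutor` type and its behaviour are
/// hypothetical and only illustrate the shape of the trait:
///
/// ```ignore
/// use tensorlogic_infer::recovery::{
///     RecoveryConfig, RecoveryResult, RecoveryStats, TlRecoverableExecutor,
/// };
/// use tensorlogic_ir::EinsumGraph;
///
/// struct NoopExecutor {
///     stats: RecoveryStats,
/// }
///
/// impl TlRecoverableExecutor for NoopExecutor {
///     type Tensor = Vec<f32>;
///     type Error = String;
///
///     fn execute_with_recovery(
///         &mut self,
///         _graph: &EinsumGraph,
///         inputs: Vec<Self::Tensor>,
///         _config: &RecoveryConfig,
///     ) -> Result<RecoveryResult<Self::Tensor>, Self::Error> {
///         // Pass the inputs through untouched; a real executor would run the graph.
///         Ok(RecoveryResult::success(inputs))
///     }
///
///     fn create_checkpoint(&mut self, _operation_index: usize) -> Result<usize, Self::Error> {
///         self.stats.record_checkpoint();
///         Ok(0)
///     }
///
///     fn restore_checkpoint(&mut self, _checkpoint_id: usize) -> Result<(), Self::Error> {
///         Ok(())
///     }
///
///     fn recovery_stats(&self) -> RecoveryStats {
///         self.stats.clone()
///     }
/// }
/// ```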
pub trait TlRecoverableExecutor {
    type Tensor;
    type Error;

    /// Execute graph with recovery configuration
    fn execute_with_recovery(
        &mut self,
        graph: &EinsumGraph,
        inputs: Vec<Self::Tensor>,
        config: &RecoveryConfig,
    ) -> Result<RecoveryResult<Self::Tensor>, Self::Error>;

    /// Create a checkpoint of current execution state
    fn create_checkpoint(&mut self, operation_index: usize) -> Result<usize, Self::Error>;

    /// Restore from a checkpoint
    fn restore_checkpoint(&mut self, checkpoint_id: usize) -> Result<(), Self::Error>;

    /// Get recovery statistics
    fn recovery_stats(&self) -> RecoveryStats;
}

/// Statistics about recovery operations
#[derive(Debug, Clone, Default)]
pub struct RecoveryStats {
    pub total_recoveries: usize,
    pub successful_recoveries: usize,
    pub failed_recoveries: usize,
    pub total_retries: usize,
    pub total_checkpoints: usize,
}

impl RecoveryStats {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn record_recovery(&mut self, success: bool) {
        self.total_recoveries += 1;
        if success {
            self.successful_recoveries += 1;
        } else {
            self.failed_recoveries += 1;
        }
    }

    pub fn record_retry(&mut self) {
        self.total_retries += 1;
    }

    pub fn record_checkpoint(&mut self) {
        self.total_checkpoints += 1;
    }

    pub fn recovery_rate(&self) -> f64 {
        if self.total_recoveries == 0 {
            return 0.0;
        }
        (self.successful_recoveries as f64) / (self.total_recoveries as f64)
    }
}

/// Retry policy with exponential backoff
pub struct RetryPolicy {
    max_retries: usize,
    base_delay_ms: u64,
    max_delay_ms: u64,
    backoff_multiplier: f64,
}

impl RetryPolicy {
    pub fn new(max_retries: usize) -> Self {
        RetryPolicy {
            max_retries,
            base_delay_ms: 100,
            max_delay_ms: 10_000,
            backoff_multiplier: 2.0,
        }
    }

    pub fn exponential(max_retries: usize, base_delay_ms: u64) -> Self {
        RetryPolicy {
            max_retries,
            base_delay_ms,
            max_delay_ms: 60_000, // 1 minute max
            backoff_multiplier: 2.0,
        }
    }

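    /// Delay before the next retry: `base_delay_ms * backoff_multiplier^retry_count`,
    /// capped at `max_delay_ms` (and pinned to the cap once `retry_count` reaches
    /// `max_retries`).
    ///
    /// For example, with the defaults of `RetryPolicy::new` (base 100 ms, multiplier 2.0):
    ///
    /// ```ignore
    /// use std::time::Duration;
    /// use tensorlogic_infer::recovery::RetryPolicy;
    ///
    /// let policy = RetryPolicy::new(3);
    /// assert_eq!(policy.calculate_delay(0), Duration::from_millis(100));
    /// assert_eq!(policy.calculate_delay(2), Duration::from_millis(400));
    /// ```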
    pub fn calculate_delay(&self, retry_count: usize) -> Duration {
        if retry_count >= self.max_retries {
            return Duration::from_millis(self.max_delay_ms);
        }

        let delay_ms =
            (self.base_delay_ms as f64) * self.backoff_multiplier.powi(retry_count as i32);
        let delay_ms = delay_ms.min(self.max_delay_ms as f64) as u64;

        Duration::from_millis(delay_ms)
    }

    pub fn should_retry(&self, retry_count: usize) -> bool {
        retry_count < self.max_retries
    }

    pub fn max_retries(&self) -> usize {
        self.max_retries
    }
}

impl Default for RetryPolicy {
    fn default() -> Self {
        Self::new(3)
    }
}

/// Policy controlling how execution degrades gracefully when operations fail
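///
/// # Example
///
/// A sketch of marking operations as skippable and attaching fallbacks:
///
/// ```ignore
/// use tensorlogic_infer::recovery::{DegradationPolicy, FallbackStrategy};
///
/// let policy = DegradationPolicy::new()
///     .mark_skippable(1)
///     .with_fallback(2, FallbackStrategy::UseDefault);
///
/// assert!(policy.can_skip(1));
/// assert_eq!(policy.get_fallback(2), Some(&FallbackStrategy::UseDefault));
/// ```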
#[derive(Debug, Clone)]
pub struct DegradationPolicy {
    /// Operations that can be skipped without critical failure
    pub skippable_operations: Vec<usize>,
    /// Fallback strategies for failed operations
    pub fallback_strategies: HashMap<usize, FallbackStrategy>,
}

impl DegradationPolicy {
    pub fn new() -> Self {
        DegradationPolicy {
            skippable_operations: Vec::new(),
            fallback_strategies: HashMap::new(),
        }
    }

    pub fn mark_skippable(mut self, operation_id: usize) -> Self {
        self.skippable_operations.push(operation_id);
        self
    }

    pub fn with_fallback(mut self, operation_id: usize, strategy: FallbackStrategy) -> Self {
        self.fallback_strategies.insert(operation_id, strategy);
        self
    }

    pub fn can_skip(&self, operation_id: usize) -> bool {
        self.skippable_operations.contains(&operation_id)
    }

    pub fn get_fallback(&self, operation_id: usize) -> Option<&FallbackStrategy> {
        self.fallback_strategies.get(&operation_id)
    }
}

impl Default for DegradationPolicy {
    fn default() -> Self {
        Self::new()
    }
}

/// Fallback strategy for failed operations
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FallbackStrategy {
    /// Skip the operation entirely
    Skip,
    /// Use a default/zero value
    UseDefault,
    /// Use result from a previous successful execution
    UseCached,
    /// Use a simpler approximation
    UseApproximation,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_recovery_config() {
        let config = RecoveryConfig::partial_results()
            .with_checkpointing(20)
            .with_max_failures(3);

        assert_eq!(config.strategy, RecoveryStrategy::ContinuePartial);
        assert_eq!(config.checkpoint_interval, Some(20));
        assert_eq!(config.max_failures, Some(3));
    }

    #[test]
    fn test_recovery_config_retry() {
        let config = RecoveryConfig::retry(5);
        assert_eq!(
            config.strategy,
            RecoveryStrategy::RetryWithBackoff { max_retries: 5 }
        );
        assert!(config.timeout.is_some());
    }

    #[test]
    fn test_recovery_result_success() {
        let result: RecoveryResult<i32> = RecoveryResult::success(vec![1, 2, 3]);
        assert!(result.success);
        assert_eq!(result.success_rate(), 1.0);
        assert_eq!(result.failure_rate(), 0.0);
        assert!(!result.has_failures());
    }

    #[test]
    fn test_recovery_result_partial() {
        let failures = vec![FailureInfo::new(2, "Error".to_string())];
        let metadata = RecoveryMetadata::new(RecoveryStrategy::ContinuePartial);
        let result: RecoveryResult<i32> =
            RecoveryResult::partial(vec![1, 2], failures, 3, metadata);

        assert!(!result.success);
        assert_eq!(result.success_rate(), 2.0 / 3.0);
        assert!(result.has_failures());
        assert_eq!(result.failures.len(), 1);
    }

    #[test]
    fn test_checkpoint_manager() {
        let mut manager: CheckpointManager<i32> = CheckpointManager::new(3);

        let id1 = manager.create_checkpoint(0, vec![1, 2, 3]);
        let _id2 = manager.create_checkpoint(1, vec![4, 5, 6]);
        let _id3 = manager.create_checkpoint(2, vec![7, 8, 9]);

        assert_eq!(manager.num_checkpoints(), 3);

        let checkpoint = manager.restore_checkpoint(id1).unwrap();
        assert_eq!(checkpoint.checkpoint_id, 0);
        assert_eq!(checkpoint.partial_results, vec![1, 2, 3]);

        // Add one more, should evict the oldest
        manager.create_checkpoint(3, vec![10, 11, 12]);
        assert_eq!(manager.num_checkpoints(), 3);
    }

    #[test]
    fn test_checkpoint_manager_latest() {
        let mut manager: CheckpointManager<i32> = CheckpointManager::new(5);

        manager.create_checkpoint(0, vec![1]);
        manager.create_checkpoint(1, vec![2]);
        manager.create_checkpoint(2, vec![3]);

        let latest = manager.latest_checkpoint().unwrap();
        assert_eq!(latest.checkpoint_id, 2);
        assert_eq!(latest.partial_results, vec![3]);
    }

    #[test]
    fn test_recovery_stats() {
        let mut stats = RecoveryStats::new();

        stats.record_recovery(true);
        stats.record_recovery(true);
        stats.record_recovery(false);
        stats.record_retry();
        stats.record_retry();
        stats.record_checkpoint();

        assert_eq!(stats.total_recoveries, 3);
        assert_eq!(stats.successful_recoveries, 2);
        assert_eq!(stats.failed_recoveries, 1);
        assert_eq!(stats.total_retries, 2);
        assert_eq!(stats.total_checkpoints, 1);
        assert!((stats.recovery_rate() - 2.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_retry_policy() {
        let policy = RetryPolicy::new(3);

        assert!(policy.should_retry(0));
        assert!(policy.should_retry(2));
        assert!(!policy.should_retry(3));
        assert!(!policy.should_retry(4));

        let delay1 = policy.calculate_delay(0);
        let delay2 = policy.calculate_delay(1);
        let delay3 = policy.calculate_delay(2);

        // Exponential backoff
        assert!(delay2 > delay1);
        assert!(delay3 > delay2);
    }

    #[test]
    fn test_retry_policy_exponential() {
        let policy = RetryPolicy::exponential(5, 50);

        let delay0 = policy.calculate_delay(0);
        let delay1 = policy.calculate_delay(1);
        let delay2 = policy.calculate_delay(2);

        assert_eq!(delay0.as_millis(), 50);
        assert_eq!(delay1.as_millis(), 100);
        assert_eq!(delay2.as_millis(), 200);
    }

    #[test]
    fn test_degradation_policy() {
        let policy = DegradationPolicy::new()
            .mark_skippable(1)
            .mark_skippable(3)
            .with_fallback(2, FallbackStrategy::UseDefault);

        assert!(policy.can_skip(1));
        assert!(!policy.can_skip(2));
        assert!(policy.can_skip(3));

        let fallback = policy.get_fallback(2);
        assert_eq!(fallback, Some(&FallbackStrategy::UseDefault));
        assert!(policy.get_fallback(1).is_none());
    }

    #[test]
    fn test_failure_info() {
        let info = FailureInfo::new(5, "Test error".to_string()).with_retries(3);

        assert_eq!(info.operation_id, 5);
        assert_eq!(info.error, "Test error");
        assert_eq!(info.retry_count, 3);
    }

    #[test]
    fn test_checkpoint_age() {
        let checkpoint: Checkpoint<i32> = Checkpoint::new(0, 0, vec![1, 2, 3]);
        std::thread::sleep(Duration::from_millis(10));
        let age = checkpoint.age();
        assert!(age >= Duration::from_millis(10));
    }
}