use std::collections::HashMap;
use std::time::{Duration, Instant};

use tensorlogic_ir::EinsumGraph;

/// Strategy for recovering from failures during graph execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecoveryStrategy {
    /// Abort execution on the first failure.
    FailFast,
    /// Skip failed operations and keep the successful results.
    ContinuePartial,
    /// Retry failed operations with exponentially increasing delays.
    RetryWithBackoff { max_retries: usize },
    /// Degrade gracefully by skipping or substituting failed operations.
    GracefulDegradation,
}

/// Configuration controlling how an executor recovers from failures.
#[derive(Debug, Clone)]
pub struct RecoveryConfig {
    /// Recovery strategy to apply.
    pub strategy: RecoveryStrategy,
    /// Create a checkpoint every N operations, if set.
    pub checkpoint_interval: Option<usize>,
    /// Abort once this many failures have accumulated, if set.
    pub max_failures: Option<usize>,
    /// Overall execution timeout, if set.
    pub timeout: Option<Duration>,
}

impl RecoveryConfig {
    pub fn fail_fast() -> Self {
        RecoveryConfig {
            strategy: RecoveryStrategy::FailFast,
            checkpoint_interval: None,
            max_failures: None,
            timeout: None,
        }
    }

    pub fn partial_results() -> Self {
        RecoveryConfig {
            strategy: RecoveryStrategy::ContinuePartial,
            checkpoint_interval: Some(10),
            max_failures: Some(5),
            timeout: None,
        }
    }

    pub fn retry(max_retries: usize) -> Self {
        RecoveryConfig {
            strategy: RecoveryStrategy::RetryWithBackoff { max_retries },
            checkpoint_interval: Some(5),
            max_failures: None,
            timeout: Some(Duration::from_secs(300)),
        }
    }

    pub fn graceful() -> Self {
        RecoveryConfig {
            strategy: RecoveryStrategy::GracefulDegradation,
            checkpoint_interval: Some(10),
            max_failures: Some(10),
            timeout: Some(Duration::from_secs(600)),
        }
    }

    pub fn with_checkpointing(mut self, interval: usize) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }

    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    pub fn with_max_failures(mut self, max: usize) -> Self {
        self.max_failures = Some(max);
        self
    }
}

impl Default for RecoveryConfig {
    fn default() -> Self {
        Self::partial_results()
    }
}

/// Outcome of an execution that may have partially failed.
#[derive(Debug, Clone)]
pub struct RecoveryResult<T> {
    /// Outputs of the operations that succeeded.
    pub outputs: Vec<T>,
    /// Details of each failed operation.
    pub failures: Vec<FailureInfo>,
    /// Total number of operations attempted.
    pub total_operations: usize,
    /// True if every operation succeeded.
    pub success: bool,
    /// Bookkeeping about the recovery process itself.
    pub metadata: RecoveryMetadata,
}

impl<T> RecoveryResult<T> {
    pub fn success(outputs: Vec<T>) -> Self {
        let total = outputs.len();
        RecoveryResult {
            outputs,
            failures: Vec::new(),
            total_operations: total,
            success: true,
            metadata: RecoveryMetadata::default(),
        }
    }

    pub fn partial(
        outputs: Vec<T>,
        failures: Vec<FailureInfo>,
        total_operations: usize,
        metadata: RecoveryMetadata,
    ) -> Self {
        RecoveryResult {
            outputs,
            failures,
            total_operations,
            success: false,
            metadata,
        }
    }

    pub fn success_rate(&self) -> f64 {
        // Treat an empty run as 0% success rather than dividing by zero.
        if self.total_operations == 0 {
            return 0.0;
        }
        (self.outputs.len() as f64) / (self.total_operations as f64)
    }

    pub fn failure_rate(&self) -> f64 {
        1.0 - self.success_rate()
    }

    pub fn has_failures(&self) -> bool {
        !self.failures.is_empty()
    }
}

/// Record of a single failed operation.
#[derive(Debug, Clone)]
pub struct FailureInfo {
    pub operation_id: usize,
    pub error: String,
    pub retry_count: usize,
    pub timestamp: Instant,
}

impl FailureInfo {
    pub fn new(operation_id: usize, error: String) -> Self {
        FailureInfo {
            operation_id,
            error,
            retry_count: 0,
            timestamp: Instant::now(),
        }
    }

    pub fn with_retries(mut self, count: usize) -> Self {
        self.retry_count = count;
        self
    }
}

/// Bookkeeping gathered while a recovery strategy runs.
#[derive(Debug, Clone)]
pub struct RecoveryMetadata {
    pub total_retries: usize,
    pub checkpoints_created: usize,
    pub execution_time: Duration,
    pub recovery_strategy_used: RecoveryStrategy,
}

impl RecoveryMetadata {
    pub fn new(strategy: RecoveryStrategy) -> Self {
        RecoveryMetadata {
            total_retries: 0,
            checkpoints_created: 0,
            execution_time: Duration::default(),
            recovery_strategy_used: strategy,
        }
    }
}

impl Default for RecoveryMetadata {
    fn default() -> Self {
        Self::new(RecoveryStrategy::FailFast)
    }
}

/// Snapshot of partial results at a point in the execution.
#[derive(Debug, Clone)]
pub struct Checkpoint<T> {
    pub checkpoint_id: usize,
    pub operation_index: usize,
    pub partial_results: Vec<T>,
    pub timestamp: Instant,
}

impl<T: Clone> Checkpoint<T> {
    pub fn new(checkpoint_id: usize, operation_index: usize, partial_results: Vec<T>) -> Self {
        Checkpoint {
            checkpoint_id,
            operation_index,
            partial_results,
            timestamp: Instant::now(),
        }
    }

    pub fn age(&self) -> Duration {
        self.timestamp.elapsed()
    }
}

/// Bounded store of checkpoints; the oldest entry is evicted once
/// `max_checkpoints` is exceeded.
pub struct CheckpointManager<T> {
    checkpoints: Vec<Checkpoint<T>>,
    max_checkpoints: usize,
    next_id: usize,
}

impl<T: Clone> CheckpointManager<T> {
    pub fn new(max_checkpoints: usize) -> Self {
        CheckpointManager {
            checkpoints: Vec::new(),
            max_checkpoints,
            next_id: 0,
        }
    }

    pub fn create_checkpoint(&mut self, operation_index: usize, partial_results: Vec<T>) -> usize {
        // Ids come from a monotonic counter rather than `checkpoints.len()`,
        // so they stay unique even after old checkpoints are evicted.
        let checkpoint_id = self.next_id;
        self.next_id += 1;
        let checkpoint = Checkpoint::new(checkpoint_id, operation_index, partial_results);

        self.checkpoints.push(checkpoint);

        // Evict the oldest checkpoint once the capacity limit is exceeded.
        if self.checkpoints.len() > self.max_checkpoints {
            self.checkpoints.remove(0);
        }

        checkpoint_id
    }

    pub fn restore_checkpoint(&self, checkpoint_id: usize) -> Option<&Checkpoint<T>> {
        // Look up by id rather than by position: after eviction, ids no
        // longer coincide with vector indices.
        self.checkpoints.iter().find(|c| c.checkpoint_id == checkpoint_id)
    }

    pub fn latest_checkpoint(&self) -> Option<&Checkpoint<T>> {
        self.checkpoints.last()
    }

    pub fn num_checkpoints(&self) -> usize {
        self.checkpoints.len()
    }

    pub fn clear(&mut self) {
        self.checkpoints.clear();
    }
}

/// Executor extension for fault-tolerant graph execution.
///
/// Implementors run an `EinsumGraph` under a `RecoveryConfig`, creating
/// checkpoints and retrying or skipping operations as the strategy dictates.
pub trait TlRecoverableExecutor {
    /// Tensor type produced by this executor.
    type Tensor;
    /// Error type produced by this executor.
    type Error;

    /// Execute a graph, applying the configured recovery strategy on failure.
    fn execute_with_recovery(
        &mut self,
        graph: &EinsumGraph,
        inputs: Vec<Self::Tensor>,
        config: &RecoveryConfig,
    ) -> Result<RecoveryResult<Self::Tensor>, Self::Error>;

    /// Snapshot executor state at the given operation index; returns a checkpoint id.
    fn create_checkpoint(&mut self, operation_index: usize) -> Result<usize, Self::Error>;

    /// Roll executor state back to a previously created checkpoint.
    fn restore_checkpoint(&mut self, checkpoint_id: usize) -> Result<(), Self::Error>;

    /// Statistics accumulated across recovery attempts.
    fn recovery_stats(&self) -> RecoveryStats;
}

/// Counters accumulated across recovery attempts.
#[derive(Debug, Clone, Default)]
pub struct RecoveryStats {
    pub total_recoveries: usize,
    pub successful_recoveries: usize,
    pub failed_recoveries: usize,
    pub total_retries: usize,
    pub total_checkpoints: usize,
}

impl RecoveryStats {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn record_recovery(&mut self, success: bool) {
        self.total_recoveries += 1;
        if success {
            self.successful_recoveries += 1;
        } else {
            self.failed_recoveries += 1;
        }
    }

    pub fn record_retry(&mut self) {
        self.total_retries += 1;
    }

    pub fn record_checkpoint(&mut self) {
        self.total_checkpoints += 1;
    }

    pub fn recovery_rate(&self) -> f64 {
        if self.total_recoveries == 0 {
            return 0.0;
        }
        (self.successful_recoveries as f64) / (self.total_recoveries as f64)
    }
}

/// Retry policy with exponential backoff:
/// `delay = base_delay_ms * backoff_multiplier^retry_count`, capped at `max_delay_ms`.
pub struct RetryPolicy {
    max_retries: usize,
    base_delay_ms: u64,
    max_delay_ms: u64,
    backoff_multiplier: f64,
}

impl RetryPolicy {
    pub fn new(max_retries: usize) -> Self {
        RetryPolicy {
            max_retries,
            base_delay_ms: 100,
            max_delay_ms: 10_000,
            backoff_multiplier: 2.0,
        }
    }

    pub fn exponential(max_retries: usize, base_delay_ms: u64) -> Self {
        RetryPolicy {
            max_retries,
            base_delay_ms,
            max_delay_ms: 60_000,
            backoff_multiplier: 2.0,
        }
    }

    pub fn calculate_delay(&self, retry_count: usize) -> Duration {
        // Retry counts at or beyond the limit saturate at the maximum delay.
        if retry_count >= self.max_retries {
            return Duration::from_millis(self.max_delay_ms);
        }

        let delay_ms =
            (self.base_delay_ms as f64) * self.backoff_multiplier.powi(retry_count as i32);
        let delay_ms = delay_ms.min(self.max_delay_ms as f64) as u64;

        Duration::from_millis(delay_ms)
    }

    pub fn should_retry(&self, retry_count: usize) -> bool {
        retry_count < self.max_retries
    }

    pub fn max_retries(&self) -> usize {
        self.max_retries
    }
}

impl Default for RetryPolicy {
    fn default() -> Self {
        Self::new(3)
    }
}

/// Policy describing how execution may degrade instead of failing outright.
#[derive(Debug, Clone)]
pub struct DegradationPolicy {
    /// Operations that may be skipped entirely on failure.
    pub skippable_operations: Vec<usize>,
    /// Per-operation fallbacks to use when the primary computation fails.
    pub fallback_strategies: HashMap<usize, FallbackStrategy>,
}

impl DegradationPolicy {
    pub fn new() -> Self {
        DegradationPolicy {
            skippable_operations: Vec::new(),
            fallback_strategies: HashMap::new(),
        }
    }

    pub fn mark_skippable(mut self, operation_id: usize) -> Self {
        self.skippable_operations.push(operation_id);
        self
    }

    pub fn with_fallback(mut self, operation_id: usize, strategy: FallbackStrategy) -> Self {
        self.fallback_strategies.insert(operation_id, strategy);
        self
    }

    pub fn can_skip(&self, operation_id: usize) -> bool {
        self.skippable_operations.contains(&operation_id)
    }

    pub fn get_fallback(&self, operation_id: usize) -> Option<&FallbackStrategy> {
        self.fallback_strategies.get(&operation_id)
    }
}

impl Default for DegradationPolicy {
    fn default() -> Self {
        Self::new()
    }
}

/// Fallback applied to a failed operation under graceful degradation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FallbackStrategy {
    /// Skip the operation and produce no output.
    Skip,
    /// Substitute a default value.
    UseDefault,
    /// Reuse a previously cached result.
    UseCached,
    /// Use a cheaper approximate computation.
    UseApproximation,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_recovery_config() {
        let config = RecoveryConfig::partial_results()
            .with_checkpointing(20)
            .with_max_failures(3);

        assert_eq!(config.strategy, RecoveryStrategy::ContinuePartial);
        assert_eq!(config.checkpoint_interval, Some(20));
        assert_eq!(config.max_failures, Some(3));
    }
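
    // Companion check for the `graceful` preset, mirroring the assertions
    // above; expected values are taken directly from `RecoveryConfig::graceful`.
    #[test]
    fn test_recovery_config_graceful() {
        let config = RecoveryConfig::graceful();
        assert_eq!(config.strategy, RecoveryStrategy::GracefulDegradation);
        assert_eq!(config.checkpoint_interval, Some(10));
        assert_eq!(config.max_failures, Some(10));
        assert_eq!(config.timeout, Some(Duration::from_secs(600)));
    }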

    #[test]
    fn test_recovery_config_retry() {
        let config = RecoveryConfig::retry(5);
        assert_eq!(
            config.strategy,
            RecoveryStrategy::RetryWithBackoff { max_retries: 5 }
        );
        assert!(config.timeout.is_some());
    }

    #[test]
    fn test_recovery_result_success() {
        let result: RecoveryResult<i32> = RecoveryResult::success(vec![1, 2, 3]);
        assert!(result.success);
        assert_eq!(result.success_rate(), 1.0);
        assert_eq!(result.failure_rate(), 0.0);
        assert!(!result.has_failures());
    }

    #[test]
    fn test_recovery_result_partial() {
        let failures = vec![FailureInfo::new(2, "Error".to_string())];
        let metadata = RecoveryMetadata::new(RecoveryStrategy::ContinuePartial);
        let result: RecoveryResult<i32> =
            RecoveryResult::partial(vec![1, 2], failures, 3, metadata);

        assert!(!result.success);
        assert_eq!(result.success_rate(), 2.0 / 3.0);
        assert!(result.has_failures());
        assert_eq!(result.failures.len(), 1);
    }
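
    // Pins down the zero-operations convention: success_rate() reports 0.0
    // for an empty run instead of dividing by zero.
    #[test]
    fn test_recovery_result_empty() {
        let result: RecoveryResult<i32> = RecoveryResult::success(Vec::new());
        assert!(result.success);
        assert_eq!(result.total_operations, 0);
        assert_eq!(result.success_rate(), 0.0);
        assert!(!result.has_failures());
    }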

    #[test]
    fn test_checkpoint_manager() {
        let mut manager: CheckpointManager<i32> = CheckpointManager::new(3);

        let id1 = manager.create_checkpoint(0, vec![1, 2, 3]);
        let _id2 = manager.create_checkpoint(1, vec![4, 5, 6]);
        let _id3 = manager.create_checkpoint(2, vec![7, 8, 9]);

        assert_eq!(manager.num_checkpoints(), 3);

        let checkpoint = manager.restore_checkpoint(id1).unwrap();
        assert_eq!(checkpoint.checkpoint_id, 0);
        assert_eq!(checkpoint.partial_results, vec![1, 2, 3]);

        // Exceeding the capacity of 3 evicts the oldest checkpoint.
        manager.create_checkpoint(3, vec![10, 11, 12]);
        assert_eq!(manager.num_checkpoints(), 3);
    }
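
    // Checks the id-based lookup in restore_checkpoint: ids remain stable
    // handles after eviction, and an evicted id simply returns None.
    #[test]
    fn test_checkpoint_restore_after_eviction() {
        let mut manager: CheckpointManager<i32> = CheckpointManager::new(2);
        let id0 = manager.create_checkpoint(0, vec![1]);
        let id1 = manager.create_checkpoint(1, vec![2]);
        let id2 = manager.create_checkpoint(2, vec![3]);

        assert!(manager.restore_checkpoint(id0).is_none());
        assert_eq!(manager.restore_checkpoint(id1).unwrap().partial_results, vec![2]);
        assert_eq!(manager.restore_checkpoint(id2).unwrap().partial_results, vec![3]);
    }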

    #[test]
    fn test_checkpoint_manager_latest() {
        let mut manager: CheckpointManager<i32> = CheckpointManager::new(5);

        manager.create_checkpoint(0, vec![1]);
        manager.create_checkpoint(1, vec![2]);
        manager.create_checkpoint(2, vec![3]);

        let latest = manager.latest_checkpoint().unwrap();
        assert_eq!(latest.checkpoint_id, 2);
        assert_eq!(latest.partial_results, vec![3]);
    }

    #[test]
    fn test_recovery_stats() {
        let mut stats = RecoveryStats::new();

        stats.record_recovery(true);
        stats.record_recovery(true);
        stats.record_recovery(false);
        stats.record_retry();
        stats.record_retry();
        stats.record_checkpoint();

        assert_eq!(stats.total_recoveries, 3);
        assert_eq!(stats.successful_recoveries, 2);
        assert_eq!(stats.failed_recoveries, 1);
        assert_eq!(stats.total_retries, 2);
        assert_eq!(stats.total_checkpoints, 1);
        assert!((stats.recovery_rate() - 2.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_retry_policy() {
        let policy = RetryPolicy::new(3);

        assert!(policy.should_retry(0));
        assert!(policy.should_retry(2));
        assert!(!policy.should_retry(3));
        assert!(!policy.should_retry(4));

        let delay1 = policy.calculate_delay(0);
        let delay2 = policy.calculate_delay(1);
        let delay3 = policy.calculate_delay(2);

        assert!(delay2 > delay1);
        assert!(delay3 > delay2);
    }

    #[test]
    fn test_retry_policy_exponential() {
        let policy = RetryPolicy::exponential(5, 50);

        let delay0 = policy.calculate_delay(0);
        let delay1 = policy.calculate_delay(1);
        let delay2 = policy.calculate_delay(2);

        assert_eq!(delay0.as_millis(), 50);
        assert_eq!(delay1.as_millis(), 100);
        assert_eq!(delay2.as_millis(), 200);
    }
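
    // Checks the clamp in calculate_delay(): within the retry limit, the
    // exponential delay is still capped at max_delay_ms (60s for this preset).
    #[test]
    fn test_retry_policy_delay_cap() {
        let policy = RetryPolicy::exponential(20, 1_000);
        // 1_000 ms * 2^10 would be ~1_024 s; the policy clamps it to 60 s.
        assert_eq!(policy.calculate_delay(10).as_millis(), 60_000);
    }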

    #[test]
    fn test_degradation_policy() {
        let policy = DegradationPolicy::new()
            .mark_skippable(1)
            .mark_skippable(3)
            .with_fallback(2, FallbackStrategy::UseDefault);

        assert!(policy.can_skip(1));
        assert!(!policy.can_skip(2));
        assert!(policy.can_skip(3));

        let fallback = policy.get_fallback(2);
        assert_eq!(fallback, Some(&FallbackStrategy::UseDefault));
        assert!(policy.get_fallback(1).is_none());
    }
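
    // Fallbacks live in a HashMap keyed by operation id, so registering a
    // second strategy for the same operation replaces the first.
    #[test]
    fn test_degradation_policy_fallback_overwrite() {
        let policy = DegradationPolicy::new()
            .with_fallback(7, FallbackStrategy::UseDefault)
            .with_fallback(7, FallbackStrategy::UseCached);

        assert_eq!(policy.get_fallback(7), Some(&FallbackStrategy::UseCached));
    }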

    #[test]
    fn test_failure_info() {
        let info = FailureInfo::new(5, "Test error".to_string()).with_retries(3);

        assert_eq!(info.operation_id, 5);
        assert_eq!(info.error, "Test error");
        assert_eq!(info.retry_count, 3);
    }

    #[test]
    fn test_checkpoint_age() {
        let checkpoint: Checkpoint<i32> = Checkpoint::new(0, 0, vec![1, 2, 3]);
        std::thread::sleep(Duration::from_millis(10));
        let age = checkpoint.age();
        assert!(age >= Duration::from_millis(10));
    }
}