1use crate::checkpoint_system::{ErrorCategory, ErrorRecovery, RecoveryLayer};
8use crate::planner::{ExecutionPlan, RiskLevel};
9use crate::types::Layer2Result;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::RwLock;
15
16#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
18pub enum ExecutionStatus {
19 #[default]
21 Pending,
22 Running,
24 Paused,
26 StepCompleted,
28 StepFailed,
30 Completed,
32 Failed,
34 AwaitingUserInput,
36 Cancelled,
38}
39
40#[derive(Debug, Clone)]
42pub struct StepResult {
43 pub subtask_id: String,
45 pub status: ExecutionStatus,
47 pub output: Option<String>,
49 pub error: Option<String>,
51 pub duration: Duration,
53 pub retry_count: u32,
55 pub recovery_layer: Option<RecoveryLayer>,
57}
58
59#[allow(clippy::type_complexity)]
61pub struct ExecutionMonitor {
62 plan: Arc<RwLock<ExecutionPlan>>,
64 status: Arc<RwLock<ExecutionStatus>>,
66 step_results: Arc<RwLock<HashMap<String, StepResult>>>,
68 #[allow(dead_code)]
70 error_recovery: Arc<ErrorRecovery>,
71 start_time: Arc<RwLock<Option<Instant>>>,
73 progress_callback: Arc<RwLock<Option<Box<dyn Fn(&str, ExecutionStatus) + Send + Sync>>>>,
75 correction_history: Arc<RwLock<Vec<CorrectionRecord>>>,
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct CorrectionRecord {
82 pub id: String,
84 pub failed_subtask: String,
86 pub error_category: String,
88 pub original_error: String,
90 pub strategy: CorrectionStrategy,
92 pub success: bool,
94 pub timestamp: chrono::DateTime<chrono::Utc>,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
100pub enum CorrectionStrategy {
101 Retry { max_attempts: u32 },
103 Skip,
105 Alternative { replacement_subtask: String },
107 Decompose { new_subtasks: Vec<String> },
109 UserIntervention { action: String },
111 AdjustParameters { new_params: serde_json::Value },
113}
114
115impl CorrectionStrategy {
116 pub fn debug_name(&self) -> &'static str {
118 match self {
119 CorrectionStrategy::Retry { .. } => "Retry",
120 CorrectionStrategy::Skip => "Skip",
121 CorrectionStrategy::Alternative { .. } => "Alternative",
122 CorrectionStrategy::Decompose { .. } => "Decompose",
123 CorrectionStrategy::UserIntervention { .. } => "UserIntervention",
124 CorrectionStrategy::AdjustParameters { .. } => "AdjustParameters",
125 }
126 }
127}
128
129impl ExecutionMonitor {
130 pub fn new(plan: ExecutionPlan) -> Self {
132 Self {
133 plan: Arc::new(RwLock::new(plan)),
134 status: Arc::new(RwLock::new(ExecutionStatus::Pending)),
135 step_results: Arc::new(RwLock::new(HashMap::new())),
136 error_recovery: Arc::new(ErrorRecovery::new()),
137 start_time: Arc::new(RwLock::new(None)),
138 progress_callback: Arc::new(RwLock::new(None)),
139 correction_history: Arc::new(RwLock::new(Vec::new())),
140 }
141 }
142
143 pub async fn set_progress_callback<F>(&self, callback: F)
145 where
146 F: Fn(&str, ExecutionStatus) + Send + Sync + 'static,
147 {
148 *self.progress_callback.write().await = Some(Box::new(callback));
149 }
150
151 pub async fn get_status(&self) -> ExecutionStatus {
153 *self.status.read().await
154 }
155
156 pub async fn get_progress(&self) -> u32 {
158 let plan = self.plan.read().await;
159 let results = self.step_results.read().await;
160
161 if plan.subtasks.is_empty() {
162 return 0;
163 }
164
165 let completed = results
166 .values()
167 .filter(|r| matches!(r.status, ExecutionStatus::StepCompleted))
168 .count();
169
170 (completed as u32 * 100) / plan.subtasks.len() as u32
171 }
172
173 pub async fn start(&self) -> Layer2Result<()> {
175 let mut status = self.status.write().await;
176 *status = ExecutionStatus::Running;
177 drop(status);
178
179 *self.start_time.write().await = Some(Instant::now());
180
181 self.notify_progress("execution_started", ExecutionStatus::Running)
183 .await;
184
185 Ok(())
186 }
187
188 pub async fn report_step_completed(
190 &self,
191 subtask_id: &str,
192 output: String,
193 ) -> Layer2Result<()> {
194 let result = StepResult {
195 subtask_id: subtask_id.to_string(),
196 status: ExecutionStatus::StepCompleted,
197 output: Some(output),
198 error: None,
199 duration: Duration::from_secs(0),
200 retry_count: 0,
201 recovery_layer: None,
202 };
203
204 self.step_results
205 .write()
206 .await
207 .insert(subtask_id.to_string(), result);
208 self.notify_progress(subtask_id, ExecutionStatus::StepCompleted)
209 .await;
210
211 Ok(())
212 }
213
214 pub async fn report_step_failed(
216 &self,
217 subtask_id: &str,
218 error: String,
219 ) -> Layer2Result<CorrectionDecision> {
220 let category = ErrorCategory::from_error_message(&error);
222
223 let result = StepResult {
225 subtask_id: subtask_id.to_string(),
226 status: ExecutionStatus::StepFailed,
227 output: None,
228 error: Some(error.clone()),
229 duration: Duration::from_secs(0),
230 retry_count: 0,
231 recovery_layer: None,
232 };
233
234 self.step_results
235 .write()
236 .await
237 .insert(subtask_id.to_string(), result);
238
239 let decision = self.decide_correction(subtask_id, &category, &error).await;
241
242 self.record_correction(subtask_id, &category, &error, &decision)
244 .await;
245
246 self.notify_progress(subtask_id, ExecutionStatus::StepFailed)
247 .await;
248
249 Ok(decision)
250 }
251
252 async fn decide_correction(
254 &self,
255 subtask_id: &str,
256 category: &ErrorCategory,
257 error: &str,
258 ) -> CorrectionDecision {
259 let plan = self.plan.read().await;
260
261 let subtask = plan.subtasks.iter().find(|s| s.id == subtask_id);
263
264 match category {
266 ErrorCategory::Transient => {
267 CorrectionDecision {
269 strategy: CorrectionStrategy::Retry { max_attempts: 3 },
270 should_continue: true,
271 user_message: Some("Temporary error, will retry automatically".to_string()),
272 }
273 }
274 ErrorCategory::Resource => {
275 CorrectionDecision {
277 strategy: CorrectionStrategy::Retry { max_attempts: 2 },
278 should_continue: true,
279 user_message: Some("Resource issue, waiting before retry".to_string()),
280 }
281 }
282 ErrorCategory::Logic => {
283 if let Some(subtask) = subtask {
285 if let Some(fallback) = &subtask.fallback {
286 CorrectionDecision {
287 strategy: CorrectionStrategy::Alternative {
288 replacement_subtask: fallback.name.clone(),
289 },
290 should_continue: true,
291 user_message: Some("Using fallback strategy".to_string()),
292 }
293 } else {
294 CorrectionDecision {
296 strategy: CorrectionStrategy::Decompose {
297 new_subtasks: vec!["simplified_step".to_string()],
298 },
299 should_continue: true,
300 user_message: Some("Breaking down the task".to_string()),
301 }
302 }
303 } else {
304 CorrectionDecision {
305 strategy: CorrectionStrategy::Skip,
306 should_continue: true,
307 user_message: Some("Skipping failed step".to_string()),
308 }
309 }
310 }
311 ErrorCategory::Configuration => {
312 CorrectionDecision {
314 strategy: CorrectionStrategy::UserIntervention {
315 action: "Please check your configuration".to_string(),
316 },
317 should_continue: false,
318 user_message: Some(format!("Configuration error: {}", error)),
319 }
320 }
321 ErrorCategory::UserInterrupt => {
322 CorrectionDecision {
324 strategy: CorrectionStrategy::Skip,
325 should_continue: false,
326 user_message: Some("Execution cancelled by user".to_string()),
327 }
328 }
329 ErrorCategory::System => {
330 if plan.risk_level == RiskLevel::Critical {
332 CorrectionDecision {
333 strategy: CorrectionStrategy::UserIntervention {
334 action: "Critical error requires manual intervention".to_string(),
335 },
336 should_continue: false,
337 user_message: Some(format!("Critical system error: {}", error)),
338 }
339 } else {
340 CorrectionDecision {
341 strategy: CorrectionStrategy::Retry { max_attempts: 1 },
342 should_continue: true,
343 user_message: Some("System error, attempting recovery".to_string()),
344 }
345 }
346 }
347 }
348 }
349
350 async fn record_correction(
352 &self,
353 subtask_id: &str,
354 category: &ErrorCategory,
355 error: &str,
356 decision: &CorrectionDecision,
357 ) {
358 let category_str = match category {
359 ErrorCategory::Transient => "Transient",
360 ErrorCategory::Resource => "Resource",
361 ErrorCategory::Configuration => "Configuration",
362 ErrorCategory::Logic => "Logic",
363 ErrorCategory::System => "System",
364 ErrorCategory::UserInterrupt => "UserInterrupt",
365 };
366 let record = CorrectionRecord {
367 id: format!("correction_{}", chrono::Utc::now().timestamp()),
368 failed_subtask: subtask_id.to_string(),
369 error_category: category_str.to_string(),
370 original_error: error.to_string(),
371 strategy: decision.strategy.clone(),
372 success: false, timestamp: chrono::Utc::now(),
374 };
375
376 self.correction_history.write().await.push(record);
377 }
378
379 pub async fn apply_correction(
381 &self,
382 subtask_id: &str,
383 decision: &CorrectionDecision,
384 ) -> Layer2Result<bool> {
385 match &decision.strategy {
386 CorrectionStrategy::Retry { max_attempts: _ } => {
387 Ok(true)
390 }
391 CorrectionStrategy::Skip => {
392 self.report_step_completed(
394 subtask_id,
395 "[SKIPPED] Step skipped due to unrecoverable error".to_string(),
396 )
397 .await?;
398 Ok(true)
399 }
400 CorrectionStrategy::Alternative {
401 replacement_subtask,
402 } => {
403 self.report_step_completed(
405 subtask_id,
406 format!("[ALTERNATIVE] Used: {}", replacement_subtask),
407 )
408 .await?;
409 Ok(true)
410 }
411 CorrectionStrategy::UserIntervention { action: _ } => {
412 let mut status = self.status.write().await;
414 *status = ExecutionStatus::AwaitingUserInput;
415 Ok(false)
416 }
417 CorrectionStrategy::Decompose { new_subtasks } => {
418 self.report_step_completed(
420 subtask_id,
421 format!("[DECOMPOSED] Into: {}", new_subtasks.join(", ")),
422 )
423 .await?;
424 Ok(true)
425 }
426 CorrectionStrategy::AdjustParameters { new_params: _ } => {
427 self.report_step_completed(
429 subtask_id,
430 "[ADJUSTED] Parameters modified".to_string(),
431 )
432 .await?;
433 Ok(true)
434 }
435 }
436 }
437
438 pub async fn complete(&self) -> Layer2Result<ExecutionSummary> {
440 let mut status = self.status.write().await;
441 *status = ExecutionStatus::Completed;
442 drop(status);
443
444 self.notify_progress("execution_completed", ExecutionStatus::Completed)
445 .await;
446
447 let plan = self.plan.read().await;
448 let results = self.step_results.read().await;
449 let corrections = self.correction_history.read().await;
450 let start_time = self.start_time.read().await;
451
452 let completed = results
453 .values()
454 .filter(|r| matches!(r.status, ExecutionStatus::StepCompleted))
455 .count();
456 let failed = results
457 .values()
458 .filter(|r| matches!(r.status, ExecutionStatus::StepFailed))
459 .count();
460 let skipped = results
461 .values()
462 .filter(|r| {
463 r.output
464 .as_ref()
465 .map(|o| o.starts_with("[SKIPPED]"))
466 .unwrap_or(false)
467 })
468 .count();
469
470 Ok(ExecutionSummary {
471 plan_id: plan.id.clone(),
472 total_steps: plan.subtasks.len(),
473 completed_steps: completed,
474 failed_steps: failed,
475 skipped_steps: skipped,
476 correction_count: corrections.len(),
477 duration: start_time.map(|t| t.elapsed()).unwrap_or(Duration::ZERO),
478 status: ExecutionStatus::Completed,
479 })
480 }
481
482 pub async fn get_correction_history(&self) -> Vec<CorrectionRecord> {
484 self.correction_history.read().await.clone()
485 }
486
487 async fn notify_progress(&self, subtask_id: &str, status: ExecutionStatus) {
489 if let Some(callback) = self.progress_callback.read().await.as_ref() {
490 callback(subtask_id, status);
491 }
492 }
493}
494
495#[derive(Debug, Clone)]
497pub struct CorrectionDecision {
498 pub strategy: CorrectionStrategy,
500 pub should_continue: bool,
502 pub user_message: Option<String>,
504}
505
506#[derive(Debug, Clone, Serialize, Deserialize)]
508pub struct ExecutionSummary {
509 pub plan_id: String,
511 pub total_steps: usize,
513 pub completed_steps: usize,
515 pub failed_steps: usize,
517 pub skipped_steps: usize,
519 pub correction_count: usize,
521 pub duration: Duration,
523 pub status: ExecutionStatus,
525}
526
527pub struct SelfCorrector {
529 history: RwLock<Vec<CorrectionRecord>>,
531 patterns: RwLock<HashMap<String, CorrectionStrategy>>,
533}
534
535impl Default for SelfCorrector {
536 fn default() -> Self {
537 Self::new()
538 }
539}
540
541impl SelfCorrector {
542 pub fn new() -> Self {
544 Self {
545 history: RwLock::new(Vec::new()),
546 patterns: RwLock::new(HashMap::new()),
547 }
548 }
549
550 pub async fn learn_pattern(&self, error_signature: &str, strategy: CorrectionStrategy) {
552 self.patterns
553 .write()
554 .await
555 .insert(error_signature.to_string(), strategy);
556 }
557
558 pub async fn get_recommended_strategy(&self, error: &str) -> Option<CorrectionStrategy> {
560 let patterns = self.patterns.read().await;
561
562 for (signature, strategy) in patterns.iter() {
564 if error.contains(signature) {
565 return Some(strategy.clone());
566 }
567 }
568
569 None
570 }
571
572 pub async fn record_result(&self, record: CorrectionRecord) {
574 if record.success {
576 let signature = Self::extract_signature(&record.original_error);
577 self.learn_pattern(&signature, record.strategy.clone())
578 .await;
579 }
580
581 self.history.write().await.push(record);
582 }
583
/// Derives a pattern signature from an error message: the lower-cased
/// message, truncated to at most 50 bytes.
///
/// Truncation never splits a character. If byte 50 falls inside a
/// multi-byte UTF-8 sequence, the cut moves back to the previous character
/// boundary; slicing at a fixed byte offset would panic on such input.
/// ASCII messages are cut at exactly 50 bytes.
fn extract_signature(error: &str) -> String {
    const MAX_SIGNATURE_BYTES: usize = 50;

    let mut signature = error.to_lowercase();
    if signature.len() > MAX_SIGNATURE_BYTES {
        let mut cut = MAX_SIGNATURE_BYTES;
        // Offset 0 is always a boundary, so this loop terminates.
        while !signature.is_char_boundary(cut) {
            cut -= 1;
        }
        signature.truncate(cut);
    }
    signature
}
594
595 pub async fn get_success_rate(&self) -> f32 {
597 let history = self.history.read().await;
598 if history.is_empty() {
599 return 0.0;
600 }
601
602 let success_count = history.iter().filter(|r| r.success).count();
603 success_count as f32 / history.len() as f32
604 }
605}
606
#[cfg(test)]
mod tests {
    use super::*;
    use crate::planner::{ExecutionPlan, SubTask};

    /// A new monitor starts in the `Pending` state.
    #[tokio::test]
    async fn test_execution_monitor_creation() {
        let plan = ExecutionPlan::new("Test task");
        let monitor = ExecutionMonitor::new(plan);

        let status = monitor.get_status().await;
        assert_eq!(status, ExecutionStatus::Pending);
    }

    /// Progress moves 0 → 50 → 100 as the two steps of a plan complete.
    #[tokio::test]
    async fn test_progress_calculation() {
        let mut plan = ExecutionPlan::new("Test task");
        plan.add_subtask(SubTask::new("s1", "Step 1", "First"));
        plan.add_subtask(SubTask::new("s2", "Step 2", "Second"));
        plan.compute_execution_order().unwrap();

        let monitor = ExecutionMonitor::new(plan);
        monitor.start().await.unwrap();

        assert_eq!(monitor.get_progress().await, 0);

        monitor
            .report_step_completed("s1", "Done".to_string())
            .await
            .unwrap();
        assert_eq!(monitor.get_progress().await, 50);

        monitor
            .report_step_completed("s2", "Done".to_string())
            .await
            .unwrap();
        assert_eq!(monitor.get_progress().await, 100);
    }

    /// Transient errors are retried and execution continues.
    #[tokio::test]
    async fn test_correction_decision() {
        let plan = ExecutionPlan::new("Test task");
        let monitor = ExecutionMonitor::new(plan);

        let decision = monitor
            .decide_correction("test_subtask", &ErrorCategory::Transient, "Network timeout")
            .await;

        assert!(decision.should_continue);
        // `matches!` only yields a bool; without `assert!` the check was a no-op.
        assert!(matches!(
            decision.strategy,
            CorrectionStrategy::Retry { .. }
        ));
    }

    /// A learned signature is recommended for an error that contains it.
    #[tokio::test]
    async fn test_self_corrector_learning() {
        let corrector = SelfCorrector::new();

        corrector
            .learn_pattern("timeout", CorrectionStrategy::Retry { max_attempts: 3 })
            .await;

        let strategy = corrector
            .get_recommended_strategy("Connection timeout occurred")
            .await;
        assert!(strategy.is_some());
    }

    /// Reporting a failed step adds an entry to the correction history.
    #[tokio::test]
    async fn test_correction_history() {
        let plan = ExecutionPlan::new("Test task");
        let monitor = ExecutionMonitor::new(plan);

        monitor
            .report_step_failed("s1", "Error occurred".to_string())
            .await
            .unwrap();

        let history = monitor.get_correction_history().await;
        assert!(!history.is_empty());
    }
}