use crate::checkpoint_system::{ErrorCategory, ErrorRecovery, RecoveryLayer};
use crate::planner::{ExecutionPlan, RiskLevel};
use crate::types::Layer2Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
18pub enum ExecutionStatus {
19 Pending,
21 Running,
23 Paused,
25 StepCompleted,
27 StepFailed,
29 Completed,
31 Failed,
33 AwaitingUserInput,
35 Cancelled,
37}
38
39impl Default for ExecutionStatus {
40 fn default() -> Self {
41 Self::Pending
42 }
43}
44
45#[derive(Debug, Clone)]
47pub struct StepResult {
48 pub subtask_id: String,
50 pub status: ExecutionStatus,
52 pub output: Option<String>,
54 pub error: Option<String>,
56 pub duration: Duration,
58 pub retry_count: u32,
60 pub recovery_layer: Option<RecoveryLayer>,
62}
63
64pub struct ExecutionMonitor {
66 plan: Arc<RwLock<ExecutionPlan>>,
68 status: Arc<RwLock<ExecutionStatus>>,
70 step_results: Arc<RwLock<HashMap<String, StepResult>>>,
72 error_recovery: Arc<ErrorRecovery>,
74 start_time: Arc<RwLock<Option<Instant>>>,
76 progress_callback: Arc<RwLock<Option<Box<dyn Fn(&str, ExecutionStatus) + Send + Sync>>>>,
78 correction_history: Arc<RwLock<Vec<CorrectionRecord>>>,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct CorrectionRecord {
85 pub id: String,
87 pub failed_subtask: String,
89 pub error_category: String,
91 pub original_error: String,
93 pub strategy: CorrectionStrategy,
95 pub success: bool,
97 pub timestamp: chrono::DateTime<chrono::Utc>,
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
103pub enum CorrectionStrategy {
104 Retry { max_attempts: u32 },
106 Skip,
108 Alternative { replacement_subtask: String },
110 Decompose { new_subtasks: Vec<String> },
112 UserIntervention { action: String },
114 AdjustParameters { new_params: serde_json::Value },
116}
117
118impl CorrectionStrategy {
119 pub fn debug_name(&self) -> &'static str {
121 match self {
122 CorrectionStrategy::Retry { .. } => "Retry",
123 CorrectionStrategy::Skip => "Skip",
124 CorrectionStrategy::Alternative { .. } => "Alternative",
125 CorrectionStrategy::Decompose { .. } => "Decompose",
126 CorrectionStrategy::UserIntervention { .. } => "UserIntervention",
127 CorrectionStrategy::AdjustParameters { .. } => "AdjustParameters",
128 }
129 }
130}
131
132impl ExecutionMonitor {
133 pub fn new(plan: ExecutionPlan) -> Self {
135 Self {
136 plan: Arc::new(RwLock::new(plan)),
137 status: Arc::new(RwLock::new(ExecutionStatus::Pending)),
138 step_results: Arc::new(RwLock::new(HashMap::new())),
139 error_recovery: Arc::new(ErrorRecovery::new()),
140 start_time: Arc::new(RwLock::new(None)),
141 progress_callback: Arc::new(RwLock::new(None)),
142 correction_history: Arc::new(RwLock::new(Vec::new())),
143 }
144 }
145
146 pub async fn set_progress_callback<F>(&self, callback: F)
148 where
149 F: Fn(&str, ExecutionStatus) + Send + Sync + 'static,
150 {
151 *self.progress_callback.write().await = Some(Box::new(callback));
152 }
153
154 pub async fn get_status(&self) -> ExecutionStatus {
156 *self.status.read().await
157 }
158
159 pub async fn get_progress(&self) -> u32 {
161 let plan = self.plan.read().await;
162 let results = self.step_results.read().await;
163
164 if plan.subtasks.is_empty() {
165 return 0;
166 }
167
168 let completed = results
169 .values()
170 .filter(|r| matches!(r.status, ExecutionStatus::StepCompleted))
171 .count();
172
173 (completed as u32 * 100) / plan.subtasks.len() as u32
174 }
175
176 pub async fn start(&self) -> Layer2Result<()> {
178 let mut status = self.status.write().await;
179 *status = ExecutionStatus::Running;
180 drop(status);
181
182 *self.start_time.write().await = Some(Instant::now());
183
184 self.notify_progress("execution_started", ExecutionStatus::Running)
186 .await;
187
188 Ok(())
189 }
190
191 pub async fn report_step_completed(
193 &self,
194 subtask_id: &str,
195 output: String,
196 ) -> Layer2Result<()> {
197 let result = StepResult {
198 subtask_id: subtask_id.to_string(),
199 status: ExecutionStatus::StepCompleted,
200 output: Some(output),
201 error: None,
202 duration: Duration::from_secs(0),
203 retry_count: 0,
204 recovery_layer: None,
205 };
206
207 self.step_results
208 .write()
209 .await
210 .insert(subtask_id.to_string(), result);
211 self.notify_progress(subtask_id, ExecutionStatus::StepCompleted)
212 .await;
213
214 Ok(())
215 }
216
217 pub async fn report_step_failed(
219 &self,
220 subtask_id: &str,
221 error: String,
222 ) -> Layer2Result<CorrectionDecision> {
223 let category = ErrorCategory::from_error_message(&error);
225
226 let result = StepResult {
228 subtask_id: subtask_id.to_string(),
229 status: ExecutionStatus::StepFailed,
230 output: None,
231 error: Some(error.clone()),
232 duration: Duration::from_secs(0),
233 retry_count: 0,
234 recovery_layer: None,
235 };
236
237 self.step_results
238 .write()
239 .await
240 .insert(subtask_id.to_string(), result);
241
242 let decision = self.decide_correction(subtask_id, &category, &error).await;
244
245 self.record_correction(subtask_id, &category, &error, &decision)
247 .await;
248
249 self.notify_progress(subtask_id, ExecutionStatus::StepFailed)
250 .await;
251
252 Ok(decision)
253 }
254
255 async fn decide_correction(
257 &self,
258 subtask_id: &str,
259 category: &ErrorCategory,
260 error: &str,
261 ) -> CorrectionDecision {
262 let plan = self.plan.read().await;
263
264 let subtask = plan.subtasks.iter().find(|s| s.id == subtask_id);
266
267 match category {
269 ErrorCategory::Transient => {
270 CorrectionDecision {
272 strategy: CorrectionStrategy::Retry { max_attempts: 3 },
273 should_continue: true,
274 user_message: Some("Temporary error, will retry automatically".to_string()),
275 }
276 }
277 ErrorCategory::Resource => {
278 CorrectionDecision {
280 strategy: CorrectionStrategy::Retry { max_attempts: 2 },
281 should_continue: true,
282 user_message: Some("Resource issue, waiting before retry".to_string()),
283 }
284 }
285 ErrorCategory::Logic => {
286 if let Some(subtask) = subtask {
288 if let Some(fallback) = &subtask.fallback {
289 CorrectionDecision {
290 strategy: CorrectionStrategy::Alternative {
291 replacement_subtask: fallback.name.clone(),
292 },
293 should_continue: true,
294 user_message: Some("Using fallback strategy".to_string()),
295 }
296 } else {
297 CorrectionDecision {
299 strategy: CorrectionStrategy::Decompose {
300 new_subtasks: vec!["simplified_step".to_string()],
301 },
302 should_continue: true,
303 user_message: Some("Breaking down the task".to_string()),
304 }
305 }
306 } else {
307 CorrectionDecision {
308 strategy: CorrectionStrategy::Skip,
309 should_continue: true,
310 user_message: Some("Skipping failed step".to_string()),
311 }
312 }
313 }
314 ErrorCategory::Configuration => {
315 CorrectionDecision {
317 strategy: CorrectionStrategy::UserIntervention {
318 action: "Please check your configuration".to_string(),
319 },
320 should_continue: false,
321 user_message: Some(format!("Configuration error: {}", error)),
322 }
323 }
324 ErrorCategory::UserInterrupt => {
325 CorrectionDecision {
327 strategy: CorrectionStrategy::Skip,
328 should_continue: false,
329 user_message: Some("Execution cancelled by user".to_string()),
330 }
331 }
332 ErrorCategory::System => {
333 if plan.risk_level == RiskLevel::Critical {
335 CorrectionDecision {
336 strategy: CorrectionStrategy::UserIntervention {
337 action: "Critical error requires manual intervention".to_string(),
338 },
339 should_continue: false,
340 user_message: Some(format!("Critical system error: {}", error)),
341 }
342 } else {
343 CorrectionDecision {
344 strategy: CorrectionStrategy::Retry { max_attempts: 1 },
345 should_continue: true,
346 user_message: Some("System error, attempting recovery".to_string()),
347 }
348 }
349 }
350 }
351 }
352
353 async fn record_correction(
355 &self,
356 subtask_id: &str,
357 category: &ErrorCategory,
358 error: &str,
359 decision: &CorrectionDecision,
360 ) {
361 let category_str = match category {
362 ErrorCategory::Transient => "Transient",
363 ErrorCategory::Resource => "Resource",
364 ErrorCategory::Configuration => "Configuration",
365 ErrorCategory::Logic => "Logic",
366 ErrorCategory::System => "System",
367 ErrorCategory::UserInterrupt => "UserInterrupt",
368 };
369 let record = CorrectionRecord {
370 id: format!("correction_{}", chrono::Utc::now().timestamp()),
371 failed_subtask: subtask_id.to_string(),
372 error_category: category_str.to_string(),
373 original_error: error.to_string(),
374 strategy: decision.strategy.clone(),
375 success: false, timestamp: chrono::Utc::now(),
377 };
378
379 self.correction_history.write().await.push(record);
380 }
381
382 pub async fn apply_correction(
384 &self,
385 subtask_id: &str,
386 decision: &CorrectionDecision,
387 ) -> Layer2Result<bool> {
388 match &decision.strategy {
389 CorrectionStrategy::Retry { max_attempts } => {
390 Ok(true)
393 }
394 CorrectionStrategy::Skip => {
395 self.report_step_completed(
397 subtask_id,
398 "[SKIPPED] Step skipped due to unrecoverable error".to_string(),
399 )
400 .await?;
401 Ok(true)
402 }
403 CorrectionStrategy::Alternative {
404 replacement_subtask,
405 } => {
406 self.report_step_completed(
408 subtask_id,
409 format!("[ALTERNATIVE] Used: {}", replacement_subtask),
410 )
411 .await?;
412 Ok(true)
413 }
414 CorrectionStrategy::UserIntervention { action } => {
415 let mut status = self.status.write().await;
417 *status = ExecutionStatus::AwaitingUserInput;
418 Ok(false)
419 }
420 CorrectionStrategy::Decompose { new_subtasks } => {
421 self.report_step_completed(
423 subtask_id,
424 format!("[DECOMPOSED] Into: {}", new_subtasks.join(", ")),
425 )
426 .await?;
427 Ok(true)
428 }
429 CorrectionStrategy::AdjustParameters { new_params } => {
430 self.report_step_completed(
432 subtask_id,
433 "[ADJUSTED] Parameters modified".to_string(),
434 )
435 .await?;
436 Ok(true)
437 }
438 }
439 }
440
441 pub async fn complete(&self) -> Layer2Result<ExecutionSummary> {
443 let mut status = self.status.write().await;
444 *status = ExecutionStatus::Completed;
445 drop(status);
446
447 self.notify_progress("execution_completed", ExecutionStatus::Completed)
448 .await;
449
450 let plan = self.plan.read().await;
451 let results = self.step_results.read().await;
452 let corrections = self.correction_history.read().await;
453 let start_time = self.start_time.read().await;
454
455 let completed = results
456 .values()
457 .filter(|r| matches!(r.status, ExecutionStatus::StepCompleted))
458 .count();
459 let failed = results
460 .values()
461 .filter(|r| matches!(r.status, ExecutionStatus::StepFailed))
462 .count();
463 let skipped = results
464 .values()
465 .filter(|r| {
466 r.output
467 .as_ref()
468 .map(|o| o.starts_with("[SKIPPED]"))
469 .unwrap_or(false)
470 })
471 .count();
472
473 Ok(ExecutionSummary {
474 plan_id: plan.id.clone(),
475 total_steps: plan.subtasks.len(),
476 completed_steps: completed,
477 failed_steps: failed,
478 skipped_steps: skipped,
479 correction_count: corrections.len(),
480 duration: start_time.map(|t| t.elapsed()).unwrap_or(Duration::ZERO),
481 status: ExecutionStatus::Completed,
482 })
483 }
484
485 pub async fn get_correction_history(&self) -> Vec<CorrectionRecord> {
487 self.correction_history.read().await.clone()
488 }
489
490 async fn notify_progress(&self, subtask_id: &str, status: ExecutionStatus) {
492 if let Some(callback) = self.progress_callback.read().await.as_ref() {
493 callback(subtask_id, status);
494 }
495 }
496}
497
498#[derive(Debug, Clone)]
500pub struct CorrectionDecision {
501 pub strategy: CorrectionStrategy,
503 pub should_continue: bool,
505 pub user_message: Option<String>,
507}
508
509#[derive(Debug, Clone, Serialize, Deserialize)]
511pub struct ExecutionSummary {
512 pub plan_id: String,
514 pub total_steps: usize,
516 pub completed_steps: usize,
518 pub failed_steps: usize,
520 pub skipped_steps: usize,
522 pub correction_count: usize,
524 pub duration: Duration,
526 pub status: ExecutionStatus,
528}
529
530pub struct SelfCorrector {
532 history: RwLock<Vec<CorrectionRecord>>,
534 patterns: RwLock<HashMap<String, CorrectionStrategy>>,
536}
537
538impl Default for SelfCorrector {
539 fn default() -> Self {
540 Self::new()
541 }
542}
543
544impl SelfCorrector {
545 pub fn new() -> Self {
547 Self {
548 history: RwLock::new(Vec::new()),
549 patterns: RwLock::new(HashMap::new()),
550 }
551 }
552
553 pub async fn learn_pattern(&self, error_signature: &str, strategy: CorrectionStrategy) {
555 self.patterns
556 .write()
557 .await
558 .insert(error_signature.to_string(), strategy);
559 }
560
561 pub async fn get_recommended_strategy(&self, error: &str) -> Option<CorrectionStrategy> {
563 let patterns = self.patterns.read().await;
564
565 for (signature, strategy) in patterns.iter() {
567 if error.contains(signature) {
568 return Some(strategy.clone());
569 }
570 }
571
572 None
573 }
574
575 pub async fn record_result(&self, record: CorrectionRecord) {
577 if record.success {
579 let signature = Self::extract_signature(&record.original_error);
580 self.learn_pattern(&signature, record.strategy.clone())
581 .await;
582 }
583
584 self.history.write().await.push(record);
585 }
586
587 fn extract_signature(error: &str) -> String {
589 let error_lower = error.to_lowercase();
591 if error_lower.len() > 50 {
592 error_lower[..50].to_string()
593 } else {
594 error_lower
595 }
596 }
597
598 pub async fn get_success_rate(&self) -> f32 {
600 let history = self.history.read().await;
601 if history.is_empty() {
602 return 0.0;
603 }
604
605 let success_count = history.iter().filter(|r| r.success).count();
606 success_count as f32 / history.len() as f32
607 }
608}
609
#[cfg(test)]
mod tests {
    use super::*;
    use crate::planner::{ExecutionPlan, SubTask};

    /// A new monitor starts out pending.
    #[tokio::test]
    async fn test_execution_monitor_creation() {
        let plan = ExecutionPlan::new("Test task");
        let monitor = ExecutionMonitor::new(plan);

        let status = monitor.get_status().await;
        assert_eq!(status, ExecutionStatus::Pending);
    }

    /// Progress goes 0 → 50 → 100 as the two steps of a two-step plan complete.
    #[tokio::test]
    async fn test_progress_calculation() {
        let mut plan = ExecutionPlan::new("Test task");
        plan.add_subtask(SubTask::new("s1", "Step 1", "First"));
        plan.add_subtask(SubTask::new("s2", "Step 2", "Second"));
        plan.compute_execution_order().unwrap();

        let monitor = ExecutionMonitor::new(plan);
        monitor.start().await.unwrap();

        assert_eq!(monitor.get_progress().await, 0);

        monitor
            .report_step_completed("s1", "Done".to_string())
            .await
            .unwrap();
        assert_eq!(monitor.get_progress().await, 50);

        monitor
            .report_step_completed("s2", "Done".to_string())
            .await
            .unwrap();
        assert_eq!(monitor.get_progress().await, 100);
    }

    /// A transient error leads to a retry and lets execution continue.
    #[tokio::test]
    async fn test_correction_decision() {
        let plan = ExecutionPlan::new("Test task");
        let monitor = ExecutionMonitor::new(plan);

        let decision = monitor
            .decide_correction("test_subtask", &ErrorCategory::Transient, "Network timeout")
            .await;

        assert!(decision.should_continue);
        // `matches!` only evaluates to a bool; without `assert!` the check verifies nothing.
        assert!(matches!(
            decision.strategy,
            CorrectionStrategy::Retry { .. }
        ));
    }

    /// A learned signature is recommended for an error that contains it.
    #[tokio::test]
    async fn test_self_corrector_learning() {
        let corrector = SelfCorrector::new();

        corrector
            .learn_pattern("timeout", CorrectionStrategy::Retry { max_attempts: 3 })
            .await;

        let strategy = corrector
            .get_recommended_strategy("Connection timeout occurred")
            .await;
        assert!(strategy.is_some());
    }

    /// Reporting a failed step leaves an entry in the correction history.
    #[tokio::test]
    async fn test_correction_history() {
        let plan = ExecutionPlan::new("Test task");
        let monitor = ExecutionMonitor::new(plan);

        monitor
            .report_step_failed("s1", "Error occurred".to_string())
            .await
            .unwrap();

        let history = monitor.get_correction_history().await;
        assert!(!history.is_empty());
    }
}