// synth_ai_core/orchestration/progress.rs

//! Progress tracking for optimization jobs.
//!
//! This module provides progress tracking and aggregation for GEPA/MIPRO
//! optimization jobs.

6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::time::Instant;
9
10use super::events::{EventCategory, EventParser, ParsedEvent};
11
12/// Token usage statistics.
13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
14pub struct TokenUsage {
15    /// Input/prompt tokens
16    pub prompt_tokens: i64,
17    /// Output/completion tokens
18    pub completion_tokens: i64,
19    /// Total tokens
20    pub total_tokens: i64,
21    /// Reasoning tokens (for o1-style models)
22    #[serde(default)]
23    pub reasoning_tokens: i64,
24    /// Cached tokens
25    #[serde(default)]
26    pub cached_tokens: i64,
27}
28
29impl TokenUsage {
30    /// Create from prompt and completion counts.
31    pub fn new(prompt: i64, completion: i64) -> Self {
32        Self {
33            prompt_tokens: prompt,
34            completion_tokens: completion,
35            total_tokens: prompt + completion,
36            reasoning_tokens: 0,
37            cached_tokens: 0,
38        }
39    }
40}
41
42/// Information about a single candidate.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct CandidateInfo {
45    /// Unique candidate ID
46    pub candidate_id: String,
47    /// Accuracy/score on training set
48    #[serde(default)]
49    pub accuracy: Option<f64>,
50    /// Multi-objective scores
51    #[serde(default)]
52    pub objectives: Option<HashMap<String, f64>>,
53    /// Validation accuracy (if validation phase completed)
54    #[serde(default)]
55    pub val_accuracy: Option<f64>,
56    /// Training accuracy (alias for accuracy)
57    #[serde(default)]
58    pub train_accuracy: Option<f64>,
59    /// Generation number
60    #[serde(default)]
61    pub generation: Option<i32>,
62    /// Parent candidate ID (for mutations)
63    #[serde(default)]
64    pub parent_id: Option<String>,
65    /// Whether on Pareto frontier
66    #[serde(default)]
67    pub is_pareto: bool,
68    /// Whether accepted into population
69    #[serde(default)]
70    pub accepted: bool,
71    /// Type of mutation used
72    #[serde(default)]
73    pub mutation_type: Option<String>,
74    /// Token usage for this candidate
75    #[serde(default)]
76    pub token_usage: Option<TokenUsage>,
77    /// Cost in USD
78    #[serde(default)]
79    pub cost_usd: Option<f64>,
80    /// Unix timestamp when evaluated
81    #[serde(default)]
82    pub timestamp: f64,
83    /// Timestamp in milliseconds
84    #[serde(default)]
85    pub timestamp_ms: Option<i64>,
86}
87
88impl Default for CandidateInfo {
89    fn default() -> Self {
90        Self {
91            candidate_id: String::new(),
92            accuracy: None,
93            objectives: None,
94            val_accuracy: None,
95            train_accuracy: None,
96            generation: None,
97            parent_id: None,
98            is_pareto: false,
99            accepted: false,
100            mutation_type: None,
101            token_usage: None,
102            cost_usd: None,
103            timestamp: 0.0,
104            timestamp_ms: None,
105        }
106    }
107}
108
109/// Baseline information.
110#[derive(Debug, Clone, Default, Serialize, Deserialize)]
111pub struct BaselineInfo {
112    /// Baseline accuracy/score
113    pub accuracy: Option<f64>,
114    /// Multi-objective scores
115    #[serde(default)]
116    pub objectives: Option<HashMap<String, f64>>,
117    /// Validation accuracy
118    #[serde(default)]
119    pub val_accuracy: Option<f64>,
120    /// Per-instance scores
121    #[serde(default)]
122    pub instance_scores: Vec<f64>,
123}
124
125/// Frontier update record.
126#[derive(Debug, Clone, Default, Serialize, Deserialize)]
127pub struct FrontierUpdate {
128    /// Update timestamp
129    pub timestamp: f64,
130    /// Candidates added
131    #[serde(default)]
132    pub added: Vec<String>,
133    /// Candidates removed
134    #[serde(default)]
135    pub removed: Vec<String>,
136    /// Current frontier
137    #[serde(default)]
138    pub frontier: Vec<String>,
139    /// Scores by candidate
140    #[serde(default)]
141    pub frontier_scores: HashMap<String, f64>,
142    /// Frontier size
143    #[serde(default)]
144    pub frontier_size: i32,
145    /// Best optimistic score
146    #[serde(default)]
147    pub optimistic_score: Option<f64>,
148    /// Generation number
149    #[serde(default)]
150    pub generation: Option<i32>,
151}
152
153/// Overall GEPA progress state.
154#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct GEPAProgress {
156    /// Current phase: "init", "optimization", "validation", "complete", "failed"
157    pub phase: String,
158    /// Rollouts completed
159    pub rollouts_completed: i32,
160    /// Total rollouts planned
161    pub rollouts_total: i32,
162    /// Generations completed
163    pub generations_completed: i32,
164    /// Candidates evaluated
165    pub candidates_evaluated: i32,
166    /// Current best score
167    pub best_score: f64,
168    /// Baseline score for lift calculation
169    pub baseline_score: Option<f64>,
170    /// Elapsed time in seconds
171    pub elapsed_seconds: f64,
172    /// Estimated time remaining
173    pub eta_seconds: Option<f64>,
174    /// Finish reason if complete
175    pub finish_reason: Option<String>,
176}
177
178impl Default for GEPAProgress {
179    fn default() -> Self {
180        Self {
181            phase: "init".to_string(),
182            rollouts_completed: 0,
183            rollouts_total: 0,
184            generations_completed: 0,
185            candidates_evaluated: 0,
186            best_score: 0.0,
187            baseline_score: None,
188            elapsed_seconds: 0.0,
189            eta_seconds: None,
190            finish_reason: None,
191        }
192    }
193}
194
195impl GEPAProgress {
196    /// Calculate progress percentage.
197    pub fn progress_pct(&self) -> f64 {
198        if self.rollouts_total > 0 {
199            (self.rollouts_completed as f64 / self.rollouts_total as f64) * 100.0
200        } else {
201            0.0
202        }
203    }
204
205    /// Calculate lift over baseline.
206    pub fn lift(&self) -> Option<f64> {
207        self.baseline_score.map(|b| {
208            if b > 0.0 {
209                (self.best_score - b) / b
210            } else {
211                0.0
212            }
213        })
214    }
215}
216
217/// Progress tracker that aggregates events into state.
218pub struct ProgressTracker {
219    /// Overall progress
220    pub progress: GEPAProgress,
221    /// All evaluated candidates
222    pub candidates: Vec<CandidateInfo>,
223    /// Candidates indexed by ID
224    candidates_by_id: HashMap<String, usize>,
225    /// Baseline information
226    pub baseline: Option<BaselineInfo>,
227    /// Current Pareto frontier
228    pub frontier: Vec<String>,
229    /// Frontier update history
230    pub frontier_history: Vec<FrontierUpdate>,
231    /// Generation history
232    pub generation_history: Vec<GenerationInfo>,
233    /// Start time for elapsed calculation
234    start_time: Option<Instant>,
235    /// Last event sequence number
236    pub last_seq: i64,
237}
238
239/// Generation summary info.
240#[derive(Debug, Clone, Default, Serialize, Deserialize)]
241pub struct GenerationInfo {
242    /// Generation number
243    pub generation: i32,
244    /// Best accuracy in generation
245    pub best_accuracy: f64,
246    /// Candidates proposed
247    pub candidates_proposed: i32,
248    /// Candidates accepted
249    pub candidates_accepted: i32,
250}
251
252impl Default for ProgressTracker {
253    fn default() -> Self {
254        Self::new()
255    }
256}
257
258impl ProgressTracker {
259    /// Create a new progress tracker.
260    pub fn new() -> Self {
261        Self {
262            progress: GEPAProgress::default(),
263            candidates: Vec::new(),
264            candidates_by_id: HashMap::new(),
265            baseline: None,
266            frontier: Vec::new(),
267            frontier_history: Vec::new(),
268            generation_history: Vec::new(),
269            start_time: None,
270            last_seq: -1,
271        }
272    }
273
274    /// Get the current best score.
275    pub fn best_score(&self) -> f64 {
276        self.progress.best_score
277    }
278
279    /// Get baseline score.
280    pub fn baseline_score(&self) -> Option<f64> {
281        self.progress.baseline_score
282    }
283
284    /// Get lift over baseline.
285    pub fn lift(&self) -> Option<f64> {
286        self.progress.lift()
287    }
288
289    /// Get current frontier candidates.
290    pub fn current_frontier(&self) -> &[String] {
291        &self.frontier
292    }
293
294    /// Update tracker with an event.
295    pub fn update(&mut self, event: &ParsedEvent) {
296        // Start timer on first event
297        if self.start_time.is_none() {
298            self.start_time = Some(Instant::now());
299        }
300
301        // Update elapsed time
302        if let Some(start) = self.start_time {
303            self.progress.elapsed_seconds = start.elapsed().as_secs_f64();
304        }
305
306        // Track sequence
307        if let Some(seq) = event.seq {
308            if seq > self.last_seq {
309                self.last_seq = seq;
310            }
311        }
312
313        // Handle by category
314        match event.category {
315            EventCategory::Baseline => self.handle_baseline(event),
316            EventCategory::Candidate => self.handle_candidate(event),
317            EventCategory::Frontier => self.handle_frontier(event),
318            EventCategory::Progress => self.handle_progress(event),
319            EventCategory::Generation => self.handle_generation(event),
320            EventCategory::Complete => self.handle_complete(event),
321            EventCategory::Termination => self.handle_termination(event),
322            EventCategory::Validation => self.handle_validation(event),
323            _ => {}
324        }
325    }
326
327    fn handle_baseline(&mut self, event: &ParsedEvent) {
328        let data = EventParser::parse_baseline(event);
329
330        self.baseline = Some(BaselineInfo {
331            accuracy: data.accuracy,
332            objectives: data.objectives,
333            val_accuracy: None,
334            instance_scores: data.instance_scores.unwrap_or_default(),
335        });
336
337        if let Some(acc) = data.accuracy {
338            self.progress.baseline_score = Some(acc);
339            // Initialize best score to baseline
340            if self.progress.best_score == 0.0 {
341                self.progress.best_score = acc;
342            }
343        }
344
345        self.progress.phase = "optimization".to_string();
346    }
347
348    fn handle_candidate(&mut self, event: &ParsedEvent) {
349        let data = EventParser::parse_candidate(event);
350
351        let candidate = CandidateInfo {
352            candidate_id: data.candidate_id.clone(),
353            accuracy: data.accuracy,
354            objectives: data.objectives,
355            val_accuracy: None,
356            train_accuracy: data.accuracy,
357            generation: data.generation,
358            parent_id: data.parent_id,
359            is_pareto: data.is_pareto,
360            accepted: data.accepted,
361            mutation_type: data.mutation_type,
362            token_usage: None,
363            cost_usd: None,
364            timestamp: self.progress.elapsed_seconds,
365            timestamp_ms: event.timestamp_ms,
366        };
367
368        // Update best score
369        if let Some(acc) = data.accuracy {
370            if acc > self.progress.best_score {
371                self.progress.best_score = acc;
372            }
373        }
374
375        // Store candidate
376        let idx = self.candidates.len();
377        self.candidates.push(candidate);
378        self.candidates_by_id.insert(data.candidate_id, idx);
379        self.progress.candidates_evaluated += 1;
380    }
381
382    fn handle_frontier(&mut self, event: &ParsedEvent) {
383        let data = EventParser::parse_frontier(event);
384
385        self.frontier = data.frontier.clone();
386
387        if let Some(best) = data.best_score {
388            if best > self.progress.best_score {
389                self.progress.best_score = best;
390            }
391        }
392
393        let update = FrontierUpdate {
394            timestamp: self.progress.elapsed_seconds,
395            added: data.added,
396            removed: data.removed,
397            frontier: data.frontier,
398            frontier_scores: data.frontier_scores.unwrap_or_default(),
399            frontier_size: data.frontier_size,
400            optimistic_score: data.best_score,
401            generation: None,
402        };
403        self.frontier_history.push(update);
404    }
405
406    fn handle_progress(&mut self, event: &ParsedEvent) {
407        let data = EventParser::parse_progress(event);
408
409        self.progress.rollouts_completed = data.rollouts_completed;
410        if let Some(total) = data.rollouts_total {
411            self.progress.rollouts_total = total;
412        }
413
414        if let Some(best) = data.best_score {
415            if best > self.progress.best_score {
416                self.progress.best_score = best;
417            }
418        }
419
420        if let Some(baseline) = data.baseline_score {
421            if self.progress.baseline_score.is_none() {
422                self.progress.baseline_score = Some(baseline);
423            }
424        }
425
426        // Estimate ETA
427        if self.progress.rollouts_total > 0 && self.progress.rollouts_completed > 0 {
428            let remaining = self.progress.rollouts_total - self.progress.rollouts_completed;
429            let rate = self.progress.elapsed_seconds / self.progress.rollouts_completed as f64;
430            self.progress.eta_seconds = Some(remaining as f64 * rate);
431        }
432    }
433
434    fn handle_generation(&mut self, event: &ParsedEvent) {
435        let data = EventParser::parse_generation(event);
436
437        self.progress.generations_completed = data.generation;
438
439        let info = GenerationInfo {
440            generation: data.generation,
441            best_accuracy: data.best_accuracy,
442            candidates_proposed: data.candidates_proposed,
443            candidates_accepted: data.candidates_accepted,
444        };
445        self.generation_history.push(info);
446    }
447
448    fn handle_complete(&mut self, event: &ParsedEvent) {
449        let data = EventParser::parse_complete(event);
450
451        self.progress.phase = "complete".to_string();
452        self.progress.finish_reason = data.finish_reason;
453
454        if let Some(best) = data.best_score {
455            self.progress.best_score = best;
456        }
457
458        if let Some(baseline) = data.baseline_score {
459            self.progress.baseline_score = Some(baseline);
460        }
461    }
462
463    fn handle_termination(&mut self, event: &ParsedEvent) {
464        let data = EventParser::parse_termination(event);
465
466        self.progress.phase = "complete".to_string();
467        self.progress.finish_reason = Some(data.reason);
468    }
469
470    fn handle_validation(&mut self, event: &ParsedEvent) {
471        self.progress.phase = "validation".to_string();
472
473        // Update candidate validation scores if provided
474        if let Some(candidate_id) = event.data.get("candidate_id").and_then(|v| v.as_str()) {
475            if let Some(val_score) = event.data.get("val_accuracy").and_then(|v| v.as_f64()) {
476                if let Some(&idx) = self.candidates_by_id.get(candidate_id) {
477                    self.candidates[idx].val_accuracy = Some(val_score);
478                }
479            }
480        }
481    }
482
483    /// Get a summary dict for serialization.
484    pub fn to_summary(&self) -> serde_json::Value {
485        serde_json::json!({
486            "phase": self.progress.phase,
487            "rollouts_completed": self.progress.rollouts_completed,
488            "rollouts_total": self.progress.rollouts_total,
489            "candidates_evaluated": self.progress.candidates_evaluated,
490            "generations_completed": self.progress.generations_completed,
491            "best_score": self.progress.best_score,
492            "baseline_score": self.progress.baseline_score,
493            "lift": self.lift(),
494            "elapsed_seconds": self.progress.elapsed_seconds,
495            "frontier_size": self.frontier.len(),
496        })
497    }
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_progress_default() {
        let progress = GEPAProgress::default();
        assert_eq!(progress.phase, "init");
        assert_eq!(progress.progress_pct(), 0.0);
        assert!(progress.lift().is_none());
    }

    #[test]
    fn test_progress_lift() {
        let mut progress = GEPAProgress::default();
        progress.baseline_score = Some(0.5);
        progress.best_score = 0.75;

        let lift = progress.lift().unwrap();
        assert!((lift - 0.5).abs() < 0.001); // 50% lift
    }

    #[test]
    fn test_progress_lift_zero_baseline() {
        // A relative lift over a zero baseline is undefined; reported as 0.0.
        let mut progress = GEPAProgress::default();
        progress.baseline_score = Some(0.0);
        progress.best_score = 0.8;
        assert_eq!(progress.lift(), Some(0.0));
    }

    #[test]
    fn test_progress_pct() {
        let mut progress = GEPAProgress::default();
        progress.rollouts_total = 200;
        progress.rollouts_completed = 50;
        assert!((progress.progress_pct() - 25.0).abs() < 1e-9);
    }

    #[test]
    fn test_token_usage_new() {
        let usage = TokenUsage::new(120, 30);
        assert_eq!(usage.prompt_tokens, 120);
        assert_eq!(usage.completion_tokens, 30);
        assert_eq!(usage.total_tokens, 150);
        assert_eq!(usage.reasoning_tokens, 0);
        assert_eq!(usage.cached_tokens, 0);
    }

    #[test]
    fn test_candidate_info_default() {
        let info = CandidateInfo::default();
        assert!(info.candidate_id.is_empty());
        assert!(info.accuracy.is_none());
        assert!(info.val_accuracy.is_none());
        assert!(!info.is_pareto);
        assert!(!info.accepted);
        assert_eq!(info.timestamp, 0.0);
        assert!(info.timestamp_ms.is_none());
    }

    #[test]
    fn test_tracker_new_is_empty() {
        let tracker = ProgressTracker::new();
        assert_eq!(tracker.last_seq, -1);
        assert_eq!(tracker.progress.phase, "init");
        assert!(tracker.candidates.is_empty());
        assert!(tracker.baseline.is_none());
        assert!(tracker.current_frontier().is_empty());
        assert!(tracker.lift().is_none());
    }

    #[test]
    fn test_tracker_baseline() {
        let mut tracker = ProgressTracker::new();

        let event = EventParser::parse(&json!({
            "type": "learning.policy.gepa.baseline",
            "seq": 1,
            "data": { "accuracy": 0.72 }
        }));

        tracker.update(&event);

        assert!(tracker.baseline.is_some());
        assert_eq!(tracker.baseline_score(), Some(0.72));
        // The baseline seeds the best score while nothing better is known.
        assert_eq!(tracker.best_score(), 0.72);
        assert_eq!(tracker.progress.phase, "optimization");
    }

    #[test]
    fn test_tracker_candidate() {
        let mut tracker = ProgressTracker::new();

        // First baseline
        tracker.update(&EventParser::parse(&json!({
            "type": "learning.policy.gepa.baseline",
            "data": { "accuracy": 0.72 }
        })));

        // Then candidate
        tracker.update(&EventParser::parse(&json!({
            "type": "learning.policy.gepa.candidate.evaluated",
            "seq": 2,
            "data": {
                "candidate_id": "cand_1",
                "accuracy": 0.85,
                "accepted": true,
                "generation": 1
            }
        })));

        assert_eq!(tracker.candidates.len(), 1);
        assert_eq!(tracker.best_score(), 0.85);
        assert_eq!(tracker.progress.candidates_evaluated, 1);

        let candidate = &tracker.candidates[0];
        assert_eq!(candidate.candidate_id, "cand_1");
        // train_accuracy mirrors accuracy; validation has not happened yet.
        assert_eq!(candidate.train_accuracy, Some(0.85));
        assert!(candidate.val_accuracy.is_none());
    }

    #[test]
    fn test_tracker_frontier() {
        let mut tracker = ProgressTracker::new();

        tracker.update(&EventParser::parse(&json!({
            "type": "learning.policy.gepa.frontier_updated",
            "data": {
                "frontier": ["cand_1", "cand_2"],
                "best_score": 0.88
            }
        })));

        assert_eq!(tracker.frontier.len(), 2);
        assert_eq!(tracker.frontier_history.len(), 1);
        assert_eq!(tracker.best_score(), 0.88);
    }

    #[test]
    fn test_tracker_complete() {
        let mut tracker = ProgressTracker::new();

        tracker.update(&EventParser::parse(&json!({
            "type": "learning.policy.gepa.job.completed",
            "data": {
                "best_score": 0.92,
                "baseline_score": 0.72,
                "finish_reason": "budget_exhausted"
            }
        })));

        assert_eq!(tracker.progress.phase, "complete");
        assert_eq!(tracker.progress.finish_reason, Some("budget_exhausted".to_string()));
        assert_eq!(tracker.best_score(), 0.92);
        assert_eq!(tracker.baseline_score(), Some(0.72));
    }
}