Skip to main content

synth_ai_core/orchestration/
progress.rs

1//! Progress tracking for optimization jobs.
2//!
3//! This module provides progress tracking and aggregation for GEPA/MIPRO
4//! optimization jobs.
5
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::time::Instant;
10
11use super::events::{EventCategory, EventParser, ParsedEvent};
12
/// Token usage statistics.
///
/// `prompt_tokens`, `completion_tokens` and `total_tokens` carry no serde
/// default; `reasoning_tokens` and `cached_tokens` are filled with `0` when
/// their key is absent from the serialized form.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Input/prompt tokens
    pub prompt_tokens: i64,
    /// Output/completion tokens
    pub completion_tokens: i64,
    /// Total tokens
    pub total_tokens: i64,
    /// Reasoning tokens (for o1-style models); `0` when the key is absent
    #[serde(default)]
    pub reasoning_tokens: i64,
    /// Cached tokens; `0` when the key is absent
    #[serde(default)]
    pub cached_tokens: i64,
}
29
30impl TokenUsage {
31    /// Create from prompt and completion counts.
32    pub fn new(prompt: i64, completion: i64) -> Self {
33        Self {
34            prompt_tokens: prompt,
35            completion_tokens: completion,
36            total_tokens: prompt + completion,
37            reasoning_tokens: 0,
38            cached_tokens: 0,
39        }
40    }
41}
42
/// Stage information for multi-stage prompts.
///
/// Only `instruction` has no serde default; the remaining fields fall back to
/// an empty map / `None` when their key is absent.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StageInfo {
    /// Instruction text
    pub instruction: String,
    /// Optional rules/constraints (empty map when absent)
    #[serde(default)]
    pub rules: HashMap<String, Value>,
    /// Optional temperature override
    #[serde(default)]
    pub temperature: Option<f64>,
    /// Optional prompt variants
    #[serde(default)]
    pub prompts: Option<Vec<String>>,
}
58
/// Seed metadata for evaluated seeds.
///
/// Only `seed` has no serde default; every other field is defaulted when its
/// key is absent.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SeedInfo {
    /// Seed identifier
    pub seed: i64,
    /// Query text (empty when absent) — presumably the task input; TODO confirm
    #[serde(default)]
    pub query: String,
    /// Expected value (empty when absent) — presumably the reference answer; TODO confirm
    #[serde(default)]
    pub expected: String,
    /// Predicted value, if recorded
    #[serde(default)]
    pub predicted: Option<String>,
    /// Correctness flag, if recorded
    #[serde(default)]
    pub correct: Option<bool>,
    /// Reward for this seed, if recorded; also deserialized from the key `score`
    #[serde(default, alias = "score")]
    pub reward: Option<f64>,
}
74
/// Rollout sample details.
///
/// Only `seed` has no serde default; the string fields default to empty and
/// `correct` defaults to `false` when absent.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RolloutSample {
    /// Seed identifier
    pub seed: i64,
    /// Query text (empty when absent) — presumably the task input; TODO confirm
    #[serde(default)]
    pub query: String,
    /// Expected value (empty when absent)
    #[serde(default)]
    pub expected: String,
    /// Predicted value (empty when absent)
    #[serde(default)]
    pub predicted: String,
    /// Correctness flag (`false` when absent)
    #[serde(default)]
    pub correct: bool,
}
88
89/// Information about a single candidate.
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct CandidateInfo {
92    /// Unique candidate ID
93    pub candidate_id: String,
94    /// Reward on training set
95    #[serde(default, alias = "accuracy")]
96    pub reward: Option<f64>,
97    /// Multi-objective scores
98    #[serde(default)]
99    pub objectives: Option<HashMap<String, f64>>,
100    /// Validation reward (if validation phase completed)
101    #[serde(default, alias = "val_accuracy")]
102    pub val_reward: Option<f64>,
103    /// Training reward
104    #[serde(default, alias = "train_accuracy")]
105    pub train_reward: Option<f64>,
106    /// Generation number
107    #[serde(default)]
108    pub generation: Option<i32>,
109    /// Parent candidate ID (for mutations)
110    #[serde(default)]
111    pub parent_id: Option<String>,
112    /// Whether on Pareto frontier
113    #[serde(default)]
114    pub is_pareto: bool,
115    /// Whether accepted into population
116    #[serde(default)]
117    pub accepted: bool,
118    /// Type of mutation used
119    #[serde(default)]
120    pub mutation_type: Option<String>,
121    /// Token usage for this candidate
122    #[serde(default)]
123    pub token_usage: Option<TokenUsage>,
124    /// Cost in USD
125    #[serde(default)]
126    pub cost_usd: Option<f64>,
127    /// Unix timestamp when evaluated
128    #[serde(default)]
129    pub timestamp: f64,
130    /// Timestamp in milliseconds
131    #[serde(default)]
132    pub timestamp_ms: Option<i64>,
133    /// First-class program stages
134    #[serde(default)]
135    pub stages: HashMap<String, StageInfo>,
136    /// Prompt summary for compatibility
137    #[serde(default)]
138    pub prompt_summary: Option<String>,
139    /// Mutation params
140    #[serde(default)]
141    pub mutation_params: Option<HashMap<String, Value>>,
142    /// Transformation details
143    #[serde(default)]
144    pub transformation: Option<HashMap<String, Value>>,
145    /// Seed rewards
146    #[serde(default, alias = "seed_scores")]
147    pub seed_rewards: Vec<Value>,
148    /// Seeds evaluated
149    #[serde(default)]
150    pub seeds_evaluated: Vec<i64>,
151    /// Seed metadata
152    #[serde(default)]
153    pub seed_info: Vec<SeedInfo>,
154    /// Rollout samples
155    #[serde(default)]
156    pub rollout_sample: Vec<RolloutSample>,
157    /// Evaluation duration in ms
158    #[serde(default)]
159    pub evaluation_duration_ms: Option<i64>,
160    /// Minibatch rewards
161    #[serde(default, alias = "minibatch_scores")]
162    pub minibatch_rewards: Vec<f64>,
163    /// Skip reason
164    #[serde(default)]
165    pub skip_reason: Option<String>,
166    /// Raw event data for debugging
167    #[serde(default)]
168    pub raw_data: HashMap<String, Value>,
169}
170
171impl Default for CandidateInfo {
172    fn default() -> Self {
173        Self {
174            candidate_id: String::new(),
175            reward: None,
176            objectives: None,
177            val_reward: None,
178            train_reward: None,
179            generation: None,
180            parent_id: None,
181            is_pareto: false,
182            accepted: false,
183            mutation_type: None,
184            token_usage: None,
185            cost_usd: None,
186            timestamp: 0.0,
187            timestamp_ms: None,
188            stages: HashMap::new(),
189            prompt_summary: None,
190            mutation_params: None,
191            transformation: None,
192            seed_rewards: Vec::new(),
193            seeds_evaluated: Vec::new(),
194            seed_info: Vec::new(),
195            rollout_sample: Vec::new(),
196            evaluation_duration_ms: None,
197            minibatch_rewards: Vec::new(),
198            skip_reason: None,
199            raw_data: HashMap::new(),
200        }
201    }
202}
203
/// Baseline information.
///
/// Several fields accept legacy key names through serde `alias`; collection
/// fields default to empty when their key is absent.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BaselineInfo {
    /// Baseline reward (alias: `accuracy`)
    #[serde(alias = "accuracy")]
    pub reward: Option<f64>,
    /// Multi-objective scores
    #[serde(default)]
    pub objectives: Option<HashMap<String, f64>>,
    /// Validation reward (alias: `val_accuracy`)
    #[serde(default, alias = "val_accuracy")]
    pub val_reward: Option<f64>,
    /// Per-instance rewards (alias: `instance_scores`)
    #[serde(default, alias = "instance_scores")]
    pub instance_rewards: Vec<f64>,
    /// Per-instance objectives
    #[serde(default)]
    pub instance_objectives: Option<Vec<HashMap<String, f64>>>,
    /// Seeds evaluated
    #[serde(default)]
    pub seeds_evaluated: Vec<i64>,
    /// Prompt configuration (if provided); kept as raw JSON
    #[serde(default)]
    pub prompt: Option<Value>,
    /// Rollout samples (if provided)
    #[serde(default)]
    pub rollout_sample: Vec<RolloutSample>,
}
232
/// Frontier update record.
///
/// Only `timestamp` has no serde default. Reward fields also accept the
/// legacy `*_score(s)` key names through `alias`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FrontierUpdate {
    /// Update timestamp
    pub timestamp: f64,
    /// Candidates added
    #[serde(default)]
    pub added: Vec<String>,
    /// Candidates removed
    #[serde(default)]
    pub removed: Vec<String>,
    /// Current frontier
    #[serde(default)]
    pub frontier: Vec<String>,
    /// Rewards by candidate (alias: `frontier_scores`)
    #[serde(default, alias = "frontier_scores")]
    pub frontier_rewards: HashMap<String, f64>,
    /// Objective scores by candidate (if provided)
    #[serde(default)]
    pub frontier_objectives: Option<Vec<HashMap<String, f64>>>,
    /// Frontier size
    #[serde(default)]
    pub frontier_size: i32,
    /// Best optimistic reward (alias: `optimistic_score`)
    #[serde(default, alias = "optimistic_score")]
    pub optimistic_reward: Option<f64>,
    /// Generation number
    #[serde(default)]
    pub generation: Option<i32>,
    /// Baseline reward, if provided (alias: `baseline_score`)
    #[serde(default, alias = "baseline_score")]
    pub baseline_reward: Option<f64>,
    /// Timestamp in milliseconds (if provided)
    #[serde(default)]
    pub timestamp_ms: Option<i64>,
}
269
/// Overall GEPA progress state.
///
/// No field carries a serde default. `best_reward` and `baseline_reward`
/// also accept the legacy keys `best_score` / `baseline_score`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GEPAProgress {
    /// Current phase: "init", "optimization", "validation", "complete", "failed"
    pub phase: String,
    /// Rollouts completed
    pub rollouts_completed: i32,
    /// Total rollouts planned (`progress_pct` treats a non-positive value as unknown)
    pub rollouts_total: i32,
    /// Generations completed
    pub generations_completed: i32,
    /// Candidates evaluated
    pub candidates_evaluated: i32,
    /// Current best reward (alias: `best_score`)
    #[serde(alias = "best_score")]
    pub best_reward: f64,
    /// Baseline reward for lift calculation (alias: `baseline_score`)
    #[serde(alias = "baseline_score")]
    pub baseline_reward: Option<f64>,
    /// Elapsed time in seconds
    pub elapsed_seconds: f64,
    /// Estimated time remaining, in seconds
    pub eta_seconds: Option<f64>,
    /// Finish reason if complete
    pub finish_reason: Option<String>,
}
296
297impl Default for GEPAProgress {
298    fn default() -> Self {
299        Self {
300            phase: "init".to_string(),
301            rollouts_completed: 0,
302            rollouts_total: 0,
303            generations_completed: 0,
304            candidates_evaluated: 0,
305            best_reward: 0.0,
306            baseline_reward: None,
307            elapsed_seconds: 0.0,
308            eta_seconds: None,
309            finish_reason: None,
310        }
311    }
312}
313
314impl GEPAProgress {
315    /// Calculate progress percentage.
316    pub fn progress_pct(&self) -> f64 {
317        if self.rollouts_total > 0 {
318            (self.rollouts_completed as f64 / self.rollouts_total as f64) * 100.0
319        } else {
320            0.0
321        }
322    }
323
324    /// Calculate lift over baseline.
325    pub fn lift(&self) -> Option<f64> {
326        self.baseline_reward.map(|b| {
327            if b > 0.0 {
328                (self.best_reward - b) / b
329            } else {
330                0.0
331            }
332        })
333    }
334}
335
/// Progress tracker that aggregates events into state.
///
/// Events are fed in through `update`, which dispatches on the event category
/// and folds each event into the fields below.
pub struct ProgressTracker {
    /// Overall progress
    pub progress: GEPAProgress,
    /// All evaluated candidates, in the order they were first seen
    pub candidates: Vec<CandidateInfo>,
    /// Candidates indexed by ID (value is the index into `candidates`)
    candidates_by_id: HashMap<String, usize>,
    /// Baseline information
    pub baseline: Option<BaselineInfo>,
    /// Current Pareto frontier
    pub frontier: Vec<String>,
    /// Frontier update history
    pub frontier_history: Vec<FrontierUpdate>,
    /// Generation history
    pub generation_history: Vec<GenerationInfo>,
    /// Start time for elapsed calculation; set when the first event arrives
    start_time: Option<Instant>,
    /// Last event sequence number (`-1` until an event with a sequence number is seen)
    pub last_seq: i64,
}
357
/// Generation summary info.
///
/// The first four fields carry no serde default; the rest are defaulted when
/// their key is absent.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GenerationInfo {
    /// Generation number
    pub generation: i32,
    /// Best reward in generation (alias: `best_accuracy`)
    #[serde(alias = "best_accuracy")]
    pub best_reward: f64,
    /// Candidates proposed
    pub candidates_proposed: i32,
    /// Candidates accepted
    pub candidates_accepted: i32,
    /// Frontier size
    #[serde(default)]
    pub frontier_size: i32,
    /// Child candidates, kept as raw JSON values
    #[serde(default)]
    pub children: Vec<Value>,
    /// Generation duration ms
    #[serde(default)]
    pub duration_ms: Option<f64>,
    /// Timestamp seconds
    #[serde(default)]
    pub timestamp: f64,
}
383
/// `Default` delegates to `ProgressTracker::new`, so both constructors produce
/// the same empty tracker.
impl Default for ProgressTracker {
    fn default() -> Self {
        Self::new()
    }
}
389
390impl ProgressTracker {
391    fn extract_instance_rewards(value: &Value) -> Option<Vec<f64>> {
392        let instance_objectives = value.get("instance_objectives")?.as_array()?;
393        if instance_objectives.is_empty() {
394            return None;
395        }
396        let mut values = Vec::with_capacity(instance_objectives.len());
397        for item in instance_objectives {
398            let reward_val = if let Some(obj) = item.as_object() {
399                if let Some(objectives) = obj.get("objectives").and_then(|v| v.as_object()) {
400                    objectives.get("reward").and_then(|v| {
401                        v.as_f64()
402                            .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
403                    })
404                } else {
405                    obj.get("reward").and_then(|v| {
406                        v.as_f64()
407                            .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
408                    })
409                }
410            } else {
411                None
412            };
413            let reward_val = reward_val?;
414            values.push(reward_val);
415        }
416        if values.is_empty() {
417            None
418        } else {
419            Some(values)
420        }
421    }
422
423    /// Create a new progress tracker.
424    pub fn new() -> Self {
425        Self {
426            progress: GEPAProgress::default(),
427            candidates: Vec::new(),
428            candidates_by_id: HashMap::new(),
429            baseline: None,
430            frontier: Vec::new(),
431            frontier_history: Vec::new(),
432            generation_history: Vec::new(),
433            start_time: None,
434            last_seq: -1,
435        }
436    }
437
    /// Get the current best reward.
    ///
    /// Returns `progress.best_reward` as is.
    pub fn best_reward(&self) -> f64 {
        self.progress.best_reward
    }
442
    /// Get baseline reward.
    ///
    /// Returns `progress.baseline_reward`; `None` until a baseline is recorded.
    pub fn baseline_reward(&self) -> Option<f64> {
        self.progress.baseline_reward
    }
447
    /// Get lift over baseline.
    ///
    /// Delegates to `GEPAProgress::lift` on the tracked progress state.
    pub fn lift(&self) -> Option<f64> {
        self.progress.lift()
    }
452
    /// Get current frontier candidates.
    ///
    /// Borrows the tracker's `frontier` list of candidate IDs.
    pub fn current_frontier(&self) -> &[String] {
        &self.frontier
    }
457
458    /// Update tracker with an event.
459    pub fn update(&mut self, event: &ParsedEvent) {
460        // Start timer on first event
461        if self.start_time.is_none() {
462            self.start_time = Some(Instant::now());
463        }
464
465        // Update elapsed time
466        if let Some(start) = self.start_time {
467            self.progress.elapsed_seconds = start.elapsed().as_secs_f64();
468        }
469
470        // Track sequence
471        if let Some(seq) = event.seq {
472            if seq > self.last_seq {
473                self.last_seq = seq;
474            }
475        }
476
477        // Handle by category
478        match event.category {
479            EventCategory::Baseline => self.handle_baseline(event),
480            EventCategory::Candidate => self.handle_candidate(event),
481            EventCategory::Frontier => self.handle_frontier(event),
482            EventCategory::Progress => self.handle_progress(event),
483            EventCategory::Generation => self.handle_generation(event),
484            EventCategory::Complete => self.handle_complete(event),
485            EventCategory::Termination => self.handle_termination(event),
486            EventCategory::Validation => self.handle_validation(event),
487            _ => {}
488        }
489    }
490
491    fn handle_baseline(&mut self, event: &ParsedEvent) {
492        let data = EventParser::parse_baseline(event);
493
494        self.baseline = Some(BaselineInfo {
495            reward: data.reward,
496            objectives: data.objectives,
497            val_reward: None,
498            instance_rewards: data.instance_rewards.unwrap_or_default(),
499            instance_objectives: data.instance_objectives,
500            seeds_evaluated: event
501                .data
502                .get("seeds_evaluated")
503                .and_then(|v| v.as_array())
504                .map(|arr| arr.iter().filter_map(|v| v.as_i64()).collect())
505                .unwrap_or_default(),
506            prompt: event.data.get("prompt").cloned(),
507            rollout_sample: event
508                .data
509                .get("rollout_sample")
510                .and_then(|v| v.as_array())
511                .map(|arr| {
512                    arr.iter()
513                        .filter_map(|item| {
514                            serde_json::from_value::<RolloutSample>(item.clone()).ok()
515                        })
516                        .collect()
517                })
518                .unwrap_or_default(),
519        });
520
521        if let Some(acc) = data.reward {
522            self.progress.baseline_reward = Some(acc);
523            // Initialize best reward to baseline
524            if self.progress.best_reward == 0.0 {
525                self.progress.best_reward = acc;
526            }
527        }
528
529        self.progress.phase = "optimization".to_string();
530    }
531
532    fn handle_candidate(&mut self, event: &ParsedEvent) {
533        let data = EventParser::parse_candidate(event);
534
535        let is_baseline = event
536            .data
537            .get("is_baseline")
538            .and_then(|v| v.as_bool())
539            .unwrap_or(false)
540            || data.parent_id.is_none();
541
542        let mut merged_data = event.data.as_object().cloned().unwrap_or_default();
543        if let Some(program_candidate) = event
544            .data
545            .get("program_candidate")
546            .and_then(|v| v.as_object())
547        {
548            for (k, v) in program_candidate {
549                merged_data.insert(k.clone(), v.clone());
550            }
551        }
552        let merged_value = Value::Object(merged_data.clone());
553
554        if is_baseline && self.baseline.is_none() {
555            let candidate_view = merged_data.clone();
556            let candidate_value = Value::Object(candidate_view.clone());
557
558            let parse_f64_map = |val: Option<&Value>| -> Option<HashMap<String, f64>> {
559                let map = val?.as_object()?;
560                let mut out = HashMap::new();
561                for (k, v) in map {
562                    let val = v
563                        .as_f64()
564                        .or_else(|| v.as_str().and_then(|s| s.parse().ok()));
565                    let val = match val {
566                        Some(val) => val,
567                        None => return None,
568                    };
569                    out.insert(k.clone(), val);
570                }
571                Some(out)
572            };
573
574            // Try top-level objectives, then score.objectives
575            let objectives = parse_f64_map(candidate_view.get("objectives")).or_else(|| {
576                candidate_view
577                    .get("score")
578                    .and_then(|v| v.as_object())
579                    .and_then(|score| parse_f64_map(score.get("objectives")))
580            });
581
582            let accuracy = objectives
583                .as_ref()
584                .and_then(|m| m.get("reward").copied())
585                .or_else(|| candidate_view.get("reward").and_then(|v| v.as_f64()))
586                .or_else(|| candidate_view.get("accuracy").and_then(|v| v.as_f64()))
587                .or_else(|| candidate_view.get("score").and_then(|v| v.as_f64()))
588                .or_else(|| {
589                    candidate_view
590                        .get("score")
591                        .and_then(|v| v.as_object())
592                        .and_then(|score| {
593                            score
594                                .get("reward")
595                                .and_then(|v| v.as_f64())
596                                .or_else(|| score.get("mean_reward").and_then(|v| v.as_f64()))
597                        })
598                });
599
600            // Auto-derive objectives from accuracy if not found
601            let objectives = objectives.or_else(|| {
602                accuracy.map(|r| {
603                    let mut m = HashMap::new();
604                    m.insert("reward".to_string(), r);
605                    m
606                })
607            });
608
609            let instance_scores = candidate_view
610                .get("instance_scores")
611                .and_then(|v| v.as_array())
612                .and_then(|arr| {
613                    let mut out = Vec::with_capacity(arr.len());
614                    for item in arr {
615                        let val = item
616                            .as_f64()
617                            .or_else(|| item.as_str().and_then(|s| s.parse().ok()))?;
618                        out.push(val);
619                    }
620                    Some(out)
621                })
622                .or_else(|| Self::extract_instance_rewards(&candidate_value))
623                .unwrap_or_default();
624
625            let instance_objectives = candidate_view
626                .get("instance_objectives")
627                .and_then(|v| v.as_array())
628                .and_then(|arr| {
629                    let mut out = Vec::with_capacity(arr.len());
630                    for item in arr {
631                        let obj = item.as_object()?;
632                        let mut map = HashMap::new();
633                        for (k, v) in obj {
634                            let val = v
635                                .as_f64()
636                                .or_else(|| v.as_str().and_then(|s| s.parse().ok()))?;
637                            map.insert(k.clone(), val);
638                        }
639                        out.push(map);
640                    }
641                    Some(out)
642                });
643
644            self.baseline = Some(BaselineInfo {
645                reward: accuracy,
646                objectives,
647                val_reward: None,
648                instance_rewards: instance_scores,
649                instance_objectives,
650                seeds_evaluated: merged_data
651                    .get("seeds_evaluated")
652                    .and_then(|v| v.as_array())
653                    .map(|arr| arr.iter().filter_map(|v| v.as_i64()).collect())
654                    .unwrap_or_default(),
655                prompt: candidate_view.get("prompt").cloned(),
656                rollout_sample: candidate_view
657                    .get("rollout_sample")
658                    .and_then(|v| v.as_array())
659                    .map(|arr| {
660                        arr.iter()
661                            .filter_map(|item| {
662                                serde_json::from_value::<RolloutSample>(item.clone()).ok()
663                            })
664                            .collect()
665                    })
666                    .unwrap_or_default(),
667            });
668
669            if let Some(acc) = accuracy {
670                self.progress.baseline_reward = Some(acc);
671            }
672        }
673
674        if self.candidates_by_id.contains_key(&data.candidate_id) {
675            return;
676        }
677
678        let mut candidate = CandidateInfo {
679            candidate_id: data.candidate_id.clone(),
680            reward: data.reward,
681            objectives: data.objectives,
682            val_reward: None,
683            train_reward: data.reward,
684            generation: data.generation,
685            parent_id: data.parent_id,
686            is_pareto: data.is_pareto,
687            accepted: data.accepted,
688            mutation_type: data.mutation_type,
689            token_usage: None,
690            cost_usd: None,
691            timestamp: self.progress.elapsed_seconds,
692            timestamp_ms: event.timestamp_ms,
693            stages: HashMap::new(),
694            prompt_summary: None,
695            mutation_params: None,
696            transformation: None,
697            seed_rewards: Vec::new(),
698            seeds_evaluated: Vec::new(),
699            seed_info: Vec::new(),
700            rollout_sample: Vec::new(),
701            evaluation_duration_ms: None,
702            minibatch_rewards: Vec::new(),
703            skip_reason: None,
704            raw_data: HashMap::new(),
705        };
706
707        candidate.seeds_evaluated = merged_data
708            .get("seeds_evaluated")
709            .and_then(|v| v.as_array())
710            .map(|arr| arr.iter().filter_map(|v| v.as_i64()).collect())
711            .unwrap_or_default();
712
713        if let Some(val) = merged_data.get("val_accuracy").and_then(|v| v.as_f64()) {
714            candidate.val_reward = Some(val);
715        } else if let Some(val) = merged_data.get("full_score").and_then(|v| v.as_f64()) {
716            candidate.val_reward = Some(val);
717        }
718
719        if let Some(val) = merged_data.get("train_accuracy").and_then(|v| v.as_f64()) {
720            candidate.train_reward = Some(val);
721        } else if let Some(val) = merged_data.get("minibatch_score").and_then(|v| v.as_f64()) {
722            candidate.train_reward = Some(val);
723        }
724
725        if let Some(cost) = merged_data.get("cost_usd").and_then(|v| v.as_f64()) {
726            candidate.cost_usd = Some(cost);
727        }
728
729        if let Some(duration) = merged_data
730            .get("evaluation_duration_ms")
731            .and_then(|v| v.as_i64())
732        {
733            candidate.evaluation_duration_ms = Some(duration);
734        }
735
736        if let Some(scores) = merged_data
737            .get("minibatch_scores")
738            .and_then(|v| v.as_array())
739        {
740            candidate.minibatch_rewards = scores
741                .iter()
742                .filter_map(|v| {
743                    v.as_f64()
744                        .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
745                })
746                .collect();
747        } else if let Some(score) = merged_data.get("minibatch_score").and_then(|v| v.as_f64()) {
748            candidate.minibatch_rewards = vec![score];
749        }
750
751        if let Some(reason) = merged_data.get("skip_reason").and_then(|v| v.as_str()) {
752            candidate.skip_reason = Some(reason.to_string());
753        }
754
755        if let Some(scores) = merged_data.get("seed_scores").and_then(|v| v.as_array()) {
756            candidate.seed_rewards = scores.clone();
757        }
758
759        if let Some(info) = merged_data.get("seed_info").and_then(|v| v.as_array()) {
760            let mut seed_info = Vec::with_capacity(info.len());
761            for item in info {
762                if let Ok(parsed) = serde_json::from_value::<SeedInfo>(item.clone()) {
763                    seed_info.push(parsed);
764                }
765            }
766            candidate.seed_info = seed_info;
767        }
768
769        if let Some(samples) = merged_data.get("rollout_sample").and_then(|v| v.as_array()) {
770            let mut rollout_sample = Vec::with_capacity(samples.len());
771            for item in samples {
772                if let Ok(parsed) = serde_json::from_value::<RolloutSample>(item.clone()) {
773                    rollout_sample.push(parsed);
774                }
775            }
776            candidate.rollout_sample = rollout_sample;
777        }
778
779        if let Some(token_usage) = merged_data.get("token_usage") {
780            if let Ok(parsed) = serde_json::from_value::<TokenUsage>(token_usage.clone()) {
781                candidate.token_usage = Some(parsed);
782            }
783        }
784
785        if let Some(mutation_params) = merged_data
786            .get("mutation_params")
787            .and_then(|v| v.as_object())
788        {
789            candidate.mutation_params = Some(mutation_params.clone().into_iter().collect());
790        }
791
792        if let Some(transformation) = merged_data
793            .get("transformation")
794            .and_then(|v| v.as_object())
795        {
796            candidate.transformation = Some(transformation.clone().into_iter().collect());
797        }
798
799        if let Some(stages) = merged_data.get("stages").and_then(|v| v.as_object()) {
800            let mut stage_map = HashMap::new();
801            for (key, value) in stages {
802                if let Ok(stage) = serde_json::from_value::<StageInfo>(value.clone()) {
803                    stage_map.insert(key.clone(), stage);
804                }
805            }
806            candidate.stages = stage_map;
807        }
808
809        if candidate.prompt_summary.is_none() {
810            if let Some(summary) = merged_data
811                .get("prompt_summary")
812                .and_then(|v| v.as_str())
813                .or_else(|| merged_data.get("prompt_text").and_then(|v| v.as_str()))
814            {
815                candidate.prompt_summary = Some(summary.to_string());
816            } else if !candidate.stages.is_empty() {
817                let mut parts = Vec::new();
818                let mut stage_ids: Vec<_> = candidate.stages.keys().collect();
819                stage_ids.sort();
820                for stage_id in stage_ids {
821                    if let Some(stage) = candidate.stages.get(stage_id) {
822                        if !stage.instruction.is_empty() {
823                            parts.push(format!(
824                                "[{}]: {}",
825                                stage_id.to_uppercase(),
826                                stage.instruction
827                            ));
828                        }
829                    }
830                }
831                if !parts.is_empty() {
832                    candidate.prompt_summary = Some(parts.join("\n"));
833                }
834            }
835        }
836
837        if let Some(raw) = merged_value.as_object() {
838            candidate.raw_data = raw.clone().into_iter().collect();
839        }
840
841        // Update best reward
842        if let Some(acc) = data.reward {
843            if acc > self.progress.best_reward {
844                self.progress.best_reward = acc;
845            }
846        }
847
848        // Store candidate
849        let idx = self.candidates.len();
850        self.candidates.push(candidate);
851        self.candidates_by_id.insert(data.candidate_id, idx);
852        self.progress.candidates_evaluated += 1;
853    }
854
855    fn handle_frontier(&mut self, event: &ParsedEvent) {
856        let data = EventParser::parse_frontier(event);
857
858        self.frontier = data.frontier.clone();
859
860        if let Some(best) = data.best_reward {
861            if best > self.progress.best_reward {
862                self.progress.best_reward = best;
863            }
864        }
865
866        let update = FrontierUpdate {
867            timestamp: self.progress.elapsed_seconds,
868            added: data.added,
869            removed: data.removed,
870            frontier: data.frontier,
871            frontier_rewards: data.frontier_rewards.unwrap_or_default(),
872            frontier_objectives: data.frontier_objectives,
873            frontier_size: data.frontier_size,
874            optimistic_reward: data.best_reward,
875            generation: event
876                .data
877                .get("generation")
878                .and_then(|v| v.as_i64())
879                .map(|v| v as i32),
880            baseline_reward: event.data.get("baseline_score").and_then(|v| v.as_f64()),
881            timestamp_ms: event.timestamp_ms,
882        };
883        self.frontier_history.push(update);
884    }
885
886    fn handle_progress(&mut self, event: &ParsedEvent) {
887        let data = EventParser::parse_progress(event);
888
889        self.progress.rollouts_completed = data.rollouts_completed;
890        if let Some(total) = data.rollouts_total {
891            self.progress.rollouts_total = total;
892        }
893
894        if let Some(best) = data.best_reward {
895            if best > self.progress.best_reward {
896                self.progress.best_reward = best;
897            }
898        }
899
900        if let Some(baseline) = data.baseline_reward {
901            if self.progress.baseline_reward.is_none() {
902                self.progress.baseline_reward = Some(baseline);
903            }
904        }
905
906        // Estimate ETA
907        if self.progress.rollouts_total > 0 && self.progress.rollouts_completed > 0 {
908            let remaining = self.progress.rollouts_total - self.progress.rollouts_completed;
909            let rate = self.progress.elapsed_seconds / self.progress.rollouts_completed as f64;
910            self.progress.eta_seconds = Some(remaining as f64 * rate);
911        }
912    }
913
914    fn handle_generation(&mut self, event: &ParsedEvent) {
915        let data = EventParser::parse_generation(event);
916
917        self.progress.generations_completed = data.generation;
918
919        let info = GenerationInfo {
920            generation: data.generation,
921            best_reward: data.best_reward,
922            candidates_proposed: data.candidates_proposed,
923            candidates_accepted: data.candidates_accepted,
924            frontier_size: event
925                .data
926                .get("frontier_size")
927                .and_then(|v| v.as_i64())
928                .map(|v| v as i32)
929                .unwrap_or(0),
930            children: event
931                .data
932                .get("children")
933                .and_then(|v| v.as_array())
934                .cloned()
935                .unwrap_or_default(),
936            duration_ms: event.data.get("duration_ms").and_then(|v| v.as_f64()),
937            timestamp: event
938                .data
939                .get("timestamp")
940                .and_then(|v| v.as_f64())
941                .unwrap_or(self.progress.elapsed_seconds),
942        };
943        self.generation_history.push(info);
944    }
945
946    fn handle_complete(&mut self, event: &ParsedEvent) {
947        let data = EventParser::parse_complete(event);
948
949        self.progress.phase = "complete".to_string();
950        self.progress.finish_reason = data.finish_reason;
951
952        if let Some(best) = data.best_reward {
953            self.progress.best_reward = best;
954        }
955
956        if let Some(baseline) = data.baseline_reward {
957            self.progress.baseline_reward = Some(baseline);
958        }
959    }
960
961    fn handle_termination(&mut self, event: &ParsedEvent) {
962        let data = EventParser::parse_termination(event);
963
964        self.progress.phase = "complete".to_string();
965        self.progress.finish_reason = Some(data.reason);
966    }
967
968    fn handle_validation(&mut self, event: &ParsedEvent) {
969        self.progress.phase = "validation".to_string();
970
971        // Update candidate validation scores if provided
972        if let Some(candidate_id) = event.data.get("candidate_id").and_then(|v| v.as_str()) {
973            if let Some(val_score) = event.data.get("val_accuracy").and_then(|v| v.as_f64()) {
974                if let Some(&idx) = self.candidates_by_id.get(candidate_id) {
975                    self.candidates[idx].val_reward = Some(val_score);
976                }
977            }
978        }
979    }
980
981    /// Get a summary dict for serialization.
982    pub fn to_summary(&self) -> serde_json::Value {
983        serde_json::json!({
984            "phase": self.progress.phase,
985            "rollouts_completed": self.progress.rollouts_completed,
986            "rollouts_total": self.progress.rollouts_total,
987            "candidates_evaluated": self.progress.candidates_evaluated,
988            "generations_completed": self.progress.generations_completed,
989            "best_reward": self.progress.best_reward,
990            "baseline_reward": self.progress.baseline_reward,
991            "lift": self.lift(),
992            "elapsed_seconds": self.progress.elapsed_seconds,
993            "frontier_size": self.frontier.len(),
994        })
995    }
996}
997
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Parse a raw JSON event and feed it to the tracker.
    fn feed(tracker: &mut ProgressTracker, raw: serde_json::Value) {
        let event = EventParser::parse(&raw);
        tracker.update(&event);
    }

    /// A default progress record starts in "init" with no progress and no lift.
    #[test]
    fn test_progress_default() {
        let fresh = GEPAProgress::default();

        assert!(fresh.lift().is_none());
        assert_eq!(fresh.progress_pct(), 0.0);
        assert_eq!(fresh.phase, "init");
    }

    /// Lift is reported once both a baseline and a best reward are known.
    #[test]
    fn test_progress_lift() {
        let progress = GEPAProgress {
            baseline_reward: Some(0.5),
            best_reward: 0.75,
            ..GEPAProgress::default()
        };

        let observed = progress.lift().unwrap();
        assert!((observed - 0.5).abs() < 0.001); // 50% lift
    }

    /// A baseline event records the baseline and starts the optimization phase.
    #[test]
    fn test_tracker_baseline() {
        let mut tracker = ProgressTracker::new();

        feed(
            &mut tracker,
            json!({
                "type": "learning.policy.gepa.baseline",
                "seq": 1,
                "data": { "accuracy": 0.72 }
            }),
        );

        assert!(tracker.baseline.is_some());
        assert_eq!(tracker.baseline_reward(), Some(0.72));
        assert_eq!(tracker.progress.phase, "optimization");
    }

    /// An evaluated candidate is stored, counted, and can raise the best reward.
    #[test]
    fn test_tracker_candidate() {
        let mut tracker = ProgressTracker::new();

        // Baseline first, then a single candidate evaluation.
        let baseline_event = json!({
            "type": "learning.policy.gepa.baseline",
            "data": { "accuracy": 0.72 }
        });
        let candidate_event = json!({
            "type": "learning.policy.gepa.candidate.evaluated",
            "seq": 2,
            "data": {
                "candidate_id": "cand_1",
                "accuracy": 0.85,
                "accepted": true,
                "generation": 1
            }
        });
        feed(&mut tracker, baseline_event);
        feed(&mut tracker, candidate_event);

        assert_eq!(tracker.candidates.len(), 1);
        assert_eq!(tracker.best_reward(), 0.85);
        assert_eq!(tracker.progress.candidates_evaluated, 1);
    }

    /// A frontier update replaces the frontier and is appended to the history.
    #[test]
    fn test_tracker_frontier() {
        let mut tracker = ProgressTracker::new();

        feed(
            &mut tracker,
            json!({
                "type": "learning.policy.gepa.frontier_updated",
                "data": {
                    "frontier": ["cand_1", "cand_2"],
                    "best_score": 0.88
                }
            }),
        );

        assert_eq!(tracker.frontier.len(), 2);
        assert_eq!(tracker.frontier_history.len(), 1);
        assert_eq!(tracker.best_reward(), 0.88);
    }

    /// A job-completed event finalises the phase, finish reason, and best reward.
    #[test]
    fn test_tracker_complete() {
        let mut tracker = ProgressTracker::new();

        feed(
            &mut tracker,
            json!({
                "type": "learning.policy.gepa.job.completed",
                "data": {
                    "best_score": 0.92,
                    "baseline_score": 0.72,
                    "finish_reason": "budget_exhausted"
                }
            }),
        );

        assert_eq!(tracker.progress.phase, "complete");
        assert_eq!(
            tracker.progress.finish_reason,
            Some("budget_exhausted".to_string())
        );
        assert_eq!(tracker.best_reward(), 0.92);
    }
}