1use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::time::Instant;
10
11use super::events::{EventCategory, EventParser, ParsedEvent};
12
13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
15pub struct TokenUsage {
16 pub prompt_tokens: i64,
18 pub completion_tokens: i64,
20 pub total_tokens: i64,
22 #[serde(default)]
24 pub reasoning_tokens: i64,
25 #[serde(default)]
27 pub cached_tokens: i64,
28}
29
30impl TokenUsage {
31 pub fn new(prompt: i64, completion: i64) -> Self {
33 Self {
34 prompt_tokens: prompt,
35 completion_tokens: completion,
36 total_tokens: prompt + completion,
37 reasoning_tokens: 0,
38 cached_tokens: 0,
39 }
40 }
41}
42
43#[derive(Debug, Clone, Default, Serialize, Deserialize)]
45pub struct StageInfo {
46 pub instruction: String,
48 #[serde(default)]
50 pub rules: HashMap<String, Value>,
51 #[serde(default)]
53 pub temperature: Option<f64>,
54 #[serde(default)]
56 pub prompts: Option<Vec<String>>,
57}
58
59#[derive(Debug, Clone, Default, Serialize, Deserialize)]
61pub struct SeedInfo {
62 pub seed: i64,
63 #[serde(default)]
64 pub query: String,
65 #[serde(default)]
66 pub expected: String,
67 #[serde(default)]
68 pub predicted: Option<String>,
69 #[serde(default)]
70 pub correct: Option<bool>,
71 #[serde(default, alias = "score")]
72 pub reward: Option<f64>,
73}
74
75#[derive(Debug, Clone, Default, Serialize, Deserialize)]
77pub struct RolloutSample {
78 pub seed: i64,
79 #[serde(default)]
80 pub query: String,
81 #[serde(default)]
82 pub expected: String,
83 #[serde(default)]
84 pub predicted: String,
85 #[serde(default)]
86 pub correct: bool,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct CandidateInfo {
92 pub candidate_id: String,
94 #[serde(default, alias = "accuracy")]
96 pub reward: Option<f64>,
97 #[serde(default)]
99 pub objectives: Option<HashMap<String, f64>>,
100 #[serde(default, alias = "val_accuracy")]
102 pub val_reward: Option<f64>,
103 #[serde(default, alias = "train_accuracy")]
105 pub train_reward: Option<f64>,
106 #[serde(default)]
108 pub generation: Option<i32>,
109 #[serde(default)]
111 pub parent_id: Option<String>,
112 #[serde(default)]
114 pub is_pareto: bool,
115 #[serde(default)]
117 pub accepted: bool,
118 #[serde(default)]
120 pub mutation_type: Option<String>,
121 #[serde(default)]
123 pub token_usage: Option<TokenUsage>,
124 #[serde(default)]
126 pub cost_usd: Option<f64>,
127 #[serde(default)]
129 pub timestamp: f64,
130 #[serde(default)]
132 pub timestamp_ms: Option<i64>,
133 #[serde(default)]
135 pub stages: HashMap<String, StageInfo>,
136 #[serde(default)]
138 pub prompt_summary: Option<String>,
139 #[serde(default)]
141 pub mutation_params: Option<HashMap<String, Value>>,
142 #[serde(default)]
144 pub transformation: Option<HashMap<String, Value>>,
145 #[serde(default, alias = "seed_scores")]
147 pub seed_rewards: Vec<Value>,
148 #[serde(default)]
150 pub seeds_evaluated: Vec<i64>,
151 #[serde(default)]
153 pub seed_info: Vec<SeedInfo>,
154 #[serde(default)]
156 pub rollout_sample: Vec<RolloutSample>,
157 #[serde(default)]
159 pub evaluation_duration_ms: Option<i64>,
160 #[serde(default, alias = "minibatch_scores")]
162 pub minibatch_rewards: Vec<f64>,
163 #[serde(default)]
165 pub skip_reason: Option<String>,
166 #[serde(default)]
168 pub raw_data: HashMap<String, Value>,
169}
170
171impl Default for CandidateInfo {
172 fn default() -> Self {
173 Self {
174 candidate_id: String::new(),
175 reward: None,
176 objectives: None,
177 val_reward: None,
178 train_reward: None,
179 generation: None,
180 parent_id: None,
181 is_pareto: false,
182 accepted: false,
183 mutation_type: None,
184 token_usage: None,
185 cost_usd: None,
186 timestamp: 0.0,
187 timestamp_ms: None,
188 stages: HashMap::new(),
189 prompt_summary: None,
190 mutation_params: None,
191 transformation: None,
192 seed_rewards: Vec::new(),
193 seeds_evaluated: Vec::new(),
194 seed_info: Vec::new(),
195 rollout_sample: Vec::new(),
196 evaluation_duration_ms: None,
197 minibatch_rewards: Vec::new(),
198 skip_reason: None,
199 raw_data: HashMap::new(),
200 }
201 }
202}
203
204#[derive(Debug, Clone, Default, Serialize, Deserialize)]
206pub struct BaselineInfo {
207 #[serde(alias = "accuracy")]
209 pub reward: Option<f64>,
210 #[serde(default)]
212 pub objectives: Option<HashMap<String, f64>>,
213 #[serde(default, alias = "val_accuracy")]
215 pub val_reward: Option<f64>,
216 #[serde(default, alias = "instance_scores")]
218 pub instance_rewards: Vec<f64>,
219 #[serde(default)]
221 pub instance_objectives: Option<Vec<HashMap<String, f64>>>,
222 #[serde(default)]
224 pub seeds_evaluated: Vec<i64>,
225 #[serde(default)]
227 pub prompt: Option<Value>,
228 #[serde(default)]
230 pub rollout_sample: Vec<RolloutSample>,
231}
232
233#[derive(Debug, Clone, Default, Serialize, Deserialize)]
235pub struct FrontierUpdate {
236 pub timestamp: f64,
238 #[serde(default)]
240 pub added: Vec<String>,
241 #[serde(default)]
243 pub removed: Vec<String>,
244 #[serde(default)]
246 pub frontier: Vec<String>,
247 #[serde(default, alias = "frontier_scores")]
249 pub frontier_rewards: HashMap<String, f64>,
250 #[serde(default)]
252 pub frontier_objectives: Option<Vec<HashMap<String, f64>>>,
253 #[serde(default)]
255 pub frontier_size: i32,
256 #[serde(default, alias = "optimistic_score")]
258 pub optimistic_reward: Option<f64>,
259 #[serde(default)]
261 pub generation: Option<i32>,
262 #[serde(default, alias = "baseline_score")]
264 pub baseline_reward: Option<f64>,
265 #[serde(default)]
267 pub timestamp_ms: Option<i64>,
268}
269
270#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct GEPAProgress {
273 pub phase: String,
275 pub rollouts_completed: i32,
277 pub rollouts_total: i32,
279 pub generations_completed: i32,
281 pub candidates_evaluated: i32,
283 #[serde(alias = "best_score")]
285 pub best_reward: f64,
286 #[serde(alias = "baseline_score")]
288 pub baseline_reward: Option<f64>,
289 pub elapsed_seconds: f64,
291 pub eta_seconds: Option<f64>,
293 pub finish_reason: Option<String>,
295}
296
297impl Default for GEPAProgress {
298 fn default() -> Self {
299 Self {
300 phase: "init".to_string(),
301 rollouts_completed: 0,
302 rollouts_total: 0,
303 generations_completed: 0,
304 candidates_evaluated: 0,
305 best_reward: 0.0,
306 baseline_reward: None,
307 elapsed_seconds: 0.0,
308 eta_seconds: None,
309 finish_reason: None,
310 }
311 }
312}
313
314impl GEPAProgress {
315 pub fn progress_pct(&self) -> f64 {
317 if self.rollouts_total > 0 {
318 (self.rollouts_completed as f64 / self.rollouts_total as f64) * 100.0
319 } else {
320 0.0
321 }
322 }
323
324 pub fn lift(&self) -> Option<f64> {
326 self.baseline_reward.map(|b| {
327 if b > 0.0 {
328 (self.best_reward - b) / b
329 } else {
330 0.0
331 }
332 })
333 }
334}
335
336pub struct ProgressTracker {
338 pub progress: GEPAProgress,
340 pub candidates: Vec<CandidateInfo>,
342 candidates_by_id: HashMap<String, usize>,
344 pub baseline: Option<BaselineInfo>,
346 pub frontier: Vec<String>,
348 pub frontier_history: Vec<FrontierUpdate>,
350 pub generation_history: Vec<GenerationInfo>,
352 start_time: Option<Instant>,
354 pub last_seq: i64,
356}
357
358#[derive(Debug, Clone, Default, Serialize, Deserialize)]
360pub struct GenerationInfo {
361 pub generation: i32,
363 #[serde(alias = "best_accuracy")]
365 pub best_reward: f64,
366 pub candidates_proposed: i32,
368 pub candidates_accepted: i32,
370 #[serde(default)]
372 pub frontier_size: i32,
373 #[serde(default)]
375 pub children: Vec<Value>,
376 #[serde(default)]
378 pub duration_ms: Option<f64>,
379 #[serde(default)]
381 pub timestamp: f64,
382}
383
384impl Default for ProgressTracker {
385 fn default() -> Self {
386 Self::new()
387 }
388}
389
390impl ProgressTracker {
391 fn extract_instance_rewards(value: &Value) -> Option<Vec<f64>> {
392 let instance_objectives = value.get("instance_objectives")?.as_array()?;
393 if instance_objectives.is_empty() {
394 return None;
395 }
396 let mut values = Vec::with_capacity(instance_objectives.len());
397 for item in instance_objectives {
398 let reward_val = if let Some(obj) = item.as_object() {
399 if let Some(objectives) = obj.get("objectives").and_then(|v| v.as_object()) {
400 objectives.get("reward").and_then(|v| {
401 v.as_f64()
402 .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
403 })
404 } else {
405 obj.get("reward").and_then(|v| {
406 v.as_f64()
407 .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
408 })
409 }
410 } else {
411 None
412 };
413 let reward_val = reward_val?;
414 values.push(reward_val);
415 }
416 if values.is_empty() {
417 None
418 } else {
419 Some(values)
420 }
421 }
422
423 pub fn new() -> Self {
425 Self {
426 progress: GEPAProgress::default(),
427 candidates: Vec::new(),
428 candidates_by_id: HashMap::new(),
429 baseline: None,
430 frontier: Vec::new(),
431 frontier_history: Vec::new(),
432 generation_history: Vec::new(),
433 start_time: None,
434 last_seq: -1,
435 }
436 }
437
438 pub fn best_reward(&self) -> f64 {
440 self.progress.best_reward
441 }
442
443 pub fn baseline_reward(&self) -> Option<f64> {
445 self.progress.baseline_reward
446 }
447
448 pub fn lift(&self) -> Option<f64> {
450 self.progress.lift()
451 }
452
453 pub fn current_frontier(&self) -> &[String] {
455 &self.frontier
456 }
457
458 pub fn update(&mut self, event: &ParsedEvent) {
460 if self.start_time.is_none() {
462 self.start_time = Some(Instant::now());
463 }
464
465 if let Some(start) = self.start_time {
467 self.progress.elapsed_seconds = start.elapsed().as_secs_f64();
468 }
469
470 if let Some(seq) = event.seq {
472 if seq > self.last_seq {
473 self.last_seq = seq;
474 }
475 }
476
477 match event.category {
479 EventCategory::Baseline => self.handle_baseline(event),
480 EventCategory::Candidate => self.handle_candidate(event),
481 EventCategory::Frontier => self.handle_frontier(event),
482 EventCategory::Progress => self.handle_progress(event),
483 EventCategory::Generation => self.handle_generation(event),
484 EventCategory::Complete => self.handle_complete(event),
485 EventCategory::Termination => self.handle_termination(event),
486 EventCategory::Validation => self.handle_validation(event),
487 _ => {}
488 }
489 }
490
491 fn handle_baseline(&mut self, event: &ParsedEvent) {
492 let data = EventParser::parse_baseline(event);
493
494 self.baseline = Some(BaselineInfo {
495 reward: data.reward,
496 objectives: data.objectives,
497 val_reward: None,
498 instance_rewards: data.instance_rewards.unwrap_or_default(),
499 instance_objectives: data.instance_objectives,
500 seeds_evaluated: event
501 .data
502 .get("seeds_evaluated")
503 .and_then(|v| v.as_array())
504 .map(|arr| arr.iter().filter_map(|v| v.as_i64()).collect())
505 .unwrap_or_default(),
506 prompt: event.data.get("prompt").cloned(),
507 rollout_sample: event
508 .data
509 .get("rollout_sample")
510 .and_then(|v| v.as_array())
511 .map(|arr| {
512 arr.iter()
513 .filter_map(|item| {
514 serde_json::from_value::<RolloutSample>(item.clone()).ok()
515 })
516 .collect()
517 })
518 .unwrap_or_default(),
519 });
520
521 if let Some(acc) = data.reward {
522 self.progress.baseline_reward = Some(acc);
523 if self.progress.best_reward == 0.0 {
525 self.progress.best_reward = acc;
526 }
527 }
528
529 self.progress.phase = "optimization".to_string();
530 }
531
532 fn handle_candidate(&mut self, event: &ParsedEvent) {
533 let data = EventParser::parse_candidate(event);
534
535 let is_baseline = event
536 .data
537 .get("is_baseline")
538 .and_then(|v| v.as_bool())
539 .unwrap_or(false)
540 || data.parent_id.is_none();
541
542 let mut merged_data = event.data.as_object().cloned().unwrap_or_default();
543 if let Some(program_candidate) = event
544 .data
545 .get("program_candidate")
546 .and_then(|v| v.as_object())
547 {
548 for (k, v) in program_candidate {
549 merged_data.insert(k.clone(), v.clone());
550 }
551 }
552 let merged_value = Value::Object(merged_data.clone());
553
554 if is_baseline && self.baseline.is_none() {
555 let candidate_view = merged_data.clone();
556 let candidate_value = Value::Object(candidate_view.clone());
557
558 let parse_f64_map = |val: Option<&Value>| -> Option<HashMap<String, f64>> {
559 let map = val?.as_object()?;
560 let mut out = HashMap::new();
561 for (k, v) in map {
562 let val = v
563 .as_f64()
564 .or_else(|| v.as_str().and_then(|s| s.parse().ok()));
565 let val = match val {
566 Some(val) => val,
567 None => return None,
568 };
569 out.insert(k.clone(), val);
570 }
571 Some(out)
572 };
573
574 let objectives = parse_f64_map(candidate_view.get("objectives")).or_else(|| {
576 candidate_view
577 .get("score")
578 .and_then(|v| v.as_object())
579 .and_then(|score| parse_f64_map(score.get("objectives")))
580 });
581
582 let accuracy = objectives
583 .as_ref()
584 .and_then(|m| m.get("reward").copied())
585 .or_else(|| candidate_view.get("reward").and_then(|v| v.as_f64()))
586 .or_else(|| candidate_view.get("accuracy").and_then(|v| v.as_f64()))
587 .or_else(|| candidate_view.get("score").and_then(|v| v.as_f64()))
588 .or_else(|| {
589 candidate_view
590 .get("score")
591 .and_then(|v| v.as_object())
592 .and_then(|score| {
593 score
594 .get("reward")
595 .and_then(|v| v.as_f64())
596 .or_else(|| score.get("mean_reward").and_then(|v| v.as_f64()))
597 })
598 });
599
600 let objectives = objectives.or_else(|| {
602 accuracy.map(|r| {
603 let mut m = HashMap::new();
604 m.insert("reward".to_string(), r);
605 m
606 })
607 });
608
609 let instance_scores = candidate_view
610 .get("instance_scores")
611 .and_then(|v| v.as_array())
612 .and_then(|arr| {
613 let mut out = Vec::with_capacity(arr.len());
614 for item in arr {
615 let val = item
616 .as_f64()
617 .or_else(|| item.as_str().and_then(|s| s.parse().ok()))?;
618 out.push(val);
619 }
620 Some(out)
621 })
622 .or_else(|| Self::extract_instance_rewards(&candidate_value))
623 .unwrap_or_default();
624
625 let instance_objectives = candidate_view
626 .get("instance_objectives")
627 .and_then(|v| v.as_array())
628 .and_then(|arr| {
629 let mut out = Vec::with_capacity(arr.len());
630 for item in arr {
631 let obj = item.as_object()?;
632 let mut map = HashMap::new();
633 for (k, v) in obj {
634 let val = v
635 .as_f64()
636 .or_else(|| v.as_str().and_then(|s| s.parse().ok()))?;
637 map.insert(k.clone(), val);
638 }
639 out.push(map);
640 }
641 Some(out)
642 });
643
644 self.baseline = Some(BaselineInfo {
645 reward: accuracy,
646 objectives,
647 val_reward: None,
648 instance_rewards: instance_scores,
649 instance_objectives,
650 seeds_evaluated: merged_data
651 .get("seeds_evaluated")
652 .and_then(|v| v.as_array())
653 .map(|arr| arr.iter().filter_map(|v| v.as_i64()).collect())
654 .unwrap_or_default(),
655 prompt: candidate_view.get("prompt").cloned(),
656 rollout_sample: candidate_view
657 .get("rollout_sample")
658 .and_then(|v| v.as_array())
659 .map(|arr| {
660 arr.iter()
661 .filter_map(|item| {
662 serde_json::from_value::<RolloutSample>(item.clone()).ok()
663 })
664 .collect()
665 })
666 .unwrap_or_default(),
667 });
668
669 if let Some(acc) = accuracy {
670 self.progress.baseline_reward = Some(acc);
671 }
672 }
673
674 if self.candidates_by_id.contains_key(&data.candidate_id) {
675 return;
676 }
677
678 let mut candidate = CandidateInfo {
679 candidate_id: data.candidate_id.clone(),
680 reward: data.reward,
681 objectives: data.objectives,
682 val_reward: None,
683 train_reward: data.reward,
684 generation: data.generation,
685 parent_id: data.parent_id,
686 is_pareto: data.is_pareto,
687 accepted: data.accepted,
688 mutation_type: data.mutation_type,
689 token_usage: None,
690 cost_usd: None,
691 timestamp: self.progress.elapsed_seconds,
692 timestamp_ms: event.timestamp_ms,
693 stages: HashMap::new(),
694 prompt_summary: None,
695 mutation_params: None,
696 transformation: None,
697 seed_rewards: Vec::new(),
698 seeds_evaluated: Vec::new(),
699 seed_info: Vec::new(),
700 rollout_sample: Vec::new(),
701 evaluation_duration_ms: None,
702 minibatch_rewards: Vec::new(),
703 skip_reason: None,
704 raw_data: HashMap::new(),
705 };
706
707 candidate.seeds_evaluated = merged_data
708 .get("seeds_evaluated")
709 .and_then(|v| v.as_array())
710 .map(|arr| arr.iter().filter_map(|v| v.as_i64()).collect())
711 .unwrap_or_default();
712
713 if let Some(val) = merged_data.get("val_accuracy").and_then(|v| v.as_f64()) {
714 candidate.val_reward = Some(val);
715 } else if let Some(val) = merged_data.get("full_score").and_then(|v| v.as_f64()) {
716 candidate.val_reward = Some(val);
717 }
718
719 if let Some(val) = merged_data.get("train_accuracy").and_then(|v| v.as_f64()) {
720 candidate.train_reward = Some(val);
721 } else if let Some(val) = merged_data.get("minibatch_score").and_then(|v| v.as_f64()) {
722 candidate.train_reward = Some(val);
723 }
724
725 if let Some(cost) = merged_data.get("cost_usd").and_then(|v| v.as_f64()) {
726 candidate.cost_usd = Some(cost);
727 }
728
729 if let Some(duration) = merged_data
730 .get("evaluation_duration_ms")
731 .and_then(|v| v.as_i64())
732 {
733 candidate.evaluation_duration_ms = Some(duration);
734 }
735
736 if let Some(scores) = merged_data
737 .get("minibatch_scores")
738 .and_then(|v| v.as_array())
739 {
740 candidate.minibatch_rewards = scores
741 .iter()
742 .filter_map(|v| {
743 v.as_f64()
744 .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
745 })
746 .collect();
747 } else if let Some(score) = merged_data.get("minibatch_score").and_then(|v| v.as_f64()) {
748 candidate.minibatch_rewards = vec![score];
749 }
750
751 if let Some(reason) = merged_data.get("skip_reason").and_then(|v| v.as_str()) {
752 candidate.skip_reason = Some(reason.to_string());
753 }
754
755 if let Some(scores) = merged_data.get("seed_scores").and_then(|v| v.as_array()) {
756 candidate.seed_rewards = scores.clone();
757 }
758
759 if let Some(info) = merged_data.get("seed_info").and_then(|v| v.as_array()) {
760 let mut seed_info = Vec::with_capacity(info.len());
761 for item in info {
762 if let Ok(parsed) = serde_json::from_value::<SeedInfo>(item.clone()) {
763 seed_info.push(parsed);
764 }
765 }
766 candidate.seed_info = seed_info;
767 }
768
769 if let Some(samples) = merged_data.get("rollout_sample").and_then(|v| v.as_array()) {
770 let mut rollout_sample = Vec::with_capacity(samples.len());
771 for item in samples {
772 if let Ok(parsed) = serde_json::from_value::<RolloutSample>(item.clone()) {
773 rollout_sample.push(parsed);
774 }
775 }
776 candidate.rollout_sample = rollout_sample;
777 }
778
779 if let Some(token_usage) = merged_data.get("token_usage") {
780 if let Ok(parsed) = serde_json::from_value::<TokenUsage>(token_usage.clone()) {
781 candidate.token_usage = Some(parsed);
782 }
783 }
784
785 if let Some(mutation_params) = merged_data
786 .get("mutation_params")
787 .and_then(|v| v.as_object())
788 {
789 candidate.mutation_params = Some(mutation_params.clone().into_iter().collect());
790 }
791
792 if let Some(transformation) = merged_data
793 .get("transformation")
794 .and_then(|v| v.as_object())
795 {
796 candidate.transformation = Some(transformation.clone().into_iter().collect());
797 }
798
799 if let Some(stages) = merged_data.get("stages").and_then(|v| v.as_object()) {
800 let mut stage_map = HashMap::new();
801 for (key, value) in stages {
802 if let Ok(stage) = serde_json::from_value::<StageInfo>(value.clone()) {
803 stage_map.insert(key.clone(), stage);
804 }
805 }
806 candidate.stages = stage_map;
807 }
808
809 if candidate.prompt_summary.is_none() {
810 if let Some(summary) = merged_data
811 .get("prompt_summary")
812 .and_then(|v| v.as_str())
813 .or_else(|| merged_data.get("prompt_text").and_then(|v| v.as_str()))
814 {
815 candidate.prompt_summary = Some(summary.to_string());
816 } else if !candidate.stages.is_empty() {
817 let mut parts = Vec::new();
818 let mut stage_ids: Vec<_> = candidate.stages.keys().collect();
819 stage_ids.sort();
820 for stage_id in stage_ids {
821 if let Some(stage) = candidate.stages.get(stage_id) {
822 if !stage.instruction.is_empty() {
823 parts.push(format!(
824 "[{}]: {}",
825 stage_id.to_uppercase(),
826 stage.instruction
827 ));
828 }
829 }
830 }
831 if !parts.is_empty() {
832 candidate.prompt_summary = Some(parts.join("\n"));
833 }
834 }
835 }
836
837 if let Some(raw) = merged_value.as_object() {
838 candidate.raw_data = raw.clone().into_iter().collect();
839 }
840
841 if let Some(acc) = data.reward {
843 if acc > self.progress.best_reward {
844 self.progress.best_reward = acc;
845 }
846 }
847
848 let idx = self.candidates.len();
850 self.candidates.push(candidate);
851 self.candidates_by_id.insert(data.candidate_id, idx);
852 self.progress.candidates_evaluated += 1;
853 }
854
855 fn handle_frontier(&mut self, event: &ParsedEvent) {
856 let data = EventParser::parse_frontier(event);
857
858 self.frontier = data.frontier.clone();
859
860 if let Some(best) = data.best_reward {
861 if best > self.progress.best_reward {
862 self.progress.best_reward = best;
863 }
864 }
865
866 let update = FrontierUpdate {
867 timestamp: self.progress.elapsed_seconds,
868 added: data.added,
869 removed: data.removed,
870 frontier: data.frontier,
871 frontier_rewards: data.frontier_rewards.unwrap_or_default(),
872 frontier_objectives: data.frontier_objectives,
873 frontier_size: data.frontier_size,
874 optimistic_reward: data.best_reward,
875 generation: event
876 .data
877 .get("generation")
878 .and_then(|v| v.as_i64())
879 .map(|v| v as i32),
880 baseline_reward: event.data.get("baseline_score").and_then(|v| v.as_f64()),
881 timestamp_ms: event.timestamp_ms,
882 };
883 self.frontier_history.push(update);
884 }
885
886 fn handle_progress(&mut self, event: &ParsedEvent) {
887 let data = EventParser::parse_progress(event);
888
889 self.progress.rollouts_completed = data.rollouts_completed;
890 if let Some(total) = data.rollouts_total {
891 self.progress.rollouts_total = total;
892 }
893
894 if let Some(best) = data.best_reward {
895 if best > self.progress.best_reward {
896 self.progress.best_reward = best;
897 }
898 }
899
900 if let Some(baseline) = data.baseline_reward {
901 if self.progress.baseline_reward.is_none() {
902 self.progress.baseline_reward = Some(baseline);
903 }
904 }
905
906 if self.progress.rollouts_total > 0 && self.progress.rollouts_completed > 0 {
908 let remaining = self.progress.rollouts_total - self.progress.rollouts_completed;
909 let rate = self.progress.elapsed_seconds / self.progress.rollouts_completed as f64;
910 self.progress.eta_seconds = Some(remaining as f64 * rate);
911 }
912 }
913
914 fn handle_generation(&mut self, event: &ParsedEvent) {
915 let data = EventParser::parse_generation(event);
916
917 self.progress.generations_completed = data.generation;
918
919 let info = GenerationInfo {
920 generation: data.generation,
921 best_reward: data.best_reward,
922 candidates_proposed: data.candidates_proposed,
923 candidates_accepted: data.candidates_accepted,
924 frontier_size: event
925 .data
926 .get("frontier_size")
927 .and_then(|v| v.as_i64())
928 .map(|v| v as i32)
929 .unwrap_or(0),
930 children: event
931 .data
932 .get("children")
933 .and_then(|v| v.as_array())
934 .cloned()
935 .unwrap_or_default(),
936 duration_ms: event.data.get("duration_ms").and_then(|v| v.as_f64()),
937 timestamp: event
938 .data
939 .get("timestamp")
940 .and_then(|v| v.as_f64())
941 .unwrap_or(self.progress.elapsed_seconds),
942 };
943 self.generation_history.push(info);
944 }
945
946 fn handle_complete(&mut self, event: &ParsedEvent) {
947 let data = EventParser::parse_complete(event);
948
949 self.progress.phase = "complete".to_string();
950 self.progress.finish_reason = data.finish_reason;
951
952 if let Some(best) = data.best_reward {
953 self.progress.best_reward = best;
954 }
955
956 if let Some(baseline) = data.baseline_reward {
957 self.progress.baseline_reward = Some(baseline);
958 }
959 }
960
961 fn handle_termination(&mut self, event: &ParsedEvent) {
962 let data = EventParser::parse_termination(event);
963
964 self.progress.phase = "complete".to_string();
965 self.progress.finish_reason = Some(data.reason);
966 }
967
968 fn handle_validation(&mut self, event: &ParsedEvent) {
969 self.progress.phase = "validation".to_string();
970
971 if let Some(candidate_id) = event.data.get("candidate_id").and_then(|v| v.as_str()) {
973 if let Some(val_score) = event.data.get("val_accuracy").and_then(|v| v.as_f64()) {
974 if let Some(&idx) = self.candidates_by_id.get(candidate_id) {
975 self.candidates[idx].val_reward = Some(val_score);
976 }
977 }
978 }
979 }
980
981 pub fn to_summary(&self) -> serde_json::Value {
983 serde_json::json!({
984 "phase": self.progress.phase,
985 "rollouts_completed": self.progress.rollouts_completed,
986 "rollouts_total": self.progress.rollouts_total,
987 "candidates_evaluated": self.progress.candidates_evaluated,
988 "generations_completed": self.progress.generations_completed,
989 "best_reward": self.progress.best_reward,
990 "baseline_reward": self.progress.baseline_reward,
991 "lift": self.lift(),
992 "elapsed_seconds": self.progress.elapsed_seconds,
993 "frontier_size": self.frontier.len(),
994 })
995 }
996}
997
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Parses a raw event payload and applies it to `tracker`.
    fn feed(tracker: &mut ProgressTracker, raw: serde_json::Value) {
        let event = EventParser::parse(&raw);
        tracker.update(&event);
    }

    #[test]
    fn test_progress_default() {
        let progress = GEPAProgress::default();

        assert_eq!(progress.phase, "init");
        assert_eq!(progress.progress_pct(), 0.0);
        assert!(progress.lift().is_none());
    }

    #[test]
    fn test_progress_lift() {
        let progress = GEPAProgress {
            baseline_reward: Some(0.5),
            best_reward: 0.75,
            ..GEPAProgress::default()
        };

        let lift = progress.lift().expect("baseline is set, so lift must be Some");
        assert!((lift - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_tracker_baseline() {
        let mut tracker = ProgressTracker::new();

        feed(
            &mut tracker,
            json!({
                "type": "learning.policy.gepa.baseline",
                "seq": 1,
                "data": { "accuracy": 0.72 }
            }),
        );

        assert!(tracker.baseline.is_some());
        assert_eq!(tracker.baseline_reward(), Some(0.72));
        assert_eq!(tracker.progress.phase, "optimization");
    }

    #[test]
    fn test_tracker_candidate() {
        let mut tracker = ProgressTracker::new();

        feed(
            &mut tracker,
            json!({
                "type": "learning.policy.gepa.baseline",
                "data": { "accuracy": 0.72 }
            }),
        );
        feed(
            &mut tracker,
            json!({
                "type": "learning.policy.gepa.candidate.evaluated",
                "seq": 2,
                "data": {
                    "candidate_id": "cand_1",
                    "accuracy": 0.85,
                    "accepted": true,
                    "generation": 1
                }
            }),
        );

        assert_eq!(tracker.candidates.len(), 1);
        assert_eq!(tracker.best_reward(), 0.85);
        assert_eq!(tracker.progress.candidates_evaluated, 1);
    }

    #[test]
    fn test_tracker_frontier() {
        let mut tracker = ProgressTracker::new();

        feed(
            &mut tracker,
            json!({
                "type": "learning.policy.gepa.frontier_updated",
                "data": {
                    "frontier": ["cand_1", "cand_2"],
                    "best_score": 0.88
                }
            }),
        );

        assert_eq!(tracker.frontier.len(), 2);
        assert_eq!(tracker.frontier_history.len(), 1);
        assert_eq!(tracker.best_reward(), 0.88);
    }

    #[test]
    fn test_tracker_complete() {
        let mut tracker = ProgressTracker::new();

        feed(
            &mut tracker,
            json!({
                "type": "learning.policy.gepa.job.completed",
                "data": {
                    "best_score": 0.92,
                    "baseline_score": 0.72,
                    "finish_reason": "budget_exhausted"
                }
            }),
        );

        assert_eq!(tracker.progress.phase, "complete");
        assert_eq!(
            tracker.progress.finish_reason.as_deref(),
            Some("budget_exhausted")
        );
        assert_eq!(tracker.best_reward(), 0.92);
    }
}