1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::time::Instant;
9
10use super::events::{EventCategory, EventParser, ParsedEvent};
11
12#[derive(Debug, Clone, Default, Serialize, Deserialize)]
14pub struct TokenUsage {
15 pub prompt_tokens: i64,
17 pub completion_tokens: i64,
19 pub total_tokens: i64,
21 #[serde(default)]
23 pub reasoning_tokens: i64,
24 #[serde(default)]
26 pub cached_tokens: i64,
27}
28
29impl TokenUsage {
30 pub fn new(prompt: i64, completion: i64) -> Self {
32 Self {
33 prompt_tokens: prompt,
34 completion_tokens: completion,
35 total_tokens: prompt + completion,
36 reasoning_tokens: 0,
37 cached_tokens: 0,
38 }
39 }
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct CandidateInfo {
45 pub candidate_id: String,
47 #[serde(default)]
49 pub accuracy: Option<f64>,
50 #[serde(default)]
52 pub objectives: Option<HashMap<String, f64>>,
53 #[serde(default)]
55 pub val_accuracy: Option<f64>,
56 #[serde(default)]
58 pub train_accuracy: Option<f64>,
59 #[serde(default)]
61 pub generation: Option<i32>,
62 #[serde(default)]
64 pub parent_id: Option<String>,
65 #[serde(default)]
67 pub is_pareto: bool,
68 #[serde(default)]
70 pub accepted: bool,
71 #[serde(default)]
73 pub mutation_type: Option<String>,
74 #[serde(default)]
76 pub token_usage: Option<TokenUsage>,
77 #[serde(default)]
79 pub cost_usd: Option<f64>,
80 #[serde(default)]
82 pub timestamp: f64,
83 #[serde(default)]
85 pub timestamp_ms: Option<i64>,
86}
87
88impl Default for CandidateInfo {
89 fn default() -> Self {
90 Self {
91 candidate_id: String::new(),
92 accuracy: None,
93 objectives: None,
94 val_accuracy: None,
95 train_accuracy: None,
96 generation: None,
97 parent_id: None,
98 is_pareto: false,
99 accepted: false,
100 mutation_type: None,
101 token_usage: None,
102 cost_usd: None,
103 timestamp: 0.0,
104 timestamp_ms: None,
105 }
106 }
107}
108
109#[derive(Debug, Clone, Default, Serialize, Deserialize)]
111pub struct BaselineInfo {
112 pub accuracy: Option<f64>,
114 #[serde(default)]
116 pub objectives: Option<HashMap<String, f64>>,
117 #[serde(default)]
119 pub val_accuracy: Option<f64>,
120 #[serde(default)]
122 pub instance_scores: Vec<f64>,
123}
124
125#[derive(Debug, Clone, Default, Serialize, Deserialize)]
127pub struct FrontierUpdate {
128 pub timestamp: f64,
130 #[serde(default)]
132 pub added: Vec<String>,
133 #[serde(default)]
135 pub removed: Vec<String>,
136 #[serde(default)]
138 pub frontier: Vec<String>,
139 #[serde(default)]
141 pub frontier_scores: HashMap<String, f64>,
142 #[serde(default)]
144 pub frontier_size: i32,
145 #[serde(default)]
147 pub optimistic_score: Option<f64>,
148 #[serde(default)]
150 pub generation: Option<i32>,
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct GEPAProgress {
156 pub phase: String,
158 pub rollouts_completed: i32,
160 pub rollouts_total: i32,
162 pub generations_completed: i32,
164 pub candidates_evaluated: i32,
166 pub best_score: f64,
168 pub baseline_score: Option<f64>,
170 pub elapsed_seconds: f64,
172 pub eta_seconds: Option<f64>,
174 pub finish_reason: Option<String>,
176}
177
178impl Default for GEPAProgress {
179 fn default() -> Self {
180 Self {
181 phase: "init".to_string(),
182 rollouts_completed: 0,
183 rollouts_total: 0,
184 generations_completed: 0,
185 candidates_evaluated: 0,
186 best_score: 0.0,
187 baseline_score: None,
188 elapsed_seconds: 0.0,
189 eta_seconds: None,
190 finish_reason: None,
191 }
192 }
193}
194
195impl GEPAProgress {
196 pub fn progress_pct(&self) -> f64 {
198 if self.rollouts_total > 0 {
199 (self.rollouts_completed as f64 / self.rollouts_total as f64) * 100.0
200 } else {
201 0.0
202 }
203 }
204
205 pub fn lift(&self) -> Option<f64> {
207 self.baseline_score.map(|b| {
208 if b > 0.0 {
209 (self.best_score - b) / b
210 } else {
211 0.0
212 }
213 })
214 }
215}
216
217pub struct ProgressTracker {
219 pub progress: GEPAProgress,
221 pub candidates: Vec<CandidateInfo>,
223 candidates_by_id: HashMap<String, usize>,
225 pub baseline: Option<BaselineInfo>,
227 pub frontier: Vec<String>,
229 pub frontier_history: Vec<FrontierUpdate>,
231 pub generation_history: Vec<GenerationInfo>,
233 start_time: Option<Instant>,
235 pub last_seq: i64,
237}
238
239#[derive(Debug, Clone, Default, Serialize, Deserialize)]
241pub struct GenerationInfo {
242 pub generation: i32,
244 pub best_accuracy: f64,
246 pub candidates_proposed: i32,
248 pub candidates_accepted: i32,
250}
251
252impl Default for ProgressTracker {
253 fn default() -> Self {
254 Self::new()
255 }
256}
257
258impl ProgressTracker {
259 pub fn new() -> Self {
261 Self {
262 progress: GEPAProgress::default(),
263 candidates: Vec::new(),
264 candidates_by_id: HashMap::new(),
265 baseline: None,
266 frontier: Vec::new(),
267 frontier_history: Vec::new(),
268 generation_history: Vec::new(),
269 start_time: None,
270 last_seq: -1,
271 }
272 }
273
274 pub fn best_score(&self) -> f64 {
276 self.progress.best_score
277 }
278
279 pub fn baseline_score(&self) -> Option<f64> {
281 self.progress.baseline_score
282 }
283
284 pub fn lift(&self) -> Option<f64> {
286 self.progress.lift()
287 }
288
289 pub fn current_frontier(&self) -> &[String] {
291 &self.frontier
292 }
293
294 pub fn update(&mut self, event: &ParsedEvent) {
296 if self.start_time.is_none() {
298 self.start_time = Some(Instant::now());
299 }
300
301 if let Some(start) = self.start_time {
303 self.progress.elapsed_seconds = start.elapsed().as_secs_f64();
304 }
305
306 if let Some(seq) = event.seq {
308 if seq > self.last_seq {
309 self.last_seq = seq;
310 }
311 }
312
313 match event.category {
315 EventCategory::Baseline => self.handle_baseline(event),
316 EventCategory::Candidate => self.handle_candidate(event),
317 EventCategory::Frontier => self.handle_frontier(event),
318 EventCategory::Progress => self.handle_progress(event),
319 EventCategory::Generation => self.handle_generation(event),
320 EventCategory::Complete => self.handle_complete(event),
321 EventCategory::Termination => self.handle_termination(event),
322 EventCategory::Validation => self.handle_validation(event),
323 _ => {}
324 }
325 }
326
327 fn handle_baseline(&mut self, event: &ParsedEvent) {
328 let data = EventParser::parse_baseline(event);
329
330 self.baseline = Some(BaselineInfo {
331 accuracy: data.accuracy,
332 objectives: data.objectives,
333 val_accuracy: None,
334 instance_scores: data.instance_scores.unwrap_or_default(),
335 });
336
337 if let Some(acc) = data.accuracy {
338 self.progress.baseline_score = Some(acc);
339 if self.progress.best_score == 0.0 {
341 self.progress.best_score = acc;
342 }
343 }
344
345 self.progress.phase = "optimization".to_string();
346 }
347
348 fn handle_candidate(&mut self, event: &ParsedEvent) {
349 let data = EventParser::parse_candidate(event);
350
351 let candidate = CandidateInfo {
352 candidate_id: data.candidate_id.clone(),
353 accuracy: data.accuracy,
354 objectives: data.objectives,
355 val_accuracy: None,
356 train_accuracy: data.accuracy,
357 generation: data.generation,
358 parent_id: data.parent_id,
359 is_pareto: data.is_pareto,
360 accepted: data.accepted,
361 mutation_type: data.mutation_type,
362 token_usage: None,
363 cost_usd: None,
364 timestamp: self.progress.elapsed_seconds,
365 timestamp_ms: event.timestamp_ms,
366 };
367
368 if let Some(acc) = data.accuracy {
370 if acc > self.progress.best_score {
371 self.progress.best_score = acc;
372 }
373 }
374
375 let idx = self.candidates.len();
377 self.candidates.push(candidate);
378 self.candidates_by_id.insert(data.candidate_id, idx);
379 self.progress.candidates_evaluated += 1;
380 }
381
382 fn handle_frontier(&mut self, event: &ParsedEvent) {
383 let data = EventParser::parse_frontier(event);
384
385 self.frontier = data.frontier.clone();
386
387 if let Some(best) = data.best_score {
388 if best > self.progress.best_score {
389 self.progress.best_score = best;
390 }
391 }
392
393 let update = FrontierUpdate {
394 timestamp: self.progress.elapsed_seconds,
395 added: data.added,
396 removed: data.removed,
397 frontier: data.frontier,
398 frontier_scores: data.frontier_scores.unwrap_or_default(),
399 frontier_size: data.frontier_size,
400 optimistic_score: data.best_score,
401 generation: None,
402 };
403 self.frontier_history.push(update);
404 }
405
406 fn handle_progress(&mut self, event: &ParsedEvent) {
407 let data = EventParser::parse_progress(event);
408
409 self.progress.rollouts_completed = data.rollouts_completed;
410 if let Some(total) = data.rollouts_total {
411 self.progress.rollouts_total = total;
412 }
413
414 if let Some(best) = data.best_score {
415 if best > self.progress.best_score {
416 self.progress.best_score = best;
417 }
418 }
419
420 if let Some(baseline) = data.baseline_score {
421 if self.progress.baseline_score.is_none() {
422 self.progress.baseline_score = Some(baseline);
423 }
424 }
425
426 if self.progress.rollouts_total > 0 && self.progress.rollouts_completed > 0 {
428 let remaining = self.progress.rollouts_total - self.progress.rollouts_completed;
429 let rate = self.progress.elapsed_seconds / self.progress.rollouts_completed as f64;
430 self.progress.eta_seconds = Some(remaining as f64 * rate);
431 }
432 }
433
434 fn handle_generation(&mut self, event: &ParsedEvent) {
435 let data = EventParser::parse_generation(event);
436
437 self.progress.generations_completed = data.generation;
438
439 let info = GenerationInfo {
440 generation: data.generation,
441 best_accuracy: data.best_accuracy,
442 candidates_proposed: data.candidates_proposed,
443 candidates_accepted: data.candidates_accepted,
444 };
445 self.generation_history.push(info);
446 }
447
448 fn handle_complete(&mut self, event: &ParsedEvent) {
449 let data = EventParser::parse_complete(event);
450
451 self.progress.phase = "complete".to_string();
452 self.progress.finish_reason = data.finish_reason;
453
454 if let Some(best) = data.best_score {
455 self.progress.best_score = best;
456 }
457
458 if let Some(baseline) = data.baseline_score {
459 self.progress.baseline_score = Some(baseline);
460 }
461 }
462
463 fn handle_termination(&mut self, event: &ParsedEvent) {
464 let data = EventParser::parse_termination(event);
465
466 self.progress.phase = "complete".to_string();
467 self.progress.finish_reason = Some(data.reason);
468 }
469
470 fn handle_validation(&mut self, event: &ParsedEvent) {
471 self.progress.phase = "validation".to_string();
472
473 if let Some(candidate_id) = event.data.get("candidate_id").and_then(|v| v.as_str()) {
475 if let Some(val_score) = event.data.get("val_accuracy").and_then(|v| v.as_f64()) {
476 if let Some(&idx) = self.candidates_by_id.get(candidate_id) {
477 self.candidates[idx].val_accuracy = Some(val_score);
478 }
479 }
480 }
481 }
482
483 pub fn to_summary(&self) -> serde_json::Value {
485 serde_json::json!({
486 "phase": self.progress.phase,
487 "rollouts_completed": self.progress.rollouts_completed,
488 "rollouts_total": self.progress.rollouts_total,
489 "candidates_evaluated": self.progress.candidates_evaluated,
490 "generations_completed": self.progress.generations_completed,
491 "best_score": self.progress.best_score,
492 "baseline_score": self.progress.baseline_score,
493 "lift": self.lift(),
494 "elapsed_seconds": self.progress.elapsed_seconds,
495 "frontier_size": self.frontier.len(),
496 })
497 }
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Parses `raw` with `EventParser::parse` and applies it to `tracker`.
    fn feed(tracker: &mut ProgressTracker, raw: &serde_json::Value) {
        let event = EventParser::parse(raw);
        tracker.update(&event);
    }

    #[test]
    fn test_progress_default() {
        let progress = GEPAProgress::default();
        assert_eq!(progress.phase, "init");
        assert_eq!(progress.progress_pct(), 0.0);
        assert!(progress.lift().is_none());
    }

    #[test]
    fn test_progress_lift() {
        let progress = GEPAProgress {
            baseline_score: Some(0.5),
            best_score: 0.75,
            ..GEPAProgress::default()
        };

        let lift = progress
            .lift()
            .expect("a baseline is set, so lift must be defined");
        assert!((lift - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_tracker_baseline() {
        let mut tracker = ProgressTracker::new();

        feed(
            &mut tracker,
            &json!({
                "type": "learning.policy.gepa.baseline",
                "seq": 1,
                "data": { "accuracy": 0.72 }
            }),
        );

        assert!(tracker.baseline.is_some());
        assert_eq!(tracker.baseline_score(), Some(0.72));
        assert_eq!(tracker.progress.phase, "optimization");
    }

    #[test]
    fn test_tracker_candidate() {
        let mut tracker = ProgressTracker::new();

        feed(
            &mut tracker,
            &json!({
                "type": "learning.policy.gepa.baseline",
                "data": { "accuracy": 0.72 }
            }),
        );
        feed(
            &mut tracker,
            &json!({
                "type": "learning.policy.gepa.candidate.evaluated",
                "seq": 2,
                "data": {
                    "candidate_id": "cand_1",
                    "accuracy": 0.85,
                    "accepted": true,
                    "generation": 1
                }
            }),
        );

        assert_eq!(tracker.candidates.len(), 1);
        assert_eq!(tracker.best_score(), 0.85);
        assert_eq!(tracker.progress.candidates_evaluated, 1);
    }

    #[test]
    fn test_tracker_frontier() {
        let mut tracker = ProgressTracker::new();

        feed(
            &mut tracker,
            &json!({
                "type": "learning.policy.gepa.frontier_updated",
                "data": {
                    "frontier": ["cand_1", "cand_2"],
                    "best_score": 0.88
                }
            }),
        );

        assert_eq!(tracker.frontier.len(), 2);
        assert_eq!(tracker.frontier_history.len(), 1);
        assert_eq!(tracker.best_score(), 0.88);
    }

    #[test]
    fn test_tracker_complete() {
        let mut tracker = ProgressTracker::new();

        feed(
            &mut tracker,
            &json!({
                "type": "learning.policy.gepa.job.completed",
                "data": {
                    "best_score": 0.92,
                    "baseline_score": 0.72,
                    "finish_reason": "budget_exhausted"
                }
            }),
        );

        assert_eq!(tracker.progress.phase, "complete");
        assert_eq!(
            tracker.progress.finish_reason,
            Some("budget_exhausted".to_string())
        );
        assert_eq!(tracker.best_score(), 0.92);
    }
}