1use crate::error::EvaluationError;
2use crate::evaluate::compare::compare_results;
3use crate::evaluate::types::{ComparisonResults, EvalResults};
4use owo_colors::OwoColorize;
5use potato_head::PyHelperFuncs;
6use pyo3::prelude::*;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use tabled::Tabled;
10use tabled::{
11 settings::{object::Rows, Alignment, Color, Format, Style},
12 Table,
13};
14
/// Outcome of a single task within a scenario.
///
/// Stored in `ScenarioResult::task_results`; `ScenarioEvalResults::as_table` counts
/// entries with `passed == true` to fill its "Passed"/"Failed" columns.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[pyclass]
pub struct TaskSummary {
    /// Identifier of the task.
    #[pyo3(get)]
    pub task_id: String,

    /// Whether the task passed.
    #[pyo3(get)]
    pub passed: bool,

    /// Numeric value recorded for the task.
    // NOTE(review): presumably the task's score/metric value; the producer of these
    // summaries is not visible in this file — confirm its meaning and range.
    #[pyo3(get)]
    pub value: f64,
}
27
#[pymethods]
impl TaskSummary {
    /// Python `str()` representation; formatting is delegated to `PyHelperFuncs::__str__`.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }
}
34
/// Aggregate metrics for one scenario-evaluation run.
///
/// Rates appear to be fractions: `as_table` multiplies them by 100 for display.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[pyclass]
pub struct EvalMetrics {
    /// Overall pass rate for the run.
    #[pyo3(get)]
    pub overall_pass_rate: f64,

    /// Pass rate keyed by dataset alias. Exposed to Python through an explicit
    /// `#[getter]` (which returns a copy) rather than `#[pyo3(get)]`.
    pub dataset_pass_rates: HashMap<String, f64>,

    /// Scenario-level pass rate.
    // NOTE(review): presumably passed_scenarios / total_scenarios — confirm with producer.
    #[pyo3(get)]
    pub scenario_pass_rate: f64,

    /// Number of scenarios evaluated.
    #[pyo3(get)]
    pub total_scenarios: usize,

    /// Number of scenarios that passed.
    #[pyo3(get)]
    pub passed_scenarios: usize,

    /// Nested pass-rate map; deserializes to an empty map when absent from JSON.
    // NOTE(review): keys look like scenario id -> task id, but the producer is not
    // visible in this file — confirm.
    #[serde(default)]
    pub scenario_task_pass_rates: HashMap<String, HashMap<String, f64>>,
}
56
/// One row of the "Aggregate Metrics" table printed by `EvalMetrics::as_table`.
#[derive(Tabled)]
struct MetricEntry {
    /// Human-readable metric label.
    #[tabled(rename = "Metric")]
    metric: String,
    /// Pre-formatted value (a percentage or a count).
    #[tabled(rename = "Value")]
    value: String,
}
64
/// One row of the "Sub-Agent Pass Rates" table printed by `EvalMetrics::as_table`.
#[derive(Tabled)]
struct DatasetPassRateEntry {
    /// Dataset alias.
    #[tabled(rename = "Alias")]
    alias: String,
    /// Pass rate pre-formatted as a percentage with one decimal.
    #[tabled(rename = "Pass Rate")]
    pass_rate: String,
}
72
73#[pymethods]
74impl EvalMetrics {
75 pub fn __str__(&self) -> String {
76 PyHelperFuncs::__str__(self)
77 }
78
79 #[getter]
80 pub fn dataset_pass_rates(&self) -> HashMap<String, f64> {
81 self.dataset_pass_rates.clone()
82 }
83
84 #[getter]
85 pub fn scenario_task_pass_rates(&self) -> HashMap<String, HashMap<String, f64>> {
86 self.scenario_task_pass_rates.clone()
87 }
88
89 pub fn as_table(&self) {
90 println!("\n{}", "Aggregate Metrics".truecolor(245, 77, 85).bold());
91
92 let entries = vec![
93 MetricEntry {
94 metric: "Overall Pass Rate".to_string(),
95 value: format!("{:.1}%", self.overall_pass_rate * 100.0),
96 },
97 MetricEntry {
98 metric: "Scenario Pass Rate".to_string(),
99 value: format!("{:.1}%", self.scenario_pass_rate * 100.0),
100 },
101 MetricEntry {
102 metric: "Total Scenarios".to_string(),
103 value: self.total_scenarios.to_string(),
104 },
105 MetricEntry {
106 metric: "Passed Scenarios".to_string(),
107 value: self.passed_scenarios.to_string(),
108 },
109 ];
110
111 let mut table = Table::new(entries);
112 table.with(Style::sharp());
113 table.modify(
114 Rows::new(0..1),
115 (
116 Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
117 Alignment::center(),
118 Color::BOLD,
119 ),
120 );
121 println!("{}", table);
122
123 if !self.dataset_pass_rates.is_empty() {
124 println!("\n{}", "Sub-Agent Pass Rates".truecolor(245, 77, 85).bold());
125
126 let mut alias_entries: Vec<_> = self
127 .dataset_pass_rates
128 .iter()
129 .map(|(alias, rate)| DatasetPassRateEntry {
130 alias: alias.clone(),
131 pass_rate: format!("{:.1}%", rate * 100.0),
132 })
133 .collect();
134 alias_entries.sort_by(|a, b| a.alias.cmp(&b.alias));
135
136 let mut alias_table = Table::new(alias_entries);
137 alias_table.with(Style::sharp());
138 alias_table.modify(
139 Rows::new(0..1),
140 (
141 Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
142 Alignment::center(),
143 Color::BOLD,
144 ),
145 );
146 println!("{}", alias_table);
147 }
148 }
149}
150
/// Result of evaluating a single scenario.
///
/// The hand-written `PartialEq` impl for this type does not compare `eval_results`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ScenarioResult {
    /// Scenario identifier; `ScenarioEvalResults::compare_to` matches scenarios
    /// across runs by this id.
    #[pyo3(get)]
    pub scenario_id: String,

    /// The query the scenario started from.
    #[pyo3(get)]
    pub initial_query: String,

    /// Detailed evaluation results for this scenario. Exposed to Python through an
    /// explicit `#[getter]` that returns a clone.
    pub eval_results: EvalResults,

    /// Whether the scenario as a whole passed.
    #[pyo3(get)]
    pub passed: bool,

    /// Scenario pass rate (displayed multiplied by 100 as a percentage).
    #[pyo3(get)]
    pub pass_rate: f64,

    /// Per-task outcomes; deserializes to an empty list when absent from JSON.
    #[pyo3(get)]
    #[serde(default)]
    pub task_results: Vec<TaskSummary>,
}
172
173impl PartialEq for ScenarioResult {
174 fn eq(&self, other: &Self) -> bool {
175 self.scenario_id == other.scenario_id
176 && self.initial_query == other.initial_query
177 && self.passed == other.passed
178 && self.pass_rate == other.pass_rate
179 && self.task_results == other.task_results
180 }
181}
182
#[pymethods]
impl ScenarioResult {
    /// Python `str()` representation; formatting is delegated to `PyHelperFuncs::__str__`.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }

    /// Detailed evaluation results for this scenario (returns a clone).
    #[getter]
    pub fn eval_results(&self) -> EvalResults {
        self.eval_results.clone()
    }
}
194
/// Per-scenario difference between a baseline run and a comparison run.
///
/// Produced by `ScenarioEvalResults::compare_to` for scenarios present in both runs.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[pyclass]
pub struct ScenarioDelta {
    /// Scenario identifier shared by both runs.
    #[pyo3(get)]
    pub scenario_id: String,

    /// Initial query, taken from the comparison (current) run.
    #[pyo3(get)]
    pub initial_query: String,

    /// Whether the scenario passed in the baseline run.
    #[pyo3(get)]
    pub baseline_passed: bool,

    /// Whether the scenario passed in the comparison run.
    #[pyo3(get)]
    pub comparison_passed: bool,

    /// Baseline pass rate.
    #[pyo3(get)]
    pub baseline_pass_rate: f64,

    /// Comparison-run pass rate.
    #[pyo3(get)]
    pub comparison_pass_rate: f64,

    /// `true` when `baseline_passed != comparison_passed`.
    #[pyo3(get)]
    pub status_changed: bool,
}
219
#[pymethods]
impl ScenarioDelta {
    /// Python `str()` representation; formatting is delegated to `PyHelperFuncs::__str__`.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }
}
226
/// Outcome of comparing a current `ScenarioEvalResults` against a baseline.
///
/// Built by `ScenarioEvalResults::compare_to`. Fields marked `#[serde(default)]` may
/// be absent from serialized JSON. The hand-written `PartialEq` impl ignores
/// `dataset_comparisons`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ScenarioComparisonResults {
    /// Per-alias comparison for aliases evaluated in both runs. Exposed to Python
    /// through an explicit `#[getter]` that returns a copy.
    pub dataset_comparisons: HashMap<String, ComparisonResults>,

    /// One delta per scenario id present in both runs.
    #[pyo3(get)]
    pub scenario_deltas: Vec<ScenarioDelta>,

    /// `metrics.overall_pass_rate` of the baseline run.
    #[pyo3(get)]
    pub baseline_overall_pass_rate: f64,

    /// `metrics.overall_pass_rate` of the comparison run.
    #[pyo3(get)]
    pub comparison_overall_pass_rate: f64,

    /// `true` when at least one alias is listed in `regressed_aliases`.
    #[pyo3(get)]
    pub regressed: bool,

    /// Aliases that did not regress and whose comparison reported at least one
    /// improved workflow.
    #[pyo3(get)]
    pub improved_aliases: Vec<String>,

    /// Aliases whose per-alias comparison reported a regression.
    #[pyo3(get)]
    pub regressed_aliases: Vec<String>,

    /// Aliases present only in the comparison run.
    #[pyo3(get)]
    #[serde(default)]
    pub new_aliases: Vec<String>,

    /// Aliases present only in the baseline run.
    #[pyo3(get)]
    #[serde(default)]
    pub removed_aliases: Vec<String>,

    /// Scenario ids present only in the comparison run.
    #[pyo3(get)]
    #[serde(default)]
    pub new_scenarios: Vec<String>,

    /// Scenario ids present only in the baseline run.
    #[pyo3(get)]
    #[serde(default)]
    pub removed_scenarios: Vec<String>,

    /// Copy of the baseline run's `metrics.dataset_pass_rates` (exposed via a getter).
    #[serde(default)]
    pub baseline_alias_pass_rates: HashMap<String, f64>,

    /// Copy of the comparison run's `metrics.dataset_pass_rates` (exposed via a getter).
    #[serde(default)]
    pub comparison_alias_pass_rates: HashMap<String, f64>,
}
272
273impl PartialEq for ScenarioComparisonResults {
274 fn eq(&self, other: &Self) -> bool {
275 self.scenario_deltas == other.scenario_deltas
276 && self.baseline_overall_pass_rate == other.baseline_overall_pass_rate
277 && self.comparison_overall_pass_rate == other.comparison_overall_pass_rate
278 && self.regressed == other.regressed
279 && self.improved_aliases == other.improved_aliases
280 && self.regressed_aliases == other.regressed_aliases
281 && self.new_aliases == other.new_aliases
282 && self.removed_aliases == other.removed_aliases
283 && self.new_scenarios == other.new_scenarios
284 && self.removed_scenarios == other.removed_scenarios
285 && self.baseline_alias_pass_rates == other.baseline_alias_pass_rates
286 && self.comparison_alias_pass_rates == other.comparison_alias_pass_rates
287 }
289}
290
/// One row of the "Sub-Agent Comparison" table, used by
/// `ScenarioComparisonResults::as_table` when only `dataset_comparisons` is available
/// (no per-alias pass rates).
#[derive(Tabled)]
struct DatasetComparisonEntry {
    /// Dataset alias.
    #[tabled(rename = "Alias")]
    alias: String,
    /// Signed, colorized mean pass-rate delta in percentage points.
    #[tabled(rename = "Delta")]
    delta: String,
    /// Colorized REGRESSION / IMPROVED / UNCHANGED label.
    #[tabled(rename = "Status")]
    status: String,
}
300
/// One row of the "Scenario Comparison" table printed by
/// `ScenarioComparisonResults::as_table`.
#[derive(Tabled)]
struct ScenarioDeltaEntry {
    /// Scenario id, truncated to its first 16 characters.
    #[tabled(rename = "Scenario ID")]
    scenario_id: String,
    /// Colorized PASS/FAIL of the baseline run.
    #[tabled(rename = "Baseline")]
    baseline: String,
    /// Colorized PASS/FAIL of the current run.
    #[tabled(rename = "Current")]
    current: String,
    /// Signed, colorized pass-rate change in percentage points.
    #[tabled(rename = "Pass Rate Δ")]
    pass_rate_delta: String,
    /// "Pass -> Fail", "Fail -> Pass", or "-" when the status did not flip.
    #[tabled(rename = "Change")]
    change: String,
}
314
/// One row of the "Sub-Agent Comparison" table built from per-alias pass rates by
/// `ScenarioComparisonResults::as_table`.
#[derive(Tabled)]
struct AliasPassRateEntry {
    /// Dataset alias (union of baseline and current aliases).
    #[tabled(rename = "Alias")]
    alias: String,
    /// Baseline pass rate as a percentage, or "-" when the alias is new.
    #[tabled(rename = "Baseline")]
    baseline: String,
    /// Current pass rate as a percentage, or "-" when the alias was removed.
    #[tabled(rename = "Current")]
    current: String,
    /// Signed, colorized delta in percentage points, or "-" if one side is missing.
    #[tabled(rename = "Delta")]
    delta: String,
    /// Colorized REGRESSION / IMPROVED / UNCHANGED / NEW / REMOVED label.
    #[tabled(rename = "Status")]
    status: String,
}
328
329#[pymethods]
330impl ScenarioComparisonResults {
331 pub fn __str__(&self) -> String {
332 PyHelperFuncs::__str__(self)
333 }
334
335 #[getter]
336 pub fn dataset_comparisons(&self) -> HashMap<String, ComparisonResults> {
337 self.dataset_comparisons.clone()
338 }
339
340 #[getter]
341 pub fn baseline_alias_pass_rates(&self) -> HashMap<String, f64> {
342 self.baseline_alias_pass_rates.clone()
343 }
344
345 #[getter]
346 pub fn comparison_alias_pass_rates(&self) -> HashMap<String, f64> {
347 self.comparison_alias_pass_rates.clone()
348 }
349
350 pub fn model_dump_json(&self) -> Result<String, EvaluationError> {
351 serde_json::to_string(self).map_err(Into::into)
352 }
353
354 #[staticmethod]
355 pub fn model_validate_json(json_string: String) -> Result<Self, EvaluationError> {
356 serde_json::from_str(&json_string).map_err(Into::into)
357 }
358
359 pub fn save(&self, path: &str) -> Result<(), EvaluationError> {
360 let json = serde_json::to_string_pretty(self)?;
361 std::fs::write(path, json)?;
362 Ok(())
363 }
364
365 #[staticmethod]
366 pub fn load(path: &str) -> Result<Self, EvaluationError> {
367 let json = std::fs::read_to_string(path)?;
368 serde_json::from_str(&json).map_err(Into::into)
369 }
370
371 pub fn as_table(&self) {
372 if !self.baseline_alias_pass_rates.is_empty()
374 || !self.comparison_alias_pass_rates.is_empty()
375 {
376 println!("\n{}", "Sub-Agent Comparison".truecolor(245, 77, 85).bold());
377
378 let mut all_aliases: Vec<String> = self
379 .baseline_alias_pass_rates
380 .keys()
381 .chain(self.comparison_alias_pass_rates.keys())
382 .cloned()
383 .collect();
384 all_aliases.sort();
385 all_aliases.dedup();
386
387 let mut alias_entries: Vec<AliasPassRateEntry> = Vec::new();
388 for alias in &all_aliases {
389 let baseline_rate = self.baseline_alias_pass_rates.get(alias);
390 let current_rate = self.comparison_alias_pass_rates.get(alias);
391
392 let (baseline_str, current_str, delta_str, status) =
393 match (baseline_rate, current_rate) {
394 (Some(b), Some(c)) => {
395 let delta = (c - b) * 100.0;
396 let d = format!("{:+.1}%", delta);
397 let colored_delta = if delta > 1.0 {
398 d.green().to_string()
399 } else if delta < -1.0 {
400 d.red().to_string()
401 } else {
402 d.yellow().to_string()
403 };
404 let s = if self.regressed_aliases.contains(alias) {
405 "REGRESSION".red().to_string()
406 } else if self.improved_aliases.contains(alias) {
407 "IMPROVED".green().to_string()
408 } else {
409 "UNCHANGED".yellow().to_string()
410 };
411 (
412 format!("{:.1}%", b * 100.0),
413 format!("{:.1}%", c * 100.0),
414 colored_delta,
415 s,
416 )
417 }
418 (None, Some(c)) => (
419 "-".to_string(),
420 format!("{:.1}%", c * 100.0),
421 "-".to_string(),
422 "NEW".green().bold().to_string(),
423 ),
424 (Some(b), None) => (
425 format!("{:.1}%", b * 100.0),
426 "-".to_string(),
427 "-".to_string(),
428 "REMOVED".red().bold().to_string(),
429 ),
430 (None, None) => continue,
431 };
432
433 alias_entries.push(AliasPassRateEntry {
434 alias: alias.clone(),
435 baseline: baseline_str,
436 current: current_str,
437 delta: delta_str,
438 status,
439 });
440 }
441
442 let mut alias_table = Table::new(alias_entries);
443 alias_table.with(Style::sharp());
444 alias_table.modify(
445 Rows::new(0..1),
446 (
447 Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
448 Alignment::center(),
449 Color::BOLD,
450 ),
451 );
452 println!("{}", alias_table);
453 } else if !self.dataset_comparisons.is_empty() {
454 println!("\n{}", "Sub-Agent Comparison".truecolor(245, 77, 85).bold());
456
457 let mut alias_entries: Vec<_> = self
458 .dataset_comparisons
459 .iter()
460 .map(|(alias, comp)| {
461 let delta = comp.mean_pass_rate_delta * 100.0;
462 let delta_str = format!("{:+.1}%", delta);
463 let colored_delta = if delta > 1.0 {
464 delta_str.green().to_string()
465 } else if delta < -1.0 {
466 delta_str.red().to_string()
467 } else {
468 delta_str.yellow().to_string()
469 };
470 let status = if comp.regressed {
471 "REGRESSION".red().to_string()
472 } else if comp.improved_workflows > 0 {
473 "IMPROVED".green().to_string()
474 } else {
475 "UNCHANGED".yellow().to_string()
476 };
477 DatasetComparisonEntry {
478 alias: alias.clone(),
479 delta: colored_delta,
480 status,
481 }
482 })
483 .collect();
484 alias_entries.sort_by(|a, b| a.alias.cmp(&b.alias));
485
486 let mut alias_table = Table::new(alias_entries);
487 alias_table.with(Style::sharp());
488 alias_table.modify(
489 Rows::new(0..1),
490 (
491 Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
492 Alignment::center(),
493 Color::BOLD,
494 ),
495 );
496 println!("{}", alias_table);
497 }
498
499 if !self.scenario_deltas.is_empty() {
501 println!("\n{}", "Scenario Comparison".truecolor(245, 77, 85).bold());
502
503 let entries: Vec<_> = self
504 .scenario_deltas
505 .iter()
506 .map(|d| {
507 let baseline_str = if d.baseline_passed {
508 "PASS".green().to_string()
509 } else {
510 "FAIL".red().to_string()
511 };
512 let current_str = if d.comparison_passed {
513 "PASS".green().to_string()
514 } else {
515 "FAIL".red().to_string()
516 };
517 let pr_delta = (d.comparison_pass_rate - d.baseline_pass_rate) * 100.0;
518 let pr_delta_str = format!("{:+.1}%", pr_delta);
519 let colored_pr_delta = if pr_delta > 1.0 {
520 pr_delta_str.green().to_string()
521 } else if pr_delta < -1.0 {
522 pr_delta_str.red().to_string()
523 } else {
524 pr_delta_str.yellow().to_string()
525 };
526 let change = match (d.baseline_passed, d.comparison_passed) {
527 (true, false) => "Pass -> Fail".red().bold().to_string(),
528 (false, true) => "Fail -> Pass".green().bold().to_string(),
529 _ => "-".to_string(),
530 };
531 ScenarioDeltaEntry {
532 scenario_id: d.scenario_id.chars().take(16).collect::<String>(),
533 baseline: baseline_str,
534 current: current_str,
535 pass_rate_delta: colored_pr_delta,
536 change,
537 }
538 })
539 .collect();
540
541 let mut delta_table = Table::new(entries);
542 delta_table.with(Style::sharp());
543 delta_table.modify(
544 Rows::new(0..1),
545 (
546 Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
547 Alignment::center(),
548 Color::BOLD,
549 ),
550 );
551 println!("{}", delta_table);
552 }
553
554 if !self.new_scenarios.is_empty() {
556 println!(
557 "\n{} {}",
558 "New Scenarios:".truecolor(245, 77, 85).bold(),
559 self.new_scenarios.join(", ").green()
560 );
561 }
562 if !self.removed_scenarios.is_empty() {
563 println!(
564 "\n{} {}",
565 "Removed Scenarios:".truecolor(245, 77, 85).bold(),
566 self.removed_scenarios.join(", ").red()
567 );
568 }
569
570 let overall_status = if self.regressed {
572 "REGRESSION DETECTED".red().bold().to_string()
573 } else if !self.improved_aliases.is_empty() {
574 "IMPROVEMENT DETECTED".green().bold().to_string()
575 } else {
576 "NO SIGNIFICANT CHANGE".yellow().bold().to_string()
577 };
578
579 println!("\n{}", "Summary".truecolor(245, 77, 85).bold());
580 println!(" Overall Status: {}", overall_status);
581 println!(
582 " Baseline Pass Rate: {:.1}%",
583 self.baseline_overall_pass_rate * 100.0
584 );
585 println!(
586 " Current Pass Rate: {:.1}%",
587 self.comparison_overall_pass_rate * 100.0
588 );
589 if !self.regressed_aliases.is_empty() {
590 println!(
591 " Regressed aliases: {}",
592 self.regressed_aliases.join(", ").red()
593 );
594 }
595 if !self.improved_aliases.is_empty() {
596 println!(
597 " Improved aliases: {}",
598 self.improved_aliases.join(", ").green()
599 );
600 }
601 if !self.new_aliases.is_empty() {
602 println!(" New aliases: {}", self.new_aliases.join(", ").green());
603 }
604 if !self.removed_aliases.is_empty() {
605 println!(
606 " Removed aliases: {}",
607 self.removed_aliases.join(", ").red()
608 );
609 }
610 let changed_count = self
611 .scenario_deltas
612 .iter()
613 .filter(|d| d.status_changed)
614 .count();
615 if changed_count > 0 {
616 println!(" Scenarios changed: {}", changed_count);
617 }
618 if !self.new_scenarios.is_empty() {
619 println!(" New scenarios: {}", self.new_scenarios.len());
620 }
621 if !self.removed_scenarios.is_empty() {
622 println!(" Removed scenarios: {}", self.removed_scenarios.len());
623 }
624 }
625}
626
/// Results of a scenario-based evaluation run.
///
/// The hand-written `PartialEq` impl for this type ignores `dataset_results`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ScenarioEvalResults {
    /// Evaluation results keyed by dataset alias (exposed via a getter that clones).
    pub dataset_results: HashMap<String, EvalResults>,

    /// Per-scenario results, in stored order (exposed via a getter that clones).
    pub scenario_results: Vec<ScenarioResult>,

    /// Aggregate metrics for the run.
    #[pyo3(get)]
    pub metrics: EvalMetrics,
}
637
638impl PartialEq for ScenarioEvalResults {
639 fn eq(&self, other: &Self) -> bool {
640 self.scenario_results == other.scenario_results && self.metrics == other.metrics
641 }
643}
644
/// One row of the "Scenario Results" table printed by `ScenarioEvalResults::as_table`.
#[derive(Tabled)]
struct ScenarioResultEntry {
    /// Scenario id, truncated to its first 16 characters.
    #[tabled(rename = "Scenario ID")]
    scenario_id: String,
    /// Initial query, cut to 40 characters plus "..." when longer.
    #[tabled(rename = "Initial Query")]
    initial_query: String,
    /// Number of task results recorded for the scenario.
    #[tabled(rename = "Tasks")]
    tasks: usize,
    /// Task results with `passed == true`.
    #[tabled(rename = "Passed")]
    passed_count: usize,
    /// `tasks - passed_count`.
    #[tabled(rename = "Failed")]
    failed_count: usize,
    /// Scenario pass rate as a percentage with one decimal.
    #[tabled(rename = "Pass Rate")]
    pass_rate: String,
    /// Colorized PASS/FAIL label.
    #[tabled(rename = "Status")]
    status: String,
}
662
663#[pymethods]
664impl ScenarioEvalResults {
665 pub fn __str__(&self) -> String {
666 PyHelperFuncs::__str__(self)
667 }
668
669 pub fn model_dump_json(&self) -> Result<String, EvaluationError> {
670 serde_json::to_string(self).map_err(Into::into)
671 }
672
673 #[staticmethod]
674 pub fn model_validate_json(json_string: String) -> Result<Self, EvaluationError> {
675 serde_json::from_str(&json_string).map_err(Into::into)
676 }
677
678 pub fn save(&self, path: &str) -> Result<(), EvaluationError> {
679 let json = serde_json::to_string_pretty(self)?;
680 std::fs::write(path, json)?;
681 Ok(())
682 }
683
684 #[staticmethod]
685 pub fn load(path: &str) -> Result<Self, EvaluationError> {
686 let json = std::fs::read_to_string(path)?;
687 serde_json::from_str(&json).map_err(Into::into)
688 }
689
690 #[getter]
691 pub fn dataset_results(&self) -> HashMap<String, EvalResults> {
692 self.dataset_results.clone()
693 }
694
695 #[getter]
696 pub fn scenario_results(&self) -> Vec<ScenarioResult> {
697 self.scenario_results.clone()
698 }
699
700 pub fn get_scenario_detail(
701 &self,
702 scenario_id: &str,
703 ) -> Result<ScenarioResult, EvaluationError> {
704 self.scenario_results
705 .iter()
706 .find(|r| r.scenario_id == scenario_id)
707 .cloned()
708 .ok_or_else(|| EvaluationError::MissingKeyError(scenario_id.to_string()))
709 }
710
711 #[pyo3(signature = (baseline, regression_threshold = 0.05))]
712 pub fn compare_to(
713 &self,
714 baseline: &ScenarioEvalResults,
715 regression_threshold: f64,
716 ) -> Result<ScenarioComparisonResults, EvaluationError> {
717 let mut dataset_comparisons = HashMap::new();
718 let mut improved_aliases = Vec::new();
719 let mut regressed_aliases = Vec::new();
720
721 for (alias, current_results) in &self.dataset_results {
722 if let Some(baseline_results) = baseline.dataset_results.get(alias) {
723 let comp =
724 compare_results(baseline_results, current_results, regression_threshold)?;
725 if comp.regressed {
726 regressed_aliases.push(alias.clone());
727 } else if comp.improved_workflows > 0 {
728 improved_aliases.push(alias.clone());
729 }
730 dataset_comparisons.insert(alias.clone(), comp);
731 }
732 }
733
734 let mut new_aliases: Vec<String> = self
736 .dataset_results
737 .keys()
738 .filter(|alias| !baseline.dataset_results.contains_key(*alias))
739 .cloned()
740 .collect();
741 new_aliases.sort();
742
743 let mut removed_aliases: Vec<String> = baseline
744 .dataset_results
745 .keys()
746 .filter(|alias| !self.dataset_results.contains_key(*alias))
747 .cloned()
748 .collect();
749 removed_aliases.sort();
750
751 let baseline_scenario_map: HashMap<_, _> = baseline
752 .scenario_results
753 .iter()
754 .map(|r| (r.scenario_id.as_str(), r))
755 .collect();
756
757 let current_scenario_map: HashMap<_, _> = self
758 .scenario_results
759 .iter()
760 .map(|r| (r.scenario_id.as_str(), r))
761 .collect();
762
763 let mut scenario_deltas = Vec::new();
764 for current in &self.scenario_results {
765 if let Some(base) = baseline_scenario_map.get(current.scenario_id.as_str()) {
766 scenario_deltas.push(ScenarioDelta {
767 scenario_id: current.scenario_id.clone(),
768 initial_query: current.initial_query.clone(),
769 baseline_passed: base.passed,
770 comparison_passed: current.passed,
771 baseline_pass_rate: base.pass_rate,
772 comparison_pass_rate: current.pass_rate,
773 status_changed: base.passed != current.passed,
774 });
775 }
776 }
777
778 let mut new_scenarios: Vec<String> = current_scenario_map
780 .keys()
781 .filter(|id| !baseline_scenario_map.contains_key(*id))
782 .map(|id| id.to_string())
783 .collect();
784 new_scenarios.sort();
785
786 let mut removed_scenarios: Vec<String> = baseline_scenario_map
787 .keys()
788 .filter(|id| !current_scenario_map.contains_key(*id))
789 .map(|id| id.to_string())
790 .collect();
791 removed_scenarios.sort();
792
793 let baseline_alias_pass_rates = baseline.metrics.dataset_pass_rates.clone();
795 let comparison_alias_pass_rates = self.metrics.dataset_pass_rates.clone();
796
797 Ok(ScenarioComparisonResults {
798 dataset_comparisons,
799 scenario_deltas,
800 baseline_overall_pass_rate: baseline.metrics.overall_pass_rate,
801 comparison_overall_pass_rate: self.metrics.overall_pass_rate,
802 regressed: !regressed_aliases.is_empty(),
803 improved_aliases,
804 regressed_aliases,
805 new_aliases,
806 removed_aliases,
807 new_scenarios,
808 removed_scenarios,
809 baseline_alias_pass_rates,
810 comparison_alias_pass_rates,
811 })
812 }
813
814 #[pyo3(signature = (show_datasets=false))]
815 pub fn as_table(&mut self, show_datasets: bool) {
816 self.metrics.as_table();
817
818 if !self.scenario_results.is_empty() {
819 println!("\n{}", "Scenario Results".truecolor(245, 77, 85).bold());
820
821 let entries: Vec<_> = self
822 .scenario_results
823 .iter()
824 .map(|r| {
825 let query = if r.initial_query.chars().count() > 40 {
826 format!(
827 "{}...",
828 r.initial_query.chars().take(40).collect::<String>()
829 )
830 } else {
831 r.initial_query.clone()
832 };
833 let status = if r.passed {
834 "✓ PASS".green().to_string()
835 } else {
836 "✗ FAIL".red().to_string()
837 };
838 let total = r.task_results.len();
839 let passed_count = r.task_results.iter().filter(|t| t.passed).count();
840 let failed_count = total - passed_count;
841 ScenarioResultEntry {
842 scenario_id: r.scenario_id.chars().take(16).collect::<String>(),
843 initial_query: query,
844 tasks: total,
845 passed_count,
846 failed_count,
847 pass_rate: format!("{:.1}%", r.pass_rate * 100.0),
848 status,
849 }
850 })
851 .collect();
852
853 let mut table = Table::new(entries);
854 table.with(Style::sharp());
855 table.modify(
856 Rows::new(0..1),
857 (
858 Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
859 Alignment::center(),
860 Color::BOLD,
861 ),
862 );
863 println!("{}", table);
864 }
865
866 if show_datasets {
867 let mut aliases: Vec<_> = self.dataset_results.keys().cloned().collect();
868 aliases.sort();
869 for alias in aliases {
870 if let Some(eval_results) = self.dataset_results.get_mut(&alias) {
871 println!(
872 "\n{}",
873 format!("Dataset: {}", alias).truecolor(245, 77, 85).bold()
874 );
875 eval_results.as_table(false);
876 }
877 }
878 }
879 }
880}
881
#[cfg(test)]
mod tests {
    use super::*;
    use crate::evaluate::types::EvalResults;

    /// Build `EvalMetrics` with the given aggregates and empty per-alias / per-task maps.
    fn empty_metrics(
        overall: f64,
        scenario_pass_rate: f64,
        total: usize,
        passed: usize,
    ) -> EvalMetrics {
        EvalMetrics {
            overall_pass_rate: overall,
            dataset_pass_rates: HashMap::new(),
            scenario_pass_rate,
            total_scenarios: total,
            passed_scenarios: passed,
            scenario_task_pass_rates: HashMap::new(),
        }
    }

    /// Build a `ScenarioResult` with empty `eval_results` and no task results.
    fn make_scenario_result(id: &str, query: &str, passed: bool, pass_rate: f64) -> ScenarioResult {
        ScenarioResult {
            scenario_id: id.to_string(),
            initial_query: query.to_string(),
            eval_results: EvalResults::new(),
            passed,
            pass_rate,
            task_results: vec![],
        }
    }

    /// Two-scenario fixture: "s1" passes (1.0), "s2" fails (0.5); no datasets.
    fn make_eval_results() -> ScenarioEvalResults {
        ScenarioEvalResults {
            dataset_results: HashMap::new(),
            scenario_results: vec![
                make_scenario_result("s1", "Make pasta", true, 1.0),
                make_scenario_result("s2", "Make curry", false, 0.5),
            ],
            metrics: empty_metrics(0.75, 0.5, 2, 1),
        }
    }

    #[test]
    fn eval_metrics_fields() {
        let m = empty_metrics(0.9, 0.85, 100, 85);
        assert_eq!(m.overall_pass_rate, 0.9);
        assert_eq!(m.scenario_pass_rate, 0.85);
        assert_eq!(m.total_scenarios, 100);
        assert_eq!(m.passed_scenarios, 85);
    }

    #[test]
    fn eval_metrics_deserialize_defaults_scenario_task_pass_rates() {
        let json = r#"{
            "overall_pass_rate": 1.0,
            "dataset_pass_rates": {},
            "scenario_pass_rate": 1.0,
            "total_scenarios": 1,
            "passed_scenarios": 1
        }"#;
        let m: EvalMetrics = serde_json::from_str(json).unwrap();
        assert!(m.scenario_task_pass_rates.is_empty());
        assert_eq!(m.total_scenarios, 1);
    }

    #[test]
    fn scenario_result_fields() {
        let r = make_scenario_result("id-1", "hello", true, 1.0);
        assert_eq!(r.scenario_id, "id-1");
        assert!(r.passed);
        assert_eq!(r.pass_rate, 1.0);
    }

    #[test]
    fn model_dump_json_roundtrip() {
        let results = make_eval_results();
        let json = results.model_dump_json().unwrap();
        let loaded = ScenarioEvalResults::model_validate_json(json).unwrap();
        assert_eq!(loaded.scenario_results.len(), 2);
        assert_eq!(loaded.metrics.total_scenarios, 2);
        assert_eq!(loaded.metrics.passed_scenarios, 1);
        assert_eq!(loaded.metrics.overall_pass_rate, 0.75);
        assert_eq!(loaded.scenario_results[0].scenario_id, "s1");
        assert_eq!(loaded.scenario_results[1].scenario_id, "s2");
        assert_eq!(loaded.scenario_results[0].pass_rate, 1.0);
        assert_eq!(loaded.scenario_results[1].pass_rate, 0.5);
    }

    #[test]
    fn compare_to_regression_detection() {
        let baseline = make_eval_results();

        let current = ScenarioEvalResults {
            dataset_results: HashMap::new(),
            scenario_results: vec![
                make_scenario_result("s1", "Make pasta", false, 0.0),
                make_scenario_result("s2", "Make curry", false, 0.5),
            ],
            metrics: empty_metrics(0.25, 0.0, 2, 0),
        };

        let comp = current.compare_to(&baseline, 0.05).unwrap();
        assert_eq!(comp.scenario_deltas.len(), 2);

        let s1 = comp
            .scenario_deltas
            .iter()
            .find(|d| d.scenario_id == "s1")
            .unwrap();
        assert!(s1.status_changed);
        assert!(s1.baseline_passed);
        assert!(!s1.comparison_passed);

        let s2 = comp
            .scenario_deltas
            .iter()
            .find(|d| d.scenario_id == "s2")
            .unwrap();
        assert!(!s2.status_changed);
    }

    #[test]
    fn compare_to_no_regression() {
        let baseline = make_eval_results();
        let current = make_eval_results();

        let comp = current.compare_to(&baseline, 0.05).unwrap();
        assert!(!comp.regressed);
        assert!(comp.scenario_deltas.iter().all(|d| !d.status_changed));
    }

    #[test]
    fn compare_to_with_dataset_results() {
        let mut baseline_dataset = HashMap::new();
        baseline_dataset.insert("agent_a".to_string(), EvalResults::new());

        let mut current_dataset = HashMap::new();
        current_dataset.insert("agent_a".to_string(), EvalResults::new());

        let baseline = ScenarioEvalResults {
            dataset_results: baseline_dataset,
            scenario_results: vec![make_scenario_result("s1", "query one", true, 1.0)],
            metrics: empty_metrics(1.0, 1.0, 1, 1),
        };

        let current = ScenarioEvalResults {
            dataset_results: current_dataset,
            scenario_results: vec![make_scenario_result("s1", "query one", false, 0.0)],
            metrics: empty_metrics(0.0, 0.0, 1, 0),
        };

        let comp = current.compare_to(&baseline, 0.05).unwrap();

        assert!(comp.dataset_comparisons.contains_key("agent_a"));

        let s1 = comp
            .scenario_deltas
            .iter()
            .find(|d| d.scenario_id == "s1")
            .unwrap();
        assert!(s1.status_changed);
        assert!(s1.baseline_passed);
        assert!(!s1.comparison_passed);

        assert!(comp.improved_aliases.is_empty());
    }

    #[test]
    fn get_scenario_detail_found() {
        let results = make_eval_results();
        let detail = results.get_scenario_detail("s1").unwrap();
        assert_eq!(detail.scenario_id, "s1");
    }

    #[test]
    fn get_scenario_detail_missing() {
        let results = make_eval_results();
        assert!(results.get_scenario_detail("nonexistent").is_err());
    }

    #[test]
    fn save_load_roundtrip() {
        let results = make_eval_results();
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("results.json");
        let path_str = path.to_str().unwrap();

        results.save(path_str).unwrap();
        let loaded = ScenarioEvalResults::load(path_str).unwrap();
        assert_eq!(results, loaded);
    }

    #[test]
    fn comparison_model_dump_json_roundtrip() {
        let comp = ScenarioComparisonResults {
            dataset_comparisons: HashMap::new(),
            scenario_deltas: vec![ScenarioDelta {
                scenario_id: "s1".to_string(),
                initial_query: "hello".to_string(),
                baseline_passed: true,
                comparison_passed: false,
                baseline_pass_rate: 1.0,
                comparison_pass_rate: 0.0,
                status_changed: true,
            }],
            baseline_overall_pass_rate: 1.0,
            comparison_overall_pass_rate: 0.5,
            regressed: true,
            improved_aliases: vec![],
            regressed_aliases: vec!["a".to_string()],
            new_aliases: vec!["c".to_string()],
            removed_aliases: vec!["b".to_string()],
            new_scenarios: vec!["s3".to_string()],
            removed_scenarios: vec![],
            baseline_alias_pass_rates: HashMap::from([("a".to_string(), 1.0)]),
            comparison_alias_pass_rates: HashMap::from([("a".to_string(), 0.5)]),
        };

        let json = comp.model_dump_json().unwrap();
        let loaded = ScenarioComparisonResults::model_validate_json(json).unwrap();
        assert_eq!(comp, loaded);
    }

    #[test]
    fn comparison_model_validate_json_defaults_missing_fields() {
        // JSON without any of the `#[serde(default)]` fields must still load.
        let json = r#"{
            "dataset_comparisons": {},
            "scenario_deltas": [],
            "baseline_overall_pass_rate": 0.5,
            "comparison_overall_pass_rate": 0.75,
            "regressed": false,
            "improved_aliases": [],
            "regressed_aliases": []
        }"#;
        let loaded = ScenarioComparisonResults::model_validate_json(json.to_string()).unwrap();
        assert_eq!(loaded.baseline_overall_pass_rate, 0.5);
        assert_eq!(loaded.comparison_overall_pass_rate, 0.75);
        assert!(loaded.new_aliases.is_empty());
        assert!(loaded.removed_aliases.is_empty());
        assert!(loaded.new_scenarios.is_empty());
        assert!(loaded.removed_scenarios.is_empty());
        assert!(loaded.baseline_alias_pass_rates.is_empty());
        assert!(loaded.comparison_alias_pass_rates.is_empty());
    }

    #[test]
    fn comparison_save_load_roundtrip() {
        let comp = ScenarioComparisonResults {
            dataset_comparisons: HashMap::new(),
            scenario_deltas: vec![],
            baseline_overall_pass_rate: 0.8,
            comparison_overall_pass_rate: 0.9,
            regressed: false,
            improved_aliases: vec!["x".to_string()],
            regressed_aliases: vec![],
            new_aliases: vec![],
            removed_aliases: vec![],
            new_scenarios: vec![],
            removed_scenarios: vec![],
            baseline_alias_pass_rates: HashMap::new(),
            comparison_alias_pass_rates: HashMap::new(),
        };

        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("comp.json");
        let path_str = path.to_str().unwrap();

        comp.save(path_str).unwrap();
        let loaded = ScenarioComparisonResults::load(path_str).unwrap();
        assert_eq!(comp, loaded);
    }

    #[test]
    fn compare_to_new_scenarios() {
        let baseline = make_eval_results();
        let current = ScenarioEvalResults {
            dataset_results: HashMap::new(),
            scenario_results: vec![
                make_scenario_result("s1", "Make pasta", true, 1.0),
                make_scenario_result("s2", "Make curry", false, 0.5),
                make_scenario_result("s3", "Make salad", true, 0.8),
            ],
            metrics: empty_metrics(0.77, 0.67, 3, 2),
        };

        let comp = current.compare_to(&baseline, 0.05).unwrap();
        assert_eq!(comp.new_scenarios, vec!["s3".to_string()]);
        assert!(comp.removed_scenarios.is_empty());
        assert_eq!(comp.scenario_deltas.len(), 2);
    }

    #[test]
    fn compare_to_removed_scenarios() {
        let baseline = ScenarioEvalResults {
            dataset_results: HashMap::new(),
            scenario_results: vec![
                make_scenario_result("s1", "Make pasta", true, 1.0),
                make_scenario_result("s2", "Make curry", false, 0.5),
                make_scenario_result("s3", "Make salad", true, 0.8),
            ],
            metrics: empty_metrics(0.77, 0.67, 3, 2),
        };
        let current = make_eval_results();

        let comp = current.compare_to(&baseline, 0.05).unwrap();
        assert_eq!(comp.removed_scenarios, vec!["s3".to_string()]);
        assert!(comp.new_scenarios.is_empty());
    }

    #[test]
    fn compare_to_new_removed_aliases() {
        let mut baseline_datasets = HashMap::new();
        baseline_datasets.insert("a".to_string(), EvalResults::new());
        baseline_datasets.insert("b".to_string(), EvalResults::new());

        let mut current_datasets = HashMap::new();
        current_datasets.insert("a".to_string(), EvalResults::new());
        current_datasets.insert("c".to_string(), EvalResults::new());

        let baseline = ScenarioEvalResults {
            dataset_results: baseline_datasets,
            scenario_results: vec![make_scenario_result("s1", "q", true, 1.0)],
            metrics: empty_metrics(1.0, 1.0, 1, 1),
        };

        let current = ScenarioEvalResults {
            dataset_results: current_datasets,
            scenario_results: vec![make_scenario_result("s1", "q", true, 1.0)],
            metrics: empty_metrics(1.0, 1.0, 1, 1),
        };

        let comp = current.compare_to(&baseline, 0.05).unwrap();
        assert_eq!(comp.new_aliases, vec!["c".to_string()]);
        assert_eq!(comp.removed_aliases, vec!["b".to_string()]);
    }

    #[test]
    fn compare_to_new_removed_aliases_are_sorted() {
        // Disjoint alias sets: no per-alias comparison runs, only new/removed lists.
        let baseline = ScenarioEvalResults {
            dataset_results: HashMap::from([
                ("z".to_string(), EvalResults::new()),
                ("y".to_string(), EvalResults::new()),
            ]),
            scenario_results: vec![],
            metrics: empty_metrics(1.0, 1.0, 0, 0),
        };
        let current = ScenarioEvalResults {
            dataset_results: HashMap::from([
                ("c".to_string(), EvalResults::new()),
                ("a".to_string(), EvalResults::new()),
                ("b".to_string(), EvalResults::new()),
            ]),
            scenario_results: vec![],
            metrics: empty_metrics(1.0, 1.0, 0, 0),
        };

        let comp = current.compare_to(&baseline, 0.05).unwrap();
        assert_eq!(
            comp.new_aliases,
            vec!["a".to_string(), "b".to_string(), "c".to_string()]
        );
        assert_eq!(comp.removed_aliases, vec!["y".to_string(), "z".to_string()]);
        assert!(comp.dataset_comparisons.is_empty());
        assert!(comp.scenario_deltas.is_empty());
        assert!(!comp.regressed);
    }

    #[test]
    fn compare_to_alias_pass_rates() {
        let mut baseline_metrics = empty_metrics(0.9, 1.0, 1, 1);
        baseline_metrics
            .dataset_pass_rates
            .insert("agent_a".to_string(), 0.9);
        baseline_metrics
            .dataset_pass_rates
            .insert("agent_b".to_string(), 0.8);

        let mut current_metrics = empty_metrics(0.85, 1.0, 1, 1);
        current_metrics
            .dataset_pass_rates
            .insert("agent_a".to_string(), 0.85);
        current_metrics
            .dataset_pass_rates
            .insert("agent_b".to_string(), 0.75);

        let baseline = ScenarioEvalResults {
            dataset_results: HashMap::new(),
            scenario_results: vec![make_scenario_result("s1", "q", true, 1.0)],
            metrics: baseline_metrics,
        };

        let current = ScenarioEvalResults {
            dataset_results: HashMap::new(),
            scenario_results: vec![make_scenario_result("s1", "q", true, 1.0)],
            metrics: current_metrics,
        };

        let comp = current.compare_to(&baseline, 0.05).unwrap();
        assert_eq!(comp.baseline_alias_pass_rates.get("agent_a"), Some(&0.9));
        assert_eq!(comp.baseline_alias_pass_rates.get("agent_b"), Some(&0.8));
        assert_eq!(comp.comparison_alias_pass_rates.get("agent_a"), Some(&0.85));
        assert_eq!(comp.comparison_alias_pass_rates.get("agent_b"), Some(&0.75));
    }
}