Skip to main content

scouter_evaluate/evaluate/
scenario_results.rs

1use crate::error::EvaluationError;
2use crate::evaluate::compare::compare_results;
3use crate::evaluate::types::{ComparisonResults, EvalResults};
4use owo_colors::OwoColorize;
5use potato_head::PyHelperFuncs;
6use pyo3::prelude::*;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use tabled::Tabled;
10use tabled::{
11    settings::{object::Rows, Alignment, Color, Format, Style},
12    Table,
13};
14
/// Aggregate metrics for a scenario evaluation run.
///
/// Rates are stored as fractions: `as_table` multiplies them by 100 and prints
/// them as percentages (range is not enforced here; presumably `0.0..=1.0`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[pyclass]
pub struct EvalMetrics {
    /// Overall pass rate for the run (fraction).
    #[pyo3(get)]
    pub overall_pass_rate: f64,

    // HashMap<K,V> fields cannot use #[pyo3(get)]; exposed via the getter method below
    /// Pass rate per dataset alias (fraction), keyed by alias.
    pub dataset_pass_rates: HashMap<String, f64>,

    /// Pass rate at the scenario level (fraction).
    #[pyo3(get)]
    pub scenario_pass_rate: f64,

    /// Total number of scenarios counted in these metrics.
    #[pyo3(get)]
    pub total_scenarios: usize,

    /// Number of scenarios that passed.
    #[pyo3(get)]
    pub passed_scenarios: usize,
}
33
/// One row of the "Aggregate Metrics" table printed by `EvalMetrics::as_table`.
#[derive(Tabled)]
struct MetricEntry {
    /// Human-readable metric label.
    #[tabled(rename = "Metric")]
    metric: String,
    /// Pre-formatted value (percentages already carry their `%` suffix).
    #[tabled(rename = "Value")]
    value: String,
}
41
/// One row of the "Sub-Agent Pass Rates" table printed by `EvalMetrics::as_table`.
#[derive(Tabled)]
struct DatasetPassRateEntry {
    /// Dataset / sub-agent alias.
    #[tabled(rename = "Alias")]
    alias: String,
    /// Pass rate pre-formatted as a percentage string.
    #[tabled(rename = "Pass Rate")]
    pass_rate: String,
}
49
50#[pymethods]
51impl EvalMetrics {
52    pub fn __str__(&self) -> String {
53        PyHelperFuncs::__str__(self)
54    }
55
56    #[getter]
57    pub fn dataset_pass_rates(&self) -> HashMap<String, f64> {
58        self.dataset_pass_rates.clone()
59    }
60
61    pub fn as_table(&self) {
62        println!("\n{}", "Aggregate Metrics".truecolor(245, 77, 85).bold());
63
64        let entries = vec![
65            MetricEntry {
66                metric: "Overall Pass Rate".to_string(),
67                value: format!("{:.1}%", self.overall_pass_rate * 100.0),
68            },
69            MetricEntry {
70                metric: "Scenario Pass Rate".to_string(),
71                value: format!("{:.1}%", self.scenario_pass_rate * 100.0),
72            },
73            MetricEntry {
74                metric: "Total Scenarios".to_string(),
75                value: self.total_scenarios.to_string(),
76            },
77            MetricEntry {
78                metric: "Passed Scenarios".to_string(),
79                value: self.passed_scenarios.to_string(),
80            },
81        ];
82
83        let mut table = Table::new(entries);
84        table.with(Style::sharp());
85        table.modify(
86            Rows::new(0..1),
87            (
88                Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
89                Alignment::center(),
90                Color::BOLD,
91            ),
92        );
93        println!("{}", table);
94
95        if !self.dataset_pass_rates.is_empty() {
96            println!("\n{}", "Sub-Agent Pass Rates".truecolor(245, 77, 85).bold());
97
98            let mut alias_entries: Vec<_> = self
99                .dataset_pass_rates
100                .iter()
101                .map(|(alias, rate)| DatasetPassRateEntry {
102                    alias: alias.clone(),
103                    pass_rate: format!("{:.1}%", rate * 100.0),
104                })
105                .collect();
106            alias_entries.sort_by(|a, b| a.alias.cmp(&b.alias));
107
108            let mut alias_table = Table::new(alias_entries);
109            alias_table.with(Style::sharp());
110            alias_table.modify(
111                Rows::new(0..1),
112                (
113                    Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
114                    Alignment::center(),
115                    Color::BOLD,
116                ),
117            );
118            println!("{}", alias_table);
119        }
120    }
121}
122
/// Outcome of evaluating a single scenario.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ScenarioResult {
    /// Identifier used to match this scenario between runs in `compare_to`.
    #[pyo3(get)]
    pub scenario_id: String,

    /// The query the scenario started from.
    #[pyo3(get)]
    pub initial_query: String,

    // No #[pyo3(get)]: Python reads this through the `eval_results` getter method.
    /// Detailed evaluation results for this scenario.
    pub eval_results: EvalResults,

    /// Whether the scenario passed.
    #[pyo3(get)]
    pub passed: bool,

    /// Pass rate for this scenario as a fraction (displayed ×100 as a percentage).
    #[pyo3(get)]
    pub pass_rate: f64,
}
140
141impl PartialEq for ScenarioResult {
142    fn eq(&self, other: &Self) -> bool {
143        self.scenario_id == other.scenario_id
144            && self.initial_query == other.initial_query
145            && self.passed == other.passed
146            && self.pass_rate == other.pass_rate
147    }
148}
149
150#[pymethods]
151impl ScenarioResult {
152    pub fn __str__(&self) -> String {
153        PyHelperFuncs::__str__(self)
154    }
155
156    #[getter]
157    pub fn eval_results(&self) -> EvalResults {
158        self.eval_results.clone()
159    }
160}
161
/// Baseline vs. current outcome for a single scenario.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[pyclass]
pub struct ScenarioDelta {
    /// Scenario identifier shared by both runs.
    #[pyo3(get)]
    pub scenario_id: String,

    /// Initial query (`compare_to` copies it from the current run).
    #[pyo3(get)]
    pub initial_query: String,

    /// Whether the scenario passed in the baseline run.
    #[pyo3(get)]
    pub baseline_passed: bool,

    /// Whether the scenario passed in the current run.
    #[pyo3(get)]
    pub comparison_passed: bool,

    /// Baseline pass rate (fraction).
    #[pyo3(get)]
    pub baseline_pass_rate: f64,

    /// Current pass rate (fraction).
    #[pyo3(get)]
    pub comparison_pass_rate: f64,

    /// `compare_to` sets this to `baseline_passed != comparison_passed`.
    #[pyo3(get)]
    pub status_changed: bool,
}
186
#[pymethods]
impl ScenarioDelta {
    /// Python `__str__`: renders this delta through the shared `PyHelperFuncs` helper.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }
}
193
/// Result of comparing a current `ScenarioEvalResults` run against a baseline run
/// (produced by `ScenarioEvalResults::compare_to`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ScenarioComparisonResults {
    // No #[pyo3(get)]: Python reads this through the `dataset_comparisons` getter method.
    /// Per-alias comparison results, keyed by alias.
    pub dataset_comparisons: HashMap<String, ComparisonResults>,

    /// Per-scenario baseline vs. current outcomes.
    #[pyo3(get)]
    pub scenario_deltas: Vec<ScenarioDelta>,

    /// Overall pass rate of the baseline run (fraction).
    #[pyo3(get)]
    pub baseline_overall_pass_rate: f64,

    /// Overall pass rate of the current run (fraction).
    #[pyo3(get)]
    pub comparison_overall_pass_rate: f64,

    /// True when the comparison flagged a regression.
    #[pyo3(get)]
    pub regressed: bool,

    /// Aliases whose dataset comparison reported an improvement.
    #[pyo3(get)]
    pub improved_aliases: Vec<String>,

    /// Aliases whose dataset comparison reported a regression.
    #[pyo3(get)]
    pub regressed_aliases: Vec<String>,

    // The `#[serde(default)]` fields below may be absent from the JSON input;
    // they then deserialize as empty collections.

    /// Aliases present only in the current run.
    #[pyo3(get)]
    #[serde(default)]
    pub new_aliases: Vec<String>,

    /// Aliases present only in the baseline run.
    #[pyo3(get)]
    #[serde(default)]
    pub removed_aliases: Vec<String>,

    /// Scenario ids present only in the current run.
    #[pyo3(get)]
    #[serde(default)]
    pub new_scenarios: Vec<String>,

    /// Scenario ids present only in the baseline run.
    #[pyo3(get)]
    #[serde(default)]
    pub removed_scenarios: Vec<String>,

    /// Baseline per-alias pass rates (fractions); exposed via a getter method.
    #[serde(default)]
    pub baseline_alias_pass_rates: HashMap<String, f64>,

    /// Current per-alias pass rates (fractions); exposed via a getter method.
    #[serde(default)]
    pub comparison_alias_pass_rates: HashMap<String, f64>,
}
239
240impl PartialEq for ScenarioComparisonResults {
241    fn eq(&self, other: &Self) -> bool {
242        self.scenario_deltas == other.scenario_deltas
243            && self.baseline_overall_pass_rate == other.baseline_overall_pass_rate
244            && self.comparison_overall_pass_rate == other.comparison_overall_pass_rate
245            && self.regressed == other.regressed
246            && self.improved_aliases == other.improved_aliases
247            && self.regressed_aliases == other.regressed_aliases
248            && self.new_aliases == other.new_aliases
249            && self.removed_aliases == other.removed_aliases
250            && self.new_scenarios == other.new_scenarios
251            && self.removed_scenarios == other.removed_scenarios
252            && self.baseline_alias_pass_rates == other.baseline_alias_pass_rates
253            && self.comparison_alias_pass_rates == other.comparison_alias_pass_rates
254        // dataset_comparisons excluded: ComparisonResults does not implement PartialEq
255    }
256}
257
/// Row of the fallback "Sub-Agent Comparison" table, built from `dataset_comparisons`
/// when no per-alias pass-rate maps are available.
#[derive(Tabled)]
struct DatasetComparisonEntry {
    /// Dataset / sub-agent alias.
    #[tabled(rename = "Alias")]
    alias: String,
    /// Signed, colour-coded pass-rate delta in percentage points.
    #[tabled(rename = "Delta")]
    delta: String,
    /// Colour-coded REGRESSION / IMPROVED / UNCHANGED label.
    #[tabled(rename = "Status")]
    status: String,
}
267
/// Row of the "Scenario Comparison" table; every cell is a pre-formatted
/// (possibly ANSI-coloured) string.
#[derive(Tabled)]
struct ScenarioDeltaEntry {
    /// Scenario id, truncated for display.
    #[tabled(rename = "Scenario ID")]
    scenario_id: String,
    /// PASS / FAIL in the baseline run.
    #[tabled(rename = "Baseline")]
    baseline: String,
    /// PASS / FAIL in the current run.
    #[tabled(rename = "Current")]
    current: String,
    /// Signed pass-rate change in percentage points.
    #[tabled(rename = "Pass Rate Δ")]
    pass_rate_delta: String,
    /// "Pass -> Fail", "Fail -> Pass", or "-" when the pass flag did not flip.
    #[tabled(rename = "Change")]
    change: String,
}
281
/// Row of the "Sub-Agent Comparison" table built from the per-alias pass-rate maps.
#[derive(Tabled)]
struct AliasPassRateEntry {
    /// Dataset / sub-agent alias.
    #[tabled(rename = "Alias")]
    alias: String,
    /// Baseline pass rate as a percentage, or "-" if absent from the baseline.
    #[tabled(rename = "Baseline")]
    baseline: String,
    /// Current pass rate as a percentage, or "-" if absent from the current run.
    #[tabled(rename = "Current")]
    current: String,
    /// Signed delta in percentage points, or "-" when either side is missing.
    #[tabled(rename = "Delta")]
    delta: String,
    /// REGRESSION / IMPROVED / UNCHANGED / NEW / REMOVED label.
    #[tabled(rename = "Status")]
    status: String,
}
295
296#[pymethods]
297impl ScenarioComparisonResults {
298    pub fn __str__(&self) -> String {
299        PyHelperFuncs::__str__(self)
300    }
301
302    #[getter]
303    pub fn dataset_comparisons(&self) -> HashMap<String, ComparisonResults> {
304        self.dataset_comparisons.clone()
305    }
306
307    #[getter]
308    pub fn baseline_alias_pass_rates(&self) -> HashMap<String, f64> {
309        self.baseline_alias_pass_rates.clone()
310    }
311
312    #[getter]
313    pub fn comparison_alias_pass_rates(&self) -> HashMap<String, f64> {
314        self.comparison_alias_pass_rates.clone()
315    }
316
317    pub fn model_dump_json(&self) -> Result<String, EvaluationError> {
318        serde_json::to_string(self).map_err(Into::into)
319    }
320
321    #[staticmethod]
322    pub fn model_validate_json(json_string: String) -> Result<Self, EvaluationError> {
323        serde_json::from_str(&json_string).map_err(Into::into)
324    }
325
326    pub fn save(&self, path: &str) -> Result<(), EvaluationError> {
327        let json = serde_json::to_string_pretty(self)?;
328        std::fs::write(path, json)?;
329        Ok(())
330    }
331
332    #[staticmethod]
333    pub fn load(path: &str) -> Result<Self, EvaluationError> {
334        let json = std::fs::read_to_string(path)?;
335        serde_json::from_str(&json).map_err(Into::into)
336    }
337
338    pub fn as_table(&self) {
339        // Sub-Agent Comparison with pass rates
340        if !self.baseline_alias_pass_rates.is_empty()
341            || !self.comparison_alias_pass_rates.is_empty()
342        {
343            println!("\n{}", "Sub-Agent Comparison".truecolor(245, 77, 85).bold());
344
345            let mut all_aliases: Vec<String> = self
346                .baseline_alias_pass_rates
347                .keys()
348                .chain(self.comparison_alias_pass_rates.keys())
349                .cloned()
350                .collect();
351            all_aliases.sort();
352            all_aliases.dedup();
353
354            let mut alias_entries: Vec<AliasPassRateEntry> = Vec::new();
355            for alias in &all_aliases {
356                let baseline_rate = self.baseline_alias_pass_rates.get(alias);
357                let current_rate = self.comparison_alias_pass_rates.get(alias);
358
359                let (baseline_str, current_str, delta_str, status) =
360                    match (baseline_rate, current_rate) {
361                        (Some(b), Some(c)) => {
362                            let delta = (c - b) * 100.0;
363                            let d = format!("{:+.1}%", delta);
364                            let colored_delta = if delta > 1.0 {
365                                d.green().to_string()
366                            } else if delta < -1.0 {
367                                d.red().to_string()
368                            } else {
369                                d.yellow().to_string()
370                            };
371                            let s = if self.regressed_aliases.contains(alias) {
372                                "REGRESSION".red().to_string()
373                            } else if self.improved_aliases.contains(alias) {
374                                "IMPROVED".green().to_string()
375                            } else {
376                                "UNCHANGED".yellow().to_string()
377                            };
378                            (
379                                format!("{:.1}%", b * 100.0),
380                                format!("{:.1}%", c * 100.0),
381                                colored_delta,
382                                s,
383                            )
384                        }
385                        (None, Some(c)) => (
386                            "-".to_string(),
387                            format!("{:.1}%", c * 100.0),
388                            "-".to_string(),
389                            "NEW".green().bold().to_string(),
390                        ),
391                        (Some(b), None) => (
392                            format!("{:.1}%", b * 100.0),
393                            "-".to_string(),
394                            "-".to_string(),
395                            "REMOVED".red().bold().to_string(),
396                        ),
397                        (None, None) => continue,
398                    };
399
400                alias_entries.push(AliasPassRateEntry {
401                    alias: alias.clone(),
402                    baseline: baseline_str,
403                    current: current_str,
404                    delta: delta_str,
405                    status,
406                });
407            }
408
409            let mut alias_table = Table::new(alias_entries);
410            alias_table.with(Style::sharp());
411            alias_table.modify(
412                Rows::new(0..1),
413                (
414                    Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
415                    Alignment::center(),
416                    Color::BOLD,
417                ),
418            );
419            println!("{}", alias_table);
420        } else if !self.dataset_comparisons.is_empty() {
421            // Fallback for old data without alias pass rates
422            println!("\n{}", "Sub-Agent Comparison".truecolor(245, 77, 85).bold());
423
424            let mut alias_entries: Vec<_> = self
425                .dataset_comparisons
426                .iter()
427                .map(|(alias, comp)| {
428                    let delta = comp.mean_pass_rate_delta * 100.0;
429                    let delta_str = format!("{:+.1}%", delta);
430                    let colored_delta = if delta > 1.0 {
431                        delta_str.green().to_string()
432                    } else if delta < -1.0 {
433                        delta_str.red().to_string()
434                    } else {
435                        delta_str.yellow().to_string()
436                    };
437                    let status = if comp.regressed {
438                        "REGRESSION".red().to_string()
439                    } else if comp.improved_workflows > 0 {
440                        "IMPROVED".green().to_string()
441                    } else {
442                        "UNCHANGED".yellow().to_string()
443                    };
444                    DatasetComparisonEntry {
445                        alias: alias.clone(),
446                        delta: colored_delta,
447                        status,
448                    }
449                })
450                .collect();
451            alias_entries.sort_by(|a, b| a.alias.cmp(&b.alias));
452
453            let mut alias_table = Table::new(alias_entries);
454            alias_table.with(Style::sharp());
455            alias_table.modify(
456                Rows::new(0..1),
457                (
458                    Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
459                    Alignment::center(),
460                    Color::BOLD,
461                ),
462            );
463            println!("{}", alias_table);
464        }
465
466        // All Scenarios table (not just changed ones)
467        if !self.scenario_deltas.is_empty() {
468            println!("\n{}", "Scenario Comparison".truecolor(245, 77, 85).bold());
469
470            let entries: Vec<_> = self
471                .scenario_deltas
472                .iter()
473                .map(|d| {
474                    let baseline_str = if d.baseline_passed {
475                        "PASS".green().to_string()
476                    } else {
477                        "FAIL".red().to_string()
478                    };
479                    let current_str = if d.comparison_passed {
480                        "PASS".green().to_string()
481                    } else {
482                        "FAIL".red().to_string()
483                    };
484                    let pr_delta = (d.comparison_pass_rate - d.baseline_pass_rate) * 100.0;
485                    let pr_delta_str = format!("{:+.1}%", pr_delta);
486                    let colored_pr_delta = if pr_delta > 1.0 {
487                        pr_delta_str.green().to_string()
488                    } else if pr_delta < -1.0 {
489                        pr_delta_str.red().to_string()
490                    } else {
491                        pr_delta_str.yellow().to_string()
492                    };
493                    let change = match (d.baseline_passed, d.comparison_passed) {
494                        (true, false) => "Pass -> Fail".red().bold().to_string(),
495                        (false, true) => "Fail -> Pass".green().bold().to_string(),
496                        _ => "-".to_string(),
497                    };
498                    ScenarioDeltaEntry {
499                        scenario_id: d.scenario_id.chars().take(16).collect::<String>(),
500                        baseline: baseline_str,
501                        current: current_str,
502                        pass_rate_delta: colored_pr_delta,
503                        change,
504                    }
505                })
506                .collect();
507
508            let mut delta_table = Table::new(entries);
509            delta_table.with(Style::sharp());
510            delta_table.modify(
511                Rows::new(0..1),
512                (
513                    Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
514                    Alignment::center(),
515                    Color::BOLD,
516                ),
517            );
518            println!("{}", delta_table);
519        }
520
521        // New/Removed Scenarios
522        if !self.new_scenarios.is_empty() {
523            println!(
524                "\n{}  {}",
525                "New Scenarios:".truecolor(245, 77, 85).bold(),
526                self.new_scenarios.join(", ").green()
527            );
528        }
529        if !self.removed_scenarios.is_empty() {
530            println!(
531                "\n{}  {}",
532                "Removed Scenarios:".truecolor(245, 77, 85).bold(),
533                self.removed_scenarios.join(", ").red()
534            );
535        }
536
537        // Summary
538        let overall_status = if self.regressed {
539            "REGRESSION DETECTED".red().bold().to_string()
540        } else if !self.improved_aliases.is_empty() {
541            "IMPROVEMENT DETECTED".green().bold().to_string()
542        } else {
543            "NO SIGNIFICANT CHANGE".yellow().bold().to_string()
544        };
545
546        println!("\n{}", "Summary".truecolor(245, 77, 85).bold());
547        println!("  Overall Status: {}", overall_status);
548        println!(
549            "  Baseline Pass Rate: {:.1}%",
550            self.baseline_overall_pass_rate * 100.0
551        );
552        println!(
553            "  Current Pass Rate:  {:.1}%",
554            self.comparison_overall_pass_rate * 100.0
555        );
556        if !self.regressed_aliases.is_empty() {
557            println!(
558                "  Regressed aliases: {}",
559                self.regressed_aliases.join(", ").red()
560            );
561        }
562        if !self.improved_aliases.is_empty() {
563            println!(
564                "  Improved aliases: {}",
565                self.improved_aliases.join(", ").green()
566            );
567        }
568        if !self.new_aliases.is_empty() {
569            println!("  New aliases: {}", self.new_aliases.join(", ").green());
570        }
571        if !self.removed_aliases.is_empty() {
572            println!(
573                "  Removed aliases: {}",
574                self.removed_aliases.join(", ").red()
575            );
576        }
577        let changed_count = self
578            .scenario_deltas
579            .iter()
580            .filter(|d| d.status_changed)
581            .count();
582        if changed_count > 0 {
583            println!("  Scenarios changed: {}", changed_count);
584        }
585        if !self.new_scenarios.is_empty() {
586            println!("  New scenarios: {}", self.new_scenarios.len());
587        }
588        if !self.removed_scenarios.is_empty() {
589            println!("  Removed scenarios: {}", self.removed_scenarios.len());
590        }
591    }
592}
593
/// Results of a scenario-based evaluation run.
///
/// Holds per-alias dataset results, per-scenario outcomes and aggregate metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ScenarioEvalResults {
    // No #[pyo3(get)]: Python reads this through the `dataset_results` getter method.
    /// Evaluation results keyed by dataset / sub-agent alias.
    pub dataset_results: HashMap<String, EvalResults>,

    // No #[pyo3(get)]: Python reads this through the `scenario_results` getter method.
    /// One result per evaluated scenario.
    pub scenario_results: Vec<ScenarioResult>,

    /// Aggregate metrics for the run.
    #[pyo3(get)]
    pub metrics: EvalMetrics,
}
604
605impl PartialEq for ScenarioEvalResults {
606    fn eq(&self, other: &Self) -> bool {
607        self.scenario_results == other.scenario_results && self.metrics == other.metrics
608        // dataset_results excluded: EvalResults does not implement PartialEq
609    }
610}
611
/// Row of the "Scenario Results" table printed by `ScenarioEvalResults::as_table`.
#[derive(Tabled)]
struct ScenarioResultEntry {
    /// Scenario id, truncated for display.
    #[tabled(rename = "Scenario ID")]
    scenario_id: String,
    /// Initial query, shortened with "..." when long.
    #[tabled(rename = "Initial Query")]
    initial_query: String,
    /// Pass rate formatted as a percentage.
    #[tabled(rename = "Pass Rate")]
    pass_rate: String,
    /// Colour-coded pass/fail label.
    #[tabled(rename = "Status")]
    status: String,
}
623
624#[pymethods]
625impl ScenarioEvalResults {
626    pub fn __str__(&self) -> String {
627        PyHelperFuncs::__str__(self)
628    }
629
630    pub fn model_dump_json(&self) -> Result<String, EvaluationError> {
631        serde_json::to_string(self).map_err(Into::into)
632    }
633
634    #[staticmethod]
635    pub fn model_validate_json(json_string: String) -> Result<Self, EvaluationError> {
636        serde_json::from_str(&json_string).map_err(Into::into)
637    }
638
639    pub fn save(&self, path: &str) -> Result<(), EvaluationError> {
640        let json = serde_json::to_string_pretty(self)?;
641        std::fs::write(path, json)?;
642        Ok(())
643    }
644
645    #[staticmethod]
646    pub fn load(path: &str) -> Result<Self, EvaluationError> {
647        let json = std::fs::read_to_string(path)?;
648        serde_json::from_str(&json).map_err(Into::into)
649    }
650
651    #[getter]
652    pub fn dataset_results(&self) -> HashMap<String, EvalResults> {
653        self.dataset_results.clone()
654    }
655
656    #[getter]
657    pub fn scenario_results(&self) -> Vec<ScenarioResult> {
658        self.scenario_results.clone()
659    }
660
661    pub fn get_scenario_detail(
662        &self,
663        scenario_id: &str,
664    ) -> Result<ScenarioResult, EvaluationError> {
665        self.scenario_results
666            .iter()
667            .find(|r| r.scenario_id == scenario_id)
668            .cloned()
669            .ok_or_else(|| EvaluationError::MissingKeyError(scenario_id.to_string()))
670    }
671
672    #[pyo3(signature = (baseline, regression_threshold = 0.05))]
673    pub fn compare_to(
674        &self,
675        baseline: &ScenarioEvalResults,
676        regression_threshold: f64,
677    ) -> Result<ScenarioComparisonResults, EvaluationError> {
678        let mut dataset_comparisons = HashMap::new();
679        let mut improved_aliases = Vec::new();
680        let mut regressed_aliases = Vec::new();
681
682        for (alias, current_results) in &self.dataset_results {
683            if let Some(baseline_results) = baseline.dataset_results.get(alias) {
684                let comp =
685                    compare_results(baseline_results, current_results, regression_threshold)?;
686                if comp.regressed {
687                    regressed_aliases.push(alias.clone());
688                } else if comp.improved_workflows > 0 {
689                    improved_aliases.push(alias.clone());
690                }
691                dataset_comparisons.insert(alias.clone(), comp);
692            }
693        }
694
695        // Detect new/removed aliases
696        let mut new_aliases: Vec<String> = self
697            .dataset_results
698            .keys()
699            .filter(|alias| !baseline.dataset_results.contains_key(*alias))
700            .cloned()
701            .collect();
702        new_aliases.sort();
703
704        let mut removed_aliases: Vec<String> = baseline
705            .dataset_results
706            .keys()
707            .filter(|alias| !self.dataset_results.contains_key(*alias))
708            .cloned()
709            .collect();
710        removed_aliases.sort();
711
712        let baseline_scenario_map: HashMap<_, _> = baseline
713            .scenario_results
714            .iter()
715            .map(|r| (r.scenario_id.as_str(), r))
716            .collect();
717
718        let current_scenario_map: HashMap<_, _> = self
719            .scenario_results
720            .iter()
721            .map(|r| (r.scenario_id.as_str(), r))
722            .collect();
723
724        let mut scenario_deltas = Vec::new();
725        for current in &self.scenario_results {
726            if let Some(base) = baseline_scenario_map.get(current.scenario_id.as_str()) {
727                scenario_deltas.push(ScenarioDelta {
728                    scenario_id: current.scenario_id.clone(),
729                    initial_query: current.initial_query.clone(),
730                    baseline_passed: base.passed,
731                    comparison_passed: current.passed,
732                    baseline_pass_rate: base.pass_rate,
733                    comparison_pass_rate: current.pass_rate,
734                    status_changed: base.passed != current.passed,
735                });
736            }
737        }
738
739        // Detect new/removed scenarios
740        let mut new_scenarios: Vec<String> = current_scenario_map
741            .keys()
742            .filter(|id| !baseline_scenario_map.contains_key(*id))
743            .map(|id| id.to_string())
744            .collect();
745        new_scenarios.sort();
746
747        let mut removed_scenarios: Vec<String> = baseline_scenario_map
748            .keys()
749            .filter(|id| !current_scenario_map.contains_key(*id))
750            .map(|id| id.to_string())
751            .collect();
752        removed_scenarios.sort();
753
754        // Per-alias pass rates from metrics
755        let baseline_alias_pass_rates = baseline.metrics.dataset_pass_rates.clone();
756        let comparison_alias_pass_rates = self.metrics.dataset_pass_rates.clone();
757
758        Ok(ScenarioComparisonResults {
759            dataset_comparisons,
760            scenario_deltas,
761            baseline_overall_pass_rate: baseline.metrics.overall_pass_rate,
762            comparison_overall_pass_rate: self.metrics.overall_pass_rate,
763            regressed: !regressed_aliases.is_empty(),
764            improved_aliases,
765            regressed_aliases,
766            new_aliases,
767            removed_aliases,
768            new_scenarios,
769            removed_scenarios,
770            baseline_alias_pass_rates,
771            comparison_alias_pass_rates,
772        })
773    }
774
775    pub fn as_table(&self) {
776        self.metrics.as_table();
777
778        if !self.scenario_results.is_empty() {
779            println!("\n{}", "Scenario Results".truecolor(245, 77, 85).bold());
780
781            let entries: Vec<_> = self
782                .scenario_results
783                .iter()
784                .map(|r| {
785                    let query = if r.initial_query.chars().count() > 40 {
786                        format!(
787                            "{}...",
788                            r.initial_query.chars().take(40).collect::<String>()
789                        )
790                    } else {
791                        r.initial_query.clone()
792                    };
793                    let status = if r.passed {
794                        "✓ PASS".green().to_string()
795                    } else {
796                        "✗ FAIL".red().to_string()
797                    };
798                    ScenarioResultEntry {
799                        scenario_id: r.scenario_id.chars().take(16).collect::<String>(),
800                        initial_query: query,
801                        pass_rate: format!("{:.1}%", r.pass_rate * 100.0),
802                        status,
803                    }
804                })
805                .collect();
806
807            let mut table = Table::new(entries);
808            table.with(Style::sharp());
809            table.modify(
810                Rows::new(0..1),
811                (
812                    Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
813                    Alignment::center(),
814                    Color::BOLD,
815                ),
816            );
817            println!("{}", table);
818        }
819    }
820}
821
822#[cfg(test)]
823mod tests {
824    use super::*;
825    use crate::evaluate::types::EvalResults;
826
827    fn empty_metrics(
828        overall: f64,
829        scenario_pass_rate: f64,
830        total: usize,
831        passed: usize,
832    ) -> EvalMetrics {
833        EvalMetrics {
834            overall_pass_rate: overall,
835            dataset_pass_rates: HashMap::new(),
836            scenario_pass_rate,
837            total_scenarios: total,
838            passed_scenarios: passed,
839        }
840    }
841
842    fn make_scenario_result(id: &str, query: &str, passed: bool, pass_rate: f64) -> ScenarioResult {
843        ScenarioResult {
844            scenario_id: id.to_string(),
845            initial_query: query.to_string(),
846            eval_results: EvalResults::new(),
847            passed,
848            pass_rate,
849        }
850    }
851
852    fn make_eval_results() -> ScenarioEvalResults {
853        ScenarioEvalResults {
854            dataset_results: HashMap::new(),
855            scenario_results: vec![
856                make_scenario_result("s1", "Make pasta", true, 1.0),
857                make_scenario_result("s2", "Make curry", false, 0.5),
858            ],
859            metrics: empty_metrics(0.75, 0.5, 2, 1),
860        }
861    }
862
863    #[test]
864    fn eval_metrics_fields() {
865        let m = empty_metrics(0.9, 0.85, 100, 85);
866        assert_eq!(m.overall_pass_rate, 0.9);
867        assert_eq!(m.scenario_pass_rate, 0.85);
868        assert_eq!(m.total_scenarios, 100);
869        assert_eq!(m.passed_scenarios, 85);
870    }
871
872    #[test]
873    fn scenario_result_fields() {
874        let r = make_scenario_result("id-1", "hello", true, 1.0);
875        assert_eq!(r.scenario_id, "id-1");
876        assert!(r.passed);
877        assert_eq!(r.pass_rate, 1.0);
878    }
879
880    #[test]
881    fn model_dump_json_roundtrip() {
882        let results = make_eval_results();
883        let json = results.model_dump_json().unwrap();
884        let loaded = ScenarioEvalResults::model_validate_json(json).unwrap();
885        assert_eq!(loaded.scenario_results.len(), 2);
886        assert_eq!(loaded.metrics.total_scenarios, 2);
887        assert_eq!(loaded.metrics.passed_scenarios, 1);
888        assert_eq!(loaded.metrics.overall_pass_rate, 0.75);
889        assert_eq!(loaded.scenario_results[0].scenario_id, "s1");
890        assert_eq!(loaded.scenario_results[1].scenario_id, "s2");
891        assert_eq!(loaded.scenario_results[0].pass_rate, 1.0);
892        assert_eq!(loaded.scenario_results[1].pass_rate, 0.5);
893    }
894
895    #[test]
896    fn compare_to_regression_detection() {
897        let baseline = make_eval_results();
898
899        let current = ScenarioEvalResults {
900            dataset_results: HashMap::new(),
901            scenario_results: vec![
902                make_scenario_result("s1", "Make pasta", false, 0.0),
903                make_scenario_result("s2", "Make curry", false, 0.5),
904            ],
905            metrics: empty_metrics(0.25, 0.0, 2, 0),
906        };
907
908        let comp = current.compare_to(&baseline, 0.05).unwrap();
909        assert_eq!(comp.scenario_deltas.len(), 2);
910
911        let s1 = comp
912            .scenario_deltas
913            .iter()
914            .find(|d| d.scenario_id == "s1")
915            .unwrap();
916        assert!(s1.status_changed);
917        assert!(s1.baseline_passed);
918        assert!(!s1.comparison_passed);
919
920        let s2 = comp
921            .scenario_deltas
922            .iter()
923            .find(|d| d.scenario_id == "s2")
924            .unwrap();
925        assert!(!s2.status_changed);
926    }
927
928    #[test]
929    fn compare_to_no_regression() {
930        let baseline = make_eval_results();
931        let current = make_eval_results();
932
933        let comp = current.compare_to(&baseline, 0.05).unwrap();
934        assert!(!comp.regressed);
935        assert!(comp.scenario_deltas.iter().all(|d| !d.status_changed));
936    }
937
938    #[test]
939    fn compare_to_with_dataset_results() {
940        let mut baseline_dataset = HashMap::new();
941        baseline_dataset.insert("agent_a".to_string(), EvalResults::new());
942
943        let mut current_dataset = HashMap::new();
944        current_dataset.insert("agent_a".to_string(), EvalResults::new());
945
946        let baseline = ScenarioEvalResults {
947            dataset_results: baseline_dataset,
948            scenario_results: vec![make_scenario_result("s1", "query one", true, 1.0)],
949            metrics: empty_metrics(1.0, 1.0, 1, 1),
950        };
951
952        let current = ScenarioEvalResults {
953            dataset_results: current_dataset,
954            scenario_results: vec![make_scenario_result("s1", "query one", false, 0.0)],
955            metrics: empty_metrics(0.0, 0.0, 1, 0),
956        };
957
958        let comp = current.compare_to(&baseline, 0.05).unwrap();
959
960        // dataset_comparisons should be populated for agent_a
961        assert!(comp.dataset_comparisons.contains_key("agent_a"));
962
963        // scenario delta for s1 should show regression
964        let s1 = comp
965            .scenario_deltas
966            .iter()
967            .find(|d| d.scenario_id == "s1")
968            .unwrap();
969        assert!(s1.status_changed);
970        assert!(s1.baseline_passed);
971        assert!(!s1.comparison_passed);
972
973        // regressed flag reflects true state
974        // no dataset-level regression (empty EvalResults → no workflows to regress)
975        assert!(comp.improved_aliases.is_empty());
976    }
977
978    #[test]
979    fn get_scenario_detail_found() {
980        let results = make_eval_results();
981        let detail = results.get_scenario_detail("s1").unwrap();
982        assert_eq!(detail.scenario_id, "s1");
983    }
984
985    #[test]
986    fn get_scenario_detail_missing() {
987        let results = make_eval_results();
988        assert!(results.get_scenario_detail("nonexistent").is_err());
989    }
990
991    #[test]
992    fn save_load_roundtrip() {
993        let results = make_eval_results();
994        let dir = tempfile::tempdir().unwrap();
995        let path = dir.path().join("results.json");
996        let path_str = path.to_str().unwrap();
997
998        results.save(path_str).unwrap();
999        let loaded = ScenarioEvalResults::load(path_str).unwrap();
1000        assert_eq!(results, loaded);
1001    }
1002
1003    #[test]
1004    fn comparison_model_dump_json_roundtrip() {
1005        let comp = ScenarioComparisonResults {
1006            dataset_comparisons: HashMap::new(),
1007            scenario_deltas: vec![ScenarioDelta {
1008                scenario_id: "s1".to_string(),
1009                initial_query: "hello".to_string(),
1010                baseline_passed: true,
1011                comparison_passed: false,
1012                baseline_pass_rate: 1.0,
1013                comparison_pass_rate: 0.0,
1014                status_changed: true,
1015            }],
1016            baseline_overall_pass_rate: 1.0,
1017            comparison_overall_pass_rate: 0.5,
1018            regressed: true,
1019            improved_aliases: vec![],
1020            regressed_aliases: vec!["a".to_string()],
1021            new_aliases: vec!["c".to_string()],
1022            removed_aliases: vec!["b".to_string()],
1023            new_scenarios: vec!["s3".to_string()],
1024            removed_scenarios: vec![],
1025            baseline_alias_pass_rates: HashMap::from([("a".to_string(), 1.0)]),
1026            comparison_alias_pass_rates: HashMap::from([("a".to_string(), 0.5)]),
1027        };
1028
1029        let json = comp.model_dump_json().unwrap();
1030        let loaded = ScenarioComparisonResults::model_validate_json(json).unwrap();
1031        assert_eq!(comp, loaded);
1032    }
1033
1034    #[test]
1035    fn comparison_save_load_roundtrip() {
1036        let comp = ScenarioComparisonResults {
1037            dataset_comparisons: HashMap::new(),
1038            scenario_deltas: vec![],
1039            baseline_overall_pass_rate: 0.8,
1040            comparison_overall_pass_rate: 0.9,
1041            regressed: false,
1042            improved_aliases: vec!["x".to_string()],
1043            regressed_aliases: vec![],
1044            new_aliases: vec![],
1045            removed_aliases: vec![],
1046            new_scenarios: vec![],
1047            removed_scenarios: vec![],
1048            baseline_alias_pass_rates: HashMap::new(),
1049            comparison_alias_pass_rates: HashMap::new(),
1050        };
1051
1052        let dir = tempfile::tempdir().unwrap();
1053        let path = dir.path().join("comp.json");
1054        let path_str = path.to_str().unwrap();
1055
1056        comp.save(path_str).unwrap();
1057        let loaded = ScenarioComparisonResults::load(path_str).unwrap();
1058        assert_eq!(comp, loaded);
1059    }
1060
1061    #[test]
1062    fn compare_to_new_scenarios() {
1063        let baseline = make_eval_results(); // s1, s2
1064        let current = ScenarioEvalResults {
1065            dataset_results: HashMap::new(),
1066            scenario_results: vec![
1067                make_scenario_result("s1", "Make pasta", true, 1.0),
1068                make_scenario_result("s2", "Make curry", false, 0.5),
1069                make_scenario_result("s3", "Make salad", true, 0.8),
1070            ],
1071            metrics: empty_metrics(0.77, 0.67, 3, 2),
1072        };
1073
1074        let comp = current.compare_to(&baseline, 0.05).unwrap();
1075        assert_eq!(comp.new_scenarios, vec!["s3".to_string()]);
1076        assert!(comp.removed_scenarios.is_empty());
1077        // Only shared scenarios appear in deltas
1078        assert_eq!(comp.scenario_deltas.len(), 2);
1079    }
1080
1081    #[test]
1082    fn compare_to_removed_scenarios() {
1083        let baseline = ScenarioEvalResults {
1084            dataset_results: HashMap::new(),
1085            scenario_results: vec![
1086                make_scenario_result("s1", "Make pasta", true, 1.0),
1087                make_scenario_result("s2", "Make curry", false, 0.5),
1088                make_scenario_result("s3", "Make salad", true, 0.8),
1089            ],
1090            metrics: empty_metrics(0.77, 0.67, 3, 2),
1091        };
1092        let current = make_eval_results(); // s1, s2
1093
1094        let comp = current.compare_to(&baseline, 0.05).unwrap();
1095        assert_eq!(comp.removed_scenarios, vec!["s3".to_string()]);
1096        assert!(comp.new_scenarios.is_empty());
1097    }
1098
1099    #[test]
1100    fn compare_to_new_removed_aliases() {
1101        let mut baseline_datasets = HashMap::new();
1102        baseline_datasets.insert("a".to_string(), EvalResults::new());
1103        baseline_datasets.insert("b".to_string(), EvalResults::new());
1104
1105        let mut current_datasets = HashMap::new();
1106        current_datasets.insert("a".to_string(), EvalResults::new());
1107        current_datasets.insert("c".to_string(), EvalResults::new());
1108
1109        let baseline = ScenarioEvalResults {
1110            dataset_results: baseline_datasets,
1111            scenario_results: vec![make_scenario_result("s1", "q", true, 1.0)],
1112            metrics: empty_metrics(1.0, 1.0, 1, 1),
1113        };
1114
1115        let current = ScenarioEvalResults {
1116            dataset_results: current_datasets,
1117            scenario_results: vec![make_scenario_result("s1", "q", true, 1.0)],
1118            metrics: empty_metrics(1.0, 1.0, 1, 1),
1119        };
1120
1121        let comp = current.compare_to(&baseline, 0.05).unwrap();
1122        assert_eq!(comp.new_aliases, vec!["c".to_string()]);
1123        assert_eq!(comp.removed_aliases, vec!["b".to_string()]);
1124    }
1125
1126    #[test]
1127    fn compare_to_alias_pass_rates() {
1128        let mut baseline_metrics = empty_metrics(0.9, 1.0, 1, 1);
1129        baseline_metrics
1130            .dataset_pass_rates
1131            .insert("agent_a".to_string(), 0.9);
1132        baseline_metrics
1133            .dataset_pass_rates
1134            .insert("agent_b".to_string(), 0.8);
1135
1136        let mut current_metrics = empty_metrics(0.85, 1.0, 1, 1);
1137        current_metrics
1138            .dataset_pass_rates
1139            .insert("agent_a".to_string(), 0.85);
1140        current_metrics
1141            .dataset_pass_rates
1142            .insert("agent_b".to_string(), 0.75);
1143
1144        let baseline = ScenarioEvalResults {
1145            dataset_results: HashMap::new(),
1146            scenario_results: vec![make_scenario_result("s1", "q", true, 1.0)],
1147            metrics: baseline_metrics,
1148        };
1149
1150        let current = ScenarioEvalResults {
1151            dataset_results: HashMap::new(),
1152            scenario_results: vec![make_scenario_result("s1", "q", true, 1.0)],
1153            metrics: current_metrics,
1154        };
1155
1156        let comp = current.compare_to(&baseline, 0.05).unwrap();
1157        assert_eq!(comp.baseline_alias_pass_rates.get("agent_a"), Some(&0.9));
1158        assert_eq!(comp.baseline_alias_pass_rates.get("agent_b"), Some(&0.8));
1159        assert_eq!(comp.comparison_alias_pass_rates.get("agent_a"), Some(&0.85));
1160        assert_eq!(comp.comparison_alias_pass_rates.get("agent_b"), Some(&0.75));
1161    }
1162}