Skip to main content

scouter_evaluate/evaluate/
scenario_results.rs

1use crate::error::EvaluationError;
2use crate::evaluate::compare::compare_results;
3use crate::evaluate::types::{ComparisonResults, EvalResults};
4use owo_colors::OwoColorize;
5use potato_head::PyHelperFuncs;
6use pyo3::prelude::*;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use tabled::Tabled;
10use tabled::{
11    settings::{object::Rows, Alignment, Color, Format, Style},
12    Table,
13};
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16#[pyclass]
17pub struct TaskSummary {
18    #[pyo3(get)]
19    pub task_id: String,
20
21    #[pyo3(get)]
22    pub passed: bool,
23
24    #[pyo3(get)]
25    pub value: f64,
26}
27
28#[pymethods]
29impl TaskSummary {
30    pub fn __str__(&self) -> String {
31        PyHelperFuncs::__str__(self)
32    }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
36#[pyclass]
37pub struct EvalMetrics {
38    #[pyo3(get)]
39    pub overall_pass_rate: f64,
40
41    // HashMap<K,V> fields cannot use #[pyo3(get)]; exposed via the getter method below
42    pub dataset_pass_rates: HashMap<String, f64>,
43
44    #[pyo3(get)]
45    pub scenario_pass_rate: f64,
46
47    #[pyo3(get)]
48    pub total_scenarios: usize,
49
50    #[pyo3(get)]
51    pub passed_scenarios: usize,
52
53    #[serde(default)]
54    pub scenario_task_pass_rates: HashMap<String, HashMap<String, f64>>,
55}
56
57#[derive(Tabled)]
58struct MetricEntry {
59    #[tabled(rename = "Metric")]
60    metric: String,
61    #[tabled(rename = "Value")]
62    value: String,
63}
64
65#[derive(Tabled)]
66struct DatasetPassRateEntry {
67    #[tabled(rename = "Alias")]
68    alias: String,
69    #[tabled(rename = "Pass Rate")]
70    pass_rate: String,
71}
72
73#[pymethods]
74impl EvalMetrics {
75    pub fn __str__(&self) -> String {
76        PyHelperFuncs::__str__(self)
77    }
78
79    #[getter]
80    pub fn dataset_pass_rates(&self) -> HashMap<String, f64> {
81        self.dataset_pass_rates.clone()
82    }
83
84    #[getter]
85    pub fn scenario_task_pass_rates(&self) -> HashMap<String, HashMap<String, f64>> {
86        self.scenario_task_pass_rates.clone()
87    }
88
89    pub fn as_table(&self) {
90        println!("\n{}", "Aggregate Metrics".truecolor(245, 77, 85).bold());
91
92        let entries = vec![
93            MetricEntry {
94                metric: "Overall Pass Rate".to_string(),
95                value: format!("{:.1}%", self.overall_pass_rate * 100.0),
96            },
97            MetricEntry {
98                metric: "Scenario Pass Rate".to_string(),
99                value: format!("{:.1}%", self.scenario_pass_rate * 100.0),
100            },
101            MetricEntry {
102                metric: "Total Scenarios".to_string(),
103                value: self.total_scenarios.to_string(),
104            },
105            MetricEntry {
106                metric: "Passed Scenarios".to_string(),
107                value: self.passed_scenarios.to_string(),
108            },
109        ];
110
111        let mut table = Table::new(entries);
112        table.with(Style::sharp());
113        table.modify(
114            Rows::new(0..1),
115            (
116                Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
117                Alignment::center(),
118                Color::BOLD,
119            ),
120        );
121        println!("{}", table);
122
123        if !self.dataset_pass_rates.is_empty() {
124            println!("\n{}", "Sub-Agent Pass Rates".truecolor(245, 77, 85).bold());
125
126            let mut alias_entries: Vec<_> = self
127                .dataset_pass_rates
128                .iter()
129                .map(|(alias, rate)| DatasetPassRateEntry {
130                    alias: alias.clone(),
131                    pass_rate: format!("{:.1}%", rate * 100.0),
132                })
133                .collect();
134            alias_entries.sort_by(|a, b| a.alias.cmp(&b.alias));
135
136            let mut alias_table = Table::new(alias_entries);
137            alias_table.with(Style::sharp());
138            alias_table.modify(
139                Rows::new(0..1),
140                (
141                    Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
142                    Alignment::center(),
143                    Color::BOLD,
144                ),
145            );
146            println!("{}", alias_table);
147        }
148    }
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
152#[pyclass]
153pub struct ScenarioResult {
154    #[pyo3(get)]
155    pub scenario_id: String,
156
157    #[pyo3(get)]
158    pub initial_query: String,
159
160    pub eval_results: EvalResults,
161
162    #[pyo3(get)]
163    pub passed: bool,
164
165    #[pyo3(get)]
166    pub pass_rate: f64,
167
168    #[pyo3(get)]
169    #[serde(default)]
170    pub task_results: Vec<TaskSummary>,
171}
172
173impl PartialEq for ScenarioResult {
174    fn eq(&self, other: &Self) -> bool {
175        self.scenario_id == other.scenario_id
176            && self.initial_query == other.initial_query
177            && self.passed == other.passed
178            && self.pass_rate == other.pass_rate
179            && self.task_results == other.task_results
180    }
181}
182
183#[pymethods]
184impl ScenarioResult {
185    pub fn __str__(&self) -> String {
186        PyHelperFuncs::__str__(self)
187    }
188
189    #[getter]
190    pub fn eval_results(&self) -> EvalResults {
191        self.eval_results.clone()
192    }
193}
194
195#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
196#[pyclass]
197pub struct ScenarioDelta {
198    #[pyo3(get)]
199    pub scenario_id: String,
200
201    #[pyo3(get)]
202    pub initial_query: String,
203
204    #[pyo3(get)]
205    pub baseline_passed: bool,
206
207    #[pyo3(get)]
208    pub comparison_passed: bool,
209
210    #[pyo3(get)]
211    pub baseline_pass_rate: f64,
212
213    #[pyo3(get)]
214    pub comparison_pass_rate: f64,
215
216    #[pyo3(get)]
217    pub status_changed: bool,
218}
219
220#[pymethods]
221impl ScenarioDelta {
222    pub fn __str__(&self) -> String {
223        PyHelperFuncs::__str__(self)
224    }
225}
226
227#[derive(Debug, Clone, Serialize, Deserialize)]
228#[pyclass]
229pub struct ScenarioComparisonResults {
230    pub dataset_comparisons: HashMap<String, ComparisonResults>,
231
232    #[pyo3(get)]
233    pub scenario_deltas: Vec<ScenarioDelta>,
234
235    #[pyo3(get)]
236    pub baseline_overall_pass_rate: f64,
237
238    #[pyo3(get)]
239    pub comparison_overall_pass_rate: f64,
240
241    #[pyo3(get)]
242    pub regressed: bool,
243
244    #[pyo3(get)]
245    pub improved_aliases: Vec<String>,
246
247    #[pyo3(get)]
248    pub regressed_aliases: Vec<String>,
249
250    #[pyo3(get)]
251    #[serde(default)]
252    pub new_aliases: Vec<String>,
253
254    #[pyo3(get)]
255    #[serde(default)]
256    pub removed_aliases: Vec<String>,
257
258    #[pyo3(get)]
259    #[serde(default)]
260    pub new_scenarios: Vec<String>,
261
262    #[pyo3(get)]
263    #[serde(default)]
264    pub removed_scenarios: Vec<String>,
265
266    #[serde(default)]
267    pub baseline_alias_pass_rates: HashMap<String, f64>,
268
269    #[serde(default)]
270    pub comparison_alias_pass_rates: HashMap<String, f64>,
271}
272
273impl PartialEq for ScenarioComparisonResults {
274    fn eq(&self, other: &Self) -> bool {
275        self.scenario_deltas == other.scenario_deltas
276            && self.baseline_overall_pass_rate == other.baseline_overall_pass_rate
277            && self.comparison_overall_pass_rate == other.comparison_overall_pass_rate
278            && self.regressed == other.regressed
279            && self.improved_aliases == other.improved_aliases
280            && self.regressed_aliases == other.regressed_aliases
281            && self.new_aliases == other.new_aliases
282            && self.removed_aliases == other.removed_aliases
283            && self.new_scenarios == other.new_scenarios
284            && self.removed_scenarios == other.removed_scenarios
285            && self.baseline_alias_pass_rates == other.baseline_alias_pass_rates
286            && self.comparison_alias_pass_rates == other.comparison_alias_pass_rates
287        // dataset_comparisons excluded: ComparisonResults does not implement PartialEq
288    }
289}
290
291#[derive(Tabled)]
292struct DatasetComparisonEntry {
293    #[tabled(rename = "Alias")]
294    alias: String,
295    #[tabled(rename = "Delta")]
296    delta: String,
297    #[tabled(rename = "Status")]
298    status: String,
299}
300
301#[derive(Tabled)]
302struct ScenarioDeltaEntry {
303    #[tabled(rename = "Scenario ID")]
304    scenario_id: String,
305    #[tabled(rename = "Baseline")]
306    baseline: String,
307    #[tabled(rename = "Current")]
308    current: String,
309    #[tabled(rename = "Pass Rate Δ")]
310    pass_rate_delta: String,
311    #[tabled(rename = "Change")]
312    change: String,
313}
314
315#[derive(Tabled)]
316struct AliasPassRateEntry {
317    #[tabled(rename = "Alias")]
318    alias: String,
319    #[tabled(rename = "Baseline")]
320    baseline: String,
321    #[tabled(rename = "Current")]
322    current: String,
323    #[tabled(rename = "Delta")]
324    delta: String,
325    #[tabled(rename = "Status")]
326    status: String,
327}
328
329#[pymethods]
330impl ScenarioComparisonResults {
331    pub fn __str__(&self) -> String {
332        PyHelperFuncs::__str__(self)
333    }
334
335    #[getter]
336    pub fn dataset_comparisons(&self) -> HashMap<String, ComparisonResults> {
337        self.dataset_comparisons.clone()
338    }
339
340    #[getter]
341    pub fn baseline_alias_pass_rates(&self) -> HashMap<String, f64> {
342        self.baseline_alias_pass_rates.clone()
343    }
344
345    #[getter]
346    pub fn comparison_alias_pass_rates(&self) -> HashMap<String, f64> {
347        self.comparison_alias_pass_rates.clone()
348    }
349
350    pub fn model_dump_json(&self) -> Result<String, EvaluationError> {
351        serde_json::to_string(self).map_err(Into::into)
352    }
353
354    #[staticmethod]
355    pub fn model_validate_json(json_string: String) -> Result<Self, EvaluationError> {
356        serde_json::from_str(&json_string).map_err(Into::into)
357    }
358
359    pub fn save(&self, path: &str) -> Result<(), EvaluationError> {
360        let json = serde_json::to_string_pretty(self)?;
361        std::fs::write(path, json)?;
362        Ok(())
363    }
364
365    #[staticmethod]
366    pub fn load(path: &str) -> Result<Self, EvaluationError> {
367        let json = std::fs::read_to_string(path)?;
368        serde_json::from_str(&json).map_err(Into::into)
369    }
370
371    pub fn as_table(&self) {
372        // Sub-Agent Comparison with pass rates
373        if !self.baseline_alias_pass_rates.is_empty()
374            || !self.comparison_alias_pass_rates.is_empty()
375        {
376            println!("\n{}", "Sub-Agent Comparison".truecolor(245, 77, 85).bold());
377
378            let mut all_aliases: Vec<String> = self
379                .baseline_alias_pass_rates
380                .keys()
381                .chain(self.comparison_alias_pass_rates.keys())
382                .cloned()
383                .collect();
384            all_aliases.sort();
385            all_aliases.dedup();
386
387            let mut alias_entries: Vec<AliasPassRateEntry> = Vec::new();
388            for alias in &all_aliases {
389                let baseline_rate = self.baseline_alias_pass_rates.get(alias);
390                let current_rate = self.comparison_alias_pass_rates.get(alias);
391
392                let (baseline_str, current_str, delta_str, status) =
393                    match (baseline_rate, current_rate) {
394                        (Some(b), Some(c)) => {
395                            let delta = (c - b) * 100.0;
396                            let d = format!("{:+.1}%", delta);
397                            let colored_delta = if delta > 1.0 {
398                                d.green().to_string()
399                            } else if delta < -1.0 {
400                                d.red().to_string()
401                            } else {
402                                d.yellow().to_string()
403                            };
404                            let s = if self.regressed_aliases.contains(alias) {
405                                "REGRESSION".red().to_string()
406                            } else if self.improved_aliases.contains(alias) {
407                                "IMPROVED".green().to_string()
408                            } else {
409                                "UNCHANGED".yellow().to_string()
410                            };
411                            (
412                                format!("{:.1}%", b * 100.0),
413                                format!("{:.1}%", c * 100.0),
414                                colored_delta,
415                                s,
416                            )
417                        }
418                        (None, Some(c)) => (
419                            "-".to_string(),
420                            format!("{:.1}%", c * 100.0),
421                            "-".to_string(),
422                            "NEW".green().bold().to_string(),
423                        ),
424                        (Some(b), None) => (
425                            format!("{:.1}%", b * 100.0),
426                            "-".to_string(),
427                            "-".to_string(),
428                            "REMOVED".red().bold().to_string(),
429                        ),
430                        (None, None) => continue,
431                    };
432
433                alias_entries.push(AliasPassRateEntry {
434                    alias: alias.clone(),
435                    baseline: baseline_str,
436                    current: current_str,
437                    delta: delta_str,
438                    status,
439                });
440            }
441
442            let mut alias_table = Table::new(alias_entries);
443            alias_table.with(Style::sharp());
444            alias_table.modify(
445                Rows::new(0..1),
446                (
447                    Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
448                    Alignment::center(),
449                    Color::BOLD,
450                ),
451            );
452            println!("{}", alias_table);
453        } else if !self.dataset_comparisons.is_empty() {
454            // Fallback for old data without alias pass rates
455            println!("\n{}", "Sub-Agent Comparison".truecolor(245, 77, 85).bold());
456
457            let mut alias_entries: Vec<_> = self
458                .dataset_comparisons
459                .iter()
460                .map(|(alias, comp)| {
461                    let delta = comp.mean_pass_rate_delta * 100.0;
462                    let delta_str = format!("{:+.1}%", delta);
463                    let colored_delta = if delta > 1.0 {
464                        delta_str.green().to_string()
465                    } else if delta < -1.0 {
466                        delta_str.red().to_string()
467                    } else {
468                        delta_str.yellow().to_string()
469                    };
470                    let status = if comp.regressed {
471                        "REGRESSION".red().to_string()
472                    } else if comp.improved_workflows > 0 {
473                        "IMPROVED".green().to_string()
474                    } else {
475                        "UNCHANGED".yellow().to_string()
476                    };
477                    DatasetComparisonEntry {
478                        alias: alias.clone(),
479                        delta: colored_delta,
480                        status,
481                    }
482                })
483                .collect();
484            alias_entries.sort_by(|a, b| a.alias.cmp(&b.alias));
485
486            let mut alias_table = Table::new(alias_entries);
487            alias_table.with(Style::sharp());
488            alias_table.modify(
489                Rows::new(0..1),
490                (
491                    Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
492                    Alignment::center(),
493                    Color::BOLD,
494                ),
495            );
496            println!("{}", alias_table);
497        }
498
499        // All Scenarios table (not just changed ones)
500        if !self.scenario_deltas.is_empty() {
501            println!("\n{}", "Scenario Comparison".truecolor(245, 77, 85).bold());
502
503            let entries: Vec<_> = self
504                .scenario_deltas
505                .iter()
506                .map(|d| {
507                    let baseline_str = if d.baseline_passed {
508                        "PASS".green().to_string()
509                    } else {
510                        "FAIL".red().to_string()
511                    };
512                    let current_str = if d.comparison_passed {
513                        "PASS".green().to_string()
514                    } else {
515                        "FAIL".red().to_string()
516                    };
517                    let pr_delta = (d.comparison_pass_rate - d.baseline_pass_rate) * 100.0;
518                    let pr_delta_str = format!("{:+.1}%", pr_delta);
519                    let colored_pr_delta = if pr_delta > 1.0 {
520                        pr_delta_str.green().to_string()
521                    } else if pr_delta < -1.0 {
522                        pr_delta_str.red().to_string()
523                    } else {
524                        pr_delta_str.yellow().to_string()
525                    };
526                    let change = match (d.baseline_passed, d.comparison_passed) {
527                        (true, false) => "Pass -> Fail".red().bold().to_string(),
528                        (false, true) => "Fail -> Pass".green().bold().to_string(),
529                        _ => "-".to_string(),
530                    };
531                    ScenarioDeltaEntry {
532                        scenario_id: d.scenario_id.chars().take(16).collect::<String>(),
533                        baseline: baseline_str,
534                        current: current_str,
535                        pass_rate_delta: colored_pr_delta,
536                        change,
537                    }
538                })
539                .collect();
540
541            let mut delta_table = Table::new(entries);
542            delta_table.with(Style::sharp());
543            delta_table.modify(
544                Rows::new(0..1),
545                (
546                    Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
547                    Alignment::center(),
548                    Color::BOLD,
549                ),
550            );
551            println!("{}", delta_table);
552        }
553
554        // New/Removed Scenarios
555        if !self.new_scenarios.is_empty() {
556            println!(
557                "\n{}  {}",
558                "New Scenarios:".truecolor(245, 77, 85).bold(),
559                self.new_scenarios.join(", ").green()
560            );
561        }
562        if !self.removed_scenarios.is_empty() {
563            println!(
564                "\n{}  {}",
565                "Removed Scenarios:".truecolor(245, 77, 85).bold(),
566                self.removed_scenarios.join(", ").red()
567            );
568        }
569
570        // Summary
571        let overall_status = if self.regressed {
572            "REGRESSION DETECTED".red().bold().to_string()
573        } else if !self.improved_aliases.is_empty() {
574            "IMPROVEMENT DETECTED".green().bold().to_string()
575        } else {
576            "NO SIGNIFICANT CHANGE".yellow().bold().to_string()
577        };
578
579        println!("\n{}", "Summary".truecolor(245, 77, 85).bold());
580        println!("  Overall Status: {}", overall_status);
581        println!(
582            "  Baseline Pass Rate: {:.1}%",
583            self.baseline_overall_pass_rate * 100.0
584        );
585        println!(
586            "  Current Pass Rate:  {:.1}%",
587            self.comparison_overall_pass_rate * 100.0
588        );
589        if !self.regressed_aliases.is_empty() {
590            println!(
591                "  Regressed aliases: {}",
592                self.regressed_aliases.join(", ").red()
593            );
594        }
595        if !self.improved_aliases.is_empty() {
596            println!(
597                "  Improved aliases: {}",
598                self.improved_aliases.join(", ").green()
599            );
600        }
601        if !self.new_aliases.is_empty() {
602            println!("  New aliases: {}", self.new_aliases.join(", ").green());
603        }
604        if !self.removed_aliases.is_empty() {
605            println!(
606                "  Removed aliases: {}",
607                self.removed_aliases.join(", ").red()
608            );
609        }
610        let changed_count = self
611            .scenario_deltas
612            .iter()
613            .filter(|d| d.status_changed)
614            .count();
615        if changed_count > 0 {
616            println!("  Scenarios changed: {}", changed_count);
617        }
618        if !self.new_scenarios.is_empty() {
619            println!("  New scenarios: {}", self.new_scenarios.len());
620        }
621        if !self.removed_scenarios.is_empty() {
622            println!("  Removed scenarios: {}", self.removed_scenarios.len());
623        }
624    }
625}
626
627#[derive(Debug, Clone, Serialize, Deserialize)]
628#[pyclass]
629pub struct ScenarioEvalResults {
630    pub dataset_results: HashMap<String, EvalResults>,
631
632    pub scenario_results: Vec<ScenarioResult>,
633
634    #[pyo3(get)]
635    pub metrics: EvalMetrics,
636}
637
638impl PartialEq for ScenarioEvalResults {
639    fn eq(&self, other: &Self) -> bool {
640        self.scenario_results == other.scenario_results && self.metrics == other.metrics
641        // dataset_results excluded: EvalResults does not implement PartialEq
642    }
643}
644
645#[derive(Tabled)]
646struct ScenarioResultEntry {
647    #[tabled(rename = "Scenario ID")]
648    scenario_id: String,
649    #[tabled(rename = "Initial Query")]
650    initial_query: String,
651    #[tabled(rename = "Tasks")]
652    tasks: usize,
653    #[tabled(rename = "Passed")]
654    passed_count: usize,
655    #[tabled(rename = "Failed")]
656    failed_count: usize,
657    #[tabled(rename = "Pass Rate")]
658    pass_rate: String,
659    #[tabled(rename = "Status")]
660    status: String,
661}
662
663#[pymethods]
664impl ScenarioEvalResults {
665    pub fn __str__(&self) -> String {
666        PyHelperFuncs::__str__(self)
667    }
668
669    pub fn model_dump_json(&self) -> Result<String, EvaluationError> {
670        serde_json::to_string(self).map_err(Into::into)
671    }
672
673    #[staticmethod]
674    pub fn model_validate_json(json_string: String) -> Result<Self, EvaluationError> {
675        serde_json::from_str(&json_string).map_err(Into::into)
676    }
677
678    pub fn save(&self, path: &str) -> Result<(), EvaluationError> {
679        let json = serde_json::to_string_pretty(self)?;
680        std::fs::write(path, json)?;
681        Ok(())
682    }
683
684    #[staticmethod]
685    pub fn load(path: &str) -> Result<Self, EvaluationError> {
686        let json = std::fs::read_to_string(path)?;
687        serde_json::from_str(&json).map_err(Into::into)
688    }
689
690    #[getter]
691    pub fn dataset_results(&self) -> HashMap<String, EvalResults> {
692        self.dataset_results.clone()
693    }
694
695    #[getter]
696    pub fn scenario_results(&self) -> Vec<ScenarioResult> {
697        self.scenario_results.clone()
698    }
699
700    pub fn get_scenario_detail(
701        &self,
702        scenario_id: &str,
703    ) -> Result<ScenarioResult, EvaluationError> {
704        self.scenario_results
705            .iter()
706            .find(|r| r.scenario_id == scenario_id)
707            .cloned()
708            .ok_or_else(|| EvaluationError::MissingKeyError(scenario_id.to_string()))
709    }
710
711    #[pyo3(signature = (baseline, regression_threshold = 0.05))]
712    pub fn compare_to(
713        &self,
714        baseline: &ScenarioEvalResults,
715        regression_threshold: f64,
716    ) -> Result<ScenarioComparisonResults, EvaluationError> {
717        let mut dataset_comparisons = HashMap::new();
718        let mut improved_aliases = Vec::new();
719        let mut regressed_aliases = Vec::new();
720
721        for (alias, current_results) in &self.dataset_results {
722            if let Some(baseline_results) = baseline.dataset_results.get(alias) {
723                let comp =
724                    compare_results(baseline_results, current_results, regression_threshold)?;
725                if comp.regressed {
726                    regressed_aliases.push(alias.clone());
727                } else if comp.improved_workflows > 0 {
728                    improved_aliases.push(alias.clone());
729                }
730                dataset_comparisons.insert(alias.clone(), comp);
731            }
732        }
733
734        // Detect new/removed aliases
735        let mut new_aliases: Vec<String> = self
736            .dataset_results
737            .keys()
738            .filter(|alias| !baseline.dataset_results.contains_key(*alias))
739            .cloned()
740            .collect();
741        new_aliases.sort();
742
743        let mut removed_aliases: Vec<String> = baseline
744            .dataset_results
745            .keys()
746            .filter(|alias| !self.dataset_results.contains_key(*alias))
747            .cloned()
748            .collect();
749        removed_aliases.sort();
750
751        let baseline_scenario_map: HashMap<_, _> = baseline
752            .scenario_results
753            .iter()
754            .map(|r| (r.scenario_id.as_str(), r))
755            .collect();
756
757        let current_scenario_map: HashMap<_, _> = self
758            .scenario_results
759            .iter()
760            .map(|r| (r.scenario_id.as_str(), r))
761            .collect();
762
763        let mut scenario_deltas = Vec::new();
764        for current in &self.scenario_results {
765            if let Some(base) = baseline_scenario_map.get(current.scenario_id.as_str()) {
766                scenario_deltas.push(ScenarioDelta {
767                    scenario_id: current.scenario_id.clone(),
768                    initial_query: current.initial_query.clone(),
769                    baseline_passed: base.passed,
770                    comparison_passed: current.passed,
771                    baseline_pass_rate: base.pass_rate,
772                    comparison_pass_rate: current.pass_rate,
773                    status_changed: base.passed != current.passed,
774                });
775            }
776        }
777
778        // Detect new/removed scenarios
779        let mut new_scenarios: Vec<String> = current_scenario_map
780            .keys()
781            .filter(|id| !baseline_scenario_map.contains_key(*id))
782            .map(|id| id.to_string())
783            .collect();
784        new_scenarios.sort();
785
786        let mut removed_scenarios: Vec<String> = baseline_scenario_map
787            .keys()
788            .filter(|id| !current_scenario_map.contains_key(*id))
789            .map(|id| id.to_string())
790            .collect();
791        removed_scenarios.sort();
792
793        // Per-alias pass rates from metrics
794        let baseline_alias_pass_rates = baseline.metrics.dataset_pass_rates.clone();
795        let comparison_alias_pass_rates = self.metrics.dataset_pass_rates.clone();
796
797        Ok(ScenarioComparisonResults {
798            dataset_comparisons,
799            scenario_deltas,
800            baseline_overall_pass_rate: baseline.metrics.overall_pass_rate,
801            comparison_overall_pass_rate: self.metrics.overall_pass_rate,
802            regressed: !regressed_aliases.is_empty(),
803            improved_aliases,
804            regressed_aliases,
805            new_aliases,
806            removed_aliases,
807            new_scenarios,
808            removed_scenarios,
809            baseline_alias_pass_rates,
810            comparison_alias_pass_rates,
811        })
812    }
813
814    #[pyo3(signature = (show_datasets=false))]
815    pub fn as_table(&mut self, show_datasets: bool) {
816        self.metrics.as_table();
817
818        if !self.scenario_results.is_empty() {
819            println!("\n{}", "Scenario Results".truecolor(245, 77, 85).bold());
820
821            let entries: Vec<_> = self
822                .scenario_results
823                .iter()
824                .map(|r| {
825                    let query = if r.initial_query.chars().count() > 40 {
826                        format!(
827                            "{}...",
828                            r.initial_query.chars().take(40).collect::<String>()
829                        )
830                    } else {
831                        r.initial_query.clone()
832                    };
833                    let status = if r.passed {
834                        "✓ PASS".green().to_string()
835                    } else {
836                        "✗ FAIL".red().to_string()
837                    };
838                    let total = r.task_results.len();
839                    let passed_count = r.task_results.iter().filter(|t| t.passed).count();
840                    let failed_count = total - passed_count;
841                    ScenarioResultEntry {
842                        scenario_id: r.scenario_id.chars().take(16).collect::<String>(),
843                        initial_query: query,
844                        tasks: total,
845                        passed_count,
846                        failed_count,
847                        pass_rate: format!("{:.1}%", r.pass_rate * 100.0),
848                        status,
849                    }
850                })
851                .collect();
852
853            let mut table = Table::new(entries);
854            table.with(Style::sharp());
855            table.modify(
856                Rows::new(0..1),
857                (
858                    Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
859                    Alignment::center(),
860                    Color::BOLD,
861                ),
862            );
863            println!("{}", table);
864        }
865
866        if show_datasets {
867            let mut aliases: Vec<_> = self.dataset_results.keys().cloned().collect();
868            aliases.sort();
869            for alias in aliases {
870                if let Some(eval_results) = self.dataset_results.get_mut(&alias) {
871                    println!(
872                        "\n{}",
873                        format!("Dataset: {}", alias).truecolor(245, 77, 85).bold()
874                    );
875                    eval_results.as_table(false);
876                }
877            }
878        }
879    }
880}
881
#[cfg(test)]
mod tests {
    //! Unit tests for scenario-level evaluation results: metric fields, JSON and file
    //! round-trips, scenario lookup, and the delta / alias bookkeeping done by `compare_to`.

    // `EvalResults`, `HashMap`, etc. come in through the parent module's imports.
    use super::*;

    /// Builds an `EvalMetrics` with the given headline numbers and empty per-dataset
    /// and per-scenario-task breakdowns.
    fn empty_metrics(
        overall: f64,
        scenario_pass_rate: f64,
        total: usize,
        passed: usize,
    ) -> EvalMetrics {
        EvalMetrics {
            overall_pass_rate: overall,
            dataset_pass_rates: HashMap::new(),
            scenario_pass_rate,
            total_scenarios: total,
            passed_scenarios: passed,
            scenario_task_pass_rates: HashMap::new(),
        }
    }

    /// Builds a `ScenarioResult` with empty eval results and no task results.
    fn make_scenario_result(id: &str, query: &str, passed: bool, pass_rate: f64) -> ScenarioResult {
        ScenarioResult {
            scenario_id: id.to_string(),
            initial_query: query.to_string(),
            eval_results: EvalResults::new(),
            passed,
            pass_rate,
            task_results: vec![],
        }
    }

    /// Two-scenario fixture: `s1` passes (rate 1.0), `s2` fails (rate 0.5); no datasets.
    fn make_eval_results() -> ScenarioEvalResults {
        ScenarioEvalResults {
            dataset_results: HashMap::new(),
            scenario_results: vec![
                make_scenario_result("s1", "Make pasta", true, 1.0),
                make_scenario_result("s2", "Make curry", false, 0.5),
            ],
            metrics: empty_metrics(0.75, 0.5, 2, 1),
        }
    }

    /// Returns the delta recorded for `scenario_id`, panicking with a descriptive
    /// message (rather than a bare `unwrap`) when it is missing.
    fn find_delta<'a>(
        comp: &'a ScenarioComparisonResults,
        scenario_id: &str,
    ) -> &'a ScenarioDelta {
        comp.scenario_deltas
            .iter()
            .find(|d| d.scenario_id == scenario_id)
            .unwrap_or_else(|| panic!("no scenario delta for {}", scenario_id))
    }

    #[test]
    fn eval_metrics_fields() {
        let m = empty_metrics(0.9, 0.85, 100, 85);
        assert_eq!(m.overall_pass_rate, 0.9);
        assert_eq!(m.scenario_pass_rate, 0.85);
        assert_eq!(m.total_scenarios, 100);
        assert_eq!(m.passed_scenarios, 85);
    }

    #[test]
    fn scenario_result_fields() {
        let r = make_scenario_result("id-1", "hello", true, 1.0);
        assert_eq!(r.scenario_id, "id-1");
        assert!(r.passed);
        assert_eq!(r.pass_rate, 1.0);
    }

    #[test]
    fn model_dump_json_roundtrip() {
        let results = make_eval_results();
        let json = results.model_dump_json().unwrap();
        let loaded = ScenarioEvalResults::model_validate_json(json).unwrap();
        assert_eq!(loaded.scenario_results.len(), 2);
        assert_eq!(loaded.metrics.total_scenarios, 2);
        assert_eq!(loaded.metrics.passed_scenarios, 1);
        assert_eq!(loaded.metrics.overall_pass_rate, 0.75);
        assert_eq!(loaded.scenario_results[0].scenario_id, "s1");
        assert_eq!(loaded.scenario_results[1].scenario_id, "s2");
        assert_eq!(loaded.scenario_results[0].pass_rate, 1.0);
        assert_eq!(loaded.scenario_results[1].pass_rate, 0.5);
    }

    #[test]
    fn compare_to_regression_detection() {
        let baseline = make_eval_results();

        let current = ScenarioEvalResults {
            dataset_results: HashMap::new(),
            scenario_results: vec![
                make_scenario_result("s1", "Make pasta", false, 0.0),
                make_scenario_result("s2", "Make curry", false, 0.5),
            ],
            metrics: empty_metrics(0.25, 0.0, 2, 0),
        };

        let comp = current.compare_to(&baseline, 0.05).unwrap();
        assert_eq!(comp.scenario_deltas.len(), 2);
        // Same scenario ids on both sides: nothing added or removed.
        assert!(comp.new_scenarios.is_empty());
        assert!(comp.removed_scenarios.is_empty());

        // s1 flipped from pass to fail.
        let s1 = find_delta(&comp, "s1");
        assert!(s1.status_changed);
        assert!(s1.baseline_passed);
        assert!(!s1.comparison_passed);

        // s2 failed in both runs, so its status did not change.
        let s2 = find_delta(&comp, "s2");
        assert!(!s2.status_changed);
    }

    #[test]
    fn compare_to_no_regression() {
        let baseline = make_eval_results();
        let current = make_eval_results();

        let comp = current.compare_to(&baseline, 0.05).unwrap();
        assert!(!comp.regressed);
        assert!(comp.scenario_deltas.iter().all(|d| !d.status_changed));
    }

    #[test]
    fn compare_to_with_dataset_results() {
        let mut baseline_dataset = HashMap::new();
        baseline_dataset.insert("agent_a".to_string(), EvalResults::new());

        let mut current_dataset = HashMap::new();
        current_dataset.insert("agent_a".to_string(), EvalResults::new());

        let baseline = ScenarioEvalResults {
            dataset_results: baseline_dataset,
            scenario_results: vec![make_scenario_result("s1", "query one", true, 1.0)],
            metrics: empty_metrics(1.0, 1.0, 1, 1),
        };

        let current = ScenarioEvalResults {
            dataset_results: current_dataset,
            scenario_results: vec![make_scenario_result("s1", "query one", false, 0.0)],
            metrics: empty_metrics(0.0, 0.0, 1, 0),
        };

        let comp = current.compare_to(&baseline, 0.05).unwrap();

        // agent_a exists on both sides, so it is compared rather than listed as new/removed.
        assert!(comp.dataset_comparisons.contains_key("agent_a"));
        assert!(comp.new_aliases.is_empty());
        assert!(comp.removed_aliases.is_empty());

        // The scenario-level regression for s1 is still reported.
        let s1 = find_delta(&comp, "s1");
        assert!(s1.status_changed);
        assert!(s1.baseline_passed);
        assert!(!s1.comparison_passed);

        // Empty EvalResults contain no workflows, so no alias can improve or regress.
        assert!(comp.improved_aliases.is_empty());
        assert!(comp.regressed_aliases.is_empty());
    }

    #[test]
    fn get_scenario_detail_found() {
        let results = make_eval_results();
        let detail = results.get_scenario_detail("s1").unwrap();
        assert_eq!(detail.scenario_id, "s1");
    }

    #[test]
    fn get_scenario_detail_missing() {
        let results = make_eval_results();
        assert!(results.get_scenario_detail("nonexistent").is_err());
    }

    #[test]
    fn save_load_roundtrip() {
        let results = make_eval_results();
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("results.json");
        let path_str = path.to_str().unwrap();

        results.save(path_str).unwrap();
        let loaded = ScenarioEvalResults::load(path_str).unwrap();
        assert_eq!(results, loaded);
    }

    #[test]
    fn comparison_model_dump_json_roundtrip() {
        let comp = ScenarioComparisonResults {
            dataset_comparisons: HashMap::new(),
            scenario_deltas: vec![ScenarioDelta {
                scenario_id: "s1".to_string(),
                initial_query: "hello".to_string(),
                baseline_passed: true,
                comparison_passed: false,
                baseline_pass_rate: 1.0,
                comparison_pass_rate: 0.0,
                status_changed: true,
            }],
            baseline_overall_pass_rate: 1.0,
            comparison_overall_pass_rate: 0.5,
            regressed: true,
            improved_aliases: vec![],
            regressed_aliases: vec!["a".to_string()],
            new_aliases: vec!["c".to_string()],
            removed_aliases: vec!["b".to_string()],
            new_scenarios: vec!["s3".to_string()],
            removed_scenarios: vec![],
            baseline_alias_pass_rates: HashMap::from([("a".to_string(), 1.0)]),
            comparison_alias_pass_rates: HashMap::from([("a".to_string(), 0.5)]),
        };

        let json = comp.model_dump_json().unwrap();
        let loaded = ScenarioComparisonResults::model_validate_json(json).unwrap();
        assert_eq!(comp, loaded);
    }

    #[test]
    fn comparison_save_load_roundtrip() {
        let comp = ScenarioComparisonResults {
            dataset_comparisons: HashMap::new(),
            scenario_deltas: vec![],
            baseline_overall_pass_rate: 0.8,
            comparison_overall_pass_rate: 0.9,
            regressed: false,
            improved_aliases: vec!["x".to_string()],
            regressed_aliases: vec![],
            new_aliases: vec![],
            removed_aliases: vec![],
            new_scenarios: vec![],
            removed_scenarios: vec![],
            baseline_alias_pass_rates: HashMap::new(),
            comparison_alias_pass_rates: HashMap::new(),
        };

        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("comp.json");
        let path_str = path.to_str().unwrap();

        comp.save(path_str).unwrap();
        let loaded = ScenarioComparisonResults::load(path_str).unwrap();
        assert_eq!(comp, loaded);
    }

    #[test]
    fn compare_to_new_scenarios() {
        let baseline = make_eval_results(); // s1, s2
        let current = ScenarioEvalResults {
            dataset_results: HashMap::new(),
            scenario_results: vec![
                make_scenario_result("s1", "Make pasta", true, 1.0),
                make_scenario_result("s2", "Make curry", false, 0.5),
                make_scenario_result("s3", "Make salad", true, 0.8),
            ],
            metrics: empty_metrics(0.77, 0.67, 3, 2),
        };

        let comp = current.compare_to(&baseline, 0.05).unwrap();
        assert_eq!(comp.new_scenarios, vec!["s3".to_string()]);
        assert!(comp.removed_scenarios.is_empty());
        // Only shared scenarios appear in deltas
        assert_eq!(comp.scenario_deltas.len(), 2);
    }

    #[test]
    fn compare_to_removed_scenarios() {
        let baseline = ScenarioEvalResults {
            dataset_results: HashMap::new(),
            scenario_results: vec![
                make_scenario_result("s1", "Make pasta", true, 1.0),
                make_scenario_result("s2", "Make curry", false, 0.5),
                make_scenario_result("s3", "Make salad", true, 0.8),
            ],
            metrics: empty_metrics(0.77, 0.67, 3, 2),
        };
        let current = make_eval_results(); // s1, s2

        let comp = current.compare_to(&baseline, 0.05).unwrap();
        assert_eq!(comp.removed_scenarios, vec!["s3".to_string()]);
        assert!(comp.new_scenarios.is_empty());
    }

    #[test]
    fn compare_to_new_removed_aliases() {
        let mut baseline_datasets = HashMap::new();
        baseline_datasets.insert("a".to_string(), EvalResults::new());
        baseline_datasets.insert("b".to_string(), EvalResults::new());

        let mut current_datasets = HashMap::new();
        current_datasets.insert("a".to_string(), EvalResults::new());
        current_datasets.insert("c".to_string(), EvalResults::new());

        let baseline = ScenarioEvalResults {
            dataset_results: baseline_datasets,
            scenario_results: vec![make_scenario_result("s1", "q", true, 1.0)],
            metrics: empty_metrics(1.0, 1.0, 1, 1),
        };

        let current = ScenarioEvalResults {
            dataset_results: current_datasets,
            scenario_results: vec![make_scenario_result("s1", "q", true, 1.0)],
            metrics: empty_metrics(1.0, 1.0, 1, 1),
        };

        let comp = current.compare_to(&baseline, 0.05).unwrap();
        assert_eq!(comp.new_aliases, vec!["c".to_string()]);
        assert_eq!(comp.removed_aliases, vec!["b".to_string()]);
    }

    #[test]
    fn compare_to_alias_pass_rates() {
        let mut baseline_metrics = empty_metrics(0.9, 1.0, 1, 1);
        baseline_metrics
            .dataset_pass_rates
            .insert("agent_a".to_string(), 0.9);
        baseline_metrics
            .dataset_pass_rates
            .insert("agent_b".to_string(), 0.8);

        let mut current_metrics = empty_metrics(0.85, 1.0, 1, 1);
        current_metrics
            .dataset_pass_rates
            .insert("agent_a".to_string(), 0.85);
        current_metrics
            .dataset_pass_rates
            .insert("agent_b".to_string(), 0.75);

        let baseline = ScenarioEvalResults {
            dataset_results: HashMap::new(),
            scenario_results: vec![make_scenario_result("s1", "q", true, 1.0)],
            metrics: baseline_metrics,
        };

        let current = ScenarioEvalResults {
            dataset_results: HashMap::new(),
            scenario_results: vec![make_scenario_result("s1", "q", true, 1.0)],
            metrics: current_metrics,
        };

        let comp = current.compare_to(&baseline, 0.05).unwrap();
        assert_eq!(comp.baseline_alias_pass_rates.get("agent_a"), Some(&0.9));
        assert_eq!(comp.baseline_alias_pass_rates.get("agent_b"), Some(&0.8));
        assert_eq!(comp.comparison_alias_pass_rates.get("agent_a"), Some(&0.85));
        assert_eq!(comp.comparison_alias_pass_rates.get("agent_b"), Some(&0.75));
    }
}