// scouter_evaluate/runner.rs
1use crate::error::EvaluationError;
2use crate::evaluate::evaluator::GenAIEvaluator;
3use crate::evaluate::scenario_results::{
4    EvalMetrics, ScenarioEvalResults, ScenarioResult, TaskSummary,
5};
6use crate::evaluate::types::{EvalResults, EvaluationConfig};
7use crate::genai::{evaluate_genai_dataset, EvalDataset};
8use crate::scenario::EvalScenarios;
9use pyo3::prelude::*;
10use scouter_state::app_state;
11use scouter_types::genai::EvalScenario;
12use scouter_types::genai::{GenAIEvalConfig, GenAIEvalProfile};
13use scouter_types::trace::build_trace_spans;
14use scouter_types::trace::sql::TraceSpan;
15use scouter_types::EvalRecord;
16use scouter_types::TraceId as ScouterTraceId;
17use serde_json::json;
18use std::collections::{BTreeMap, HashMap, HashSet};
19use std::sync::Arc;
20use tracing::{debug, error, warn};
21
/// Accumulator used by `evaluate_datasets` to flatten per-scenario datasets
/// into a single dataset per alias.
struct AliasData {
    // All records for this alias, flattened across every scenario.
    records: Vec<EvalRecord>,
    // First-seen profile for this alias; `None` until the first dataset is visited.
    profile: Option<Arc<GenAIEvalProfile>>,
    // Spans gathered from every scenario dataset containing this alias.
    spans: Vec<TraceSpan>,
}
27
/// Stateful evaluation engine that orchestrates scenario evaluation.
///
/// `EvalRunner` owns the scenario definitions and profiles (as `Arc`s),
/// mirroring the `ScouterQueue` pattern. It provides:
/// - `collect_scenario_data()`: Populates scenario datasets and contexts
/// - `evaluate()`: Runs multi-level evaluation (sub-agent + scenario + aggregate),
///   pulling captured spans from the global buffer automatically.
#[derive(Debug)]
#[pyclass]
pub struct EvalRunner {
    // alias → eval profile; Arc so datasets can share profiles without cloning.
    profiles: HashMap<String, Arc<GenAIEvalProfile>>,
    // Scenario definitions plus collected datasets, contexts, and results.
    scenarios: EvalScenarios,
}
41
42#[pymethods]
43impl EvalRunner {
44    #[new]
45    #[pyo3(signature = (scenarios, profiles))]
46    pub fn new(scenarios: EvalScenarios, profiles: HashMap<String, GenAIEvalProfile>) -> Self {
47        let arc_profiles: HashMap<String, Arc<GenAIEvalProfile>> = profiles
48            .into_iter()
49            .map(|(k, v)| (k, Arc::new(v)))
50            .collect();
51        Self {
52            profiles: arc_profiles,
53            scenarios,
54        }
55    }
56
57    #[getter]
58    pub fn scenarios(&self) -> EvalScenarios {
59        self.scenarios.clone()
60    }
61
62    /// Run multi-level evaluation.
63    ///
64    /// Spans are pulled automatically from the global capture buffer —
65    /// no need to pass them explicitly.
66    ///
67    /// **LEVEL 1** — Sub-agent evaluation: flatten all records per alias → one EvalDataset → evaluate
68    /// **LEVEL 2** — Scenario evaluation: per scenario with tasks, evaluate against response + traces
69    /// **LEVEL 3** — Aggregate metrics
70    #[pyo3(signature = (config=None))]
71    pub fn evaluate(
72        &mut self,
73        config: Option<EvaluationConfig>,
74    ) -> Result<ScenarioEvalResults, EvaluationError> {
75        let config = Arc::new(config.unwrap_or_default());
76
77        if tokio::runtime::Handle::try_current().is_ok() {
78            return Err(EvaluationError::GenAIEvaluatorError(
79                "EvalRunner.evaluate() cannot be called from within an async context. \
80                 Use evaluate_async() or call from a synchronous Python context."
81                    .to_string(),
82            ));
83        }
84
85        app_state()
86            .handle()
87            .block_on(async { self.evaluate_async(&config).await })
88    }
89
90    /// Populate scenario data into the internal `EvalScenarios` container.
91    ///
92    /// # Arguments
93    /// * `records` - Map of alias → eval records for this scenario
94    /// * `response` - The agent's final response for this scenario
95    /// * `scenario` - Reference to the scenario definition
96    #[pyo3(signature = (records, response, scenario))]
97    pub fn collect_scenario_data(
98        &mut self,
99        records: HashMap<String, Vec<EvalRecord>>,
100        response: String,
101        scenario: &EvalScenario,
102    ) -> Result<(), EvaluationError> {
103        let mut alias_datasets: HashMap<String, EvalDataset> = HashMap::new();
104        let scenario_id = scenario.id.clone();
105
106        for (alias, mut alias_records) in records {
107            // Tag each record with the scenario_id
108            let scenario_tag = format!("scouter.eval.scenario_id={}", scenario_id);
109            for record in &mut alias_records {
110                if !record.tags.contains(&scenario_tag) {
111                    record.tags.push(scenario_tag.clone());
112                }
113            }
114
115            let profile = self.profiles.get(&alias).ok_or_else(|| {
116                EvaluationError::MissingKeyError(format!(
117                    "No profile found for alias '{}' in scenario '{}'",
118                    alias, scenario_id
119                ))
120            })?;
121
122            alias_datasets.insert(
123                alias,
124                EvalDataset {
125                    records: Arc::new(alias_records),
126                    profile: Arc::clone(profile),
127                    spans: Arc::new(vec![]),
128                },
129            );
130        }
131
132        if self.scenarios.scenario_datasets.contains_key(&scenario_id) {
133            return Err(EvaluationError::MissingKeyError(format!(
134                "Scenario '{}' already has data — collect_scenario_data called twice",
135                scenario_id
136            )));
137        }
138
139        self.scenarios
140            .scenario_datasets
141            .insert(scenario_id.clone(), alias_datasets);
142
143        // Build context JSON for scenario-level evaluation
144        let context = json!({
145            "response": response,
146            "expected_outcome": scenario.expected_outcome,
147            "metadata": scenario.metadata,
148        });
149        self.scenarios
150            .scenario_contexts
151            .insert(scenario_id, context);
152
153        Ok(())
154    }
155}
156
157impl EvalRunner {
158    async fn evaluate_async(
159        &mut self,
160        config: &Arc<EvaluationConfig>,
161    ) -> Result<ScenarioEvalResults, EvaluationError> {
162        let scenario_ids: HashSet<String> = self
163            .scenarios
164            .scenarios
165            .iter()
166            .map(|s| s.id.clone())
167            .collect();
168
169        // Single buffer scan → grouped by scenario_id (via span attribute)
170        let mut raw_by_scenario =
171            scouter_types::span_capture::get_spans_grouped_by_scenario_id(&scenario_ids);
172
173        // Also resolve spans via EvalRecord.trace_id for scenarios whose spans were not
174        // captured with the scouter.eval.scenario_id attribute (e.g. plain OTel spans).
175        let mut record_trace_to_scenario: HashMap<ScouterTraceId, String> = HashMap::new();
176        for (scenario_id, datasets) in &self.scenarios.scenario_datasets {
177            for dataset in datasets.values() {
178                for record in dataset.records.iter() {
179                    if let Some(tid) = record.trace_id {
180                        record_trace_to_scenario.insert(tid, scenario_id.clone());
181                    }
182                }
183            }
184        }
185
186        if !record_trace_to_scenario.is_empty() {
187            // Only fetch trace_ids not already covered by the attribute-based pass
188            let already_covered: HashSet<ScouterTraceId> = raw_by_scenario
189                .values()
190                .flat_map(|spans: &Vec<_>| spans.iter().map(|s| s.trace_id))
191                .collect();
192            let new_trace_ids: HashSet<ScouterTraceId> = record_trace_to_scenario
193                .keys()
194                .filter(|tid| !already_covered.contains(tid))
195                .copied()
196                .collect();
197            if !new_trace_ids.is_empty() {
198                let extra =
199                    scouter_types::span_capture::get_captured_spans_by_trace_ids(&new_trace_ids);
200                for span in extra {
201                    if let Some(sid) = record_trace_to_scenario.get(&span.trace_id) {
202                        raw_by_scenario.entry(sid.clone()).or_default().push(span);
203                    }
204                }
205            }
206        }
207
208        // Convert each group to tree-enriched TraceSpans
209        let spans_by_scenario: HashMap<String, Arc<Vec<TraceSpan>>> = raw_by_scenario
210            .into_iter()
211            .map(|(id, records)| (id, Arc::new(build_trace_spans(records))))
212            .collect();
213
214        // Assign spans to datasets by scenario_id key lookup
215        for (scenario_id, datasets) in &mut self.scenarios.scenario_datasets {
216            if let Some(spans) = spans_by_scenario.get(scenario_id) {
217                for dataset in datasets.values_mut() {
218                    dataset.spans = Arc::clone(spans);
219                }
220            }
221        }
222
223        // Level 1: Sub-agent evaluation
224        let dataset_results = self.evaluate_datasets(config).await?;
225
226        // Level 2: Scenario evaluation
227        let scenario_results = self.evaluate_scenarios(&spans_by_scenario).await?;
228
229        // Level 3: Aggregate metrics
230        let metrics = compute_metrics(&dataset_results, &scenario_results);
231
232        // Store results on the EvalScenarios container
233        self.scenarios.dataset_results = dataset_results.clone();
234        self.scenarios.scenario_results = scenario_results.clone();
235        self.scenarios.metrics = Some(metrics.clone());
236
237        Ok(ScenarioEvalResults {
238            dataset_results,
239            scenario_results,
240            metrics,
241        })
242    }
243
244    /// LEVEL 1: For each alias across all scenarios, flatten records into one
245    /// dataset and evaluate holistically.
246    async fn evaluate_datasets(
247        &self,
248        config: &Arc<EvaluationConfig>,
249    ) -> Result<HashMap<String, EvalResults>, EvaluationError> {
250        // Collect all aliases
251        let mut alias_data: HashMap<String, AliasData> = HashMap::new();
252
253        for datasets in self.scenarios.scenario_datasets.values() {
254            for (alias, dataset) in datasets {
255                let entry = alias_data
256                    .entry(alias.clone())
257                    .or_insert_with(|| AliasData {
258                        records: Vec::new(),
259                        profile: None,
260                        spans: Vec::new(),
261                    });
262                entry.records.extend(dataset.records.iter().cloned());
263                if entry.profile.is_none() {
264                    entry.profile = Some(Arc::clone(&dataset.profile));
265                } else {
266                    // First-seen profile wins per alias — warn when a subsequent scenario
267                    // provides a different profile instance for the same alias.
268                    warn!(
269                        "Alias '{}': profile already set — first-seen profile wins; ignoring profile from a subsequent scenario. Ensure all scenarios use the same profile for this alias.",
270                        alias
271                    );
272                }
273                entry.spans.extend(dataset.spans.iter().cloned());
274            }
275        }
276
277        let mut results = HashMap::new();
278
279        for (
280            alias,
281            AliasData {
282                records,
283                profile,
284                spans,
285            },
286        ) in alias_data
287        {
288            if records.is_empty() {
289                continue;
290            }
291
292            let profile = match profile {
293                Some(p) => p,
294                None => continue,
295            };
296
297            let dataset = EvalDataset {
298                records: Arc::new(records),
299                profile,
300                spans: Arc::new(spans),
301            };
302
303            debug!("Evaluating sub-agent dataset for alias '{}'", alias);
304            match evaluate_genai_dataset(&dataset, config).await {
305                Ok(eval_results) => {
306                    results.insert(alias, eval_results);
307                }
308                Err(e) => {
309                    error!("Failed to evaluate dataset for alias '{}': {:?}", alias, e);
310                    return Err(e);
311                }
312            }
313        }
314
315        Ok(results)
316    }
317
318    /// LEVEL 2: For each scenario that has tasks, build a record from
319    /// the scenario context and evaluate against the profile + spans looked up by scenario_id.
320    async fn evaluate_scenarios(
321        &self,
322        spans_by_scenario: &HashMap<String, Arc<Vec<TraceSpan>>>,
323    ) -> Result<Vec<ScenarioResult>, EvaluationError> {
324        let mut results = Vec::new();
325
326        for scenario in &self.scenarios.scenarios {
327            if !scenario.has_tasks() {
328                continue;
329            }
330
331            let context = self
332                .scenarios
333                .scenario_contexts
334                .get(&scenario.id)
335                .cloned()
336                .ok_or_else(|| {
337                    EvaluationError::MissingKeyError(format!(
338                        "Scenario '{}' has tasks but no context — call collect_scenario_data() first",
339                        scenario.id
340                    ))
341                })?;
342
343            // Build EvalRecord from scenario context
344            let record = EvalRecord {
345                context,
346                record_id: scenario.id.clone(),
347                tags: vec![format!("scouter.eval.scenario_id={}", scenario.id)],
348                ..Default::default()
349            };
350
351            // Build profile from scenario tasks
352            let profile = GenAIEvalProfile::build_from_parts_async(
353                GenAIEvalConfig::default(),
354                scenario.tasks.clone(),
355                None,
356            )
357            .await?;
358            let profile = Arc::new(profile);
359
360            // Look up spans directly by scenario_id — no filtering needed
361            let spans_arc = spans_by_scenario
362                .get(&scenario.id)
363                .cloned()
364                .unwrap_or_else(|| Arc::new(Vec::new()));
365
366            // Evaluate
367            match GenAIEvaluator::process_event_record(&record, profile, spans_arc).await {
368                Ok(eval_set) => {
369                    let task_results: Vec<TaskSummary> = eval_set
370                        .records
371                        .iter()
372                        .map(|r| TaskSummary {
373                            task_id: r.task_id.clone(),
374                            passed: r.passed,
375                            value: if r.passed { 1.0 } else { 0.0 },
376                        })
377                        .collect();
378
379                    let mut eval_results = EvalResults::new();
380                    eval_results.add_success(&record, eval_set, BTreeMap::new());
381
382                    let (passed, pass_rate) = compute_pass_rate(&eval_results);
383
384                    results.push(ScenarioResult {
385                        scenario_id: scenario.id.clone(),
386                        initial_query: scenario.initial_query.clone(),
387                        eval_results,
388                        passed,
389                        pass_rate,
390                        task_results,
391                    });
392                }
393                Err(e) => {
394                    error!("Failed to evaluate scenario '{}': {:?}", scenario.id, e);
395                    let mut eval_results = EvalResults::new();
396                    eval_results.add_failure(&record, e.to_string());
397
398                    results.push(ScenarioResult {
399                        scenario_id: scenario.id.clone(),
400                        initial_query: scenario.initial_query.clone(),
401                        eval_results,
402                        passed: false,
403                        pass_rate: 0.0,
404                        task_results: vec![],
405                    });
406                }
407            }
408        }
409
410        Ok(results)
411    }
412}
413
414/// LEVEL 3: Compute aggregate metrics
415fn compute_metrics(
416    dataset_results: &HashMap<String, EvalResults>,
417    scenario_results: &[ScenarioResult],
418) -> EvalMetrics {
419    let mut dataset_pass_rates: HashMap<String, f64> = HashMap::new();
420    for (alias, results) in dataset_results {
421        let (_, pass_rate) = compute_pass_rate(results);
422        dataset_pass_rates.insert(alias.clone(), pass_rate);
423    }
424
425    let total_scenarios = scenario_results.len();
426    let passed_scenarios = scenario_results.iter().filter(|s| s.passed).count();
427    let scenario_pass_rate = if total_scenarios > 0 {
428        passed_scenarios as f64 / total_scenarios as f64
429    } else {
430        0.0
431    };
432
433    let mut all_rates: Vec<f64> = dataset_pass_rates.values().copied().collect();
434    if total_scenarios > 0 {
435        all_rates.push(scenario_pass_rate);
436    }
437    let overall_pass_rate = if all_rates.is_empty() {
438        0.0
439    } else {
440        all_rates.iter().sum::<f64>() / all_rates.len() as f64
441    };
442
443    // Build per-scenario, per-task pass rates
444    let mut scenario_task_pass_rates: HashMap<String, HashMap<String, f64>> = HashMap::new();
445    for sr in scenario_results {
446        if !sr.task_results.is_empty() {
447            let task_rates: HashMap<String, f64> = sr
448                .task_results
449                .iter()
450                .map(|t| (t.task_id.clone(), if t.passed { 1.0 } else { 0.0 }))
451                .collect();
452            scenario_task_pass_rates.insert(sr.scenario_id.clone(), task_rates);
453        }
454    }
455
456    EvalMetrics {
457        overall_pass_rate,
458        dataset_pass_rates,
459        scenario_pass_rate,
460        total_scenarios,
461        passed_scenarios,
462        scenario_task_pass_rates,
463    }
464}
465
466/// Compute pass/fail and pass rate from EvalResults
467fn compute_pass_rate(results: &EvalResults) -> (bool, f64) {
468    if results.aligned_results.is_empty() {
469        return (false, 0.0);
470    }
471
472    let mut total_tasks = 0;
473    let mut passed_tasks = 0;
474
475    for aligned in &results.aligned_results {
476        for task_result in &aligned.eval_set.records {
477            total_tasks += 1;
478            if task_result.passed {
479                passed_tasks += 1;
480            }
481        }
482    }
483
484    if total_tasks == 0 {
485        return (false, 0.0);
486    }
487
488    let pass_rate = passed_tasks as f64 / total_tasks as f64;
489    let passed = passed_tasks == total_tasks;
490    (passed, pass_rate)
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use scouter_types::genai::utils::AssertionTasks;
    use scouter_types::genai::EvalScenario;

    /// Task container with no tasks of any kind.
    fn empty_tasks() -> AssertionTasks {
        AssertionTasks {
            assertion: vec![],
            judge: vec![],
            trace: vec![],
            agent: vec![],
        }
    }

    /// Minimal scenario with no tasks — Level 2 evaluation skips it.
    fn make_scenario(id: &str, query: &str) -> EvalScenario {
        EvalScenario {
            id: id.to_string(),
            initial_query: query.to_string(),
            predefined_turns: vec![],
            simulated_user_persona: None,
            termination_signal: None,
            max_turns: 10,
            expected_outcome: Some("Expected output".to_string()),
            tasks: empty_tasks(),
            metadata: None,
        }
    }

    /// Scenario with a single assertion task: context "response" must contain "hello".
    fn make_scenario_with_tasks(id: &str, query: &str) -> EvalScenario {
        use scouter_types::genai::{AssertionTask, ComparisonOperator, EvaluationTaskType};

        let task = AssertionTask {
            id: "check_response".to_string(),
            context_path: Some("response".to_string()),
            item_context_path: None,
            operator: ComparisonOperator::Contains,
            expected_value: serde_json::Value::String("hello".to_string()),
            description: None,
            depends_on: vec![],
            task_type: EvaluationTaskType::Assertion,
            result: None,
            condition: false,
        };

        EvalScenario {
            id: id.to_string(),
            initial_query: query.to_string(),
            predefined_turns: vec![],
            simulated_user_persona: None,
            termination_signal: None,
            max_turns: 10,
            expected_outcome: Some("Response contains hello".to_string()),
            tasks: AssertionTasks {
                assertion: vec![task],
                judge: vec![],
                trace: vec![],
                agent: vec![],
            },
            metadata: None,
        }
    }

    /// Single default profile keyed by alias "agent_a".
    fn make_default_profiles() -> HashMap<String, GenAIEvalProfile> {
        let mut profiles = HashMap::new();
        profiles.insert("agent_a".to_string(), GenAIEvalProfile::default());
        profiles
    }

    // Happy path: datasets stored, records tagged with the scenario id,
    // and the scenario context JSON is recorded.
    #[test]
    fn collect_scenario_data_stores_datasets_and_contexts() {
        let mut runner = EvalRunner::new(
            EvalScenarios::new(vec![make_scenario("s1", "Hello")]),
            make_default_profiles(),
        );

        let mut records = HashMap::new();
        let record = EvalRecord::default();
        records.insert("agent_a".to_string(), vec![record]);

        let scenario = runner.scenarios.scenarios[0].clone();

        runner
            .collect_scenario_data(records, "Agent response".to_string(), &scenario)
            .unwrap();

        assert!(runner.scenarios.scenario_datasets.contains_key("s1"));
        let datasets = &runner.scenarios.scenario_datasets["s1"];
        assert!(datasets.contains_key("agent_a"));
        assert_eq!(datasets["agent_a"].records.len(), 1);

        assert!(datasets["agent_a"].records[0]
            .tags
            .contains(&"scouter.eval.scenario_id=s1".to_string()));

        assert!(runner.scenarios.scenario_contexts.contains_key("s1"));
        let ctx = &runner.scenarios.scenario_contexts["s1"];
        assert_eq!(ctx["response"], "Agent response");
    }

    // Calling collect_scenario_data twice for the same scenario id is an error.
    #[test]
    fn collect_scenario_data_duplicate_returns_error() {
        let mut runner = EvalRunner::new(
            EvalScenarios::new(vec![make_scenario("s1", "Hello")]),
            make_default_profiles(),
        );
        let mut records = HashMap::new();
        records.insert("agent_a".to_string(), vec![EvalRecord::default()]);
        let scenario = runner.scenarios.scenarios[0].clone();

        runner
            .collect_scenario_data(records.clone(), "Response".to_string(), &scenario)
            .unwrap();
        let result = runner.collect_scenario_data(records, "Response again".to_string(), &scenario);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("already has data"));
    }

    // An alias without a registered profile is rejected.
    #[test]
    fn collect_scenario_data_missing_profile_errors() {
        let mut runner = EvalRunner::new(
            EvalScenarios::new(vec![make_scenario("s1", "Hello")]),
            HashMap::new(),
        );

        let mut records = HashMap::new();
        records.insert("agent_a".to_string(), vec![EvalRecord::default()]);

        let scenario = runner.scenarios.scenarios[0].clone();

        let result = runner.collect_scenario_data(records, "Response".to_string(), &scenario);

        assert!(result.is_err());
    }

    // Multiple aliases in one scenario each get their own dataset.
    #[test]
    fn collect_scenario_data_multiple_aliases() {
        let mut profiles = make_default_profiles();
        profiles.insert("agent_b".to_string(), GenAIEvalProfile::default());

        let mut runner = EvalRunner::new(
            EvalScenarios::new(vec![make_scenario("s1", "Hello")]),
            profiles,
        );

        let mut records = HashMap::new();
        records.insert("agent_a".to_string(), vec![EvalRecord::default()]);
        records.insert(
            "agent_b".to_string(),
            vec![EvalRecord::default(), EvalRecord::default()],
        );

        let scenario = runner.scenarios.scenarios[0].clone();
        runner
            .collect_scenario_data(records, "Response".to_string(), &scenario)
            .unwrap();

        let datasets = &runner.scenarios.scenario_datasets["s1"];
        assert_eq!(datasets["agent_a"].records.len(), 1);
        assert_eq!(datasets["agent_b"].records.len(), 2);
    }

    // Scenarios without tasks yield only Level 1 (dataset) results.
    #[test]
    fn evaluate_no_tasks_only_datasets() {
        let mut runner = EvalRunner::new(
            EvalScenarios::new(vec![make_scenario("s1", "Hello")]),
            make_default_profiles(),
        );

        let mut records = HashMap::new();
        records.insert("agent_a".to_string(), vec![EvalRecord::default()]);

        let scenario = runner.scenarios.scenarios[0].clone();
        runner
            .collect_scenario_data(records, "Response".to_string(), &scenario)
            .unwrap();

        let result = runner.evaluate(None).unwrap();

        assert!(result.dataset_results.contains_key("agent_a"));
        assert!(result.scenario_results.is_empty());
        assert!(result.metrics.dataset_pass_rates.contains_key("agent_a"));
        assert_eq!(result.metrics.total_scenarios, 0);
    }

    // A passing assertion task produces a passed scenario with pass_rate 1.0.
    #[test]
    fn evaluate_with_assertion_tasks() {
        let scenario = make_scenario_with_tasks("s1", "Say hello");
        let mut runner =
            EvalRunner::new(EvalScenarios::new(vec![scenario.clone()]), HashMap::new());

        let context = json!({
            "response": "hello world",
            "expected_outcome": "Response contains hello",
            "metadata": null,
        });
        runner
            .scenarios
            .scenario_contexts
            .insert("s1".to_string(), context);

        let result = runner.evaluate(None).unwrap();

        assert_eq!(result.scenario_results.len(), 1);
        assert_eq!(result.scenario_results[0].scenario_id, "s1");
        assert!(result.scenario_results[0].passed);
        assert_eq!(result.scenario_results[0].pass_rate, 1.0);
        assert_eq!(result.metrics.total_scenarios, 1);
        assert_eq!(result.metrics.passed_scenarios, 1);
        assert_eq!(result.metrics.scenario_pass_rate, 1.0);
    }

    // A failing assertion task produces a failed scenario with pass_rate 0.0.
    #[test]
    fn evaluate_with_failing_assertion() {
        let scenario = make_scenario_with_tasks("s1", "Say hello");
        let mut runner =
            EvalRunner::new(EvalScenarios::new(vec![scenario.clone()]), HashMap::new());

        let context = json!({
            "response": "goodbye world",
            "expected_outcome": "Response contains hello",
            "metadata": null,
        });
        runner
            .scenarios
            .scenario_contexts
            .insert("s1".to_string(), context);

        let result = runner.evaluate(None).unwrap();

        assert_eq!(result.scenario_results.len(), 1);
        assert!(!result.scenario_results[0].passed);
        assert_eq!(result.scenario_results[0].pass_rate, 0.0);
        assert_eq!(result.metrics.passed_scenarios, 0);
    }

    // Per-task summaries are populated for passing tasks.
    #[test]
    fn evaluate_with_assertion_tasks_populates_task_results() {
        let scenario = make_scenario_with_tasks("s1", "Say hello");
        let mut runner =
            EvalRunner::new(EvalScenarios::new(vec![scenario.clone()]), HashMap::new());

        let context = json!({
            "response": "hello world",
            "expected_outcome": "Response contains hello",
            "metadata": null,
        });
        runner
            .scenarios
            .scenario_contexts
            .insert("s1".to_string(), context);

        let result = runner.evaluate(None).unwrap();

        let sr = &result.scenario_results[0];
        assert_eq!(sr.task_results.len(), 1);
        assert_eq!(sr.task_results[0].task_id, "check_response");
        assert!(sr.task_results[0].passed);
        assert_eq!(sr.task_results[0].value, 1.0);
    }

    // Per-task summaries and scenario_task_pass_rates reflect failing tasks.
    #[test]
    fn evaluate_with_failing_assertion_task_results() {
        let scenario = make_scenario_with_tasks("s1", "Say hello");
        let mut runner =
            EvalRunner::new(EvalScenarios::new(vec![scenario.clone()]), HashMap::new());

        let context = json!({
            "response": "goodbye world",
            "expected_outcome": "Response contains hello",
            "metadata": null,
        });
        runner
            .scenarios
            .scenario_contexts
            .insert("s1".to_string(), context);

        let result = runner.evaluate(None).unwrap();

        let sr = &result.scenario_results[0];
        assert_eq!(sr.task_results.len(), 1);
        assert_eq!(sr.task_results[0].task_id, "check_response");
        assert!(!sr.task_results[0].passed);
        assert_eq!(sr.task_results[0].value, 0.0);

        let rates = &result.metrics.scenario_task_pass_rates;
        assert!(rates.contains_key("s1"));
        assert_eq!(rates["s1"]["check_response"], 0.0);
    }

    // compute_metrics builds per-scenario, per-task rate maps from TaskSummary.
    #[test]
    fn compute_metrics_scenario_task_pass_rates() {
        use crate::evaluate::scenario_results::{ScenarioResult, TaskSummary};
        use crate::evaluate::types::EvalResults;

        let make_result = |id: &str, passed: bool| ScenarioResult {
            scenario_id: id.to_string(),
            initial_query: "q".to_string(),
            eval_results: EvalResults::new(),
            passed,
            pass_rate: if passed { 1.0 } else { 0.0 },
            task_results: vec![
                TaskSummary {
                    task_id: "t1".to_string(),
                    passed,
                    value: if passed { 1.0 } else { 0.0 },
                },
                TaskSummary {
                    task_id: "t2".to_string(),
                    passed: true,
                    value: 1.0,
                },
            ],
        };

        let scenario_results = vec![make_result("s1", true), make_result("s2", false)];
        let metrics = compute_metrics(&HashMap::new(), &scenario_results);

        assert!(metrics.scenario_task_pass_rates.contains_key("s1"));
        assert!(metrics.scenario_task_pass_rates.contains_key("s2"));
        assert_eq!(metrics.scenario_task_pass_rates["s1"]["t1"], 1.0);
        assert_eq!(metrics.scenario_task_pass_rates["s2"]["t1"], 0.0);
        assert_eq!(metrics.scenario_task_pass_rates["s1"]["t2"], 1.0);
        assert_eq!(metrics.scenario_task_pass_rates["s2"]["t2"], 1.0);
    }

    // A scenario with tasks but no collected context is a hard error.
    #[test]
    fn evaluate_scenario_with_tasks_but_no_context_errors() {
        let scenario = make_scenario_with_tasks("s1", "Say hello");
        let mut runner = EvalRunner::new(EvalScenarios::new(vec![scenario]), HashMap::new());

        let result = runner.evaluate(None);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("no context"));
    }

    // Empty EvalResults → (false, 0.0).
    #[test]
    fn compute_pass_rate_empty_results() {
        let results = EvalResults::new();
        let (passed, rate) = compute_pass_rate(&results);
        assert!(!passed);
        assert_eq!(rate, 0.0);
    }

    // Aligned results containing zero tasks also count as failed.
    #[test]
    fn compute_pass_rate_zero_tasks() {
        let mut results = EvalResults::new();
        let record = EvalRecord::default();
        let eval_set = scouter_types::genai::EvalSet::new(vec![], Default::default());
        results.add_success(&record, eval_set, BTreeMap::new());

        let (passed, rate) = compute_pass_rate(&results);
        assert!(!passed);
        assert_eq!(rate, 0.0);
    }

    // Mixed pass/fail scenarios aggregate to a 0.5 scenario pass rate.
    #[test]
    fn evaluate_multiple_scenarios_mixed_results() {
        let s_pass = make_scenario_with_tasks("s_pass", "Say hello");
        let s_fail = make_scenario_with_tasks("s_fail", "Say hello");
        let mut runner = EvalRunner::new(EvalScenarios::new(vec![s_pass, s_fail]), HashMap::new());

        runner.scenarios.scenario_contexts.insert(
            "s_pass".to_string(),
            json!({"response": "hello world", "expected_outcome": null, "metadata": null}),
        );
        runner.scenarios.scenario_contexts.insert(
            "s_fail".to_string(),
            json!({"response": "goodbye", "expected_outcome": null, "metadata": null}),
        );

        let result = runner.evaluate(None).unwrap();
        assert_eq!(result.scenario_results.len(), 2);
        assert_eq!(result.metrics.total_scenarios, 2);
        assert_eq!(result.metrics.passed_scenarios, 1);
        assert_eq!(result.metrics.scenario_pass_rate, 0.5);
    }

    // No inputs at all → all metrics zero.
    #[test]
    fn compute_metrics_empty() {
        let metrics = compute_metrics(&HashMap::new(), &[]);

        assert_eq!(metrics.overall_pass_rate, 0.0);
        assert_eq!(metrics.scenario_pass_rate, 0.0);
        assert_eq!(metrics.total_scenarios, 0);
        assert_eq!(metrics.passed_scenarios, 0);
    }

    // With no dataset results, overall rate equals the scenario pass rate.
    #[test]
    fn compute_metrics_with_scenario_results() {
        let scenario_results = vec![
            ScenarioResult {
                scenario_id: "s1".to_string(),
                initial_query: "Q1".to_string(),
                eval_results: EvalResults::new(),
                passed: true,
                pass_rate: 1.0,
                task_results: vec![],
            },
            ScenarioResult {
                scenario_id: "s2".to_string(),
                initial_query: "Q2".to_string(),
                eval_results: EvalResults::new(),
                passed: false,
                pass_rate: 0.5,
                task_results: vec![],
            },
        ];

        let metrics = compute_metrics(&HashMap::new(), &scenario_results);

        assert_eq!(metrics.total_scenarios, 2);
        assert_eq!(metrics.passed_scenarios, 1);
        assert_eq!(metrics.scenario_pass_rate, 0.5);
        assert_eq!(metrics.overall_pass_rate, 0.5);
    }
}