// traitclaw_eval/runner.rs
//! Async `Metric` trait and `EvalRunner` execution engine.
//!
//! # Example
//!
//! ```rust
//! use traitclaw_eval::runner::{AsyncMetric, EvalRunner};
//! use traitclaw_eval::{EvalSuite, TestCase};
//! use async_trait::async_trait;
//!
//! struct AlwaysOne;
//!
//! #[async_trait]
//! impl AsyncMetric for AlwaysOne {
//!     fn name(&self) -> &'static str { "always_one" }
//!     async fn score(&self, _input: &str, _output: &str, _kw: &[&str]) -> f64 { 1.0 }
//! }
//!
//! # async fn example() {
//! let runner = EvalRunner::new().metric(Box::new(AlwaysOne)).threshold(0.8);
//! # }
//! ```

23use std::sync::Arc;
24
25use async_trait::async_trait;
26
27use crate::{EvalReport, EvalSuite, TestResult};
28
/// Async trait for evaluation metrics.
///
/// Implement this to add custom scoring logic. Metrics are stored as
/// `Arc<dyn AsyncMetric>` inside `EvalRunner`, hence the
/// `Send + Sync + 'static` bounds.
#[async_trait]
pub trait AsyncMetric: Send + Sync + 'static {
    /// Metric name — used as the key in `TestResult.scores`.
    fn name(&self) -> &'static str;

    /// Score the actual output.
    ///
    /// * `input` — the test case's input string.
    /// * `actual_output` — the agent's response to `input`.
    /// * `expected_keywords` — keywords from the test case (may be empty).
    ///
    /// Returns a score from `0.0` (worst) to `1.0` (best).
    async fn score(&self, input: &str, actual_output: &str, expected_keywords: &[&str]) -> f64;
}
42
/// A callable async agent for use with `EvalRunner`.
///
/// Returns the agent's response for a given input string.
#[async_trait]
pub trait EvalAgent: Send + Sync {
    /// Run the agent on the given input and return a response.
    ///
    /// # Errors
    ///
    /// Implementations return an error when no response can be produced
    /// (the runner propagates this and aborts the evaluation).
    async fn respond(&self, input: &str) -> traitclaw_core::Result<String>;
}
51
/// Evaluation runner — executes a suite against an agent using async metrics.
pub struct EvalRunner {
    /// Metrics applied to every test case. Stored as `Arc` trait objects so
    /// boxed metrics handed to `metric()` can be shared cheaply.
    metrics: Vec<Arc<dyn AsyncMetric>>,
    /// Minimum per-metric score for a case to count as passed (default 0.7).
    threshold: f64,
}
57
58impl EvalRunner {
59    /// Create a new `EvalRunner` with no metrics and default threshold 0.7.
60    #[must_use]
61    pub fn new() -> Self {
62        Self {
63            metrics: Vec::new(),
64            threshold: 0.7,
65        }
66    }
67
68    /// Add a metric to score agent outputs with.
69    #[must_use]
70    pub fn metric(mut self, metric: Box<dyn AsyncMetric>) -> Self {
71        self.metrics.push(Arc::from(metric));
72        self
73    }
74
75    /// Set the minimum score threshold for a test case to pass.
76    ///
77    /// A case passes if **all** metric scores are ≥ threshold.
78    #[must_use]
79    pub fn threshold(mut self, threshold: f64) -> Self {
80        self.threshold = threshold;
81        self
82    }
83
84    /// Execute the evaluation suite against the agent.
85    ///
86    /// For each test case: calls `agent.respond(input)`, scores with all metrics,
87    /// marks passed/failed, and aggregates into an `EvalReport`.
88    ///
89    /// # Errors
90    ///
91    /// Returns an error if the agent fails on any test case.
92    pub async fn run(
93        &self,
94        agent: &dyn EvalAgent,
95        suite: &EvalSuite,
96    ) -> traitclaw_core::Result<EvalReport> {
97        let mut results = Vec::new();
98        let mut total_score = 0.0;
99        let mut passed_count = 0;
100        let mut score_count = 0;
101
102        for case in suite.cases() {
103            let actual_output = agent.respond(&case.input).await?;
104
105            let keywords: Vec<&str> = case.expected_keywords.iter().map(String::as_str).collect();
106
107            let mut scores = std::collections::HashMap::new();
108            for metric in &self.metrics {
109                let s = metric.score(&case.input, &actual_output, &keywords).await;
110                scores.insert(metric.name().to_string(), s);
111                total_score += s;
112                score_count += 1;
113            }
114
115            // If no metrics configured, auto-pass based on keywords (keyword_match)
116            if self.metrics.is_empty() {
117                let kw_score = score_keywords(&actual_output, &keywords);
118                scores.insert("keyword_match".to_string(), kw_score);
119                total_score += kw_score;
120                score_count += 1;
121            }
122
123            let all_pass = scores.values().all(|&s| s >= self.threshold);
124            if all_pass {
125                passed_count += 1;
126            }
127
128            results.push(TestResult {
129                case_id: case.id.clone(),
130                actual_output,
131                scores,
132                passed: all_pass,
133            });
134        }
135
136        let average_score = if score_count > 0 {
137            total_score / score_count as f64
138        } else {
139            0.0
140        };
141
142        Ok(EvalReport {
143            suite_name: suite.name().to_string(),
144            results,
145            average_score,
146            passed: passed_count,
147            total: suite.cases().len(),
148        })
149    }
150}
151
152impl Default for EvalRunner {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
/// Fraction of `keywords` found in `output`, matched case-insensitively.
///
/// Returns `1.0` when `keywords` is empty (nothing required → full marks).
fn score_keywords(output: &str, keywords: &[&str]) -> f64 {
    if keywords.is_empty() {
        return 1.0;
    }
    let haystack = output.to_lowercase();
    // Lowercase each keyword too: previously only the output was lowercased,
    // so a keyword containing an uppercase letter could never match.
    let matched = keywords
        .iter()
        .filter(|kw| haystack.contains(kw.to_lowercase().as_str()))
        .count();
    matched as f64 / keywords.len() as f64
}
166
167// ── Adapters for sync Metric ─────────────────────────────────────────────────
168
169/// Wraps a sync `Metric` impl as an `AsyncMetric`.
170pub struct SyncMetricAdapter<M: crate::Metric>(pub M);
171
172#[async_trait]
173impl<M: crate::Metric> AsyncMetric for SyncMetricAdapter<M> {
174    fn name(&self) -> &'static str {
175        self.0.name()
176    }
177
178    async fn score(&self, input: &str, actual_output: &str, expected_keywords: &[&str]) -> f64 {
179        self.0.score(input, actual_output, expected_keywords)
180    }
181}
182
183// ─────────────────────────────────────────────────────────────────────────────
184// Tests
185// ─────────────────────────────────────────────────────────────────────────────
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EvalSuite, TestCase};

    /// Test agent that prefixes every input with `"echo: "`.
    struct EchoAgent;

    #[async_trait]
    impl EvalAgent for EchoAgent {
        async fn respond(&self, input: &str) -> traitclaw_core::Result<String> {
            Ok(format!("echo: {input}"))
        }
    }

    /// Metric reporting a constant score under a fixed name.
    struct FixedMetric(f64, &'static str);

    #[async_trait]
    impl AsyncMetric for FixedMetric {
        fn name(&self) -> &'static str {
            self.1
        }
        async fn score(&self, _: &str, _: &str, _: &[&str]) -> f64 {
            self.0
        }
    }

    /// Keyword-fraction metric mirroring the runner's built-in fallback.
    struct KeywordAsyncMetric;

    #[async_trait]
    impl AsyncMetric for KeywordAsyncMetric {
        fn name(&self) -> &'static str {
            "keyword"
        }
        async fn score(&self, _: &str, output: &str, kw: &[&str]) -> f64 {
            if kw.is_empty() {
                return 1.0;
            }
            let haystack = output.to_lowercase();
            let hits = kw.iter().filter(|&&k| haystack.contains(k)).count();
            hits as f64 / kw.len() as f64
        }
    }

    #[tokio::test]
    async fn test_eval_runner_three_cases() {
        // AC #7: 3 test cases with KeywordMetric → report with 3 results
        let suite = EvalSuite::new("suite")
            .add_case(TestCase::new("c1", "hello").expect_contains("echo"))
            .add_case(TestCase::new("c2", "world").expect_contains("echo"))
            .add_case(TestCase::new("c3", "foo").expect_contains("echo"));

        let runner = EvalRunner::new()
            .threshold(0.8)
            .metric(Box::new(KeywordAsyncMetric));

        let report = runner.run(&EchoAgent, &suite).await.unwrap();

        assert_eq!(report.results.len(), 3);
        assert_eq!(report.total, 3);
        // EchoAgent always includes "echo" → all should pass
        assert_eq!(report.passed, 3);
    }

    #[tokio::test]
    async fn test_eval_runner_threshold_fail() {
        // AC #8: threshold 0.8 → case scoring 0.0 marked as failed
        let suite = EvalSuite::new("s")
            .add_case(TestCase::new("c1", "hello").expect_contains("xyzabc")); // won't match

        let runner = EvalRunner::new()
            .threshold(0.8)
            .metric(Box::new(KeywordAsyncMetric));

        let report = runner.run(&EchoAgent, &suite).await.unwrap();
        assert_eq!(report.passed, 0, "case with 0.0 keyword score should fail");
    }

    #[tokio::test]
    async fn test_eval_runner_average_score() {
        // Average across metrics and cases
        let suite = EvalSuite::new("s")
            .add_case(TestCase::new("c1", "hello"))
            .add_case(TestCase::new("c2", "world"));

        let runner = EvalRunner::new()
            .threshold(0.7)
            .metric(Box::new(FixedMetric(0.8, "m")));

        let report = runner.run(&EchoAgent, &suite).await.unwrap();
        assert!((report.average_score - 0.8).abs() < 1e-6);
        assert_eq!(report.passed, 2);
    }

    #[tokio::test]
    async fn test_sync_metric_adapter() {
        let adapter = SyncMetricAdapter(crate::KeywordMetric);
        let value = adapter.score("in", "hello world", &["hello"]).await;
        assert!((value - 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_empty_suite_gives_zero_results() {
        let suite = EvalSuite::new("empty");
        let runner = EvalRunner::new().metric(Box::new(KeywordAsyncMetric));
        let report = runner.run(&EchoAgent, &suite).await.unwrap();
        assert_eq!(report.results.len(), 0);
        assert_eq!(report.total, 0);
    }
}
295}