Skip to main content

traitclaw_eval/
metrics.rs

1//! Specialized metrics for LLM output evaluation.
2//!
3//! - [`LlmJudgeMetric`]: LLM-powered quality scoring
//! - [`SchemaValidationMetric`]: lightweight JSON shape validation (top-level key presence)
5//! - [`ToolUsageMetric`]: verifies expected tool calls were made
6
7use std::sync::Arc;
8
9use async_trait::async_trait;
10
11use crate::runner::AsyncMetric;
12
13// ─────────────────────────────────────────────────────────────────────────────
14// LLM Judge provider trait
15// ─────────────────────────────────────────────────────────────────────────────
16
17/// A minimal provider interface for LLM judge calls.
18///
19/// Implement this to connect `LlmJudgeMetric` to any LLM backend.
20#[async_trait]
21pub trait JudgeProvider: Send + Sync + 'static {
22    /// Call the LLM with the given prompt, return the text response.
23    async fn complete(&self, prompt: &str) -> traitclaw_core::Result<String>;
24}
25
26// ─────────────────────────────────────────────────────────────────────────────
27// LlmJudgeMetric
28// ─────────────────────────────────────────────────────────────────────────────
29
30/// LLM-based evaluation metric with named criteria.
31///
32/// Calls an LLM with a custom evaluation prompt and parses a 0.0–1.0 score.
33///
34/// # Example
35///
36/// ```rust
37/// use traitclaw_eval::metrics::{JudgeProvider, LlmJudgeMetric};
38/// use async_trait::async_trait;
39///
40/// struct MockJudge;
41///
42/// #[async_trait]
43/// impl JudgeProvider for MockJudge {
44///     async fn complete(&self, _prompt: &str) -> traitclaw_core::Result<String> {
45///         Ok("Score: 0.85".to_string())
46///     }
47/// }
48///
49/// let metric = LlmJudgeMetric::new(MockJudge)
50///     .with_criteria("accuracy", "Is the answer factually correct?");
51/// ```
52pub struct LlmJudgeMetric<P: JudgeProvider> {
53    provider: Arc<P>,
54    criteria: Vec<(String, String)>,
55}
56
57impl<P: JudgeProvider> LlmJudgeMetric<P> {
58    /// Create a new `LlmJudgeMetric` backed by the given provider.
59    #[must_use]
60    pub fn new(provider: P) -> Self {
61        Self {
62            provider: Arc::new(provider),
63            criteria: Vec::new(),
64        }
65    }
66
67    /// Add a named evaluation criterion.
68    ///
69    /// The criterion name and prompt are used to build the judge prompt.
70    #[must_use]
71    pub fn with_criteria(mut self, name: impl Into<String>, prompt: impl Into<String>) -> Self {
72        self.criteria.push((name.into(), prompt.into()));
73        self
74    }
75}
76
77#[async_trait]
78impl<P: JudgeProvider> AsyncMetric for LlmJudgeMetric<P> {
79    fn name(&self) -> &'static str {
80        "llm_judge"
81    }
82
83    async fn score(&self, input: &str, actual_output: &str, _kw: &[&str]) -> f64 {
84        let criteria_text = if self.criteria.is_empty() {
85            "Is this a high-quality response?".to_string()
86        } else {
87            self.criteria
88                .iter()
89                .map(|(name, prompt)| format!("- {name}: {prompt}"))
90                .collect::<Vec<_>>()
91                .join("\n")
92        };
93
94        let prompt = format!(
95            "Evaluate the following agent response:\n\nInput: {input}\n\nResponse: {actual_output}\n\nCriteria:\n{criteria_text}\n\nProvide a score from 0.0 to 1.0. Respond with only: Score: <number>"
96        );
97
98        match self.provider.complete(&prompt).await {
99            Ok(response) => parse_score(&response),
100            Err(_) => 0.0,
101        }
102    }
103}
104
/// Extracts a judge score from an LLM response, clamped to `0.0..=1.0`.
///
/// Parsing strategy:
/// 1. The first line of the form `Score: <number>` anywhere in the response wins.
///    The `Score:` label is matched ASCII case-insensitively, and surrounding
///    whitespace is ignored.
/// 2. Otherwise, the first line consisting solely of a number is used.
/// 3. Otherwise, the response is unscorable and `0.0` is returned.
///
/// Non-finite values (`NaN`, `inf`) are treated as unparseable. A malformed reply
/// therefore can never yield a NaN score.
pub(crate) fn parse_score(response: &str) -> f64 {
    /// Parses `text` as a finite `f64`.
    ///
    /// `f64::from_str` accepts "NaN"/"inf", and `NaN.clamp(0.0, 1.0)` is still NaN,
    /// so non-finite values are rejected here.
    fn parse_finite(text: &str) -> Option<f64> {
        text.trim().parse::<f64>().ok().filter(|v| v.is_finite())
    }

    /// Returns the text following a leading `Score:` label (ASCII case-insensitive).
    fn after_label(line: &str) -> Option<&str> {
        const LABEL: &str = "score:";
        // `get` yields `None` instead of panicking when the cut would split a
        // multi-byte character.
        let head = line.get(..LABEL.len())?;
        let rest = line.get(LABEL.len()..)?;
        head.eq_ignore_ascii_case(LABEL).then_some(rest)
    }

    let lines = || response.lines().map(str::trim);

    // Two passes: a labelled score anywhere takes precedence over a bare number
    // that happens to appear on an earlier line.
    lines()
        .find_map(|line| after_label(line).and_then(parse_finite))
        .or_else(|| lines().find_map(parse_finite))
        .map_or(0.0, |score| score.clamp(0.0, 1.0))
}
122
123// ─────────────────────────────────────────────────────────────────────────────
124// SchemaValidationMetric
125// ─────────────────────────────────────────────────────────────────────────────
126
127/// Validates that the agent output is valid JSON matching an expected schema shape.
128///
129/// The "schema" here is a `serde_json::Value` used as a **template** — all keys
130/// present in the schema must be present in the output. A full JSON schema validator
131/// would require an external crate; this is a lightweight key-presence check.
132///
133/// # Example
134///
135/// ```rust
136/// use traitclaw_eval::metrics::SchemaValidationMetric;
137/// use traitclaw_eval::runner::AsyncMetric;
138///
139/// let metric = SchemaValidationMetric::new(serde_json::json!({
140///     "name": "string",
141///     "score": "number"
142/// }));
143/// ```
144pub struct SchemaValidationMetric {
145    schema: serde_json::Value,
146}
147
148impl SchemaValidationMetric {
149    /// Create a new `SchemaValidationMetric` with the given expected schema.
150    #[must_use]
151    pub fn new(schema: serde_json::Value) -> Self {
152        Self { schema }
153    }
154}
155
156#[async_trait]
157impl AsyncMetric for SchemaValidationMetric {
158    fn name(&self) -> &'static str {
159        "schema_validation"
160    }
161
162    async fn score(&self, _input: &str, actual_output: &str, _kw: &[&str]) -> f64 {
163        // Parse output as JSON
164        let Ok(output_val) = serde_json::from_str::<serde_json::Value>(actual_output) else {
165            return 0.0; // not valid JSON
166        };
167
168        // Check that all schema keys exist in output
169        let schema_obj = match &self.schema {
170            serde_json::Value::Object(m) => m,
171            _ => return if output_val == self.schema { 1.0 } else { 0.0 },
172        };
173
174        let output_obj = match &output_val {
175            serde_json::Value::Object(m) => m,
176            _ => return 0.0,
177        };
178
179        if schema_obj.is_empty() {
180            return 1.0;
181        }
182
183        let present = schema_obj
184            .keys()
185            .filter(|k| output_obj.contains_key(*k))
186            .count();
187        present as f64 / schema_obj.len() as f64
188    }
189}
190
191// ─────────────────────────────────────────────────────────────────────────────
192// ToolUsageMetric
193// ─────────────────────────────────────────────────────────────────────────────
194
/// Estimates whether the agent used the expected tools by looking for their
/// names in its final output.
///
/// The score is the fraction of expected tool names that occur in the output.
/// Names are compared case-insensitively as plain substrings, so `search` also
/// matches `research`. This heuristic only works when the agent's response
/// mentions the tools it called, e.g. in a summary.
///
/// # Example
///
/// ```rust
/// use traitclaw_eval::metrics::ToolUsageMetric;
/// use traitclaw_eval::runner::AsyncMetric;
///
/// let metric = ToolUsageMetric::new(vec!["search", "calculator"]);
/// ```
pub struct ToolUsageMetric {
    /// Tool names that should be mentioned in the output.
    expected_tools: Vec<String>,
}
211
212impl ToolUsageMetric {
213    /// Create a new `ToolUsageMetric` expecting the given tool names.
214    #[must_use]
215    pub fn new(expected_tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
216        Self {
217            expected_tools: expected_tools.into_iter().map(Into::into).collect(),
218        }
219    }
220}
221
222#[async_trait]
223impl AsyncMetric for ToolUsageMetric {
224    fn name(&self) -> &'static str {
225        "tool_usage"
226    }
227
228    async fn score(&self, _input: &str, actual_output: &str, _kw: &[&str]) -> f64 {
229        if self.expected_tools.is_empty() {
230            return 1.0;
231        }
232
233        let output_lower = actual_output.to_lowercase();
234        let found = self
235            .expected_tools
236            .iter()
237            .filter(|tool| output_lower.contains(tool.to_lowercase().as_str()))
238            .count();
239
240        found as f64 / self.expected_tools.len() as f64
241    }
242}
243
244// ─────────────────────────────────────────────────────────────────────────────
245// Tests
246// ─────────────────────────────────────────────────────────────────────────────
247
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    // ── LlmJudgeMetric ───────────────────────────────────────────────────────

    /// Judge that always replies with a fixed string.
    struct MockJudge(String);

    #[async_trait]
    impl JudgeProvider for MockJudge {
        async fn complete(&self, _prompt: &str) -> traitclaw_core::Result<String> {
            Ok(self.0.clone())
        }
    }

    /// Judge that replies with a fixed string and remembers the last prompt it
    /// received, so tests can inspect the prompt the metric built.
    struct RecordingJudge {
        reply: String,
        last_prompt: Mutex<Option<String>>,
    }

    impl RecordingJudge {
        fn new(reply: &str) -> Self {
            Self {
                reply: reply.to_string(),
                last_prompt: Mutex::new(None),
            }
        }

        fn recorded_prompt(&self) -> String {
            self.last_prompt
                .lock()
                .unwrap()
                .clone()
                .expect("judge should have been called")
        }
    }

    #[async_trait]
    impl JudgeProvider for RecordingJudge {
        async fn complete(&self, prompt: &str) -> traitclaw_core::Result<String> {
            *self.last_prompt.lock().unwrap() = Some(prompt.to_string());
            Ok(self.reply.clone())
        }
    }

    #[tokio::test]
    async fn test_llm_judge_parses_score() {
        // AC #6: mock returns 0.85 → metric score = 0.85
        let metric = LlmJudgeMetric::new(MockJudge("Score: 0.85".to_string()))
            .with_criteria("accuracy", "Is it accurate?");

        let score = metric.score("input", "output", &[]).await;
        assert!((score - 0.85).abs() < 1e-6, "expected 0.85, got {score}");
    }

    #[tokio::test]
    async fn test_llm_judge_clamps_above_one() {
        let metric = LlmJudgeMetric::new(MockJudge("Score: 1.5".to_string()));
        let score = metric.score("in", "out", &[]).await;
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_llm_judge_invalid_response_returns_zero() {
        let metric = LlmJudgeMetric::new(MockJudge("I cannot provide a score.".to_string()));
        let score = metric.score("in", "out", &[]).await;
        assert!((score - 0.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_llm_judge_prompt_includes_input_output_and_criteria() {
        let metric = LlmJudgeMetric::new(RecordingJudge::new("Score: 0.5"))
            .with_criteria("accuracy", "Is it accurate?")
            .with_criteria("tone", "Is it polite?");

        let score = metric.score("the question", "the answer", &[]).await;
        assert!((score - 0.5).abs() < 1e-6, "expected 0.5, got {score}");

        let prompt = metric.provider.recorded_prompt();
        assert!(prompt.contains("Input: the question"), "prompt: {prompt}");
        assert!(prompt.contains("Response: the answer"), "prompt: {prompt}");
        assert!(
            prompt.contains("- accuracy: Is it accurate?\n- tone: Is it polite?"),
            "criteria should be listed in insertion order; prompt: {prompt}"
        );
    }

    #[tokio::test]
    async fn test_llm_judge_default_criterion_when_none_configured() {
        let metric = LlmJudgeMetric::new(RecordingJudge::new("Score: 0.5"));
        let _ = metric.score("in", "out", &[]).await;
        let prompt = metric.provider.recorded_prompt();
        assert!(
            prompt.contains("Is this a high-quality response?"),
            "prompt: {prompt}"
        );
    }

    #[test]
    fn test_parse_score_variants() {
        assert!((parse_score("Score: 0.75") - 0.75).abs() < 1e-6);
        assert!((parse_score("0.90") - 0.90).abs() < 1e-6);
        assert!((parse_score("no score here") - 0.0).abs() < 1e-6);
        assert!((parse_score("Score: 1.5") - 1.0).abs() < 1e-6); // clamped above
        assert!((parse_score("Score: -0.2") - 0.0).abs() < 1e-6); // clamped below
        assert!((parse_score("  Score:   0.4  ") - 0.4).abs() < 1e-6); // whitespace tolerated
        assert!((parse_score("Reasoning first.\nScore: 0.3") - 0.3).abs() < 1e-6); // later line
        assert!((parse_score("") - 0.0).abs() < 1e-6);
    }

    // ── SchemaValidationMetric ───────────────────────────────────────────────

    #[tokio::test]
    async fn test_schema_validation_valid_json() {
        // AC #7: valid JSON with all required keys → score = 1.0
        let metric = SchemaValidationMetric::new(serde_json::json!({
            "name": "string",
            "score": "number"
        }));
        let output = r#"{"name": "test", "score": 42}"#;
        let score = metric.score("in", output, &[]).await;
        assert!((score - 1.0).abs() < 1e-6, "expected 1.0, got {score}");
    }

    #[tokio::test]
    async fn test_schema_validation_partial_keys() {
        let metric = SchemaValidationMetric::new(serde_json::json!({
            "name": "string",
            "score": "number",
            "extra": "string"
        }));
        let output = r#"{"name": "test"}"#; // only 1/3 keys
        let score = metric.score("in", output, &[]).await;
        assert!((score - 1.0 / 3.0).abs() < 1e-6, "expected 1/3, got {score}");
    }

    #[tokio::test]
    async fn test_schema_validation_invalid_json() {
        // AC #7: invalid JSON → score = 0.0
        let metric = SchemaValidationMetric::new(serde_json::json!({"name": "string"}));
        let score = metric.score("in", "not json at all", &[]).await;
        assert!((score - 0.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_schema_validation_non_object_output_for_object_schema() {
        let metric = SchemaValidationMetric::new(serde_json::json!({"name": "string"}));
        let score = metric.score("in", "[1, 2]", &[]).await;
        assert!((score - 0.0).abs() < 1e-6, "expected 0.0, got {score}");
    }

    #[tokio::test]
    async fn test_schema_validation_empty_schema_accepts_any_object() {
        let metric = SchemaValidationMetric::new(serde_json::json!({}));
        let score = metric.score("in", r#"{"anything": true}"#, &[]).await;
        assert!((score - 1.0).abs() < 1e-6, "expected 1.0, got {score}");
    }

    #[tokio::test]
    async fn test_schema_validation_extra_output_keys_are_ignored() {
        let metric = SchemaValidationMetric::new(serde_json::json!({"name": "string"}));
        let score = metric
            .score("in", r#"{"name": "x", "unexpected": 1}"#, &[])
            .await;
        assert!((score - 1.0).abs() < 1e-6, "expected 1.0, got {score}");
    }

    #[tokio::test]
    async fn test_schema_validation_scalar_schema_requires_equality() {
        let metric = SchemaValidationMetric::new(serde_json::json!(42));
        let equal = metric.score("in", "42", &[]).await;
        let different = metric.score("in", "43", &[]).await;
        assert!((equal - 1.0).abs() < 1e-6, "expected 1.0, got {equal}");
        assert!((different - 0.0).abs() < 1e-6, "expected 0.0, got {different}");
    }

    // ── ToolUsageMetric ──────────────────────────────────────────────────────

    #[tokio::test]
    async fn test_tool_usage_all_found() {
        let metric = ToolUsageMetric::new(vec!["search", "calculator"]);
        let score = metric
            .score("in", "I used search and calculator tools", &[])
            .await;
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_tool_usage_partial() {
        let metric = ToolUsageMetric::new(vec!["search", "calculator"]);
        let score = metric.score("in", "I only used search", &[]).await;
        assert!((score - 0.5).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_tool_usage_none_found() {
        let metric = ToolUsageMetric::new(vec!["search"]);
        let score = metric.score("in", "I didn't call any tools", &[]).await;
        assert!((score - 0.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_tool_usage_empty_expected() {
        let metric = ToolUsageMetric::new(Vec::<String>::new());
        let score = metric.score("in", "anything", &[]).await;
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_tool_usage_is_case_insensitive() {
        let metric = ToolUsageMetric::new(vec!["Search"]);
        let score = metric.score("in", "SEARCH was used", &[]).await;
        assert!((score - 1.0).abs() < 1e-6, "expected 1.0, got {score}");
    }
}