Skip to main content

swink_agent_eval/
match_.rs

1//! Trajectory matching evaluator.
2//!
3//! Compares the actual tool call sequence against an expected golden path
4//! using one of three matching modes.
5
6use serde::{Deserialize, Serialize};
7
8use crate::evaluator::Evaluator;
9use crate::score::Score;
10use crate::types::{EvalCase, EvalMetricResult, ExpectedToolCall, Invocation, RecordedToolCall};
11
12/// How to compare actual tool calls against expected.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14#[serde(rename_all = "snake_case")]
15pub enum MatchMode {
16    /// Same tools, same order, same count. No extras allowed.
17    Exact,
18    /// Expected tools must appear in order. Extra tools between are allowed.
19    InOrder,
20    /// All expected tools must appear somewhere. Order and extras don't matter.
21    AnyOrder,
22}
23
24/// Evaluator that compares actual tool call trajectories against expected golden paths.
25///
26/// Returns `None` when the case has no `expected_trajectory`.
27pub struct TrajectoryMatcher {
28    mode: MatchMode,
29}
30
31impl TrajectoryMatcher {
32    /// Create a matcher with the given mode.
33    #[must_use]
34    pub const fn new(mode: MatchMode) -> Self {
35        Self { mode }
36    }
37
38    /// Exact matching: same tools, same order, same count.
39    #[must_use]
40    pub const fn exact() -> Self {
41        Self::new(MatchMode::Exact)
42    }
43
44    /// In-order matching: expected tools appear in order, extras allowed.
45    #[must_use]
46    pub const fn in_order() -> Self {
47        Self::new(MatchMode::InOrder)
48    }
49
50    /// Any-order matching: all expected tools appear, any order.
51    #[must_use]
52    pub const fn any_order() -> Self {
53        Self::new(MatchMode::AnyOrder)
54    }
55}
56
57impl Evaluator for TrajectoryMatcher {
58    fn name(&self) -> &'static str {
59        "trajectory"
60    }
61
62    fn evaluate(&self, case: &EvalCase, invocation: &Invocation) -> Option<EvalMetricResult> {
63        let expected = case.expected_trajectory.as_ref()?;
64
65        // Flatten all actual tool calls across turns.
66        let actual: Vec<&RecordedToolCall> = invocation
67            .turns
68            .iter()
69            .flat_map(|t| &t.tool_calls)
70            .collect();
71
72        let (score, details) = match self.mode {
73            MatchMode::Exact => score_exact(expected, &actual),
74            MatchMode::InOrder => score_in_order(expected, &actual),
75            MatchMode::AnyOrder => score_any_order(expected, &actual),
76        };
77
78        Some(EvalMetricResult {
79            evaluator_name: "trajectory".to_string(),
80            score,
81            details: Some(details),
82        })
83    }
84}
85
86/// Check if a recorded tool call matches an expected one.
87fn matches_expected(expected: &ExpectedToolCall, actual: &RecordedToolCall) -> bool {
88    if expected.tool_name != actual.name {
89        return false;
90    }
91    expected
92        .arguments
93        .as_ref()
94        .is_none_or(|expected_args| *expected_args == actual.arguments)
95}
96
97/// Exact: same count, same order, each pair matches.
98#[allow(clippy::cast_precision_loss)]
99fn score_exact(expected: &[ExpectedToolCall], actual: &[&RecordedToolCall]) -> (Score, String) {
100    if expected.len() != actual.len() {
101        return (
102            Score::new(0.0, 1.0),
103            format!(
104                "expected {} tool calls, got {}",
105                expected.len(),
106                actual.len()
107            ),
108        );
109    }
110
111    let matched = expected
112        .iter()
113        .zip(actual.iter())
114        .filter(|(e, a)| matches_expected(e, a))
115        .count();
116
117    let total = expected.len().max(1);
118    let value = matched as f64 / total as f64;
119    let details = format!("{matched}/{total} tool calls matched exactly");
120    (Score::new(value, 1.0), details)
121}
122
123/// In-order: expected tools appear in sequence, extras between are fine.
124#[allow(clippy::cast_precision_loss)]
125fn score_in_order(expected: &[ExpectedToolCall], actual: &[&RecordedToolCall]) -> (Score, String) {
126    if expected.is_empty() {
127        return (Score::pass(), "no expected tool calls".to_string());
128    }
129
130    let mut expected_idx = 0;
131    for actual_call in actual {
132        if expected_idx >= expected.len() {
133            break;
134        }
135        if matches_expected(&expected[expected_idx], actual_call) {
136            expected_idx += 1;
137        }
138    }
139
140    let total = expected.len();
141    let value = expected_idx as f64 / total as f64;
142    let details = format!("{expected_idx}/{total} expected tool calls found in order");
143    (Score::new(value, 1.0), details)
144}
145
146/// Any-order: each expected call must appear at least once.
147#[allow(clippy::cast_precision_loss)]
148fn score_any_order(expected: &[ExpectedToolCall], actual: &[&RecordedToolCall]) -> (Score, String) {
149    if expected.is_empty() {
150        return (Score::pass(), "no expected tool calls".to_string());
151    }
152
153    let matched = expected
154        .iter()
155        .filter(|e| actual.iter().any(|a| matches_expected(e, a)))
156        .count();
157
158    let total = expected.len();
159    let value = matched as f64 / total as f64;
160    let details = format!("{matched}/{total} expected tool calls found (any order)");
161    (Score::new(value, 1.0), details)
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a recorded tool call with a placeholder id.
    fn recorded(name: &str, args: serde_json::Value) -> RecordedToolCall {
        RecordedToolCall {
            id: String::from("id"),
            name: String::from(name),
            arguments: args,
        }
    }

    /// Build an expected tool call; `None` arguments accept any recorded arguments.
    fn expected(name: &str, args: Option<serde_json::Value>) -> ExpectedToolCall {
        ExpectedToolCall {
            tool_name: String::from(name),
            arguments: args,
        }
    }

    /// Borrow every recorded call, giving the slice-of-references shape the scorers take.
    fn borrow_all(calls: &[RecordedToolCall]) -> Vec<&RecordedToolCall> {
        calls.iter().collect()
    }

    /// Assert that `score.value` equals `want` within `f64::EPSILON`.
    fn assert_score(score: &Score, want: f64) {
        assert!((score.value - want).abs() < f64::EPSILON);
    }

    #[test]
    fn exact_match_all() {
        let wanted = [
            expected("read", Some(json!({"path": "a.txt"}))),
            expected("write", None),
        ];
        let calls = [
            recorded("read", json!({"path": "a.txt"})),
            recorded("write", json!({"path": "b.txt"})),
        ];
        let (score, _) = score_exact(&wanted, &borrow_all(&calls));
        assert_score(&score, 1.0);
    }

    #[test]
    fn exact_match_wrong_order() {
        let wanted = [expected("read", None), expected("write", None)];
        let calls = [recorded("write", json!({})), recorded("read", json!({}))];
        let (score, _) = score_exact(&wanted, &borrow_all(&calls));
        assert_score(&score, 0.0);
    }

    #[test]
    fn in_order_with_extras() {
        let wanted = [expected("read", None), expected("write", None)];
        let calls = [
            recorded("search", json!({})),
            recorded("read", json!({})),
            recorded("think", json!({})),
            recorded("write", json!({})),
        ];
        let (score, _) = score_in_order(&wanted, &borrow_all(&calls));
        assert_score(&score, 1.0);
    }

    #[test]
    fn any_order_finds_all() {
        let wanted = [expected("write", None), expected("read", None)];
        let calls = [recorded("read", json!({})), recorded("write", json!({}))];
        let (score, _) = score_any_order(&wanted, &borrow_all(&calls));
        assert_score(&score, 1.0);
    }

    #[test]
    fn any_order_partial_match() {
        let wanted = [expected("read", None), expected("delete", None)];
        let calls = [recorded("read", json!({})), recorded("write", json!({}))];
        let (score, _) = score_any_order(&wanted, &borrow_all(&calls));
        assert_score(&score, 0.5);
    }
}