1use serde::{Deserialize, Serialize};
7
8use crate::evaluator::Evaluator;
9use crate::score::Score;
10use crate::types::{EvalCase, EvalMetricResult, ExpectedToolCall, Invocation, RecordedToolCall};
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14#[serde(rename_all = "snake_case")]
15pub enum MatchMode {
16 Exact,
18 InOrder,
20 AnyOrder,
22}
23
24pub struct TrajectoryMatcher {
28 mode: MatchMode,
29}
30
31impl TrajectoryMatcher {
32 #[must_use]
34 pub const fn new(mode: MatchMode) -> Self {
35 Self { mode }
36 }
37
38 #[must_use]
40 pub const fn exact() -> Self {
41 Self::new(MatchMode::Exact)
42 }
43
44 #[must_use]
46 pub const fn in_order() -> Self {
47 Self::new(MatchMode::InOrder)
48 }
49
50 #[must_use]
52 pub const fn any_order() -> Self {
53 Self::new(MatchMode::AnyOrder)
54 }
55}
56
57impl Evaluator for TrajectoryMatcher {
58 fn name(&self) -> &'static str {
59 "trajectory"
60 }
61
62 fn evaluate(&self, case: &EvalCase, invocation: &Invocation) -> Option<EvalMetricResult> {
63 let expected = case.expected_trajectory.as_ref()?;
64
65 let actual: Vec<&RecordedToolCall> = invocation
67 .turns
68 .iter()
69 .flat_map(|t| &t.tool_calls)
70 .collect();
71
72 let (score, details) = match self.mode {
73 MatchMode::Exact => score_exact(expected, &actual),
74 MatchMode::InOrder => score_in_order(expected, &actual),
75 MatchMode::AnyOrder => score_any_order(expected, &actual),
76 };
77
78 Some(EvalMetricResult {
79 evaluator_name: "trajectory".to_string(),
80 score,
81 details: Some(details),
82 })
83 }
84}
85
86fn matches_expected(expected: &ExpectedToolCall, actual: &RecordedToolCall) -> bool {
88 if expected.tool_name != actual.name {
89 return false;
90 }
91 expected
92 .arguments
93 .as_ref()
94 .is_none_or(|expected_args| *expected_args == actual.arguments)
95}
96
97#[allow(clippy::cast_precision_loss)]
99fn score_exact(expected: &[ExpectedToolCall], actual: &[&RecordedToolCall]) -> (Score, String) {
100 if expected.len() != actual.len() {
101 return (
102 Score::new(0.0, 1.0),
103 format!(
104 "expected {} tool calls, got {}",
105 expected.len(),
106 actual.len()
107 ),
108 );
109 }
110
111 let matched = expected
112 .iter()
113 .zip(actual.iter())
114 .filter(|(e, a)| matches_expected(e, a))
115 .count();
116
117 let total = expected.len().max(1);
118 let value = matched as f64 / total as f64;
119 let details = format!("{matched}/{total} tool calls matched exactly");
120 (Score::new(value, 1.0), details)
121}
122
123#[allow(clippy::cast_precision_loss)]
125fn score_in_order(expected: &[ExpectedToolCall], actual: &[&RecordedToolCall]) -> (Score, String) {
126 if expected.is_empty() {
127 return (Score::pass(), "no expected tool calls".to_string());
128 }
129
130 let mut expected_idx = 0;
131 for actual_call in actual {
132 if expected_idx >= expected.len() {
133 break;
134 }
135 if matches_expected(&expected[expected_idx], actual_call) {
136 expected_idx += 1;
137 }
138 }
139
140 let total = expected.len();
141 let value = expected_idx as f64 / total as f64;
142 let details = format!("{expected_idx}/{total} expected tool calls found in order");
143 (Score::new(value, 1.0), details)
144}
145
146#[allow(clippy::cast_precision_loss)]
148fn score_any_order(expected: &[ExpectedToolCall], actual: &[&RecordedToolCall]) -> (Score, String) {
149 if expected.is_empty() {
150 return (Score::pass(), "no expected tool calls".to_string());
151 }
152
153 let matched = expected
154 .iter()
155 .filter(|e| actual.iter().any(|a| matches_expected(e, a)))
156 .count();
157
158 let total = expected.len();
159 let value = matched as f64 / total as f64;
160 let details = format!("{matched}/{total} expected tool calls found (any order)");
161 (Score::new(value, 1.0), details)
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a recorded tool call with a fixed id.
    fn recorded(name: &str, args: serde_json::Value) -> RecordedToolCall {
        RecordedToolCall {
            id: "id".to_string(),
            name: name.to_string(),
            arguments: args,
        }
    }

    /// Builds an expected tool call; `None` arguments accept any arguments.
    fn expected(name: &str, args: Option<serde_json::Value>) -> ExpectedToolCall {
        ExpectedToolCall {
            tool_name: name.to_string(),
            arguments: args,
        }
    }

    /// Asserts that `score` carries the value `want` (within `f64::EPSILON`).
    fn assert_value(score: &Score, want: f64) {
        assert!(
            (score.value - want).abs() < f64::EPSILON,
            "expected score {want}, got {}",
            score.value
        );
    }

    #[test]
    fn exact_match_all() {
        let exp = vec![
            expected("read", Some(json!({"path": "a.txt"}))),
            expected("write", None),
        ];
        let act = [
            recorded("read", json!({"path": "a.txt"})),
            recorded("write", json!({"path": "b.txt"})),
        ];
        let refs: Vec<&RecordedToolCall> = act.iter().collect();
        let (score, _) = score_exact(&exp, &refs);
        assert_value(&score, 1.0);
    }

    #[test]
    fn exact_match_wrong_order() {
        let exp = vec![expected("read", None), expected("write", None)];
        let act = [recorded("write", json!({})), recorded("read", json!({}))];
        let refs: Vec<&RecordedToolCall> = act.iter().collect();
        let (score, _) = score_exact(&exp, &refs);
        assert_value(&score, 0.0);
    }

    #[test]
    fn exact_length_mismatch_scores_zero() {
        let exp = vec![expected("read", None)];
        let act = [recorded("read", json!({})), recorded("write", json!({}))];
        let refs: Vec<&RecordedToolCall> = act.iter().collect();
        let (score, details) = score_exact(&exp, &refs);
        assert_value(&score, 0.0);
        assert_eq!(details, "expected 1 tool calls, got 2");
    }

    #[test]
    fn exact_argument_mismatch() {
        let exp = vec![expected("read", Some(json!({"path": "a.txt"})))];
        let act = [recorded("read", json!({"path": "b.txt"}))];
        let refs: Vec<&RecordedToolCall> = act.iter().collect();
        let (score, _) = score_exact(&exp, &refs);
        assert_value(&score, 0.0);
    }

    #[test]
    fn in_order_with_extras() {
        let exp = vec![expected("read", None), expected("write", None)];
        let act = [
            recorded("search", json!({})),
            recorded("read", json!({})),
            recorded("think", json!({})),
            recorded("write", json!({})),
        ];
        let refs: Vec<&RecordedToolCall> = act.iter().collect();
        let (score, _) = score_in_order(&exp, &refs);
        assert_value(&score, 1.0);
    }

    #[test]
    fn in_order_wrong_order_is_partial() {
        // "read" is found at the second position; no "write" follows it.
        let exp = vec![expected("read", None), expected("write", None)];
        let act = [recorded("write", json!({})), recorded("read", json!({}))];
        let refs: Vec<&RecordedToolCall> = act.iter().collect();
        let (score, _) = score_in_order(&exp, &refs);
        assert_value(&score, 0.5);
    }

    #[test]
    fn any_order_finds_all() {
        let exp = vec![expected("write", None), expected("read", None)];
        let act = [recorded("read", json!({})), recorded("write", json!({}))];
        let refs: Vec<&RecordedToolCall> = act.iter().collect();
        let (score, _) = score_any_order(&exp, &refs);
        assert_value(&score, 1.0);
    }

    #[test]
    fn any_order_partial_match() {
        let exp = vec![expected("read", None), expected("delete", None)];
        let act = [recorded("read", json!({})), recorded("write", json!({}))];
        let refs: Vec<&RecordedToolCall> = act.iter().collect();
        let (score, _) = score_any_order(&exp, &refs);
        assert_value(&score, 0.5);
    }

    #[test]
    fn any_order_argument_mismatch() {
        let exp = vec![expected("read", Some(json!({"path": "a.txt"})))];
        let act = [recorded("read", json!({"path": "b.txt"}))];
        let refs: Vec<&RecordedToolCall> = act.iter().collect();
        let (score, _) = score_any_order(&exp, &refs);
        assert_value(&score, 0.0);
    }
}