1use crate::agent::AgentOutput;
2use crate::multi_turn;
3use serde::{Deserialize, Serialize};
4use std::ops::RangeInclusive;
5
/// Outcome of evaluating one [`Assertion`] against an agent run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssertionResult {
    /// Human-readable label for the check that was performed.
    pub description: String,
    /// Whether the assertion held.
    pub passed: bool,
    /// Failure explanation. Results built by [`Assertion::evaluate`] carry `None`
    /// exactly when `passed` is `true`.
    pub message: Option<String>,
    /// Copied from [`Assertion::is_security`]: `true` for the forbid/allowlist checks.
    pub is_security: bool,
    /// Optional grouping label taken from [`Assertion::category`] (e.g. `"multi-turn"`).
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub category: Option<String>,
}
16
/// A single check applied to an [`AgentOutput`] by [`Assertion::evaluate`].
///
/// Because `Custom` holds a boxed closure, this enum cannot derive `Debug` or `Clone`.
pub enum Assertion {
    /// Passes when every listed tool name appears in `output.tools_called`.
    ExpectTools(Vec<String>),
    /// Fails if any listed tool name appears in `output.tools_called`. Security assertion.
    ForbidTools(Vec<String>),
    /// Passes when at least one tool was called.
    ExpectAnyTool,
    /// Passes when no tool was called.
    ExpectNoTools,
    /// Passes when `output.final_text` contains the given substring.
    ExpectTextContains(String),
    /// Passes when `output.final_text` does not contain the given substring.
    ExpectTextNotContains(String),
    /// Passes when the number of turns (`output.turns.len()`) lies in the inclusive range.
    ExpectTurns(RangeInclusive<usize>),
    /// Passes when every called tool appears in the `available_tools` slice handed to
    /// [`Assertion::evaluate`]. Security assertion.
    ExpectToolsWithinAllowlist,
    /// Passes when `output.error` is `None`.
    ExpectNoError,
    /// `(tool, args)`: passes when at least one call to `tool` has arguments exactly equal
    /// to `args`. Fails if `tool` was never called.
    ExpectToolArgs(String, serde_json::Value),
    /// `(tool, partial)`: passes when at least one call to `tool` has arguments that
    /// structurally contain `partial` (see `json_contains`). Fails if `tool` was never called.
    ExpectToolArgsContain(String, serde_json::Value),
    /// `(tool, param, value)`: passes when at least one call to `tool` has top-level
    /// argument `param` equal to `value`. Fails if `tool` was never called.
    ExpectToolArg(String, String, serde_json::Value),
    /// `(tool, param)`: passes when at least one call to `tool` has a top-level argument
    /// named `param`. Fails if `tool` was never called.
    ExpectToolArgExists(String, String),
    /// `(tool, n)`: passes when `tool` was called exactly `n` times, as counted by
    /// `AgentOutput::tool_calls_by_name`.
    ExpectToolCallCount(String, usize),
    /// Passes when the listed tools occur in this relative order within the flattened call
    /// sequence (`AgentOutput::all_tool_calls`); other calls may be interleaved.
    ExpectToolCallOrder(Vec<String>),
    /// `(turn, tool)`: passes when the turn at 0-based position `turn` contains a call to
    /// `tool`. An out-of-range position fails.
    ExpectToolOnTurn(usize, String),
    /// `(range, tools)`: passes when every listed tool is reported by
    /// `multi_turn::tools_in_range` for `range`.
    /// NOTE(review): whether `range` selects turn positions or `index` values is decided by
    /// that helper — confirm there.
    ExpectToolsInTurnRange(RangeInclusive<usize>, Vec<String>),
    /// `(range, tools)`: fails if any listed tool is reported by
    /// `multi_turn::tools_in_range` for `range`. Security assertion.
    ForbidToolsInTurnRange(RangeInclusive<usize>, Vec<String>),
    /// Passes when the last turn contains a call to the given tool.
    ExpectFinalTool(String),
    /// `(tool, param, value)`: passes when the last turn contains a call to `tool` whose
    /// top-level argument `param` equals `value`.
    ExpectFinalToolArg(String, String, serde_json::Value),
    /// `(gather, action)`: passes when the first turn found for the `gather` tools comes
    /// strictly before the first turn found for the `action` tools. It also passes when
    /// gathering happened and no action turn was found. It fails when no gathering turn was
    /// found. Turns are located by `multi_turn::first_turn_with_tools` (presumably "first
    /// turn using any of the tools" — confirm).
    ExpectGatheringBeforeAction(Vec<String>, Vec<String>),
    /// Passes when the tool is called on the last turn and on no earlier turn.
    ExpectToolOnlyOnFinalTurn(String),
    /// Arbitrary check. `Err(msg)` fails the assertion with `msg` as its message.
    Custom(Box<dyn Fn(&AgentOutput) -> Result<(), String> + Send + Sync>),
}
49
50impl Assertion {
51 pub fn is_security(&self) -> bool {
53 matches!(
54 self,
55 Assertion::ForbidTools(_)
56 | Assertion::ForbidToolsInTurnRange(_, _)
57 | Assertion::ExpectToolsWithinAllowlist
58 )
59 }
60
61 pub fn category(&self) -> Option<&str> {
63 match self {
64 Assertion::ExpectToolsInTurnRange(_, _)
65 | Assertion::ForbidToolsInTurnRange(_, _)
66 | Assertion::ExpectFinalTool(_)
67 | Assertion::ExpectFinalToolArg(_, _, _)
68 | Assertion::ExpectGatheringBeforeAction(_, _)
69 | Assertion::ExpectToolOnlyOnFinalTurn(_) => Some("multi-turn"),
70 _ => None,
71 }
72 }
73
74 pub fn evaluate(
76 &self,
77 output: &AgentOutput,
78 available_tools: &[String],
79 ) -> AssertionResult {
80 let is_security = self.is_security();
81 let category = self.category().map(|s| s.to_string());
82
83 match self {
84 Assertion::ExpectTools(tools) => {
85 let missing: Vec<_> = tools
86 .iter()
87 .filter(|t| !output.tools_called.contains(t))
88 .collect();
89 AssertionResult {
90 description: format!("expect tools {:?}", tools),
91 passed: missing.is_empty(),
92 message: if missing.is_empty() {
93 None
94 } else {
95 Some(format!("Missing tool calls: {:?}", missing))
96 },
97 is_security,
98 category,
99 }
100 }
101
102 Assertion::ForbidTools(tools) => {
103 let found: Vec<_> = tools
104 .iter()
105 .filter(|t| output.tools_called.contains(t))
106 .collect();
107 AssertionResult {
108 description: format!("forbid tools {:?}", tools),
109 passed: found.is_empty(),
110 message: if found.is_empty() {
111 None
112 } else {
113 Some(format!("Forbidden tools were called: {:?}", found))
114 },
115 is_security,
116 category,
117 }
118 }
119
120 Assertion::ExpectAnyTool => AssertionResult {
121 description: "expect any tool call".into(),
122 passed: !output.tools_called.is_empty(),
123 message: if output.tools_called.is_empty() {
124 Some("No tools were called".into())
125 } else {
126 None
127 },
128 is_security,
129 category,
130 },
131
132 Assertion::ExpectNoTools => AssertionResult {
133 description: "expect no tool calls".into(),
134 passed: output.tools_called.is_empty(),
135 message: if output.tools_called.is_empty() {
136 None
137 } else {
138 Some(format!("Tools were called: {:?}", output.tools_called))
139 },
140 is_security,
141 category,
142 },
143
144 Assertion::ExpectTextContains(s) => AssertionResult {
145 description: format!("expect text contains {:?}", s),
146 passed: output.final_text.contains(s.as_str()),
147 message: if output.final_text.contains(s.as_str()) {
148 None
149 } else {
150 Some(format!(
151 "Text does not contain {:?}. Got: {:?}",
152 s,
153 truncate(&output.final_text, 200)
154 ))
155 },
156 is_security,
157 category,
158 },
159
160 Assertion::ExpectTextNotContains(s) => AssertionResult {
161 description: format!("expect text not contains {:?}", s),
162 passed: !output.final_text.contains(s.as_str()),
163 message: if !output.final_text.contains(s.as_str()) {
164 None
165 } else {
166 Some(format!("Text contains forbidden substring {:?}", s))
167 },
168 is_security,
169 category,
170 },
171
172 Assertion::ExpectTurns(range) => {
173 let count = output.turns.len();
174 AssertionResult {
175 description: format!("expect turns in {:?}", range),
176 passed: range.contains(&count),
177 message: if range.contains(&count) {
178 None
179 } else {
180 Some(format!(
181 "Turn count {} not in range {:?}",
182 count, range
183 ))
184 },
185 is_security,
186 category,
187 }
188 }
189
190 Assertion::ExpectToolsWithinAllowlist => {
191 let violations: Vec<_> = output
192 .tools_called
193 .iter()
194 .filter(|t| !available_tools.contains(t))
195 .collect();
196 AssertionResult {
197 description: "expect tools within allowlist".into(),
198 passed: violations.is_empty(),
199 message: if violations.is_empty() {
200 None
201 } else {
202 Some(format!(
203 "Tools called outside allowlist: {:?} (allowed: {:?})",
204 violations, available_tools
205 ))
206 },
207 is_security: true,
208 category,
209 }
210 }
211
212 Assertion::ExpectNoError => AssertionResult {
213 description: "expect no error".into(),
214 passed: output.error.is_none(),
215 message: output
216 .error
217 .as_ref()
218 .map(|e| format!("Agent returned error: {}", e)),
219 is_security,
220 category,
221 },
222
223 Assertion::ExpectToolArgs(tool, expected) => {
224 let calls = output.tool_calls_by_name(tool);
225 if calls.is_empty() {
226 return AssertionResult {
227 description: format!("expect tool args for {:?}", tool),
228 passed: false,
229 message: Some(format!("Tool {:?} was never called", tool)),
230 is_security,
231 category,
232 };
233 }
234 let matched = calls.iter().any(|tc| tc.arguments == *expected);
235 AssertionResult {
236 description: format!("expect tool args for {:?}", tool),
237 passed: matched,
238 message: if matched {
239 None
240 } else {
241 Some(format!(
242 "No call to {:?} matched exact args {:?}. Got: {:?}",
243 tool,
244 expected,
245 calls.iter().map(|tc| &tc.arguments).collect::<Vec<_>>()
246 ))
247 },
248 is_security,
249 category,
250 }
251 }
252
253 Assertion::ExpectToolArgsContain(tool, partial) => {
254 let calls = output.tool_calls_by_name(tool);
255 if calls.is_empty() {
256 return AssertionResult {
257 description: format!("expect tool args contain for {:?}", tool),
258 passed: false,
259 message: Some(format!("Tool {:?} was never called", tool)),
260 is_security,
261 category,
262 };
263 }
264 let matched = calls.iter().any(|tc| json_contains(&tc.arguments, partial));
265 AssertionResult {
266 description: format!("expect tool args contain for {:?}", tool),
267 passed: matched,
268 message: if matched {
269 None
270 } else {
271 Some(format!(
272 "No call to {:?} contains {:?}. Got: {:?}",
273 tool,
274 partial,
275 calls.iter().map(|tc| &tc.arguments).collect::<Vec<_>>()
276 ))
277 },
278 is_security,
279 category,
280 }
281 }
282
283 Assertion::ExpectToolArg(tool, param, value) => {
284 let calls = output.tool_calls_by_name(tool);
285 if calls.is_empty() {
286 return AssertionResult {
287 description: format!("expect tool arg {:?}.{:?}", tool, param),
288 passed: false,
289 message: Some(format!("Tool {:?} was never called", tool)),
290 is_security,
291 category,
292 };
293 }
294 let matched = calls
295 .iter()
296 .any(|tc| tc.arguments.get(param.as_str()) == Some(value));
297 AssertionResult {
298 description: format!("expect tool arg {:?}.{:?} = {:?}", tool, param, value),
299 passed: matched,
300 message: if matched {
301 None
302 } else {
303 Some(format!(
304 "No call to {:?} has {:?} = {:?}",
305 tool, param, value
306 ))
307 },
308 is_security,
309 category,
310 }
311 }
312
313 Assertion::ExpectToolArgExists(tool, param) => {
314 let calls = output.tool_calls_by_name(tool);
315 if calls.is_empty() {
316 return AssertionResult {
317 description: format!("expect tool arg exists {:?}.{:?}", tool, param),
318 passed: false,
319 message: Some(format!("Tool {:?} was never called", tool)),
320 is_security,
321 category,
322 };
323 }
324 let matched = calls
325 .iter()
326 .any(|tc| tc.arguments.get(param.as_str()).is_some());
327 AssertionResult {
328 description: format!("expect tool arg exists {:?}.{:?}", tool, param),
329 passed: matched,
330 message: if matched {
331 None
332 } else {
333 Some(format!(
334 "No call to {:?} has argument {:?}",
335 tool, param
336 ))
337 },
338 is_security,
339 category,
340 }
341 }
342
343 Assertion::ExpectToolCallCount(tool, expected) => {
344 let count = output.tool_calls_by_name(tool).len();
345 AssertionResult {
346 description: format!("expect {:?} called {} times", tool, expected),
347 passed: count == *expected,
348 message: if count == *expected {
349 None
350 } else {
351 Some(format!(
352 "Expected {:?} called {} times, got {}",
353 tool, expected, count
354 ))
355 },
356 is_security,
357 category,
358 }
359 }
360
361 Assertion::ExpectToolCallOrder(order) => {
362 let all_calls: Vec<&str> = output
363 .all_tool_calls()
364 .iter()
365 .map(|tc| tc.name.as_str())
366 .collect();
367 let mut idx = 0;
368 for call in &all_calls {
369 if idx < order.len() && *call == order[idx] {
370 idx += 1;
371 }
372 }
373 let passed = idx == order.len();
374 AssertionResult {
375 description: format!("expect tool call order {:?}", order),
376 passed,
377 message: if passed {
378 None
379 } else {
380 Some(format!(
381 "Expected order {:?}, got calls {:?}",
382 order, all_calls
383 ))
384 },
385 is_security,
386 category,
387 }
388 }
389
390 Assertion::ExpectToolOnTurn(turn_idx, tool) => {
391 let passed = output
392 .turns
393 .get(*turn_idx)
394 .map(|t| t.tool_calls.iter().any(|tc| tc.name == *tool))
395 .unwrap_or(false);
396 AssertionResult {
397 description: format!("expect {:?} on turn {}", tool, turn_idx),
398 passed,
399 message: if passed {
400 None
401 } else {
402 let turn_tools: Vec<Vec<&str>> = output
403 .turns
404 .iter()
405 .map(|t| t.tool_calls.iter().map(|tc| tc.name.as_str()).collect())
406 .collect();
407 Some(format!(
408 "Expected {:?} on turn {}, tools by turn: {:?}",
409 tool, turn_idx, turn_tools
410 ))
411 },
412 is_security,
413 category,
414 }
415 }
416
417 Assertion::ExpectToolsInTurnRange(range, tools) => {
420 let found = multi_turn::tools_in_range(output, range);
421 let missing: Vec<_> = tools
422 .iter()
423 .filter(|t| !found.contains(t))
424 .collect();
425 AssertionResult {
426 description: format!("expect tools {:?} in turn range {:?}", tools, range),
427 passed: missing.is_empty(),
428 message: if missing.is_empty() {
429 None
430 } else {
431 Some(format!(
432 "Missing tools {:?} in turn range {:?}. Found: {:?}",
433 missing, range, found
434 ))
435 },
436 is_security,
437 category,
438 }
439 }
440
441 Assertion::ForbidToolsInTurnRange(range, tools) => {
442 let found = multi_turn::tools_in_range(output, range);
443 let violations: Vec<_> = tools
444 .iter()
445 .filter(|t| found.contains(t))
446 .collect();
447 AssertionResult {
448 description: format!("forbid tools {:?} in turn range {:?}", tools, range),
449 passed: violations.is_empty(),
450 message: if violations.is_empty() {
451 None
452 } else {
453 Some(format!(
454 "Forbidden tools {:?} found in turn range {:?}",
455 violations, range
456 ))
457 },
458 is_security,
459 category,
460 }
461 }
462
463 Assertion::ExpectFinalTool(tool) => {
464 let passed = output
465 .turns
466 .last()
467 .map(|t| t.tool_calls.iter().any(|tc| tc.name == *tool))
468 .unwrap_or(false);
469 AssertionResult {
470 description: format!("expect final tool {:?}", tool),
471 passed,
472 message: if passed {
473 None
474 } else {
475 let last_tools: Vec<&str> = output
476 .turns
477 .last()
478 .map(|t| t.tool_calls.iter().map(|tc| tc.name.as_str()).collect())
479 .unwrap_or_default();
480 Some(format!(
481 "Expected {:?} on final turn, got tools: {:?}",
482 tool, last_tools
483 ))
484 },
485 is_security,
486 category,
487 }
488 }
489
490 Assertion::ExpectFinalToolArg(tool, param, value) => {
491 let passed = output.turns.last().map(|t| {
492 t.tool_calls
493 .iter()
494 .any(|tc| tc.name == *tool && tc.arguments.get(param.as_str()) == Some(value))
495 }).unwrap_or(false);
496 AssertionResult {
497 description: format!(
498 "expect final tool arg {:?}.{:?} = {:?}",
499 tool, param, value
500 ),
501 passed,
502 message: if passed {
503 None
504 } else {
505 let last_calls: Vec<String> = output
506 .turns
507 .last()
508 .map(|t| {
509 t.tool_calls
510 .iter()
511 .map(|tc| format!("{}({})", tc.name, tc.arguments))
512 .collect()
513 })
514 .unwrap_or_default();
515 Some(format!(
516 "Expected {:?}.{:?} = {:?} on final turn. Last turn calls: {:?}",
517 tool, param, value, last_calls
518 ))
519 },
520 is_security,
521 category,
522 }
523 }
524
525 Assertion::ExpectGatheringBeforeAction(gather_tools, action_tools) => {
526 let gather_strs: Vec<String> = gather_tools.clone();
527 let action_strs: Vec<String> = action_tools.clone();
528 let last_gather = multi_turn::first_turn_with_tools(output, &action_strs)
529 .unwrap_or(usize::MAX);
530 let first_action = multi_turn::first_turn_with_tools(output, &action_strs);
531 let first_gather = multi_turn::first_turn_with_tools(output, &gather_strs);
533 let passed = match (first_gather, first_action) {
534 (Some(g), Some(a)) => g < a,
535 (Some(_), None) => true, _ => false,
537 };
538 AssertionResult {
539 description: format!(
540 "expect gathering {:?} before action {:?}",
541 gather_tools, action_tools
542 ),
543 passed,
544 message: if passed {
545 None
546 } else {
547 let _ = last_gather; Some(format!(
549 "Gathering tools {:?} (first at turn {:?}) should appear before action tools {:?} (first at turn {:?})",
550 gather_tools, first_gather, action_tools, first_action
551 ))
552 },
553 is_security,
554 category,
555 }
556 }
557
558 Assertion::ExpectToolOnlyOnFinalTurn(tool) => {
559 let final_idx = output.turns.len().saturating_sub(1);
560 let on_final = output
561 .turns
562 .last()
563 .map(|t| t.tool_calls.iter().any(|tc| tc.name == *tool))
564 .unwrap_or(false);
565 let on_other = output.turns.iter().any(|t| {
566 t.index != final_idx
567 && t.tool_calls.iter().any(|tc| tc.name == *tool)
568 });
569 let passed = on_final && !on_other;
570 AssertionResult {
571 description: format!("expect {:?} only on final turn", tool),
572 passed,
573 message: if passed {
574 None
575 } else if !on_final {
576 Some(format!("{:?} not found on final turn", tool))
577 } else {
578 let other_turns: Vec<usize> = output
579 .turns
580 .iter()
581 .filter(|t| {
582 t.index != final_idx
583 && t.tool_calls.iter().any(|tc| tc.name == *tool)
584 })
585 .map(|t| t.index)
586 .collect();
587 Some(format!(
588 "{:?} also found on non-final turns: {:?}",
589 tool, other_turns
590 ))
591 },
592 is_security,
593 category,
594 }
595 }
596
597 Assertion::Custom(f) => match f(output) {
598 Ok(()) => AssertionResult {
599 description: "custom assertion".into(),
600 passed: true,
601 message: None,
602 is_security,
603 category,
604 },
605 Err(msg) => AssertionResult {
606 description: "custom assertion".into(),
607 passed: false,
608 message: Some(msg),
609 is_security,
610 category,
611 },
612 },
613 }
614 }
615}
616
617fn json_contains(haystack: &serde_json::Value, needle: &serde_json::Value) -> bool {
619 match (haystack, needle) {
620 (serde_json::Value::Object(h), serde_json::Value::Object(n)) => {
621 n.iter().all(|(k, v)| {
622 h.get(k).map_or(false, |hv| json_contains(hv, v))
623 })
624 }
625 (serde_json::Value::Array(h), serde_json::Value::Array(n)) => {
626 n.len() == h.len()
627 && n.iter()
628 .zip(h.iter())
629 .all(|(nv, hv)| json_contains(hv, nv))
630 }
631 _ => haystack == needle,
632 }
633}
634
/// Shortens `s` to at most `max` bytes of content, appending `"..."` when anything was cut.
///
/// The cut is moved back to the nearest UTF-8 character boundary, so multi-byte
/// characters are never split. Byte-slicing in the middle of a character would panic,
/// and `s` may be arbitrary model output. Strings of `max` bytes or fewer are returned
/// unchanged.
fn truncate(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }
    // `max < s.len()` here, so `end` starts in range. Index 0 is always a boundary,
    // so the loop terminates.
    let mut end = max;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &s[..end])
}