// spice_framework/assertion.rs

1use crate::agent::AgentOutput;
2use crate::multi_turn;
3use serde::{Deserialize, Serialize};
4use std::ops::RangeInclusive;
5
/// Result of evaluating a single assertion.
///
/// Produced by [`Assertion::evaluate`]; derives serde's `Serialize` and
/// `Deserialize` in addition to `Debug` and `Clone`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssertionResult {
    /// Human-readable description of the assertion that was checked.
    pub description: String,
    /// `true` when the assertion held for the evaluated output.
    pub passed: bool,
    /// Failure detail. `Assertion::evaluate` sets this to `None` on success.
    pub message: Option<String>,
    /// Whether the assertion is security-related (see `Assertion::is_security`).
    pub is_security: bool,
    /// Reporting group (see `Assertion::category`); left out of the serialized
    /// form entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub category: Option<String>,
}
16
/// An assertion to evaluate against agent output.
///
/// Each variant describes one check; `Assertion::evaluate` turns it into an
/// `AssertionResult`.
pub enum Assertion {
    /// Every listed tool name must appear among the tools that were called.
    ExpectTools(Vec<String>),
    /// None of the listed tool names may appear among the tools that were called.
    ForbidTools(Vec<String>),
    /// At least one tool must have been called.
    ExpectAnyTool,
    /// No tool may have been called.
    ExpectNoTools,
    /// The final text must contain this substring.
    ExpectTextContains(String),
    /// The final text must not contain this substring.
    ExpectTextNotContains(String),
    /// The number of turns must lie within this inclusive range.
    ExpectTurns(RangeInclusive<usize>),
    /// Every called tool must be in the `available_tools` list given to `evaluate`.
    ExpectToolsWithinAllowlist,
    /// The output must not carry an error.
    ExpectNoError,
    /// `(tool, args)`: some call to `tool` must have arguments exactly equal to `args`.
    ExpectToolArgs(String, serde_json::Value),
    /// `(tool, partial)`: some call to `tool` must have arguments that contain
    /// `partial` (partial JSON match).
    ExpectToolArgsContain(String, serde_json::Value),
    /// `(tool, param, value)`: some call to `tool` must have `param` equal to `value`.
    ExpectToolArg(String, String, serde_json::Value),
    /// `(tool, param)`: some call to `tool` must carry an argument named `param`.
    ExpectToolArgExists(String, String),
    /// `(tool, count)`: `tool` must have been called exactly `count` times.
    ExpectToolCallCount(String, usize),
    /// The listed tool names must occur in this order among all tool calls;
    /// other calls may be interleaved between them.
    ExpectToolCallOrder(Vec<String>),
    /// `(turn, tool)`: the turn at position `turn` in the turn list must contain
    /// a call to `tool`.
    ExpectToolOnTurn(usize, String),
    /// Tools must appear in the specified turn range.
    ExpectToolsInTurnRange(RangeInclusive<usize>, Vec<String>),
    /// Tools must NOT appear in the specified turn range.
    ForbidToolsInTurnRange(RangeInclusive<usize>, Vec<String>),
    /// Last turn must contain this tool call.
    ExpectFinalTool(String),
    /// Last turn's tool call must have this argument value.
    ExpectFinalToolArg(String, String, serde_json::Value),
    /// Gather tools must appear before action tools.
    ExpectGatheringBeforeAction(Vec<String>, Vec<String>),
    /// Tool appears on last turn and no other.
    ExpectToolOnlyOnFinalTurn(String),
    /// Arbitrary check: `Ok(())` passes, `Err(message)` fails with that message.
    Custom(Box<dyn Fn(&AgentOutput) -> Result<(), String> + Send + Sync>),
}
49
50impl Assertion {
51    /// Whether this is a security-related assertion.
52    pub fn is_security(&self) -> bool {
53        matches!(
54            self,
55            Assertion::ForbidTools(_)
56                | Assertion::ForbidToolsInTurnRange(_, _)
57                | Assertion::ExpectToolsWithinAllowlist
58        )
59    }
60
61    /// Category for reporting grouping.
62    pub fn category(&self) -> Option<&str> {
63        match self {
64            Assertion::ExpectToolsInTurnRange(_, _)
65            | Assertion::ForbidToolsInTurnRange(_, _)
66            | Assertion::ExpectFinalTool(_)
67            | Assertion::ExpectFinalToolArg(_, _, _)
68            | Assertion::ExpectGatheringBeforeAction(_, _)
69            | Assertion::ExpectToolOnlyOnFinalTurn(_) => Some("multi-turn"),
70            _ => None,
71        }
72    }
73
74    /// Evaluate this assertion against agent output.
75    pub fn evaluate(
76        &self,
77        output: &AgentOutput,
78        available_tools: &[String],
79    ) -> AssertionResult {
80        let is_security = self.is_security();
81        let category = self.category().map(|s| s.to_string());
82
83        match self {
84            Assertion::ExpectTools(tools) => {
85                let missing: Vec<_> = tools
86                    .iter()
87                    .filter(|t| !output.tools_called.contains(t))
88                    .collect();
89                AssertionResult {
90                    description: format!("expect tools {:?}", tools),
91                    passed: missing.is_empty(),
92                    message: if missing.is_empty() {
93                        None
94                    } else {
95                        Some(format!("Missing tool calls: {:?}", missing))
96                    },
97                    is_security,
98                    category,
99                }
100            }
101
102            Assertion::ForbidTools(tools) => {
103                let found: Vec<_> = tools
104                    .iter()
105                    .filter(|t| output.tools_called.contains(t))
106                    .collect();
107                AssertionResult {
108                    description: format!("forbid tools {:?}", tools),
109                    passed: found.is_empty(),
110                    message: if found.is_empty() {
111                        None
112                    } else {
113                        Some(format!("Forbidden tools were called: {:?}", found))
114                    },
115                    is_security,
116                    category,
117                }
118            }
119
120            Assertion::ExpectAnyTool => AssertionResult {
121                description: "expect any tool call".into(),
122                passed: !output.tools_called.is_empty(),
123                message: if output.tools_called.is_empty() {
124                    Some("No tools were called".into())
125                } else {
126                    None
127                },
128                is_security,
129                category,
130            },
131
132            Assertion::ExpectNoTools => AssertionResult {
133                description: "expect no tool calls".into(),
134                passed: output.tools_called.is_empty(),
135                message: if output.tools_called.is_empty() {
136                    None
137                } else {
138                    Some(format!("Tools were called: {:?}", output.tools_called))
139                },
140                is_security,
141                category,
142            },
143
144            Assertion::ExpectTextContains(s) => AssertionResult {
145                description: format!("expect text contains {:?}", s),
146                passed: output.final_text.contains(s.as_str()),
147                message: if output.final_text.contains(s.as_str()) {
148                    None
149                } else {
150                    Some(format!(
151                        "Text does not contain {:?}. Got: {:?}",
152                        s,
153                        truncate(&output.final_text, 200)
154                    ))
155                },
156                is_security,
157                category,
158            },
159
160            Assertion::ExpectTextNotContains(s) => AssertionResult {
161                description: format!("expect text not contains {:?}", s),
162                passed: !output.final_text.contains(s.as_str()),
163                message: if !output.final_text.contains(s.as_str()) {
164                    None
165                } else {
166                    Some(format!("Text contains forbidden substring {:?}", s))
167                },
168                is_security,
169                category,
170            },
171
172            Assertion::ExpectTurns(range) => {
173                let count = output.turns.len();
174                AssertionResult {
175                    description: format!("expect turns in {:?}", range),
176                    passed: range.contains(&count),
177                    message: if range.contains(&count) {
178                        None
179                    } else {
180                        Some(format!(
181                            "Turn count {} not in range {:?}",
182                            count, range
183                        ))
184                    },
185                    is_security,
186                    category,
187                }
188            }
189
190            Assertion::ExpectToolsWithinAllowlist => {
191                let violations: Vec<_> = output
192                    .tools_called
193                    .iter()
194                    .filter(|t| !available_tools.contains(t))
195                    .collect();
196                AssertionResult {
197                    description: "expect tools within allowlist".into(),
198                    passed: violations.is_empty(),
199                    message: if violations.is_empty() {
200                        None
201                    } else {
202                        Some(format!(
203                            "Tools called outside allowlist: {:?} (allowed: {:?})",
204                            violations, available_tools
205                        ))
206                    },
207                    is_security: true,
208                    category,
209                }
210            }
211
212            Assertion::ExpectNoError => AssertionResult {
213                description: "expect no error".into(),
214                passed: output.error.is_none(),
215                message: output
216                    .error
217                    .as_ref()
218                    .map(|e| format!("Agent returned error: {}", e)),
219                is_security,
220                category,
221            },
222
223            Assertion::ExpectToolArgs(tool, expected) => {
224                let calls = output.tool_calls_by_name(tool);
225                if calls.is_empty() {
226                    return AssertionResult {
227                        description: format!("expect tool args for {:?}", tool),
228                        passed: false,
229                        message: Some(format!("Tool {:?} was never called", tool)),
230                        is_security,
231                        category,
232                    };
233                }
234                let matched = calls.iter().any(|tc| tc.arguments == *expected);
235                AssertionResult {
236                    description: format!("expect tool args for {:?}", tool),
237                    passed: matched,
238                    message: if matched {
239                        None
240                    } else {
241                        Some(format!(
242                            "No call to {:?} matched exact args {:?}. Got: {:?}",
243                            tool,
244                            expected,
245                            calls.iter().map(|tc| &tc.arguments).collect::<Vec<_>>()
246                        ))
247                    },
248                    is_security,
249                    category,
250                }
251            }
252
253            Assertion::ExpectToolArgsContain(tool, partial) => {
254                let calls = output.tool_calls_by_name(tool);
255                if calls.is_empty() {
256                    return AssertionResult {
257                        description: format!("expect tool args contain for {:?}", tool),
258                        passed: false,
259                        message: Some(format!("Tool {:?} was never called", tool)),
260                        is_security,
261                        category,
262                    };
263                }
264                let matched = calls.iter().any(|tc| json_contains(&tc.arguments, partial));
265                AssertionResult {
266                    description: format!("expect tool args contain for {:?}", tool),
267                    passed: matched,
268                    message: if matched {
269                        None
270                    } else {
271                        Some(format!(
272                            "No call to {:?} contains {:?}. Got: {:?}",
273                            tool,
274                            partial,
275                            calls.iter().map(|tc| &tc.arguments).collect::<Vec<_>>()
276                        ))
277                    },
278                    is_security,
279                    category,
280                }
281            }
282
283            Assertion::ExpectToolArg(tool, param, value) => {
284                let calls = output.tool_calls_by_name(tool);
285                if calls.is_empty() {
286                    return AssertionResult {
287                        description: format!("expect tool arg {:?}.{:?}", tool, param),
288                        passed: false,
289                        message: Some(format!("Tool {:?} was never called", tool)),
290                        is_security,
291                        category,
292                    };
293                }
294                let matched = calls
295                    .iter()
296                    .any(|tc| tc.arguments.get(param.as_str()) == Some(value));
297                AssertionResult {
298                    description: format!("expect tool arg {:?}.{:?} = {:?}", tool, param, value),
299                    passed: matched,
300                    message: if matched {
301                        None
302                    } else {
303                        Some(format!(
304                            "No call to {:?} has {:?} = {:?}",
305                            tool, param, value
306                        ))
307                    },
308                    is_security,
309                    category,
310                }
311            }
312
313            Assertion::ExpectToolArgExists(tool, param) => {
314                let calls = output.tool_calls_by_name(tool);
315                if calls.is_empty() {
316                    return AssertionResult {
317                        description: format!("expect tool arg exists {:?}.{:?}", tool, param),
318                        passed: false,
319                        message: Some(format!("Tool {:?} was never called", tool)),
320                        is_security,
321                        category,
322                    };
323                }
324                let matched = calls
325                    .iter()
326                    .any(|tc| tc.arguments.get(param.as_str()).is_some());
327                AssertionResult {
328                    description: format!("expect tool arg exists {:?}.{:?}", tool, param),
329                    passed: matched,
330                    message: if matched {
331                        None
332                    } else {
333                        Some(format!(
334                            "No call to {:?} has argument {:?}",
335                            tool, param
336                        ))
337                    },
338                    is_security,
339                    category,
340                }
341            }
342
343            Assertion::ExpectToolCallCount(tool, expected) => {
344                let count = output.tool_calls_by_name(tool).len();
345                AssertionResult {
346                    description: format!("expect {:?} called {} times", tool, expected),
347                    passed: count == *expected,
348                    message: if count == *expected {
349                        None
350                    } else {
351                        Some(format!(
352                            "Expected {:?} called {} times, got {}",
353                            tool, expected, count
354                        ))
355                    },
356                    is_security,
357                    category,
358                }
359            }
360
361            Assertion::ExpectToolCallOrder(order) => {
362                let all_calls: Vec<&str> = output
363                    .all_tool_calls()
364                    .iter()
365                    .map(|tc| tc.name.as_str())
366                    .collect();
367                let mut idx = 0;
368                for call in &all_calls {
369                    if idx < order.len() && *call == order[idx] {
370                        idx += 1;
371                    }
372                }
373                let passed = idx == order.len();
374                AssertionResult {
375                    description: format!("expect tool call order {:?}", order),
376                    passed,
377                    message: if passed {
378                        None
379                    } else {
380                        Some(format!(
381                            "Expected order {:?}, got calls {:?}",
382                            order, all_calls
383                        ))
384                    },
385                    is_security,
386                    category,
387                }
388            }
389
390            Assertion::ExpectToolOnTurn(turn_idx, tool) => {
391                let passed = output
392                    .turns
393                    .get(*turn_idx)
394                    .map(|t| t.tool_calls.iter().any(|tc| tc.name == *tool))
395                    .unwrap_or(false);
396                AssertionResult {
397                    description: format!("expect {:?} on turn {}", tool, turn_idx),
398                    passed,
399                    message: if passed {
400                        None
401                    } else {
402                        let turn_tools: Vec<Vec<&str>> = output
403                            .turns
404                            .iter()
405                            .map(|t| t.tool_calls.iter().map(|tc| tc.name.as_str()).collect())
406                            .collect();
407                        Some(format!(
408                            "Expected {:?} on turn {}, tools by turn: {:?}",
409                            tool, turn_idx, turn_tools
410                        ))
411                    },
412                    is_security,
413                    category,
414                }
415            }
416
417            // --- Multi-turn assertions ---
418
419            Assertion::ExpectToolsInTurnRange(range, tools) => {
420                let found = multi_turn::tools_in_range(output, range);
421                let missing: Vec<_> = tools
422                    .iter()
423                    .filter(|t| !found.contains(t))
424                    .collect();
425                AssertionResult {
426                    description: format!("expect tools {:?} in turn range {:?}", tools, range),
427                    passed: missing.is_empty(),
428                    message: if missing.is_empty() {
429                        None
430                    } else {
431                        Some(format!(
432                            "Missing tools {:?} in turn range {:?}. Found: {:?}",
433                            missing, range, found
434                        ))
435                    },
436                    is_security,
437                    category,
438                }
439            }
440
441            Assertion::ForbidToolsInTurnRange(range, tools) => {
442                let found = multi_turn::tools_in_range(output, range);
443                let violations: Vec<_> = tools
444                    .iter()
445                    .filter(|t| found.contains(t))
446                    .collect();
447                AssertionResult {
448                    description: format!("forbid tools {:?} in turn range {:?}", tools, range),
449                    passed: violations.is_empty(),
450                    message: if violations.is_empty() {
451                        None
452                    } else {
453                        Some(format!(
454                            "Forbidden tools {:?} found in turn range {:?}",
455                            violations, range
456                        ))
457                    },
458                    is_security,
459                    category,
460                }
461            }
462
463            Assertion::ExpectFinalTool(tool) => {
464                let passed = output
465                    .turns
466                    .last()
467                    .map(|t| t.tool_calls.iter().any(|tc| tc.name == *tool))
468                    .unwrap_or(false);
469                AssertionResult {
470                    description: format!("expect final tool {:?}", tool),
471                    passed,
472                    message: if passed {
473                        None
474                    } else {
475                        let last_tools: Vec<&str> = output
476                            .turns
477                            .last()
478                            .map(|t| t.tool_calls.iter().map(|tc| tc.name.as_str()).collect())
479                            .unwrap_or_default();
480                        Some(format!(
481                            "Expected {:?} on final turn, got tools: {:?}",
482                            tool, last_tools
483                        ))
484                    },
485                    is_security,
486                    category,
487                }
488            }
489
490            Assertion::ExpectFinalToolArg(tool, param, value) => {
491                let passed = output.turns.last().map(|t| {
492                    t.tool_calls
493                        .iter()
494                        .any(|tc| tc.name == *tool && tc.arguments.get(param.as_str()) == Some(value))
495                }).unwrap_or(false);
496                AssertionResult {
497                    description: format!(
498                        "expect final tool arg {:?}.{:?} = {:?}",
499                        tool, param, value
500                    ),
501                    passed,
502                    message: if passed {
503                        None
504                    } else {
505                        let last_calls: Vec<String> = output
506                            .turns
507                            .last()
508                            .map(|t| {
509                                t.tool_calls
510                                    .iter()
511                                    .map(|tc| format!("{}({})", tc.name, tc.arguments))
512                                    .collect()
513                            })
514                            .unwrap_or_default();
515                        Some(format!(
516                            "Expected {:?}.{:?} = {:?} on final turn. Last turn calls: {:?}",
517                            tool, param, value, last_calls
518                        ))
519                    },
520                    is_security,
521                    category,
522                }
523            }
524
525            Assertion::ExpectGatheringBeforeAction(gather_tools, action_tools) => {
526                let gather_strs: Vec<String> = gather_tools.clone();
527                let action_strs: Vec<String> = action_tools.clone();
528                let last_gather = multi_turn::first_turn_with_tools(output, &action_strs)
529                    .unwrap_or(usize::MAX);
530                let first_action = multi_turn::first_turn_with_tools(output, &action_strs);
531                // Check that at least one gather tool was called before any action tool
532                let first_gather = multi_turn::first_turn_with_tools(output, &gather_strs);
533                let passed = match (first_gather, first_action) {
534                    (Some(g), Some(a)) => g < a,
535                    (Some(_), None) => true, // gathered but no action (still valid)
536                    _ => false,
537                };
538                AssertionResult {
539                    description: format!(
540                        "expect gathering {:?} before action {:?}",
541                        gather_tools, action_tools
542                    ),
543                    passed,
544                    message: if passed {
545                        None
546                    } else {
547                        let _ = last_gather; // suppress warning
548                        Some(format!(
549                            "Gathering tools {:?} (first at turn {:?}) should appear before action tools {:?} (first at turn {:?})",
550                            gather_tools, first_gather, action_tools, first_action
551                        ))
552                    },
553                    is_security,
554                    category,
555                }
556            }
557
558            Assertion::ExpectToolOnlyOnFinalTurn(tool) => {
559                let final_idx = output.turns.len().saturating_sub(1);
560                let on_final = output
561                    .turns
562                    .last()
563                    .map(|t| t.tool_calls.iter().any(|tc| tc.name == *tool))
564                    .unwrap_or(false);
565                let on_other = output.turns.iter().any(|t| {
566                    t.index != final_idx
567                        && t.tool_calls.iter().any(|tc| tc.name == *tool)
568                });
569                let passed = on_final && !on_other;
570                AssertionResult {
571                    description: format!("expect {:?} only on final turn", tool),
572                    passed,
573                    message: if passed {
574                        None
575                    } else if !on_final {
576                        Some(format!("{:?} not found on final turn", tool))
577                    } else {
578                        let other_turns: Vec<usize> = output
579                            .turns
580                            .iter()
581                            .filter(|t| {
582                                t.index != final_idx
583                                    && t.tool_calls.iter().any(|tc| tc.name == *tool)
584                            })
585                            .map(|t| t.index)
586                            .collect();
587                        Some(format!(
588                            "{:?} also found on non-final turns: {:?}",
589                            tool, other_turns
590                        ))
591                    },
592                    is_security,
593                    category,
594                }
595            }
596
597            Assertion::Custom(f) => match f(output) {
598                Ok(()) => AssertionResult {
599                    description: "custom assertion".into(),
600                    passed: true,
601                    message: None,
602                    is_security,
603                    category,
604                },
605                Err(msg) => AssertionResult {
606                    description: "custom assertion".into(),
607                    passed: false,
608                    message: Some(msg),
609                    is_security,
610                    category,
611                },
612            },
613        }
614    }
615}
616
617/// Check if `haystack` is a superset of `needle` (partial JSON match).
618fn json_contains(haystack: &serde_json::Value, needle: &serde_json::Value) -> bool {
619    match (haystack, needle) {
620        (serde_json::Value::Object(h), serde_json::Value::Object(n)) => {
621            n.iter().all(|(k, v)| {
622                h.get(k).map_or(false, |hv| json_contains(hv, v))
623            })
624        }
625        (serde_json::Value::Array(h), serde_json::Value::Array(n)) => {
626            n.len() == h.len()
627                && n.iter()
628                    .zip(h.iter())
629                    .all(|(nv, hv)| json_contains(hv, nv))
630        }
631        _ => haystack == needle,
632    }
633}
634
/// Shorten `s` to at most `max` bytes for display, appending `...` when cut.
///
/// Strings of `max` bytes or fewer are returned unchanged. Longer strings are
/// cut at the last UTF-8 character boundary at or before byte `max`, so a
/// multi-byte character straddling the limit is dropped whole instead of
/// causing a slicing panic.
fn truncate(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }
    // Back up to a character boundary; index 0 is always one, so this terminates.
    let mut end = max;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &s[..end])
}