1use std::collections::{BTreeMap, BTreeSet};
22
23use thiserror::Error;
24
25use crate::property::dsl::{parse_with_overrides, DslError, InvariantFile, MAX_EXTENDS_DEPTH};
26
/// Built-in invariant packs compiled into the binary, as `(name, yaml_source)` pairs.
///
/// Every source is embedded at compile time with `include_str!` from the
/// `packs/` directory two levels above this file, so these packs need no
/// filesystem access at runtime. Lookups in [`embedded_pack_source`] match the
/// name exactly and return the first hit, so each name should appear only once.
/// Entries are listed in alphabetical order by name.
pub const EMBEDDED_PACKS: &[(&str, &str)] = &[
    ("auth", include_str!("../../packs/auth.yaml")),
    ("auth-flow", include_str!("../../packs/auth-flow.yaml")),
    (
        "authorization",
        include_str!("../../packs/authorization.yaml"),
    ),
    (
        "context-poisoning",
        include_str!("../../packs/context-poisoning.yaml"),
    ),
    ("error-shape", include_str!("../../packs/error-shape.yaml")),
    ("idempotency", include_str!("../../packs/idempotency.yaml")),
    (
        "injection-shell",
        include_str!("../../packs/injection-shell.yaml"),
    ),
    (
        "injection-sql",
        include_str!("../../packs/injection-sql.yaml"),
    ),
    (
        "large-payload",
        include_str!("../../packs/large-payload.yaml"),
    ),
    (
        "mcp-spec-conformance",
        include_str!("../../packs/mcp-spec-conformance.yaml"),
    ),
    ("pagination", include_str!("../../packs/pagination.yaml")),
    (
        "path-traversal",
        include_str!("../../packs/path-traversal.yaml"),
    ),
    (
        "prompt-injection",
        include_str!("../../packs/prompt-injection.yaml"),
    ),
    (
        "prompt-injection-v2",
        include_str!("../../packs/prompt-injection-v2.yaml"),
    ),
    ("rate-limit", include_str!("../../packs/rate-limit.yaml")),
    (
        "secrets-leakage",
        include_str!("../../packs/secrets-leakage.yaml"),
    ),
    ("security", include_str!("../../packs/security.yaml")),
    ("stateful", include_str!("../../packs/stateful.yaml")),
    (
        "tool-annotations",
        include_str!("../../packs/tool-annotations.yaml"),
    ),
    ("unicode", include_str!("../../packs/unicode.yaml")),
];
88
89pub fn embedded_pack_names() -> impl Iterator<Item = &'static str> {
91 EMBEDDED_PACKS.iter().map(|(name, _)| *name)
92}
93
94pub fn embedded_pack_source(name: &str) -> Option<&'static str> {
98 EMBEDDED_PACKS
99 .iter()
100 .find(|(candidate, _)| *candidate == name)
101 .map(|(_, source)| *source)
102}
103
/// Errors produced while resolving a pack and its `extends` chain.
#[derive(Debug, Error)]
pub enum PackError {
    /// The [`PackLoader`] could not supply the source of a parent pack.
    #[error("pack `{name}` could not be loaded: {message}")]
    Loader {
        /// Name of the pack that was requested from the loader.
        name: String,
        /// Error text returned verbatim by the loader.
        message: String,
    },
    /// Error raised by the DSL layer, e.g. when `parse_with_overrides`
    /// rejects a pack's source; displayed exactly as the inner [`DslError`].
    #[error(transparent)]
    Dsl(#[from] DslError),
    /// A pack name reappeared on its own `extends` path; carries the repeated name.
    #[error("cyclic `extends` chain: {0}")]
    Cycle(String),
    /// The `extends` chain nested deeper than [`MAX_EXTENDS_DEPTH`] hops.
    #[error("`extends` chain exceeded depth {MAX_EXTENDS_DEPTH}")]
    DepthExceeded,
}
125
/// A source of pack YAML text, looked up by pack name.
///
/// [`resolve`] calls this for each parent named in an `extends` list. This file
/// implements it for [`EmbeddedLoader`], [`LayeredLoader`], and any
/// `Fn(&str) -> Result<String, String>` closure.
pub trait PackLoader {
    /// Returns the YAML source of pack `name`, or a human-readable message
    /// explaining why it could not be loaded. [`resolve`] wraps that message in
    /// [`PackError::Loader`].
    fn load(&self, name: &str) -> std::result::Result<String, String>;
}
134
135impl<F> PackLoader for F
136where
137 F: Fn(&str) -> std::result::Result<String, String>,
138{
139 fn load(&self, name: &str) -> std::result::Result<String, String> {
140 self(name)
141 }
142}
143
/// A [`PackLoader`] that serves only the built-in packs in [`EMBEDDED_PACKS`].
///
/// It is stateless, so it is a zero-sized unit struct.
#[derive(Debug, Default, Clone, Copy)]
pub struct EmbeddedLoader;
149
150impl PackLoader for EmbeddedLoader {
151 fn load(&self, name: &str) -> std::result::Result<String, String> {
152 embedded_pack_source(name)
153 .map(|source| source.to_string())
154 .ok_or_else(|| format!("no embedded pack named `{name}`"))
155 }
156}
157
/// A [`PackLoader`] that consults `primary` first and falls back to `secondary`.
///
/// When both loaders fail, their messages are joined as
/// `"<primary error>; <secondary error>"`.
///
/// The struct carries no trait bounds of its own. The `PackLoader` bounds live
/// on the `impl` blocks that need them, so the type can be named, stored, and
/// derived generically. The derives apply only when both `P` and `S` support
/// them, which mirrors [`EmbeddedLoader`].
#[derive(Debug, Default, Clone, Copy)]
pub struct LayeredLoader<P, S> {
    /// Loader tried first.
    pub primary: P,
    /// Loader tried only when `primary` returns an error.
    pub secondary: S,
}
169
170impl<P: PackLoader, S: PackLoader> LayeredLoader<P, S> {
171 pub fn new(primary: P, secondary: S) -> Self {
174 Self { primary, secondary }
175 }
176}
177
178impl<P: PackLoader, S: PackLoader> PackLoader for LayeredLoader<P, S> {
179 fn load(&self, name: &str) -> std::result::Result<String, String> {
180 match self.primary.load(name) {
181 Ok(source) => Ok(source),
182 Err(primary_err) => self
183 .secondary
184 .load(name)
185 .map_err(|secondary_err| format!("{primary_err}; {secondary_err}")),
186 }
187 }
188}
189
190pub fn resolve(
200 source: &str,
201 overrides: &BTreeMap<String, String>,
202 loader: &dyn PackLoader,
203) -> std::result::Result<InvariantFile, PackError> {
204 let mut visited: BTreeSet<String> = BTreeSet::new();
205 resolve_inner(source, overrides, loader, &mut visited, 0)
206}
207
208fn resolve_inner(
209 source: &str,
210 overrides: &BTreeMap<String, String>,
211 loader: &dyn PackLoader,
212 visited: &mut BTreeSet<String>,
213 depth: usize,
214) -> std::result::Result<InvariantFile, PackError> {
215 if depth > MAX_EXTENDS_DEPTH {
216 return Err(PackError::DepthExceeded);
217 }
218 let mut file = parse_with_overrides(source, overrides)?;
219
220 let extends = file
223 .metadata
224 .as_mut()
225 .map(|m| std::mem::take(&mut m.extends))
226 .unwrap_or_default();
227
228 if extends.is_empty() {
229 return Ok(file);
230 }
231
232 let mut imported: Vec<crate::property::dsl::Invariant> = Vec::new();
233 let mut imported_for_each: Vec<crate::property::dsl::ForEachToolBlock> = Vec::new();
234 for parent_name in extends {
235 if !visited.insert(parent_name.clone()) {
236 return Err(PackError::Cycle(parent_name));
237 }
238 let parent_source = loader
239 .load(&parent_name)
240 .map_err(|message| PackError::Loader {
241 name: parent_name.clone(),
242 message,
243 })?;
244 let parent = resolve_inner(&parent_source, overrides, loader, visited, depth + 1)?;
245 visited.remove(&parent_name);
248 imported.extend(parent.invariants);
249 imported_for_each.extend(parent.for_each_tool);
250 }
251
252 imported.append(&mut file.invariants);
254 file.invariants = imported;
255 imported_for_each.append(&mut file.for_each_tool);
256 file.for_each_tool = imported_for_each;
257 Ok(file)
258}
259
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// In-memory [`PackLoader`] backed by a name -> source map.
    struct MapLoader(HashMap<String, String>);

    impl PackLoader for MapLoader {
        fn load(&self, name: &str) -> std::result::Result<String, String> {
            self.0
                .get(name)
                .cloned()
                .ok_or_else(|| format!("unknown pack `{name}`"))
        }
    }

    /// Builds a [`MapLoader`] from borrowed `(name, source)` pairs.
    fn loader(packs: &[(&str, &str)]) -> MapLoader {
        MapLoader(
            packs
                .iter()
                .map(|(name, source)| ((*name).to_string(), (*source).to_string()))
                .collect(),
        )
    }

    /// Builds a linear chain `p0 -> p1 -> ... -> p{len-1}`.
    ///
    /// Pack `p{i}` therefore sits at depth `i`. The last pack gets an explicit
    /// empty `extends` list, so it parses cleanly if resolution reaches it.
    fn linear_chain(len: usize) -> Vec<(String, String)> {
        (0..len)
            .map(|i| {
                let name = format!("p{i}");
                let next = if i + 1 == len {
                    "[]".to_string()
                } else {
                    format!("[p{}]", i + 1)
                };
                let source = format!(
                    "version: 3\nmetadata:\n  name: {name}\n  extends: {next}\ninvariants: []\n"
                );
                (name, source)
            })
            .collect()
    }

    /// Resolves the head of a [`linear_chain`] of length `len` (must be >= 1).
    fn resolve_chain(len: usize) -> std::result::Result<InvariantFile, PackError> {
        let chain = linear_chain(len);
        let pairs: Vec<(&str, &str)> = chain
            .iter()
            .map(|(name, src)| (name.as_str(), src.as_str()))
            .collect();
        resolve(&chain[0].1, &BTreeMap::new(), &loader(&pairs))
    }

    #[test]
    fn no_extends_passes_through() {
        let source = r#"
version: 3
metadata:
  name: solo
invariants:
  - name: t
    tool: echo
    fixed: {}
    assert:
      - kind: equals
        lhs: { value: 1 }
        rhs: { value: 1 }
"#;
        let file = resolve(source, &BTreeMap::new(), &loader(&[])).unwrap();
        assert_eq!(file.invariants.len(), 1);
    }

    #[test]
    fn extends_prepends_parent_invariants() {
        let parent = r#"
version: 3
metadata:
  name: parent
invariants:
  - name: parent.a
    tool: echo
    fixed: {}
    assert: []
"#;
        let child = r#"
version: 3
metadata:
  name: child
  extends: [parent]
invariants:
  - name: child.a
    tool: echo
    fixed: {}
    assert: []
"#;
        let file = resolve(child, &BTreeMap::new(), &loader(&[("parent", parent)])).unwrap();
        let names: Vec<_> = file.invariants.iter().map(|i| i.name.clone()).collect();
        assert_eq!(names, vec!["parent.a".to_string(), "child.a".to_string()]);
    }

    #[test]
    fn cycle_is_detected() {
        let a = r#"
version: 3
metadata:
  name: a
  extends: [b]
invariants: []
"#;
        let b = r#"
version: 3
metadata:
  name: b
  extends: [a]
invariants: []
"#;
        let err = resolve(a, &BTreeMap::new(), &loader(&[("a", a), ("b", b)])).unwrap_err();
        assert!(matches!(err, PackError::Cycle(_)));
    }

    #[test]
    fn depth_cap_is_enforced() {
        // Derived from the constant instead of a hard-coded length, so the test
        // keeps exercising the cap if MAX_EXTENDS_DEPTH changes: the last pack
        // sits one hop past the limit.
        let err = resolve_chain(MAX_EXTENDS_DEPTH + 2).unwrap_err();
        assert!(matches!(err, PackError::DepthExceeded));
    }

    #[test]
    fn chain_at_depth_cap_resolves() {
        // The last pack sits exactly at MAX_EXTENDS_DEPTH, which is still allowed.
        let file = resolve_chain(MAX_EXTENDS_DEPTH + 1).unwrap();
        assert!(file.invariants.is_empty());
    }

    #[test]
    fn loader_failure_surfaces() {
        let child = r#"
version: 3
metadata:
  name: child
  extends: [missing]
invariants: []
"#;
        let err = resolve(child, &BTreeMap::new(), &loader(&[])).unwrap_err();
        match err {
            PackError::Loader { name, .. } => assert_eq!(name, "missing"),
            other => panic!("expected loader error, got {other:?}"),
        }
    }

    #[test]
    fn layered_loader_falls_back_and_joins_errors() {
        let layered = LayeredLoader::new(loader(&[("a", "A")]), loader(&[("b", "B")]));
        assert_eq!(layered.load("a").unwrap(), "A");
        assert_eq!(layered.load("b").unwrap(), "B");
        assert_eq!(
            layered.load("c").unwrap_err(),
            "unknown pack `c`; unknown pack `c`"
        );
    }

    #[test]
    fn embedded_loader_reports_unknown_pack() {
        let err = EmbeddedLoader.load("definitely-not-a-pack").unwrap_err();
        assert_eq!(err, "no embedded pack named `definitely-not-a-pack`");
    }

    #[test]
    fn embedded_pack_names_are_unique() {
        // `embedded_pack_source` returns the first match, so a duplicate name
        // would silently shadow a later entry.
        let names: BTreeSet<&str> = embedded_pack_names().collect();
        assert_eq!(names.len(), EMBEDDED_PACKS.len());
    }
}