1use std::collections::{BTreeMap, BTreeSet};
22
23use thiserror::Error;
24
25use crate::property::dsl::{parse_with_overrides, DslError, InvariantFile, MAX_EXTENDS_DEPTH};
26
/// Every pack compiled into the binary, as `(name, YAML source)` pairs.
///
/// Each source is baked in at compile time with `include_str!` from the
/// `../../packs/` directory (relative to this file), so embedded packs need no
/// filesystem access at runtime. Entries are listed in alphabetical order by
/// name; [`embedded_pack_source`] returns the first entry whose name matches.
pub const EMBEDDED_PACKS: &[(&str, &str)] = &[
    ("auth", include_str!("../../packs/auth.yaml")),
    (
        "authorization",
        include_str!("../../packs/authorization.yaml"),
    ),
    ("error-shape", include_str!("../../packs/error-shape.yaml")),
    ("idempotency", include_str!("../../packs/idempotency.yaml")),
    (
        "injection-shell",
        include_str!("../../packs/injection-shell.yaml"),
    ),
    (
        "injection-sql",
        include_str!("../../packs/injection-sql.yaml"),
    ),
    (
        "large-payload",
        include_str!("../../packs/large-payload.yaml"),
    ),
    ("pagination", include_str!("../../packs/pagination.yaml")),
    (
        "path-traversal",
        include_str!("../../packs/path-traversal.yaml"),
    ),
    (
        "prompt-injection",
        include_str!("../../packs/prompt-injection.yaml"),
    ),
    ("rate-limit", include_str!("../../packs/rate-limit.yaml")),
    (
        "secrets-leakage",
        include_str!("../../packs/secrets-leakage.yaml"),
    ),
    ("security", include_str!("../../packs/security.yaml")),
    (
        "tool-annotations",
        include_str!("../../packs/tool-annotations.yaml"),
    ),
    ("unicode", include_str!("../../packs/unicode.yaml")),
];
74
75pub fn embedded_pack_names() -> impl Iterator<Item = &'static str> {
77 EMBEDDED_PACKS.iter().map(|(name, _)| *name)
78}
79
80pub fn embedded_pack_source(name: &str) -> Option<&'static str> {
84 EMBEDDED_PACKS
85 .iter()
86 .find(|(candidate, _)| *candidate == name)
87 .map(|(_, source)| *source)
88}
89
/// Errors produced while resolving a document's `extends` chain.
#[derive(Debug, Error)]
pub enum PackError {
    /// A [`PackLoader`] could not supply the source of a parent pack.
    #[error("pack `{name}` could not be loaded: {message}")]
    Loader {
        /// Name of the pack that was requested from the loader.
        name: String,
        /// Error message returned by the loader.
        message: String,
    },
    /// A DSL error (e.g. a document in the chain failed to parse), converted
    /// via `From`; its message is shown unchanged.
    #[error(transparent)]
    Dsl(#[from] DslError),
    /// A pack name reappeared on its own `extends` path. As produced by
    /// [`resolve`], the payload is the repeated pack name.
    #[error("cyclic `extends` chain: {0}")]
    Cycle(String),
    /// The `extends` chain nested deeper than [`MAX_EXTENDS_DEPTH`] levels.
    #[error("`extends` chain exceeded depth {MAX_EXTENDS_DEPTH}")]
    DepthExceeded,
}
111
/// Source of pack documents, looked up by pack name.
///
/// Implementations return the raw document text on success or a
/// human-readable message on failure; [`resolve`] wraps failures in
/// [`PackError::Loader`]. Any `Fn(&str) -> Result<String, String>` implements
/// this trait automatically.
pub trait PackLoader {
    /// Returns the source text of the pack called `name`, or an error message.
    fn load(&self, name: &str) -> std::result::Result<String, String>;
}
120
121impl<F> PackLoader for F
122where
123 F: Fn(&str) -> std::result::Result<String, String>,
124{
125 fn load(&self, name: &str) -> std::result::Result<String, String> {
126 self(name)
127 }
128}
129
130#[derive(Debug, Default, Clone, Copy)]
134pub struct EmbeddedLoader;
135
136impl PackLoader for EmbeddedLoader {
137 fn load(&self, name: &str) -> std::result::Result<String, String> {
138 embedded_pack_source(name)
139 .map(|source| source.to_string())
140 .ok_or_else(|| format!("no embedded pack named `{name}`"))
141 }
142}
143
144pub struct LayeredLoader<P: PackLoader, S: PackLoader> {
150 pub primary: P,
152 pub secondary: S,
154}
155
156impl<P: PackLoader, S: PackLoader> LayeredLoader<P, S> {
157 pub fn new(primary: P, secondary: S) -> Self {
160 Self { primary, secondary }
161 }
162}
163
164impl<P: PackLoader, S: PackLoader> PackLoader for LayeredLoader<P, S> {
165 fn load(&self, name: &str) -> std::result::Result<String, String> {
166 match self.primary.load(name) {
167 Ok(source) => Ok(source),
168 Err(primary_err) => self
169 .secondary
170 .load(name)
171 .map_err(|secondary_err| format!("{primary_err}; {secondary_err}")),
172 }
173 }
174}
175
176pub fn resolve(
186 source: &str,
187 overrides: &BTreeMap<String, String>,
188 loader: &dyn PackLoader,
189) -> std::result::Result<InvariantFile, PackError> {
190 let mut visited: BTreeSet<String> = BTreeSet::new();
191 resolve_inner(source, overrides, loader, &mut visited, 0)
192}
193
/// Recursive worker behind [`resolve`].
///
/// `visited` holds the pack names on the current `extends` path, not every
/// pack ever loaded. A name is inserted before its parent is resolved and
/// removed again afterwards. As a result, sibling parents may share an
/// ancestor, while a pack that reappears on its own path is reported as
/// [`PackError::Cycle`]. `depth` counts `extends` hops from the root
/// document (root = 0).
fn resolve_inner(
    source: &str,
    overrides: &BTreeMap<String, String>,
    loader: &dyn PackLoader,
    visited: &mut BTreeSet<String>,
    depth: usize,
) -> std::result::Result<InvariantFile, PackError> {
    // Checked before parsing: a pack past the cap has already been loaded by
    // the caller, but is never parsed.
    if depth > MAX_EXTENDS_DEPTH {
        return Err(PackError::DepthExceeded);
    }
    let mut file = parse_with_overrides(source, overrides)?;

    // Move the parent list out of the metadata so the resolved file no longer
    // carries `extends`; a file without metadata has no parents.
    let extends = file
        .metadata
        .as_mut()
        .map(|m| std::mem::take(&mut m.extends))
        .unwrap_or_default();

    if extends.is_empty() {
        return Ok(file);
    }

    let mut imported: Vec<crate::property::dsl::Invariant> = Vec::new();
    let mut imported_for_each: Vec<crate::property::dsl::ForEachToolBlock> = Vec::new();
    for parent_name in extends {
        // `insert` returns false when the name is already on the current path.
        if !visited.insert(parent_name.clone()) {
            return Err(PackError::Cycle(parent_name));
        }
        let parent_source = loader
            .load(&parent_name)
            .map_err(|message| PackError::Loader {
                name: parent_name.clone(),
                message,
            })?;
        // Parents are resolved with the same overrides as the child.
        let parent = resolve_inner(&parent_source, overrides, loader, visited, depth + 1)?;
        // Pop this parent off the path. The early returns above skip this,
        // but they abort the whole resolution, so `visited` is discarded anyway.
        visited.remove(&parent_name);
        // Only invariants and for-each blocks are inherited; the rest of the
        // parent file (e.g. its metadata) is dropped.
        imported.extend(parent.invariants);
        imported_for_each.extend(parent.for_each_tool);
    }

    // Inherited entries come first, in `extends` order, followed by the
    // document's own entries.
    imported.append(&mut file.invariants);
    file.invariants = imported;
    imported_for_each.append(&mut file.for_each_tool);
    file.for_each_tool = imported_for_each;
    Ok(file)
}
245
#[cfg(test)]
// Tests are allowed to panic on unexpected results.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// In-memory loader serving fixture packs by name.
    struct MapLoader(HashMap<String, String>);

    impl PackLoader for MapLoader {
        fn load(&self, name: &str) -> std::result::Result<String, String> {
            self.0
                .get(name)
                .cloned()
                .ok_or_else(|| format!("unknown pack `{name}`"))
        }
    }

    /// Builds a [`MapLoader`] from `(name, source)` pairs.
    fn loader(packs: &[(&str, &str)]) -> MapLoader {
        MapLoader(
            packs
                .iter()
                .map(|(name, source)| ((*name).to_string(), (*source).to_string()))
                .collect(),
        )
    }

    /// Builds a linear chain `p0 -> p1 -> ... -> p{len-1}` in which every pack
    /// extends the next one and the last pack extends nothing.
    fn chain(len: usize) -> Vec<(String, String)> {
        (0..len)
            .map(|i| {
                let name = format!("p{i}");
                let extends = if i + 1 == len {
                    String::new()
                } else {
                    format!("  extends: [p{}]\n", i + 1)
                };
                let source =
                    format!("version: 3\nmetadata:\n  name: {name}\n{extends}invariants: []\n");
                (name, source)
            })
            .collect()
    }

    /// Resolves `p0` of a chain of `len` packs (`len` must be at least 1).
    fn resolve_chain(len: usize) -> std::result::Result<InvariantFile, PackError> {
        let packs = chain(len);
        let pairs: Vec<(&str, &str)> = packs
            .iter()
            .map(|(name, src)| (name.as_str(), src.as_str()))
            .collect();
        resolve(&packs[0].1, &BTreeMap::new(), &loader(&pairs))
    }

    #[test]
    fn no_extends_passes_through() {
        let source = r#"
version: 3
metadata:
  name: solo
invariants:
  - name: t
    tool: echo
    fixed: {}
    assert:
      - kind: equals
        lhs: { value: 1 }
        rhs: { value: 1 }
"#;
        let file = resolve(source, &BTreeMap::new(), &loader(&[])).unwrap();
        assert_eq!(file.invariants.len(), 1);
    }

    #[test]
    fn extends_prepends_parent_invariants() {
        let parent = r#"
version: 3
metadata:
  name: parent
invariants:
  - name: parent.a
    tool: echo
    fixed: {}
    assert: []
"#;
        let child = r#"
version: 3
metadata:
  name: child
  extends: [parent]
invariants:
  - name: child.a
    tool: echo
    fixed: {}
    assert: []
"#;
        let file = resolve(child, &BTreeMap::new(), &loader(&[("parent", parent)])).unwrap();
        let names: Vec<_> = file.invariants.iter().map(|i| i.name.clone()).collect();
        assert_eq!(names, vec!["parent.a".to_string(), "child.a".to_string()]);
    }

    #[test]
    fn multiple_parents_are_imported_in_extends_order() {
        let left = r#"
version: 3
metadata:
  name: left
invariants:
  - name: left.a
    tool: echo
    fixed: {}
    assert: []
"#;
        let right = r#"
version: 3
metadata:
  name: right
invariants:
  - name: right.a
    tool: echo
    fixed: {}
    assert: []
"#;
        let child = r#"
version: 3
metadata:
  name: child
  extends: [left, right]
invariants:
  - name: child.a
    tool: echo
    fixed: {}
    assert: []
"#;
        let packs = loader(&[("left", left), ("right", right)]);
        let file = resolve(child, &BTreeMap::new(), &packs).unwrap();
        let names: Vec<_> = file.invariants.iter().map(|i| i.name.clone()).collect();
        assert_eq!(
            names,
            vec![
                "left.a".to_string(),
                "right.a".to_string(),
                "child.a".to_string()
            ]
        );
    }

    #[test]
    fn sibling_parents_may_share_an_ancestor() {
        let base = r#"
version: 3
metadata:
  name: base
invariants: []
"#;
        let left = r#"
version: 3
metadata:
  name: left
  extends: [base]
invariants: []
"#;
        let right = r#"
version: 3
metadata:
  name: right
  extends: [base]
invariants: []
"#;
        let child = r#"
version: 3
metadata:
  name: child
  extends: [left, right]
invariants: []
"#;
        let packs = loader(&[("base", base), ("left", left), ("right", right)]);
        // A shared ancestor is not on its own path twice, so it is not a cycle.
        assert!(resolve(child, &BTreeMap::new(), &packs).is_ok());
    }

    #[test]
    fn cycle_is_detected() {
        let a = r#"
version: 3
metadata:
  name: a
  extends: [b]
invariants: []
"#;
        let b = r#"
version: 3
metadata:
  name: b
  extends: [a]
invariants: []
"#;
        let err = resolve(a, &BTreeMap::new(), &loader(&[("a", a), ("b", b)])).unwrap_err();
        assert!(matches!(err, PackError::Cycle(_)));
    }

    #[test]
    fn depth_cap_is_enforced() {
        // The root sits at depth 0, so MAX + 2 packs put the last one at depth
        // MAX + 1 (one past the cap), whatever the cap's value is.
        let err = resolve_chain(MAX_EXTENDS_DEPTH + 2).unwrap_err();
        assert!(matches!(err, PackError::DepthExceeded));
    }

    #[test]
    fn depth_at_cap_is_accepted() {
        // MAX + 1 packs put the deepest one exactly at depth MAX.
        let file = resolve_chain(MAX_EXTENDS_DEPTH + 1).unwrap();
        assert!(file.invariants.is_empty());
    }

    #[test]
    fn loader_failure_surfaces() {
        let child = r#"
version: 3
metadata:
  name: child
  extends: [missing]
invariants: []
"#;
        let err = resolve(child, &BTreeMap::new(), &loader(&[])).unwrap_err();
        match err {
            PackError::Loader { name, .. } => assert_eq!(name, "missing"),
            other => panic!("expected loader error, got {other:?}"),
        }
    }

    #[test]
    fn layered_loader_prefers_primary_then_falls_back() {
        let layered = LayeredLoader::new(
            loader(&[("shared", "from primary"), ("only-primary", "p")]),
            loader(&[("shared", "from secondary"), ("only-secondary", "s")]),
        );
        assert_eq!(layered.load("shared").unwrap(), "from primary");
        assert_eq!(layered.load("only-primary").unwrap(), "p");
        assert_eq!(layered.load("only-secondary").unwrap(), "s");
        assert_eq!(
            layered.load("missing").unwrap_err(),
            "unknown pack `missing`; unknown pack `missing`"
        );
    }

    #[test]
    fn closures_act_as_loaders() {
        let closure = |name: &str| -> std::result::Result<String, String> {
            if name == "inline" {
                Ok("version: 3\n".to_string())
            } else {
                Err(format!("closure has no `{name}`"))
            }
        };
        assert_eq!(PackLoader::load(&closure, "inline").unwrap(), "version: 3\n");
        assert_eq!(
            PackLoader::load(&closure, "other").unwrap_err(),
            "closure has no `other`"
        );
    }

    #[test]
    fn embedded_loader_rejects_unknown_pack() {
        let err = EmbeddedLoader.load("no-such-pack").unwrap_err();
        assert_eq!(err, "no embedded pack named `no-such-pack`");
    }

    #[test]
    fn embedded_packs_have_unique_names_and_load() {
        let names: Vec<&str> = embedded_pack_names().collect();
        assert_eq!(names.len(), EMBEDDED_PACKS.len());
        let unique: BTreeSet<&str> = names.iter().copied().collect();
        assert_eq!(unique.len(), names.len(), "duplicate embedded pack name");
        for name in names {
            let source = embedded_pack_source(name).expect("listed pack must have a source");
            assert_eq!(EmbeddedLoader.load(name).unwrap(), source);
        }
    }
}