Skip to main content

taudit_core/
custom_rules.rs

1use crate::finding::{Finding, FindingCategory, Recommendation, Severity};
2use crate::graph::{AuthorityGraph, NodeKind, TrustZone};
3use crate::propagation::PropagationPath;
4use serde::Deserialize;
5use std::collections::HashMap;
6use std::fmt;
7use std::fs;
8use std::io;
9use std::path::{Path, PathBuf};
10
11/// A user-defined rule loaded from YAML. Fires when source, sink, and path
12/// predicates all match a propagation path produced by the engine.
13#[derive(Debug, Clone, Deserialize)]
14pub struct CustomRule {
15    pub id: String,
16    pub name: String,
17    #[serde(default)]
18    pub description: String,
19    pub severity: Severity,
20    pub category: FindingCategory,
21    #[serde(rename = "match", default)]
22    pub match_spec: MatchSpec,
23}
24
25#[derive(Debug, Clone, Default, Deserialize)]
26pub struct MatchSpec {
27    #[serde(default)]
28    pub source: NodeMatcher,
29    #[serde(default)]
30    pub sink: NodeMatcher,
31    #[serde(default)]
32    pub path: PathMatcher,
33}
34
35#[derive(Debug, Clone, Default, Deserialize)]
36pub struct NodeMatcher {
37    #[serde(default)]
38    pub node_type: Option<NodeKind>,
39    #[serde(default)]
40    pub trust_zone: Option<TrustZone>,
41    #[serde(default)]
42    pub metadata: HashMap<String, String>,
43}
44
45#[derive(Debug, Clone, Default, Deserialize)]
46pub struct PathMatcher {
47    #[serde(default)]
48    pub crosses_to: Vec<TrustZone>,
49}
50
/// Failure while loading custom rules from disk. Each variant carries the
/// path being processed when the failure occurred and the underlying error.
#[derive(Debug)]
pub enum CustomRuleError {
    /// An I/O read failed. This is either a rule file that could not be read,
    /// or the rules directory itself (its path is then carried) that could not
    /// be listed.
    FileRead(PathBuf, io::Error),
    /// A rule file was read, but its contents did not deserialize into a
    /// [`CustomRule`].
    YamlParse(PathBuf, serde_yaml::Error),
}
56
57impl fmt::Display for CustomRuleError {
58    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59        match self {
60            CustomRuleError::FileRead(path, err) => {
61                write!(
62                    f,
63                    "failed to read custom rule file {}: {err}",
64                    path.display()
65                )
66            }
67            CustomRuleError::YamlParse(path, err) => {
68                write!(
69                    f,
70                    "failed to parse custom rule file {}: {err}",
71                    path.display()
72                )
73            }
74        }
75    }
76}
77
78impl std::error::Error for CustomRuleError {
79    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
80        match self {
81            CustomRuleError::FileRead(_, err) => Some(err),
82            CustomRuleError::YamlParse(_, err) => Some(err),
83        }
84    }
85}
86
87/// Load all `*.yml` and `*.yaml` files from `dir`. Files are read in sorted
88/// order for deterministic output. Returns a list of all errors alongside
89/// successfully parsed rules — callers decide whether to fail fast or continue.
90pub fn load_rules_dir(dir: &Path) -> Result<Vec<CustomRule>, Vec<CustomRuleError>> {
91    let mut entries: Vec<PathBuf> = Vec::new();
92    let read_dir = match fs::read_dir(dir) {
93        Ok(rd) => rd,
94        Err(err) => return Err(vec![CustomRuleError::FileRead(dir.to_path_buf(), err)]),
95    };
96
97    for entry in read_dir.flatten() {
98        let path = entry.path();
99        if !path.is_file() {
100            continue;
101        }
102        match path.extension().and_then(|e| e.to_str()) {
103            Some("yml") | Some("yaml") => entries.push(path),
104            _ => {}
105        }
106    }
107    entries.sort();
108
109    let mut rules = Vec::new();
110    let mut errors = Vec::new();
111    for path in entries {
112        match fs::read_to_string(&path) {
113            Ok(content) => match serde_yaml::from_str::<CustomRule>(&content) {
114                Ok(rule) => rules.push(rule),
115                Err(err) => errors.push(CustomRuleError::YamlParse(path, err)),
116            },
117            Err(err) => errors.push(CustomRuleError::FileRead(path, err)),
118        }
119    }
120
121    if errors.is_empty() {
122        Ok(rules)
123    } else {
124        Err(errors)
125    }
126}
127
128impl NodeMatcher {
129    fn matches(&self, node: &crate::graph::Node) -> bool {
130        if let Some(kind) = self.node_type {
131            if node.kind != kind {
132                return false;
133            }
134        }
135        if let Some(zone) = self.trust_zone {
136            if node.trust_zone != zone {
137                return false;
138            }
139        }
140        for (key, expected) in &self.metadata {
141            match node.metadata.get(key) {
142                Some(actual) if actual == expected => {}
143                _ => return false,
144            }
145        }
146        true
147    }
148}
149
150impl PathMatcher {
151    fn matches(&self, path: &PropagationPath) -> bool {
152        if self.crosses_to.is_empty() {
153            return true;
154        }
155        match path.boundary_crossing {
156            Some((_, to_zone)) => self.crosses_to.contains(&to_zone),
157            None => false,
158        }
159    }
160}
161
162/// Evaluate every (rule, path) pair. A finding is produced when the rule's
163/// source, sink, and path predicates all match. Findings carry the rule id in
164/// the message so operators can trace back to the originating YAML.
165pub fn evaluate_custom_rules(
166    graph: &AuthorityGraph,
167    paths: &[PropagationPath],
168    rules: &[CustomRule],
169) -> Vec<Finding> {
170    let mut findings = Vec::new();
171
172    for rule in rules {
173        for path in paths {
174            let source_node = match graph.node(path.source) {
175                Some(n) => n,
176                None => continue,
177            };
178            let sink_node = match graph.node(path.sink) {
179                Some(n) => n,
180                None => continue,
181            };
182
183            if !rule.match_spec.source.matches(source_node) {
184                continue;
185            }
186            if !rule.match_spec.sink.matches(sink_node) {
187                continue;
188            }
189            if !rule.match_spec.path.matches(path) {
190                continue;
191            }
192
193            findings.push(Finding {
194                severity: rule.severity,
195                category: rule.category,
196                nodes_involved: vec![path.source, path.sink],
197                message: format!(
198                    "[{}] {}: {} -> {}",
199                    rule.id, rule.name, source_node.name, sink_node.name
200                ),
201                recommendation: Recommendation::Manual {
202                    action: if rule.description.is_empty() {
203                        format!("Review custom rule '{}'", rule.id)
204                    } else {
205                        rule.description.clone()
206                    },
207                },
208                path: Some(path.clone()),
209            });
210        }
211    }
212
213    findings
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{AuthorityGraph, EdgeKind, PipelineSource};
    use crate::propagation::{propagation_analysis, DEFAULT_MAX_HOPS};

    fn source() -> PipelineSource {
        PipelineSource {
            file: "test.yml".into(),
            repo: None,
            git_ref: None,
        }
    }

    /// Matcher that constrains only the trust zone.
    fn in_zone(zone: TrustZone) -> NodeMatcher {
        NodeMatcher {
            trust_zone: Some(zone),
            ..NodeMatcher::default()
        }
    }

    /// Rule with the given id, severity, and source/sink matchers. The path
    /// predicate is unconstrained and the description is empty.
    fn make_rule(
        id: &str,
        severity: Severity,
        source: NodeMatcher,
        sink: NodeMatcher,
    ) -> CustomRule {
        CustomRule {
            id: id.into(),
            name: "n".into(),
            description: String::new(),
            severity,
            category: FindingCategory::AuthorityPropagation,
            match_spec: MatchSpec {
                source,
                sink,
                path: PathMatcher::default(),
            },
        }
    }

    /// Temporary directory that is emptied on creation and removed on drop,
    /// so a failing assertion does not leave files in the system temp dir.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        fn new(label: &str) -> Self {
            let dir = std::env::temp_dir()
                .join(format!("taudit-custom-rules-{label}-{}", std::process::id()));
            // Leftovers from an earlier crashed run with the same PID would
            // otherwise be picked up as rule files.
            let _ = fs::remove_dir_all(&dir);
            fs::create_dir_all(&dir).unwrap();
            ScratchDir(dir)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    fn build_graph_with_paths() -> (AuthorityGraph, Vec<PropagationPath>) {
        let mut g = AuthorityGraph::new(source());
        let secret = g.add_node(NodeKind::Secret, "API_KEY", TrustZone::FirstParty);
        let trusted = g.add_node(NodeKind::Step, "build", TrustZone::FirstParty);
        let untrusted = g.add_node(NodeKind::Step, "third-party", TrustZone::Untrusted);

        g.add_edge(trusted, secret, EdgeKind::HasAccessTo);
        g.add_edge(trusted, untrusted, EdgeKind::DelegatesTo);

        let paths = propagation_analysis(&g, DEFAULT_MAX_HOPS);
        (g, paths)
    }

    #[test]
    fn custom_rule_fires_on_matching_path() {
        let (graph, paths) = build_graph_with_paths();

        let rule = CustomRule {
            name: "Secret reaching untrusted step".into(),
            description: "Custom policy".into(),
            ..make_rule(
                "secret_to_untrusted",
                Severity::Critical,
                in_zone(TrustZone::FirstParty),
                in_zone(TrustZone::Untrusted),
            )
        };

        let findings = evaluate_custom_rules(&graph, &paths, &[rule]);
        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].severity, Severity::Critical);
        assert!(findings[0].message.contains("secret_to_untrusted"));
        // A non-empty description is passed through as the recommendation.
        assert!(matches!(
            &findings[0].recommendation,
            Recommendation::Manual { action } if action == "Custom policy"
        ));
    }

    #[test]
    fn custom_rule_does_not_fire_when_predicates_miss() {
        let (graph, paths) = build_graph_with_paths();

        let miss = make_rule(
            "miss",
            Severity::Critical,
            in_zone(TrustZone::Untrusted),
            NodeMatcher::default(),
        );

        assert!(evaluate_custom_rules(&graph, &paths, &[miss]).is_empty());
    }

    #[test]
    fn yaml_round_trip_loads_full_rule() {
        let yaml = r#"
id: my_secret_to_untrusted
name: Secret reaching untrusted step
description: "Custom policy: secrets must not reach untrusted steps"
severity: critical
category: authority_propagation
match:
  source:
    node_type: secret
    trust_zone: first_party
  sink:
    node_type: step
    trust_zone: untrusted
  path:
    crosses_to: [untrusted]
"#;
        let rule: CustomRule = serde_yaml::from_str(yaml).expect("yaml must parse");
        assert_eq!(rule.id, "my_secret_to_untrusted");
        assert_eq!(rule.severity, Severity::Critical);
        assert_eq!(rule.match_spec.source.node_type, Some(NodeKind::Secret));
        assert_eq!(rule.match_spec.sink.trust_zone, Some(TrustZone::Untrusted));
        assert_eq!(rule.match_spec.path.crosses_to, vec![TrustZone::Untrusted]);
    }

    #[test]
    fn metadata_predicate_must_match_all_keys() {
        let mut g = AuthorityGraph::new(source());
        let mut meta = HashMap::new();
        meta.insert("kind".to_string(), "deploy".to_string());
        let secret =
            g.add_node_with_metadata(NodeKind::Secret, "TOKEN", TrustZone::FirstParty, meta);
        let sink = g.add_node(NodeKind::Step, "remote", TrustZone::Untrusted);
        let step = g.add_node(NodeKind::Step, "use", TrustZone::FirstParty);
        g.add_edge(step, secret, EdgeKind::HasAccessTo);
        g.add_edge(step, sink, EdgeKind::DelegatesTo);

        let paths = propagation_analysis(&g, DEFAULT_MAX_HOPS);

        // Secret matcher that requires every given metadata entry.
        let secret_with = |entries: &[(&str, &str)]| NodeMatcher {
            node_type: Some(NodeKind::Secret),
            trust_zone: None,
            metadata: entries
                .iter()
                .map(|(k, v)| (k.to_string(), v.to_string()))
                .collect(),
        };

        let hit = make_rule(
            "hit",
            Severity::High,
            secret_with(&[("kind", "deploy")]),
            NodeMatcher::default(),
        );
        assert_eq!(evaluate_custom_rules(&g, &paths, &[hit]).len(), 1);

        let wrong_value = make_rule(
            "wrong_value",
            Severity::High,
            secret_with(&[("kind", "build")]),
            NodeMatcher::default(),
        );
        assert!(evaluate_custom_rules(&g, &paths, &[wrong_value]).is_empty());

        // One key matches, but the other is absent from the node: no finding.
        let missing_key = make_rule(
            "missing_key",
            Severity::High,
            secret_with(&[("kind", "deploy"), ("env", "prod")]),
            NodeMatcher::default(),
        );
        assert!(evaluate_custom_rules(&g, &paths, &[missing_key]).is_empty());
    }

    #[test]
    fn load_rules_dir_reads_yml_and_yaml() {
        let tmp = ScratchDir::new("ok");

        fs::write(
            tmp.path().join("a.yml"),
            "id: a\nname: a\nseverity: high\ncategory: authority_propagation\n",
        )
        .unwrap();
        fs::write(
            tmp.path().join("b.yaml"),
            "id: b\nname: b\nseverity: medium\ncategory: unpinned_action\n",
        )
        .unwrap();
        fs::write(tmp.path().join("c.txt"), "ignored").unwrap();
        // A directory with a rule-like name must be skipped, not read.
        fs::create_dir(tmp.path().join("d.yml")).unwrap();

        let rules = load_rules_dir(tmp.path()).expect("load must succeed");
        let ids: Vec<&str> = rules.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["a", "b"]);
    }

    #[test]
    fn load_rules_dir_reports_yaml_errors_with_path() {
        let tmp = ScratchDir::new("bad");
        fs::write(
            tmp.path().join("bad.yml"),
            "id: x\nseverity: not-a-real-severity\n",
        )
        .unwrap();

        let errs = load_rules_dir(tmp.path()).expect_err("should fail");
        assert_eq!(errs.len(), 1);
        let msg = errs[0].to_string();
        assert!(msg.contains("bad.yml"), "error must mention path: {msg}");
    }

    #[test]
    fn load_rules_dir_reports_missing_directory() {
        let tmp = ScratchDir::new("missing");
        let absent = tmp.path().join("does-not-exist");

        let errs = load_rules_dir(&absent).expect_err("missing dir must fail");
        assert_eq!(errs.len(), 1);
        assert!(matches!(errs[0], CustomRuleError::FileRead(..)));
    }
}