Skip to main content

symbi_runtime/toolclad/
scope.rs

1//! Scope enforcement for ToolClad arguments
2//!
3//! Validates scope_target, ip_address, cidr, and url arguments
4//! against a project scope definition (scope/scope.toml).
5
6use std::net::IpAddr;
7use std::path::Path;
8
/// Project scope definition, as loaded from `scope/scope.toml` by [`Scope::load`].
///
/// All entries are kept as the raw strings from the scope file; they are
/// interpreted when a target is validated with [`Scope::check`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Scope {
    /// In-scope IP addresses and CIDR blocks (e.g. `"10.0.1.5"`,
    /// `"10.0.1.0/24"`). A hostname listed here is also accepted when a
    /// hostname target matches it exactly.
    pub targets: Vec<String>,
    /// In-scope hostnames: exact names (`"example.com"`) or `*.` wildcards
    /// (`"*.example.com"`, which also covers `example.com` itself).
    pub domains: Vec<String>,
    /// Entries explicitly excluded from scope. Exclusions are evaluated before,
    /// and take precedence over, `targets` and `domains`.
    pub exclude: Vec<String>,
}
16
17impl Scope {
18    /// Load scope from scope/scope.toml if it exists.
19    pub fn load(project_dir: &Path) -> Option<Self> {
20        let path = project_dir.join("scope").join("scope.toml");
21        if !path.exists() {
22            return None;
23        }
24        let content = std::fs::read_to_string(&path).ok()?;
25        let table: toml::Table = toml::from_str(&content).ok()?;
26        let scope = table.get("scope")?;
27
28        let targets = scope
29            .get("targets")
30            .and_then(|v| v.as_array())
31            .map(|a| {
32                a.iter()
33                    .filter_map(|v| v.as_str().map(String::from))
34                    .collect()
35            })
36            .unwrap_or_default();
37        let domains = scope
38            .get("domains")
39            .and_then(|v| v.as_array())
40            .map(|a| {
41                a.iter()
42                    .filter_map(|v| v.as_str().map(String::from))
43                    .collect()
44            })
45            .unwrap_or_default();
46        let exclude = scope
47            .get("exclude")
48            .and_then(|v| v.as_array())
49            .map(|a| {
50                a.iter()
51                    .filter_map(|v| v.as_str().map(String::from))
52                    .collect()
53            })
54            .unwrap_or_default();
55
56        Some(Scope {
57            targets,
58            domains,
59            exclude,
60        })
61    }
62
63    /// Check if a target (IP, CIDR, or hostname) is within scope.
64    pub fn check(&self, target: &str) -> Result<(), String> {
65        // Check exclusions first
66        if self.exclude.contains(&target.to_string()) {
67            return Err(format!(
68                "Target '{}' is explicitly excluded from scope",
69                target
70            ));
71        }
72
73        // Try as IP address
74        if let Ok(ip) = target.parse::<IpAddr>() {
75            return self.check_ip(ip, target);
76        }
77
78        // Try as CIDR
79        if target.contains('/') {
80            let parts: Vec<&str> = target.split('/').collect();
81            if let Ok(ip) = parts[0].parse::<IpAddr>() {
82                return self.check_ip(ip, target);
83            }
84        }
85
86        // Treat as hostname
87        self.check_hostname(target)
88    }
89
90    fn check_ip(&self, ip: IpAddr, original: &str) -> Result<(), String> {
91        for scope_target in &self.targets {
92            if scope_target.contains('/') {
93                // CIDR range check
94                if ip_in_cidr(ip, scope_target) {
95                    return Ok(());
96                }
97            } else if let Ok(scope_ip) = scope_target.parse::<IpAddr>() {
98                if ip == scope_ip {
99                    return Ok(());
100                }
101            }
102        }
103        Err(format!(
104            "Target '{}' is not in scope (allowed: {})",
105            original,
106            self.targets.join(", ")
107        ))
108    }
109
110    fn check_hostname(&self, hostname: &str) -> Result<(), String> {
111        for domain in &self.domains {
112            if domain.starts_with("*.") {
113                let suffix = &domain[1..]; // .example.com
114                if hostname.ends_with(suffix) || hostname == &domain[2..] {
115                    return Ok(());
116                }
117            } else if hostname == domain {
118                return Ok(());
119            }
120        }
121        // Also check if hostname matches any target string exactly
122        if self.targets.contains(&hostname.to_string()) {
123            return Ok(());
124        }
125        Err(format!(
126            "Target '{}' is not in scope (allowed domains: {})",
127            hostname,
128            self.domains.join(", ")
129        ))
130    }
131}
132
/// Return `true` when `ip` lies inside the CIDR block written as `cidr`
/// (for example `"10.0.1.0/24"` or `"2001:db8::/32"`).
///
/// Yields `false` for malformed input (no `/`, more than one `/`, an
/// unparsable address or prefix), for a prefix longer than the address family
/// allows, and when `ip` and the network are different IP versions. Host bits
/// set in the network address are ignored.
fn ip_in_cidr(ip: IpAddr, cidr: &str) -> bool {
    let Some((net_text, len_text)) = cidr.split_once('/') else {
        return false;
    };
    let (Ok(network), Ok(prefix_len)) = (net_text.parse::<IpAddr>(), len_text.parse::<u32>())
    else {
        return false;
    };

    match (ip, network) {
        (IpAddr::V4(addr), IpAddr::V4(net)) if prefix_len <= 32 => {
            // A shift by 32 (prefix 0) overflows; checked_shl turns that into
            // an all-zero mask, which matches every address.
            let mask = u32::MAX.checked_shl(32 - prefix_len).unwrap_or(0);
            (u32::from(addr) & mask) == (u32::from(net) & mask)
        }
        (IpAddr::V6(addr), IpAddr::V6(net)) if prefix_len <= 128 => {
            let mask = u128::MAX.checked_shl(128 - prefix_len).unwrap_or(0);
            (u128::from(addr) & mask) == (u128::from(net) & mask)
        }
        // Prefix too long for the family, or mismatched IP versions.
        _ => false,
    }
}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Convert a slice of literals into owned strings.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// Build a `Scope` from literal targets, domains and exclusions.
    fn scope_of(targets: &[&str], domains: &[&str], exclude: &[&str]) -> Scope {
        Scope {
            targets: strings(targets),
            domains: strings(domains),
            exclude: strings(exclude),
        }
    }

    #[test]
    fn test_ip_in_cidr() {
        let cases = [
            ("10.0.1.5", "10.0.1.0/24", true),
            ("10.0.1.255", "10.0.1.0/24", true),
            ("10.0.2.1", "10.0.1.0/24", false),
            ("192.168.0.1", "192.168.0.0/16", true),
        ];
        for (ip, cidr, expected) in cases {
            assert_eq!(
                ip_in_cidr(ip.parse().unwrap(), cidr),
                expected,
                "ip_in_cidr({ip}, {cidr})"
            );
        }
    }

    #[test]
    fn test_scope_check_ip() {
        let scope = scope_of(&["10.0.1.0/24", "192.168.1.0/24"], &[], &["10.0.1.1"]);
        for allowed in ["10.0.1.5", "192.168.1.100"] {
            assert!(scope.check(allowed).is_ok(), "{allowed} should be in scope");
        }
        assert!(scope.check("10.0.1.1").is_err(), "excluded address was accepted");
        assert!(scope.check("10.0.2.1").is_err(), "out-of-range address was accepted");
    }

    #[test]
    fn test_scope_check_hostname() {
        let scope = scope_of(&[], &["example.com", "*.test.example.com"], &[]);
        for allowed in ["example.com", "foo.test.example.com", "test.example.com"] {
            assert!(scope.check(allowed).is_ok(), "{allowed} should be in scope");
        }
        assert!(scope.check("evil.com").is_err(), "foreign domain was accepted");
    }

    #[test]
    fn test_scope_check_cidr_target() {
        let scope = scope_of(&["10.0.1.0/24"], &[], &[]);
        // A CIDR target nested in the scope block is accepted; a disjoint one is not.
        assert!(scope.check("10.0.1.0/28").is_ok());
        assert!(scope.check("10.0.2.0/24").is_err());
    }
}