Skip to main content

rust_bash/
network.rs

1use std::collections::HashSet;
2use std::time::Duration;
3
/// Policy controlling network access for commands like `curl`.
///
/// Disabled by default — scripts have no network access unless the embedder
/// explicitly enables it and configures an allow-list.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct NetworkPolicy {
    /// Whether network access is enabled; `false` in the default policy.
    pub enabled: bool,
    /// URL prefixes a request URL must match; see [`NetworkPolicy::validate_url`].
    /// Empty by default, which rejects every URL.
    pub allowed_url_prefixes: Vec<String>,
    /// Permitted HTTP method names; see [`NetworkPolicy::validate_method`].
    /// Entries should be upper-case because the requested method is
    /// upper-cased before lookup.
    pub allowed_methods: HashSet<String>,
    /// Maximum number of redirects (5 in the default policy).
    pub max_redirects: usize,
    /// Maximum response size in bytes (10 MiB in the default policy).
    pub max_response_size: usize,
    /// Request timeout (30 seconds in the default policy).
    pub timeout: Duration,
}
17
18impl Default for NetworkPolicy {
19    fn default() -> Self {
20        Self {
21            enabled: false,
22            allowed_url_prefixes: Vec::new(),
23            allowed_methods: HashSet::from(["GET".to_string(), "POST".to_string()]),
24            max_redirects: 5,
25            max_response_size: 10 * 1024 * 1024, // 10 MB
26            timeout: Duration::from_secs(30),
27        }
28    }
29}
30
31impl NetworkPolicy {
32    /// Validate that `url` matches at least one entry in `allowed_url_prefixes`.
33    ///
34    /// The raw URL is first parsed and re-serialized via `url::Url` to
35    /// normalize it (resolve default ports, percent-encoding, etc.), and then
36    /// each allowed prefix is checked with a simple `starts_with`.
37    /// Prefixes are also normalized via `url::Url` when possible to prevent
38    /// subdomain confusion attacks (e.g. a prefix of `"https://api.example.com"`
39    /// without a trailing slash would otherwise match `"https://api.example.com.evil.com/"`).
40    pub fn validate_url(&self, url: &str) -> Result<(), String> {
41        let parsed = url::Url::parse(url).map_err(|e| format!("invalid URL '{url}': {e}"))?;
42        let normalized = parsed.as_str();
43
44        for prefix in &self.allowed_url_prefixes {
45            let norm_prefix = url::Url::parse(prefix)
46                .map(|u| u.to_string())
47                .unwrap_or_else(|_| prefix.clone());
48            if normalized.starts_with(&norm_prefix) {
49                return Ok(());
50            }
51        }
52
53        Err(format!("URL not allowed by network policy: {normalized}"))
54    }
55
56    /// Validate that `method` is in the set of allowed HTTP methods.
57    pub fn validate_method(&self, method: &str) -> Result<(), String> {
58        let upper = method.to_uppercase();
59        if self.allowed_methods.contains(&upper) {
60            Ok(())
61        } else {
62            Err(format!(
63                "HTTP method not allowed by network policy: {upper}"
64            ))
65        }
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;

    /// Policy whose allow-list holds exactly `prefix`; every other field
    /// comes from `NetworkPolicy::default()`.
    fn policy_allowing(prefix: &str) -> NetworkPolicy {
        NetworkPolicy {
            allowed_url_prefixes: vec![prefix.to_string()],
            ..NetworkPolicy::default()
        }
    }

    #[test]
    fn default_is_disabled() {
        assert!(!NetworkPolicy::default().enabled);
    }

    #[test]
    fn default_allows_get_and_post() {
        let methods = NetworkPolicy::default().allowed_methods;
        for allowed in ["GET", "POST"] {
            assert!(methods.contains(allowed), "{allowed} should be allowed by default");
        }
        assert!(!methods.contains("DELETE"));
    }

    #[test]
    fn validate_url_matches_prefix() {
        let policy = policy_allowing("https://api.example.com/");
        for url in [
            "https://api.example.com/v1/data",
            "https://api.example.com/users?id=1",
        ] {
            assert!(policy.validate_url(url).is_ok(), "{url} should be allowed");
        }
    }

    #[test]
    fn validate_url_rejects_different_domain() {
        let policy = policy_allowing("https://api.example.com/");
        assert!(policy.validate_url("https://api.example.com.evil.org/").is_err());
    }

    #[test]
    fn validate_url_rejects_different_scheme() {
        let policy = policy_allowing("https://api.example.com/");
        assert!(policy.validate_url("http://api.example.com/").is_err());
    }

    #[test]
    fn validate_url_rejects_subdomain_without_trailing_slash() {
        let policy = policy_allowing("https://api.example.com");
        // A prefix lacking the trailing slash must still not cover a look-alike host...
        assert!(policy.validate_url("https://api.example.com.evil.com/").is_err());
        // ...while the intended host keeps working.
        assert!(policy.validate_url("https://api.example.com/v1/data").is_ok());
    }

    #[test]
    fn validate_url_rejects_userinfo_attack() {
        let policy = policy_allowing("https://api.example.com/");
        // "api.example.com" is userinfo here and the real host is evil.com, so
        // the normalized URL must not be treated as covered by the prefix.
        assert!(policy.validate_url("https://api.example.com@evil.com/").is_err());
    }

    #[test]
    fn validate_url_no_prefixes_rejects_all() {
        assert!(NetworkPolicy::default().validate_url("https://example.com/").is_err());
    }

    #[test]
    fn validate_url_invalid_url() {
        assert!(NetworkPolicy::default().validate_url("not a url").is_err());
    }

    #[test]
    fn validate_method_allowed() {
        let policy = NetworkPolicy::default();
        for method in ["GET", "get", "POST"] {
            assert!(policy.validate_method(method).is_ok(), "{method} should be allowed");
        }
    }

    #[test]
    fn validate_method_rejected() {
        let policy = NetworkPolicy::default();
        for method in ["DELETE", "PUT"] {
            assert!(policy.validate_method(method).is_err(), "{method} should be rejected");
        }
    }
}