Skip to main content

zerobox_network_proxy/
state.rs

1use crate::config::NetworkDomainPermissions;
2use crate::config::NetworkMode;
3use crate::config::NetworkProxyConfig;
4use crate::config::NetworkUnixSocketPermissions;
5use crate::mitm::MitmState;
6use crate::policy::DomainPattern;
7use crate::policy::compile_allowlist_globset;
8use crate::policy::compile_denylist_globset;
9use crate::policy::is_global_wildcard_domain_pattern;
10use crate::runtime::ConfigState;
11use serde::Deserialize;
12use std::collections::HashSet;
13use std::sync::Arc;
14
15pub use crate::runtime::BlockedRequest;
16pub use crate::runtime::BlockedRequestArgs;
17pub use crate::runtime::NetworkProxyAuditMetadata;
18pub use crate::runtime::NetworkProxyState;
19#[cfg(test)]
20pub(crate) use crate::runtime::network_proxy_state_for_policy;
21
/// Managed limits that a user-supplied network proxy configuration must respect.
///
/// Each field is optional. The per-field notes below describe how
/// `validate_policy_against_constraints` interprets the value.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct NetworkProxyConstraints {
    /// When `Some(false)`, a config with `network.enabled = true` is rejected.
    pub enabled: Option<bool>,
    /// Most permissive mode a config may request; a config whose mode ranks
    /// higher is rejected.
    pub mode: Option<NetworkMode>,
    /// When `Some(false)`, a config with `network.allow_upstream_proxy = true`
    /// is rejected. `Some(true)` and `None` impose no restriction.
    pub allow_upstream_proxy: Option<bool>,
    /// When `Some(false)`, a config with
    /// `network.dangerously_allow_non_loopback_proxy = true` is rejected.
    /// `Some(true)` and `None` impose no restriction.
    pub dangerously_allow_non_loopback_proxy: Option<bool>,
    /// Whether a config may set `network.dangerously_allow_all_unix_sockets`.
    /// When `None`, it is permitted only if `allow_unix_sockets` is also `None`.
    pub dangerously_allow_all_unix_sockets: Option<bool>,
    /// Managed domain allowlist. Must not contain a global wildcard pattern.
    /// How the config's allowlist is compared against it depends on
    /// `allowlist_expansion_enabled`.
    pub allowed_domains: Option<Vec<String>>,
    /// Selects the allowlist comparison:
    /// `Some(true)` - every managed entry must be present in the config's
    /// allowlist unless the config denies it;
    /// `Some(false)` - the config's allowlist must equal the managed entries
    /// minus those the config denies;
    /// `None` - every config entry must be allowed by some managed pattern.
    pub allowlist_expansion_enabled: Option<bool>,
    /// Managed domain denylist. Must not contain a global wildcard pattern.
    pub denied_domains: Option<Vec<String>>,
    /// Selects the denylist comparison: `Some(false)` requires the config's
    /// denylist to equal the managed one; `Some(true)` and `None` only require
    /// every managed entry to be present.
    pub denylist_expansion_enabled: Option<bool>,
    /// Managed Unix-socket allowlist; every socket entry in the config must
    /// appear in it (compared after ASCII lowercasing).
    pub allow_unix_sockets: Option<Vec<String>>,
    /// When `Some(false)`, a config with `network.allow_local_binding = true`
    /// is rejected.
    pub allow_local_binding: Option<bool>,
}
36
/// Deserializable top-level configuration fragment whose network settings are
/// all optional.
#[derive(Debug, Clone, Deserialize)]
pub struct PartialNetworkProxyConfig {
    /// Network section; falls back to `PartialNetworkConfig::default()` when
    /// the input omits it.
    #[serde(default)]
    pub network: PartialNetworkConfig,
}
42
/// Network settings in which every field is an `Option`.
///
/// NOTE(review): the field names match those read from `config.network` and
/// `NetworkProxyConstraints` in this file; presumably this is the partial form
/// merged into those — confirm against the caller.
#[derive(Debug, Default, Clone, Deserialize)]
pub struct PartialNetworkConfig {
    /// Optional `enabled` setting.
    pub enabled: Option<bool>,
    /// Optional network mode.
    pub mode: Option<NetworkMode>,
    /// Optional `allow_upstream_proxy` setting.
    pub allow_upstream_proxy: Option<bool>,
    /// Optional `dangerously_allow_non_loopback_proxy` setting.
    pub dangerously_allow_non_loopback_proxy: Option<bool>,
    /// Optional `dangerously_allow_all_unix_sockets` setting.
    pub dangerously_allow_all_unix_sockets: Option<bool>,
    /// Optional domain permissions.
    #[serde(default)]
    pub domains: Option<NetworkDomainPermissions>,
    /// Optional Unix-socket permissions.
    #[serde(default)]
    pub unix_sockets: Option<NetworkUnixSocketPermissions>,
    /// Optional `allow_local_binding` setting.
    pub allow_local_binding: Option<bool>,
}
56
57pub fn build_config_state(
58    config: NetworkProxyConfig,
59    constraints: NetworkProxyConstraints,
60) -> anyhow::Result<ConfigState> {
61    crate::config::validate_unix_socket_allowlist_paths(&config)?;
62    let allowed_domains = config.network.allowed_domains().unwrap_or_default();
63    let denied_domains = config.network.denied_domains().unwrap_or_default();
64    validate_non_global_wildcard_domain_patterns("network.denied_domains", &denied_domains)
65        .map_err(NetworkProxyConstraintError::into_anyhow)?;
66    let deny_set = compile_denylist_globset(&denied_domains)?;
67    let allow_set = compile_allowlist_globset(&allowed_domains)?;
68    let mitm = if config.network.mitm {
69        Some(Arc::new(MitmState::new(
70            config.network.allow_upstream_proxy,
71        )?))
72    } else {
73        None
74    };
75    Ok(ConfigState {
76        config,
77        allow_set,
78        deny_set,
79        mitm,
80        constraints,
81        blocked: std::collections::VecDeque::new(),
82        blocked_total: 0,
83    })
84}
85
86pub fn validate_policy_against_constraints(
87    config: &NetworkProxyConfig,
88    constraints: &NetworkProxyConstraints,
89) -> Result<(), NetworkProxyConstraintError> {
90    fn invalid_value(
91        field_name: &'static str,
92        candidate: impl Into<String>,
93        allowed: impl Into<String>,
94    ) -> NetworkProxyConstraintError {
95        NetworkProxyConstraintError::InvalidValue {
96            field_name,
97            candidate: candidate.into(),
98            allowed: allowed.into(),
99        }
100    }
101
102    fn validate<T>(
103        candidate: T,
104        validator: impl FnOnce(&T) -> Result<(), NetworkProxyConstraintError>,
105    ) -> Result<(), NetworkProxyConstraintError> {
106        validator(&candidate)
107    }
108
109    let enabled = config.network.enabled;
110    let config_allowed_domains = config.network.allowed_domains().unwrap_or_default();
111    let config_denied_domains = config.network.denied_domains().unwrap_or_default();
112    let denied_domain_overrides: HashSet<String> = config_denied_domains
113        .iter()
114        .map(|entry| entry.to_ascii_lowercase())
115        .collect();
116    let config_allow_unix_sockets = config.network.allow_unix_sockets();
117    validate_non_global_wildcard_domain_patterns("network.denied_domains", &config_denied_domains)?;
118    if let Some(max_enabled) = constraints.enabled {
119        validate(enabled, move |candidate| {
120            if *candidate && !max_enabled {
121                Err(invalid_value(
122                    "network.enabled",
123                    "true",
124                    "false (disabled by managed config)",
125                ))
126            } else {
127                Ok(())
128            }
129        })?;
130    }
131
132    if let Some(max_mode) = constraints.mode {
133        validate(config.network.mode, move |candidate| {
134            if network_mode_rank(*candidate) > network_mode_rank(max_mode) {
135                Err(invalid_value(
136                    "network.mode",
137                    format!("{candidate:?}"),
138                    format!("{max_mode:?} or more restrictive"),
139                ))
140            } else {
141                Ok(())
142            }
143        })?;
144    }
145
146    let allow_upstream_proxy = constraints.allow_upstream_proxy;
147    validate(
148        config.network.allow_upstream_proxy,
149        move |candidate| match allow_upstream_proxy {
150            Some(true) | None => Ok(()),
151            Some(false) => {
152                if *candidate {
153                    Err(invalid_value(
154                        "network.allow_upstream_proxy",
155                        "true",
156                        "false (disabled by managed config)",
157                    ))
158                } else {
159                    Ok(())
160                }
161            }
162        },
163    )?;
164
165    let allow_non_loopback_proxy = constraints.dangerously_allow_non_loopback_proxy;
166    validate(
167        config.network.dangerously_allow_non_loopback_proxy,
168        move |candidate| match allow_non_loopback_proxy {
169            Some(true) | None => Ok(()),
170            Some(false) => {
171                if *candidate {
172                    Err(invalid_value(
173                        "network.dangerously_allow_non_loopback_proxy",
174                        "true",
175                        "false (disabled by managed config)",
176                    ))
177                } else {
178                    Ok(())
179                }
180            }
181        },
182    )?;
183
184    let allow_all_unix_sockets = constraints
185        .dangerously_allow_all_unix_sockets
186        .unwrap_or(constraints.allow_unix_sockets.is_none());
187    validate(
188        config.network.dangerously_allow_all_unix_sockets,
189        move |candidate| {
190            if *candidate && !allow_all_unix_sockets {
191                Err(invalid_value(
192                    "network.dangerously_allow_all_unix_sockets",
193                    "true",
194                    "false (disabled by managed config)",
195                ))
196            } else {
197                Ok(())
198            }
199        },
200    )?;
201
202    if let Some(allow_local_binding) = constraints.allow_local_binding {
203        validate(config.network.allow_local_binding, move |candidate| {
204            if *candidate && !allow_local_binding {
205                Err(invalid_value(
206                    "network.allow_local_binding",
207                    "true",
208                    "false (disabled by managed config)",
209                ))
210            } else {
211                Ok(())
212            }
213        })?;
214    }
215
216    if let Some(allowed_domains) = &constraints.allowed_domains {
217        validate_non_global_wildcard_domain_patterns("network.allowed_domains", allowed_domains)?;
218        match constraints.allowlist_expansion_enabled {
219            Some(true) => {
220                let required_set: HashSet<String> = allowed_domains
221                    .iter()
222                    .map(|entry| entry.to_ascii_lowercase())
223                    .collect();
224                validate(config_allowed_domains, |candidate| {
225                    let candidate_set: HashSet<String> = candidate
226                        .iter()
227                        .map(|entry| entry.to_ascii_lowercase())
228                        .collect();
229                    let missing: Vec<String> = required_set
230                        .iter()
231                        .filter(|entry| {
232                            !candidate_set.contains(*entry)
233                                && !denied_domain_overrides.contains(*entry)
234                        })
235                        .cloned()
236                        .collect();
237                    if missing.is_empty() {
238                        Ok(())
239                    } else {
240                        Err(invalid_value(
241                            "network.allowed_domains",
242                            "missing managed allowed_domains entries",
243                            format!("{missing:?}"),
244                        ))
245                    }
246                })?;
247            }
248            Some(false) => {
249                let required_set: HashSet<String> = allowed_domains
250                    .iter()
251                    .map(|entry| entry.to_ascii_lowercase())
252                    .collect();
253                validate(config_allowed_domains, |candidate| {
254                    let candidate_set: HashSet<String> = candidate
255                        .iter()
256                        .map(|entry| entry.to_ascii_lowercase())
257                        .collect();
258                    let expected_set: HashSet<String> = required_set
259                        .difference(&denied_domain_overrides)
260                        .cloned()
261                        .collect();
262                    if candidate_set == expected_set {
263                        Ok(())
264                    } else {
265                        Err(invalid_value(
266                            "network.allowed_domains",
267                            format!("{candidate:?}"),
268                            "must match managed allowed_domains",
269                        ))
270                    }
271                })?;
272            }
273            None => {
274                let managed_patterns: Vec<DomainPattern> = allowed_domains
275                    .iter()
276                    .map(|entry| DomainPattern::parse_for_constraints(entry))
277                    .collect();
278                validate(config_allowed_domains, move |candidate| {
279                    let mut invalid = Vec::new();
280                    for entry in candidate {
281                        let candidate_pattern = DomainPattern::parse_for_constraints(entry);
282                        if !managed_patterns
283                            .iter()
284                            .any(|managed| managed.allows(&candidate_pattern))
285                        {
286                            invalid.push(entry.clone());
287                        }
288                    }
289                    if invalid.is_empty() {
290                        Ok(())
291                    } else {
292                        Err(invalid_value(
293                            "network.allowed_domains",
294                            format!("{invalid:?}"),
295                            "subset of managed allowed_domains",
296                        ))
297                    }
298                })?;
299            }
300        }
301    }
302
303    if let Some(denied_domains) = &constraints.denied_domains {
304        validate_non_global_wildcard_domain_patterns("network.denied_domains", denied_domains)?;
305        let required_set: HashSet<String> = denied_domains
306            .iter()
307            .map(|s| s.to_ascii_lowercase())
308            .collect();
309        match constraints.denylist_expansion_enabled {
310            Some(false) => {
311                validate(config_denied_domains, move |candidate| {
312                    let candidate_set: HashSet<String> = candidate
313                        .iter()
314                        .map(|entry| entry.to_ascii_lowercase())
315                        .collect();
316                    if candidate_set == required_set {
317                        Ok(())
318                    } else {
319                        Err(invalid_value(
320                            "network.denied_domains",
321                            format!("{candidate:?}"),
322                            "must match managed denied_domains",
323                        ))
324                    }
325                })?;
326            }
327            Some(true) | None => {
328                validate(config_denied_domains, move |candidate| {
329                    let candidate_set: HashSet<String> =
330                        candidate.iter().map(|s| s.to_ascii_lowercase()).collect();
331                    let missing: Vec<String> = required_set
332                        .iter()
333                        .filter(|entry| !candidate_set.contains(*entry))
334                        .cloned()
335                        .collect();
336                    if missing.is_empty() {
337                        Ok(())
338                    } else {
339                        Err(invalid_value(
340                            "network.denied_domains",
341                            "missing managed denied_domains entries",
342                            format!("{missing:?}"),
343                        ))
344                    }
345                })?;
346            }
347        }
348    }
349
350    if let Some(allow_unix_sockets) = &constraints.allow_unix_sockets {
351        let allowed_set: HashSet<String> = allow_unix_sockets
352            .iter()
353            .map(|s| s.to_ascii_lowercase())
354            .collect();
355        validate(config_allow_unix_sockets, move |candidate| {
356            let mut invalid = Vec::new();
357            for entry in candidate {
358                if !allowed_set.contains(&entry.to_ascii_lowercase()) {
359                    invalid.push(entry.clone());
360                }
361            }
362            if invalid.is_empty() {
363                Ok(())
364            } else {
365                Err(invalid_value(
366                    "network.allow_unix_sockets",
367                    format!("{invalid:?}"),
368                    "subset of managed allow_unix_sockets",
369                ))
370            }
371        })?;
372    }
373
374    Ok(())
375}
376
377fn validate_non_global_wildcard_domain_patterns(
378    field_name: &'static str,
379    patterns: &[String],
380) -> Result<(), NetworkProxyConstraintError> {
381    if let Some(pattern) = patterns
382        .iter()
383        .find(|pattern| is_global_wildcard_domain_pattern(pattern))
384    {
385        return Err(NetworkProxyConstraintError::InvalidValue {
386            field_name,
387            candidate: pattern.trim().to_string(),
388            allowed: "exact hosts or scoped wildcards like *.example.com or **.example.com"
389                .to_string(),
390        });
391    }
392    Ok(())
393}
394
/// Error produced when a network proxy configuration value is not acceptable,
/// either because it violates a managed constraint or because it uses a
/// forbidden domain pattern.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum NetworkProxyConstraintError {
    /// A configuration field holds a value that is not permitted.
    #[error("invalid value for {field_name}: {candidate} (allowed {allowed})")]
    InvalidValue {
        /// Dotted name of the offending field, e.g. `network.allowed_domains`.
        field_name: &'static str,
        /// The rejected value, or a short description of the problem (for
        /// example when managed entries are missing).
        candidate: String,
        /// Description of what would be accepted; for "missing entries"
        /// errors this carries the list of missing entries.
        allowed: String,
    },
}
404
405impl NetworkProxyConstraintError {
406    pub fn into_anyhow(self) -> anyhow::Error {
407        anyhow::anyhow!(self)
408    }
409}
410
411fn network_mode_rank(mode: NetworkMode) -> u8 {
412    match mode {
413        NetworkMode::Limited => 0,
414        NetworkMode::Full => 1,
415    }
416}
417
// Test-only module, compiled only under `cfg(test)`.
// NOTE(review): the body is empty in this view — presumably the tests live
// elsewhere or were elided; confirm before relying on coverage for this file.
#[cfg(test)]
mod tests {}