Skip to main content

zerobox_network_proxy/
state.rs

1use crate::config::NetworkDomainPermissions;
2use crate::config::NetworkMode;
3use crate::config::NetworkProxyConfig;
4use crate::config::NetworkUnixSocketPermissions;
5use crate::mitm::MitmState;
6use crate::mitm::MitmUpstreamConfig;
7use crate::policy::DomainPattern;
8use crate::policy::compile_allowlist_globset;
9use crate::policy::compile_denylist_globset;
10use crate::policy::is_global_wildcard_domain_pattern;
11use crate::runtime::ConfigState;
12use serde::Deserialize;
13use std::collections::HashSet;
14use std::sync::Arc;
15
16pub use crate::runtime::BlockedRequest;
17pub use crate::runtime::BlockedRequestArgs;
18pub use crate::runtime::NetworkProxyAuditMetadata;
19pub use crate::runtime::NetworkProxyState;
20#[cfg(test)]
21pub(crate) use crate::runtime::network_proxy_state_for_policy;
22
23#[derive(Debug, Default, Clone, PartialEq, Eq)]
24pub struct NetworkProxyConstraints {
25    pub enabled: Option<bool>,
26    pub mode: Option<NetworkMode>,
27    pub allow_upstream_proxy: Option<bool>,
28    pub dangerously_allow_non_loopback_proxy: Option<bool>,
29    pub dangerously_allow_all_unix_sockets: Option<bool>,
30    pub allowed_domains: Option<Vec<String>>,
31    pub allowlist_expansion_enabled: Option<bool>,
32    pub denied_domains: Option<Vec<String>>,
33    pub denylist_expansion_enabled: Option<bool>,
34    pub allow_unix_sockets: Option<Vec<String>>,
35    pub allow_local_binding: Option<bool>,
36}
37
38#[derive(Debug, Clone, Deserialize)]
39pub struct PartialNetworkProxyConfig {
40    #[serde(default)]
41    pub network: PartialNetworkConfig,
42}
43
44#[derive(Debug, Default, Clone, Deserialize)]
45pub struct PartialNetworkConfig {
46    pub enabled: Option<bool>,
47    pub mode: Option<NetworkMode>,
48    pub allow_upstream_proxy: Option<bool>,
49    pub dangerously_allow_non_loopback_proxy: Option<bool>,
50    pub dangerously_allow_all_unix_sockets: Option<bool>,
51    #[serde(default)]
52    pub domains: Option<NetworkDomainPermissions>,
53    #[serde(default)]
54    pub unix_sockets: Option<NetworkUnixSocketPermissions>,
55    pub allow_local_binding: Option<bool>,
56}
57
58pub fn build_config_state(
59    config: NetworkProxyConfig,
60    constraints: NetworkProxyConstraints,
61) -> anyhow::Result<ConfigState> {
62    crate::config::validate_unix_socket_allowlist_paths(&config)?;
63    let allowed_domains = config.network.allowed_domains().unwrap_or_default();
64    let denied_domains = config.network.denied_domains().unwrap_or_default();
65    validate_non_global_wildcard_domain_patterns("network.denied_domains", &denied_domains)
66        .map_err(NetworkProxyConstraintError::into_anyhow)?;
67    let deny_set = compile_denylist_globset(&denied_domains)?;
68    let allow_set = compile_allowlist_globset(&allowed_domains)?;
69    let mitm = if config.network.mitm {
70        Some(Arc::new(MitmState::new(MitmUpstreamConfig {
71            allow_upstream_proxy: config.network.allow_upstream_proxy,
72            allow_local_binding: config.network.allow_local_binding,
73        })?))
74    } else {
75        None
76    };
77    Ok(ConfigState {
78        config,
79        allow_set,
80        deny_set,
81        mitm,
82        constraints,
83        blocked: std::collections::VecDeque::new(),
84        blocked_total: 0,
85    })
86}
87
88pub fn validate_policy_against_constraints(
89    config: &NetworkProxyConfig,
90    constraints: &NetworkProxyConstraints,
91) -> Result<(), NetworkProxyConstraintError> {
92    fn invalid_value(
93        field_name: &'static str,
94        candidate: impl Into<String>,
95        allowed: impl Into<String>,
96    ) -> NetworkProxyConstraintError {
97        NetworkProxyConstraintError::InvalidValue {
98            field_name,
99            candidate: candidate.into(),
100            allowed: allowed.into(),
101        }
102    }
103
104    fn validate<T>(
105        candidate: T,
106        validator: impl FnOnce(&T) -> Result<(), NetworkProxyConstraintError>,
107    ) -> Result<(), NetworkProxyConstraintError> {
108        validator(&candidate)
109    }
110
111    let enabled = config.network.enabled;
112    let config_allowed_domains = config.network.allowed_domains().unwrap_or_default();
113    let config_denied_domains = config.network.denied_domains().unwrap_or_default();
114    let denied_domain_overrides: HashSet<String> = config_denied_domains
115        .iter()
116        .map(|entry| entry.to_ascii_lowercase())
117        .collect();
118    let config_allow_unix_sockets = config.network.allow_unix_sockets();
119    validate_non_global_wildcard_domain_patterns("network.denied_domains", &config_denied_domains)?;
120    if let Some(max_enabled) = constraints.enabled {
121        validate(enabled, move |candidate| {
122            if *candidate && !max_enabled {
123                Err(invalid_value(
124                    "network.enabled",
125                    "true",
126                    "false (disabled by managed config)",
127                ))
128            } else {
129                Ok(())
130            }
131        })?;
132    }
133
134    if let Some(max_mode) = constraints.mode {
135        validate(config.network.mode, move |candidate| {
136            if network_mode_rank(*candidate) > network_mode_rank(max_mode) {
137                Err(invalid_value(
138                    "network.mode",
139                    format!("{candidate:?}"),
140                    format!("{max_mode:?} or more restrictive"),
141                ))
142            } else {
143                Ok(())
144            }
145        })?;
146    }
147
148    let allow_upstream_proxy = constraints.allow_upstream_proxy;
149    validate(
150        config.network.allow_upstream_proxy,
151        move |candidate| match allow_upstream_proxy {
152            Some(true) | None => Ok(()),
153            Some(false) => {
154                if *candidate {
155                    Err(invalid_value(
156                        "network.allow_upstream_proxy",
157                        "true",
158                        "false (disabled by managed config)",
159                    ))
160                } else {
161                    Ok(())
162                }
163            }
164        },
165    )?;
166
167    let allow_non_loopback_proxy = constraints.dangerously_allow_non_loopback_proxy;
168    validate(
169        config.network.dangerously_allow_non_loopback_proxy,
170        move |candidate| match allow_non_loopback_proxy {
171            Some(true) | None => Ok(()),
172            Some(false) => {
173                if *candidate {
174                    Err(invalid_value(
175                        "network.dangerously_allow_non_loopback_proxy",
176                        "true",
177                        "false (disabled by managed config)",
178                    ))
179                } else {
180                    Ok(())
181                }
182            }
183        },
184    )?;
185
186    let allow_all_unix_sockets = constraints
187        .dangerously_allow_all_unix_sockets
188        .unwrap_or(constraints.allow_unix_sockets.is_none());
189    validate(
190        config.network.dangerously_allow_all_unix_sockets,
191        move |candidate| {
192            if *candidate && !allow_all_unix_sockets {
193                Err(invalid_value(
194                    "network.dangerously_allow_all_unix_sockets",
195                    "true",
196                    "false (disabled by managed config)",
197                ))
198            } else {
199                Ok(())
200            }
201        },
202    )?;
203
204    if let Some(allow_local_binding) = constraints.allow_local_binding {
205        validate(config.network.allow_local_binding, move |candidate| {
206            if *candidate && !allow_local_binding {
207                Err(invalid_value(
208                    "network.allow_local_binding",
209                    "true",
210                    "false (disabled by managed config)",
211                ))
212            } else {
213                Ok(())
214            }
215        })?;
216    }
217
218    if let Some(allowed_domains) = &constraints.allowed_domains {
219        validate_non_global_wildcard_domain_patterns("network.allowed_domains", allowed_domains)?;
220        match constraints.allowlist_expansion_enabled {
221            Some(true) => {
222                let required_set: HashSet<String> = allowed_domains
223                    .iter()
224                    .map(|entry| entry.to_ascii_lowercase())
225                    .collect();
226                validate(config_allowed_domains, |candidate| {
227                    let candidate_set: HashSet<String> = candidate
228                        .iter()
229                        .map(|entry| entry.to_ascii_lowercase())
230                        .collect();
231                    let missing: Vec<String> = required_set
232                        .iter()
233                        .filter(|entry| {
234                            !candidate_set.contains(*entry)
235                                && !denied_domain_overrides.contains(*entry)
236                        })
237                        .cloned()
238                        .collect();
239                    if missing.is_empty() {
240                        Ok(())
241                    } else {
242                        Err(invalid_value(
243                            "network.allowed_domains",
244                            "missing managed allowed_domains entries",
245                            format!("{missing:?}"),
246                        ))
247                    }
248                })?;
249            }
250            Some(false) => {
251                let required_set: HashSet<String> = allowed_domains
252                    .iter()
253                    .map(|entry| entry.to_ascii_lowercase())
254                    .collect();
255                validate(config_allowed_domains, |candidate| {
256                    let candidate_set: HashSet<String> = candidate
257                        .iter()
258                        .map(|entry| entry.to_ascii_lowercase())
259                        .collect();
260                    let expected_set: HashSet<String> = required_set
261                        .difference(&denied_domain_overrides)
262                        .cloned()
263                        .collect();
264                    if candidate_set == expected_set {
265                        Ok(())
266                    } else {
267                        Err(invalid_value(
268                            "network.allowed_domains",
269                            format!("{candidate:?}"),
270                            "must match managed allowed_domains",
271                        ))
272                    }
273                })?;
274            }
275            None => {
276                let managed_patterns: Vec<DomainPattern> = allowed_domains
277                    .iter()
278                    .map(|entry| DomainPattern::parse_for_constraints(entry))
279                    .collect();
280                validate(config_allowed_domains, move |candidate| {
281                    let mut invalid = Vec::new();
282                    for entry in candidate {
283                        let candidate_pattern = DomainPattern::parse_for_constraints(entry);
284                        if !managed_patterns
285                            .iter()
286                            .any(|managed| managed.allows(&candidate_pattern))
287                        {
288                            invalid.push(entry.clone());
289                        }
290                    }
291                    if invalid.is_empty() {
292                        Ok(())
293                    } else {
294                        Err(invalid_value(
295                            "network.allowed_domains",
296                            format!("{invalid:?}"),
297                            "subset of managed allowed_domains",
298                        ))
299                    }
300                })?;
301            }
302        }
303    }
304
305    if let Some(denied_domains) = &constraints.denied_domains {
306        validate_non_global_wildcard_domain_patterns("network.denied_domains", denied_domains)?;
307        let required_set: HashSet<String> = denied_domains
308            .iter()
309            .map(|s| s.to_ascii_lowercase())
310            .collect();
311        match constraints.denylist_expansion_enabled {
312            Some(false) => {
313                validate(config_denied_domains, move |candidate| {
314                    let candidate_set: HashSet<String> = candidate
315                        .iter()
316                        .map(|entry| entry.to_ascii_lowercase())
317                        .collect();
318                    if candidate_set == required_set {
319                        Ok(())
320                    } else {
321                        Err(invalid_value(
322                            "network.denied_domains",
323                            format!("{candidate:?}"),
324                            "must match managed denied_domains",
325                        ))
326                    }
327                })?;
328            }
329            Some(true) | None => {
330                validate(config_denied_domains, move |candidate| {
331                    let candidate_set: HashSet<String> =
332                        candidate.iter().map(|s| s.to_ascii_lowercase()).collect();
333                    let missing: Vec<String> = required_set
334                        .iter()
335                        .filter(|entry| !candidate_set.contains(*entry))
336                        .cloned()
337                        .collect();
338                    if missing.is_empty() {
339                        Ok(())
340                    } else {
341                        Err(invalid_value(
342                            "network.denied_domains",
343                            "missing managed denied_domains entries",
344                            format!("{missing:?}"),
345                        ))
346                    }
347                })?;
348            }
349        }
350    }
351
352    if let Some(allow_unix_sockets) = &constraints.allow_unix_sockets {
353        let allowed_set: HashSet<String> = allow_unix_sockets
354            .iter()
355            .map(|s| s.to_ascii_lowercase())
356            .collect();
357        validate(config_allow_unix_sockets, move |candidate| {
358            let mut invalid = Vec::new();
359            for entry in candidate {
360                if !allowed_set.contains(&entry.to_ascii_lowercase()) {
361                    invalid.push(entry.clone());
362                }
363            }
364            if invalid.is_empty() {
365                Ok(())
366            } else {
367                Err(invalid_value(
368                    "network.allow_unix_sockets",
369                    format!("{invalid:?}"),
370                    "subset of managed allow_unix_sockets",
371                ))
372            }
373        })?;
374    }
375
376    Ok(())
377}
378
379fn validate_non_global_wildcard_domain_patterns(
380    field_name: &'static str,
381    patterns: &[String],
382) -> Result<(), NetworkProxyConstraintError> {
383    if let Some(pattern) = patterns
384        .iter()
385        .find(|pattern| is_global_wildcard_domain_pattern(pattern))
386    {
387        return Err(NetworkProxyConstraintError::InvalidValue {
388            field_name,
389            candidate: pattern.trim().to_string(),
390            allowed: "exact hosts or scoped wildcards like *.example.com or **.example.com"
391                .to_string(),
392        });
393    }
394    Ok(())
395}
396
397#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
398pub enum NetworkProxyConstraintError {
399    #[error("invalid value for {field_name}: {candidate} (allowed {allowed})")]
400    InvalidValue {
401        field_name: &'static str,
402        candidate: String,
403        allowed: String,
404    },
405}
406
407impl NetworkProxyConstraintError {
408    pub fn into_anyhow(self) -> anyhow::Error {
409        anyhow::anyhow!(self)
410    }
411}
412
413fn network_mode_rank(mode: NetworkMode) -> u8 {
414    match mode {
415        NetworkMode::Limited => 0,
416        NetworkMode::Full => 1,
417    }
418}
419
// NOTE(review): this test module is empty as written. The `#[cfg(test)]`
// re-export of `network_proxy_state_for_policy` near the top of this file
// suggests tests were expected to use it — confirm whether the tests live
// elsewhere or are missing.
#[cfg(test)]
mod tests {}