Skip to main content

zlayer_proxy/
network_policy.rs

1//! Network policy access control for the reverse proxy.
2//!
3//! This module provides [`NetworkPolicyChecker`], which evaluates incoming
4//! requests against [`NetworkPolicySpec`] access rules to determine whether
5//! a source IP is allowed to reach a given service/port combination.
6
7use std::net::IpAddr;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use tracing::{debug, warn};
11use zlayer_spec::{AccessAction, AccessRule, NetworkPolicySpec};
12
13/// Checks incoming requests against network access policies.
14///
15/// When a request arrives, the checker:
16/// 1. Finds all networks whose CIDRs contain the source IP.
17/// 2. If no networks match, access is allowed (default open).
18/// 3. If any matching network has a deny rule for the target, access is denied.
19/// 4. If any matching network has an allow rule for the target, access is allowed.
20/// 5. If the source belongs to a network but no rules match, access is denied
21///    (having a network policy implies explicit access control).
22#[derive(Clone)]
23pub struct NetworkPolicyChecker {
24    policies: Arc<RwLock<Vec<NetworkPolicySpec>>>,
25}
26
27impl NetworkPolicyChecker {
28    /// Create a new checker backed by the given shared policy list.
29    pub fn new(policies: Arc<RwLock<Vec<NetworkPolicySpec>>>) -> Self {
30        Self { policies }
31    }
32
33    /// Check if `source_ip` is allowed to access a target service on the given port.
34    ///
35    /// Returns `true` if access is allowed, `false` if denied.
36    ///
37    /// The `deployment` parameter exists for forward compatibility with
38    /// per-deployment rules; pass `"*"` when the deployment is unknown.
39    pub async fn check_access(
40        &self,
41        source_ip: IpAddr,
42        service: &str,
43        deployment: &str,
44        port: u16,
45    ) -> bool {
46        let policies = self.policies.read().await;
47
48        let matching_networks: Vec<&NetworkPolicySpec> = policies
49            .iter()
50            .filter(|p| ip_in_cidrs(source_ip, &p.cidrs))
51            .collect();
52
53        // No network policy governs this IP — allow by default.
54        if matching_networks.is_empty() {
55            return true;
56        }
57
58        // Phase 1: explicit deny takes priority.
59        for network in &matching_networks {
60            for rule in &network.access_rules {
61                if rule_matches(rule, service, deployment, port)
62                    && rule.action == AccessAction::Deny
63                {
64                    warn!(
65                        source = %source_ip,
66                        network = %network.name,
67                        service = %service,
68                        port = %port,
69                        "Network policy denied access"
70                    );
71                    return false;
72                }
73            }
74        }
75
76        // Phase 2: explicit allow.
77        for network in &matching_networks {
78            for rule in &network.access_rules {
79                if rule_matches(rule, service, deployment, port)
80                    && rule.action == AccessAction::Allow
81                {
82                    debug!(
83                        source = %source_ip,
84                        network = %network.name,
85                        service = %service,
86                        port = %port,
87                        "Network policy allowed access"
88                    );
89                    return true;
90                }
91            }
92        }
93
94        // Source is governed by a network policy but no rules matched — default deny.
95        warn!(
96            source = %source_ip,
97            service = %service,
98            port = %port,
99            "Source in network policy but no matching rule; default deny"
100        );
101        false
102    }
103}
104
/// Returns `true` if `ip` falls within any of the given CIDR strings.
///
/// Each entry is either `address/prefix` (e.g. `10.0.0.0/8`, `fd00::/64`) or a bare address.
/// A bare address is treated as a single-host network (`/32` for IPv4, `/128` for IPv6).
/// Surrounding whitespace is ignored. Entries that cannot be parsed are skipped.
fn ip_in_cidrs(ip: IpAddr, cidrs: &[String]) -> bool {
    cidrs
        .iter()
        .filter_map(|cidr| parse_cidr(cidr))
        .any(|(network, prefix_len)| cidr_contains(network, prefix_len, ip))
}

/// Parses `address/prefix` or a bare address into `(network, prefix_len)`; `None` if malformed.
fn parse_cidr(cidr: &str) -> Option<(IpAddr, u32)> {
    let cidr = cidr.trim();
    match cidr.split_once('/') {
        Some((net, prefix)) => Some((
            net.trim().parse::<IpAddr>().ok()?,
            prefix.trim().parse::<u32>().ok()?,
        )),
        None => {
            let addr = cidr.parse::<IpAddr>().ok()?;
            let host_len = if addr.is_ipv4() { 32 } else { 128 };
            Some((addr, host_len))
        }
    }
}

/// Returns `true` if `addr` is within the CIDR `network/prefix_len`.
///
/// Prefix lengths larger than the address width are clamped, so `10.0.0.1/40` behaves like
/// `/32`. A prefix of `0` matches every address.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are compared as their IPv4 equivalent against
/// IPv4 networks. Plain IPv4 addresses are compared in their mapped form against IPv6 networks.
/// Dual-stack listeners report IPv4 peers in the mapped form. Without this, a source such as
/// `::ffff:10.1.2.3` would never match a `10.0.0.0/8` policy and would be allowed by default.
/// With it, the same host is classified identically whichever family its socket used.
fn cidr_contains(network: IpAddr, prefix_len: u32, addr: IpAddr) -> bool {
    match (network, addr) {
        (IpAddr::V4(net), IpAddr::V4(ip)) => {
            let prefix_len = prefix_len.min(32);
            // A shift of 32 (prefix 0) overflows `checked_shl`, giving an all-zero mask.
            let mask = u32::MAX.checked_shl(32 - prefix_len).unwrap_or(0);
            (u32::from(net) & mask) == (u32::from(ip) & mask)
        }
        (IpAddr::V6(net), IpAddr::V6(ip)) => {
            let prefix_len = prefix_len.min(128);
            // A shift of 128 (prefix 0) overflows `checked_shl`, giving an all-zero mask.
            let mask = u128::MAX.checked_shl(128 - prefix_len).unwrap_or(0);
            (u128::from(net) & mask) == (u128::from(ip) & mask)
        }
        // IPv6 source against an IPv4 network: only IPv4-mapped sources can match.
        (IpAddr::V4(_), IpAddr::V6(ip)) => ip
            .to_ipv4_mapped()
            .is_some_and(|v4| cidr_contains(network, prefix_len, IpAddr::V4(v4))),
        // IPv4 source against an IPv6 network: compare as `::ffff:a.b.c.d`.
        (IpAddr::V6(_), IpAddr::V4(ip)) => {
            cidr_contains(network, prefix_len, IpAddr::V6(ip.to_ipv6_mapped()))
        }
    }
}
145
146/// Check whether a single access rule matches the given target.
147fn rule_matches(rule: &AccessRule, service: &str, deployment: &str, port: u16) -> bool {
148    let service_match = rule.service == "*" || rule.service == service;
149    let deployment_match = rule.deployment == "*" || rule.deployment == deployment;
150    let port_match = rule
151        .ports
152        .as_ref()
153        .is_none_or(|ports| ports.contains(&port));
154    service_match && deployment_match && port_match
155}
156
#[cfg(test)]
mod tests {
    //! Unit tests for the policy evaluation order and for CIDR / rule matching.

    use super::*;
    use zlayer_spec::{AccessAction, AccessRule, NetworkPolicySpec};

    /// Builds a policy named `name` covering `cidrs` with the given rules; other fields default.
    fn make_policy(name: &str, cidrs: Vec<&str>, rules: Vec<AccessRule>) -> NetworkPolicySpec {
        NetworkPolicySpec {
            name: name.to_string(),
            cidrs: cidrs.into_iter().map(String::from).collect(),
            access_rules: rules,
            ..Default::default()
        }
    }

    /// Builds an `Allow` rule; `ports: None` means "any port".
    fn allow_rule(service: &str, deployment: &str, ports: Option<Vec<u16>>) -> AccessRule {
        AccessRule {
            service: service.to_string(),
            deployment: deployment.to_string(),
            ports,
            action: AccessAction::Allow,
        }
    }

    /// Builds a `Deny` rule; `ports: None` means "any port".
    fn deny_rule(service: &str, deployment: &str, ports: Option<Vec<u16>>) -> AccessRule {
        AccessRule {
            service: service.to_string(),
            deployment: deployment.to_string(),
            ports,
            action: AccessAction::Deny,
        }
    }

    #[tokio::test]
    async fn test_no_matching_network_allows() {
        let policies = Arc::new(RwLock::new(vec![make_policy(
            "corp",
            vec!["10.0.0.0/8"],
            vec![allow_rule("api", "*", None)],
        )]));
        let checker = NetworkPolicyChecker::new(policies);

        // 192.168.1.1 is not in 10.0.0.0/8 — should allow.
        assert!(
            checker
                .check_access("192.168.1.1".parse().unwrap(), "api", "*", 8080)
                .await
        );
    }

    #[tokio::test]
    async fn test_matching_allow_rule() {
        let policies = Arc::new(RwLock::new(vec![make_policy(
            "corp",
            vec!["10.0.0.0/8"],
            vec![allow_rule("api", "*", None)],
        )]));
        let checker = NetworkPolicyChecker::new(policies);

        // 10.1.2.3 is in 10.0.0.0/8 and rule allows "api" — should allow.
        assert!(
            checker
                .check_access("10.1.2.3".parse().unwrap(), "api", "*", 8080)
                .await
        );
    }

    #[tokio::test]
    async fn test_matching_deny_rule() {
        let policies = Arc::new(RwLock::new(vec![make_policy(
            "restricted",
            vec!["10.0.0.0/8"],
            vec![deny_rule("admin", "*", None)],
        )]));
        let checker = NetworkPolicyChecker::new(policies);

        // 10.1.2.3 is in 10.0.0.0/8 and rule denies "admin" — should deny.
        assert!(
            !checker
                .check_access("10.1.2.3".parse().unwrap(), "admin", "*", 443)
                .await
        );
    }

    #[tokio::test]
    async fn test_deny_takes_priority_over_allow() {
        let policies = Arc::new(RwLock::new(vec![make_policy(
            "mixed",
            vec!["10.0.0.0/8"],
            vec![
                allow_rule("api", "*", None),
                deny_rule("api", "*", Some(vec![9090])),
            ],
        )]));
        let checker = NetworkPolicyChecker::new(policies);

        // Port 8080 is allowed (no deny matches port 8080).
        assert!(
            checker
                .check_access("10.1.2.3".parse().unwrap(), "api", "*", 8080)
                .await
        );

        // Port 9090 is denied (deny rule matches).
        assert!(
            !checker
                .check_access("10.1.2.3".parse().unwrap(), "api", "*", 9090)
                .await
        );
    }

    #[tokio::test]
    async fn test_deny_in_one_network_overrides_allow_in_another() {
        // Overlapping networks: a deny in one beats an allow in the other, even though the
        // allowing network is listed first.
        let policies = Arc::new(RwLock::new(vec![
            make_policy(
                "everyone",
                vec!["10.0.0.0/8"],
                vec![allow_rule("*", "*", None)],
            ),
            make_policy(
                "lab",
                vec!["10.1.0.0/16"],
                vec![deny_rule("admin", "*", None)],
            ),
        ]));
        let checker = NetworkPolicyChecker::new(policies);

        // 10.1.2.3 is in both networks: "admin" is denied by "lab"...
        assert!(
            !checker
                .check_access("10.1.2.3".parse().unwrap(), "admin", "*", 443)
                .await
        );
        // ...while other services are still allowed by "everyone".
        assert!(
            checker
                .check_access("10.1.2.3".parse().unwrap(), "api", "*", 443)
                .await
        );
        // 10.2.0.1 is only in "everyone", so "admin" is allowed there.
        assert!(
            checker
                .check_access("10.2.0.1".parse().unwrap(), "admin", "*", 443)
                .await
        );
    }

    #[tokio::test]
    async fn test_network_but_no_matching_rule_denies() {
        let policies = Arc::new(RwLock::new(vec![make_policy(
            "corp",
            vec!["10.0.0.0/8"],
            vec![allow_rule("api", "*", None)],
        )]));
        let checker = NetworkPolicyChecker::new(policies);

        // 10.1.2.3 is in the network, but "frontend" has no matching rule — default deny.
        assert!(
            !checker
                .check_access("10.1.2.3".parse().unwrap(), "frontend", "*", 80)
                .await
        );
    }

    #[tokio::test]
    async fn test_wildcard_service_rule() {
        let policies = Arc::new(RwLock::new(vec![make_policy(
            "admin-net",
            vec!["172.16.0.0/12"],
            vec![allow_rule("*", "*", None)],
        )]));
        let checker = NetworkPolicyChecker::new(policies);

        // Wildcard service rule should match any service.
        assert!(
            checker
                .check_access("172.16.5.10".parse().unwrap(), "anything", "*", 443)
                .await
        );
    }

    #[tokio::test]
    async fn test_deployment_specific_rule() {
        let policies = Arc::new(RwLock::new(vec![make_policy(
            "corp",
            vec!["10.0.0.0/8"],
            vec![allow_rule("api", "prod", None)],
        )]));
        let checker = NetworkPolicyChecker::new(policies);

        assert!(
            checker
                .check_access("10.1.2.3".parse().unwrap(), "api", "prod", 80)
                .await
        );
        assert!(
            !checker
                .check_access("10.1.2.3".parse().unwrap(), "api", "staging", 80)
                .await
        );
        // Requesting deployment "*" only matches rules whose deployment is also "*".
        assert!(
            !checker
                .check_access("10.1.2.3".parse().unwrap(), "api", "*", 80)
                .await
        );
    }

    #[tokio::test]
    async fn test_port_restriction() {
        let policies = Arc::new(RwLock::new(vec![make_policy(
            "web",
            vec!["10.200.0.0/16"],
            vec![allow_rule("api", "*", Some(vec![80, 443]))],
        )]));
        let checker = NetworkPolicyChecker::new(policies);

        // Port 443 — allowed.
        assert!(
            checker
                .check_access("10.200.1.1".parse().unwrap(), "api", "*", 443)
                .await
        );

        // Port 8080 — not in ports list, no matching rule — default deny.
        assert!(
            !checker
                .check_access("10.200.1.1".parse().unwrap(), "api", "*", 8080)
                .await
        );
    }

    #[tokio::test]
    async fn test_multiple_networks() {
        let policies = Arc::new(RwLock::new(vec![
            make_policy(
                "office",
                vec!["192.168.1.0/24"],
                vec![allow_rule("api", "*", None)],
            ),
            make_policy(
                "vpn",
                vec!["10.200.0.0/16"],
                vec![allow_rule("*", "*", None)],
            ),
        ]));
        let checker = NetworkPolicyChecker::new(policies);

        // Office user can reach "api" but not "admin".
        assert!(
            checker
                .check_access("192.168.1.50".parse().unwrap(), "api", "*", 80)
                .await
        );
        assert!(
            !checker
                .check_access("192.168.1.50".parse().unwrap(), "admin", "*", 80)
                .await
        );

        // VPN user can reach anything.
        assert!(
            checker
                .check_access("10.200.5.5".parse().unwrap(), "admin", "*", 80)
                .await
        );
    }

    #[tokio::test]
    async fn test_empty_policies_allows_all() {
        let policies = Arc::new(RwLock::new(Vec::new()));
        let checker = NetworkPolicyChecker::new(policies);

        assert!(
            checker
                .check_access("1.2.3.4".parse().unwrap(), "anything", "*", 80)
                .await
        );
    }

    #[test]
    fn test_ip_in_cidrs_v4() {
        let cidrs = vec!["10.0.0.0/8".to_string(), "192.168.1.0/24".to_string()];

        assert!(ip_in_cidrs("10.1.2.3".parse().unwrap(), &cidrs));
        assert!(ip_in_cidrs("192.168.1.100".parse().unwrap(), &cidrs));
        assert!(!ip_in_cidrs("172.16.0.1".parse().unwrap(), &cidrs));
    }

    #[test]
    fn test_ip_in_cidrs_v6() {
        let cidrs = vec!["fd00::/64".to_string()];

        assert!(ip_in_cidrs("fd00::1".parse().unwrap(), &cidrs));
        assert!(!ip_in_cidrs("fd01::1".parse().unwrap(), &cidrs));
    }

    #[test]
    fn test_ip_in_cidrs_empty() {
        assert!(!ip_in_cidrs("10.0.0.1".parse().unwrap(), &[]));
    }

    #[test]
    fn test_ip_in_cidrs_prefix_boundaries() {
        // /0 covers the whole IPv4 space.
        let everything = vec!["0.0.0.0/0".to_string()];
        assert!(ip_in_cidrs("203.0.113.9".parse().unwrap(), &everything));

        // /32 covers exactly one host.
        let host = vec!["10.0.0.5/32".to_string()];
        assert!(ip_in_cidrs("10.0.0.5".parse().unwrap(), &host));
        assert!(!ip_in_cidrs("10.0.0.6".parse().unwrap(), &host));
    }

    #[test]
    fn test_ip_in_cidrs_skips_malformed_entries() {
        // Unparseable entries are ignored without hiding the valid ones after them.
        let cidrs = vec![
            "not-a-cidr/8".to_string(),
            "10.0.0.0/abc".to_string(),
            "192.168.0.0/16".to_string(),
        ];
        assert!(ip_in_cidrs("192.168.3.4".parse().unwrap(), &cidrs));
        assert!(!ip_in_cidrs("10.0.0.1".parse().unwrap(), &cidrs));
    }

    #[test]
    fn test_ip_in_cidrs_family_mismatch() {
        let cidrs = vec!["fd00::/8".to_string()];
        assert!(!ip_in_cidrs("10.0.0.1".parse().unwrap(), &cidrs));
    }

    #[test]
    fn test_rule_matches_wildcards() {
        let rule = allow_rule("*", "*", None);
        assert!(rule_matches(&rule, "any-service", "any-deployment", 12345));
    }

    #[test]
    fn test_rule_matches_specific() {
        let rule = allow_rule("api", "prod", Some(vec![443]));

        assert!(rule_matches(&rule, "api", "prod", 443));
        assert!(!rule_matches(&rule, "api", "staging", 443));
        assert!(!rule_matches(&rule, "web", "prod", 443));
        assert!(!rule_matches(&rule, "api", "prod", 80));
    }
}