1use std::net::IpAddr;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use tracing::{debug, warn};
11use zlayer_spec::{AccessAction, AccessRule, NetworkPolicySpec};
12
13#[derive(Clone)]
23pub struct NetworkPolicyChecker {
24 policies: Arc<RwLock<Vec<NetworkPolicySpec>>>,
25}
26
27impl NetworkPolicyChecker {
28 pub fn new(policies: Arc<RwLock<Vec<NetworkPolicySpec>>>) -> Self {
30 Self { policies }
31 }
32
33 pub async fn check_access(
40 &self,
41 source_ip: IpAddr,
42 service: &str,
43 deployment: &str,
44 port: u16,
45 ) -> bool {
46 let policies = self.policies.read().await;
47
48 let matching_networks: Vec<&NetworkPolicySpec> = policies
49 .iter()
50 .filter(|p| ip_in_cidrs(source_ip, &p.cidrs))
51 .collect();
52
53 if matching_networks.is_empty() {
55 return true;
56 }
57
58 for network in &matching_networks {
60 for rule in &network.access_rules {
61 if rule_matches(rule, service, deployment, port)
62 && rule.action == AccessAction::Deny
63 {
64 warn!(
65 source = %source_ip,
66 network = %network.name,
67 service = %service,
68 port = %port,
69 "Network policy denied access"
70 );
71 return false;
72 }
73 }
74 }
75
76 for network in &matching_networks {
78 for rule in &network.access_rules {
79 if rule_matches(rule, service, deployment, port)
80 && rule.action == AccessAction::Allow
81 {
82 debug!(
83 source = %source_ip,
84 network = %network.name,
85 service = %service,
86 port = %port,
87 "Network policy allowed access"
88 );
89 return true;
90 }
91 }
92 }
93
94 warn!(
96 source = %source_ip,
97 service = %service,
98 port = %port,
99 "Source in network policy but no matching rule; default deny"
100 );
101 false
102 }
103}
104
105fn ip_in_cidrs(ip: IpAddr, cidrs: &[String]) -> bool {
107 for cidr_str in cidrs {
108 if let Some((net_str, prefix_str)) = cidr_str.split_once('/') {
109 let Ok(net_addr) = net_str.parse::<IpAddr>() else {
110 continue;
111 };
112 let Ok(prefix_len) = prefix_str.parse::<u32>() else {
113 continue;
114 };
115 if cidr_contains(net_addr, prefix_len, ip) {
116 return true;
117 }
118 }
119 }
120 false
121}
122
/// Returns `true` if `addr` lies within the network `network/prefix_len`.
///
/// `prefix_len` is clamped to the address width (32 for IPv4, 128 for IPv6),
/// so an oversized prefix behaves like a single-host match. A prefix of 0
/// matches every comparable address.
///
/// Cross-family pairs are compared through IPv4-mapped IPv6 addresses
/// (`::ffff:a.b.c.d`, RFC 4291 §2.5.5.2):
/// - An IPv4 network matches a mapped IPv6 peer by its embedded IPv4 address.
/// - An IPv6 network matches an IPv4 peer by the peer's mapped form.
///
/// Dual-stack listeners commonly report IPv4 peers in mapped form. Without
/// this, such peers would escape every IPv4 policy. Any other cross-family
/// pair never matches.
fn cidr_contains(network: IpAddr, prefix_len: u32, addr: IpAddr) -> bool {
    // Do the top `prefix_len` bits of two `width`-bit values (widened to u128) agree?
    fn prefix_eq(net: u128, ip: u128, prefix_len: u32, width: u32) -> bool {
        let host_bits = width - prefix_len.min(width);
        // A shift of 128 (IPv6 /0) is out of range for u128: no bits to compare.
        (net ^ ip).checked_shr(host_bits).unwrap_or(0) == 0
    }

    match (network, addr) {
        (IpAddr::V4(net), IpAddr::V4(ip)) => {
            prefix_eq(u32::from(net).into(), u32::from(ip).into(), prefix_len, 32)
        }
        (IpAddr::V6(net), IpAddr::V6(ip)) => {
            prefix_eq(u128::from(net), u128::from(ip), prefix_len, 128)
        }
        // IPv4 policy, IPv6 peer: only an IPv4-mapped peer can match.
        (IpAddr::V4(net), IpAddr::V6(ip)) => ip.to_ipv4_mapped().is_some_and(|mapped| {
            prefix_eq(u32::from(net).into(), u32::from(mapped).into(), prefix_len, 32)
        }),
        // IPv6 policy, IPv4 peer: compare in the IPv4-mapped IPv6 space.
        (IpAddr::V6(net), IpAddr::V4(ip)) => prefix_eq(
            u128::from(net),
            u128::from(ip.to_ipv6_mapped()),
            prefix_len,
            128,
        ),
    }
}
145
146fn rule_matches(rule: &AccessRule, service: &str, deployment: &str, port: u16) -> bool {
148 let service_match = rule.service == "*" || rule.service == service;
149 let deployment_match = rule.deployment == "*" || rule.deployment == deployment;
150 let port_match = rule
151 .ports
152 .as_ref()
153 .is_none_or(|ports| ports.contains(&port));
154 service_match && deployment_match && port_match
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use zlayer_spec::{AccessAction, AccessRule, NetworkPolicySpec};

    /// Builds a policy named `name` covering `cidrs` with the given rules;
    /// every other field takes its `Default` value.
    fn make_policy(name: &str, cidrs: Vec<&str>, rules: Vec<AccessRule>) -> NetworkPolicySpec {
        NetworkPolicySpec {
            name: name.to_string(),
            cidrs: cidrs.into_iter().map(String::from).collect(),
            access_rules: rules,
            ..Default::default()
        }
    }

    /// Builds an `Allow` rule; `ports: None` means "any port".
    fn allow_rule(service: &str, deployment: &str, ports: Option<Vec<u16>>) -> AccessRule {
        AccessRule {
            service: service.to_string(),
            deployment: deployment.to_string(),
            ports,
            action: AccessAction::Allow,
        }
    }

    /// Builds a `Deny` rule; `ports: None` means "any port".
    fn deny_rule(service: &str, deployment: &str, ports: Option<Vec<u16>>) -> AccessRule {
        AccessRule {
            service: service.to_string(),
            deployment: deployment.to_string(),
            ports,
            action: AccessAction::Deny,
        }
    }

    #[tokio::test]
    async fn test_no_matching_network_allows() {
        let policies = Arc::new(RwLock::new(vec![make_policy(
            "corp",
            vec!["10.0.0.0/8"],
            vec![allow_rule("api", "*", None)],
        )]));
        let checker = NetworkPolicyChecker::new(policies);

        // Source outside every policy network is not governed -> allowed.
        assert!(
            checker
                .check_access("192.168.1.1".parse().unwrap(), "api", "*", 8080)
                .await
        );
    }

    #[tokio::test]
    async fn test_matching_allow_rule() {
        let policies = Arc::new(RwLock::new(vec![make_policy(
            "corp",
            vec!["10.0.0.0/8"],
            vec![allow_rule("api", "*", None)],
        )]));
        let checker = NetworkPolicyChecker::new(policies);

        assert!(
            checker
                .check_access("10.1.2.3".parse().unwrap(), "api", "*", 8080)
                .await
        );
    }

    #[tokio::test]
    async fn test_matching_deny_rule() {
        let policies = Arc::new(RwLock::new(vec![make_policy(
            "restricted",
            vec!["10.0.0.0/8"],
            vec![deny_rule("admin", "*", None)],
        )]));
        let checker = NetworkPolicyChecker::new(policies);

        assert!(
            !checker
                .check_access("10.1.2.3".parse().unwrap(), "admin", "*", 443)
                .await
        );
    }

    #[tokio::test]
    async fn test_deny_takes_priority_over_allow() {
        let policies = Arc::new(RwLock::new(vec![make_policy(
            "mixed",
            vec!["10.0.0.0/8"],
            vec![
                allow_rule("api", "*", None),
                deny_rule("api", "*", Some(vec![9090])),
            ],
        )]));
        let checker = NetworkPolicyChecker::new(policies);

        // Port not covered by the deny -> the allow applies.
        assert!(
            checker
                .check_access("10.1.2.3".parse().unwrap(), "api", "*", 8080)
                .await
        );

        // Deny wins even though the allow rule is listed first.
        assert!(
            !checker
                .check_access("10.1.2.3".parse().unwrap(), "api", "*", 9090)
                .await
        );
    }

    #[tokio::test]
    async fn test_deny_in_one_network_overrides_allow_in_another() {
        let policies = Arc::new(RwLock::new(vec![
            make_policy(
                "everyone",
                vec!["10.0.0.0/8"],
                vec![allow_rule("*", "*", None)],
            ),
            make_policy(
                "quarantine",
                vec!["10.9.0.0/16"],
                vec![deny_rule("api", "*", None)],
            ),
        ]));
        let checker = NetworkPolicyChecker::new(policies);

        // Source in both networks: the quarantine deny wins for `api`...
        assert!(
            !checker
                .check_access("10.9.1.1".parse().unwrap(), "api", "*", 443)
                .await
        );
        // ...while other services remain allowed by `everyone`.
        assert!(
            checker
                .check_access("10.9.1.1".parse().unwrap(), "web", "*", 443)
                .await
        );
        // Source only in `everyone` is unaffected by the quarantine deny.
        assert!(
            checker
                .check_access("10.1.1.1".parse().unwrap(), "api", "*", 443)
                .await
        );
    }

    #[tokio::test]
    async fn test_network_but_no_matching_rule_denies() {
        let policies = Arc::new(RwLock::new(vec![make_policy(
            "corp",
            vec!["10.0.0.0/8"],
            vec![allow_rule("api", "*", None)],
        )]));
        let checker = NetworkPolicyChecker::new(policies);

        assert!(
            !checker
                .check_access("10.1.2.3".parse().unwrap(), "frontend", "*", 80)
                .await
        );
    }

    #[tokio::test]
    async fn test_deployment_scoped_rule() {
        let policies = Arc::new(RwLock::new(vec![make_policy(
            "corp",
            vec!["10.0.0.0/8"],
            vec![allow_rule("api", "prod", None)],
        )]));
        let checker = NetworkPolicyChecker::new(policies);

        assert!(
            checker
                .check_access("10.1.2.3".parse().unwrap(), "api", "prod", 443)
                .await
        );
        // Same service, other deployment: no rule matches -> default deny.
        assert!(
            !checker
                .check_access("10.1.2.3".parse().unwrap(), "api", "staging", 443)
                .await
        );
    }

    #[tokio::test]
    async fn test_wildcard_service_rule() {
        let policies = Arc::new(RwLock::new(vec![make_policy(
            "admin-net",
            vec!["172.16.0.0/12"],
            vec![allow_rule("*", "*", None)],
        )]));
        let checker = NetworkPolicyChecker::new(policies);

        assert!(
            checker
                .check_access("172.16.5.10".parse().unwrap(), "anything", "*", 443)
                .await
        );
    }

    #[tokio::test]
    async fn test_port_restriction() {
        let policies = Arc::new(RwLock::new(vec![make_policy(
            "web",
            vec!["10.200.0.0/16"],
            vec![allow_rule("api", "*", Some(vec![80, 443]))],
        )]));
        let checker = NetworkPolicyChecker::new(policies);

        assert!(
            checker
                .check_access("10.200.1.1".parse().unwrap(), "api", "*", 443)
                .await
        );

        assert!(
            !checker
                .check_access("10.200.1.1".parse().unwrap(), "api", "*", 8080)
                .await
        );
    }

    #[tokio::test]
    async fn test_multiple_networks() {
        let policies = Arc::new(RwLock::new(vec![
            make_policy(
                "office",
                vec!["192.168.1.0/24"],
                vec![allow_rule("api", "*", None)],
            ),
            make_policy(
                "vpn",
                vec!["10.200.0.0/16"],
                vec![allow_rule("*", "*", None)],
            ),
        ]));
        let checker = NetworkPolicyChecker::new(policies);

        assert!(
            checker
                .check_access("192.168.1.50".parse().unwrap(), "api", "*", 80)
                .await
        );
        assert!(
            !checker
                .check_access("192.168.1.50".parse().unwrap(), "admin", "*", 80)
                .await
        );

        assert!(
            checker
                .check_access("10.200.5.5".parse().unwrap(), "admin", "*", 80)
                .await
        );
    }

    #[tokio::test]
    async fn test_empty_policies_allows_all() {
        let policies = Arc::new(RwLock::new(Vec::new()));
        let checker = NetworkPolicyChecker::new(policies);

        assert!(
            checker
                .check_access("1.2.3.4".parse().unwrap(), "anything", "*", 80)
                .await
        );
    }

    #[test]
    fn test_ip_in_cidrs_v4() {
        let cidrs = vec!["10.0.0.0/8".to_string(), "192.168.1.0/24".to_string()];

        assert!(ip_in_cidrs("10.1.2.3".parse().unwrap(), &cidrs));
        assert!(ip_in_cidrs("192.168.1.100".parse().unwrap(), &cidrs));
        assert!(!ip_in_cidrs("172.16.0.1".parse().unwrap(), &cidrs));
    }

    #[test]
    fn test_ip_in_cidrs_v6() {
        let cidrs = vec!["fd00::/64".to_string()];

        assert!(ip_in_cidrs("fd00::1".parse().unwrap(), &cidrs));
        assert!(!ip_in_cidrs("fd01::1".parse().unwrap(), &cidrs));
    }

    #[test]
    fn test_ip_in_cidrs_empty() {
        assert!(!ip_in_cidrs("10.0.0.1".parse().unwrap(), &[]));
    }

    #[test]
    fn test_ip_in_cidrs_prefix_boundaries() {
        let everything = vec!["0.0.0.0/0".to_string()];
        assert!(ip_in_cidrs("8.8.8.8".parse().unwrap(), &everything));

        let single_host = vec!["192.168.1.5/32".to_string()];
        assert!(ip_in_cidrs("192.168.1.5".parse().unwrap(), &single_host));
        assert!(!ip_in_cidrs("192.168.1.6".parse().unwrap(), &single_host));
    }

    #[test]
    fn test_ip_in_cidrs_skips_malformed_entries() {
        let cidrs = vec![
            "garbage/8".to_string(),
            "10.0.0.0/abc".to_string(),
            "172.16.0.0/12".to_string(),
        ];

        // Unparseable entries contribute nothing...
        assert!(!ip_in_cidrs("10.1.2.3".parse().unwrap(), &cidrs));
        // ...but valid entries alongside them still match.
        assert!(ip_in_cidrs("172.16.5.5".parse().unwrap(), &cidrs));
    }

    #[test]
    fn test_ip_in_cidrs_unrelated_family_does_not_match() {
        let cidrs = vec!["10.0.0.0/8".to_string()];
        assert!(!ip_in_cidrs("fd00::1".parse().unwrap(), &cidrs));
    }

    #[test]
    fn test_rule_matches_wildcards() {
        let rule = allow_rule("*", "*", None);
        assert!(rule_matches(&rule, "any-service", "any-deployment", 12345));
    }

    #[test]
    fn test_rule_matches_specific() {
        let rule = allow_rule("api", "prod", Some(vec![443]));

        assert!(rule_matches(&rule, "api", "prod", 443));
        assert!(!rule_matches(&rule, "api", "staging", 443));
        assert!(!rule_matches(&rule, "web", "prod", 443));
        assert!(!rule_matches(&rule, "api", "prod", 80));
    }
}