1use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};
2use std::collections::HashMap;
3use std::fmt::Display;
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
5use std::sync::Arc;
6
7#[derive(Clone, Debug, Default)]
9pub struct Acls {
10 acls: HashMap<String, Arc<Acl>>,
11}
12
13impl Acls {
14 pub fn new() -> Self {
15 Self {
16 acls: HashMap::new(),
17 }
18 }
19
20 pub fn get_acl(&self, name: &str) -> Option<&Arc<Acl>> {
21 self.acls.get(name)
22 }
23
24 pub fn insert(&mut self, name: String, acl: Acl) {
25 self.acls.insert(name, Arc::new(acl));
26 }
27}
28
29#[derive(Debug, Default, Deserialize)]
47pub struct Acl {
48 pub(crate) entries: Vec<Entry>,
49}
50
51impl Acl {
52 pub fn lookup(&self, ip: IpAddr) -> Option<&Entry> {
59 self.entries.iter().fold(None, |acc, entry| {
60 if let Some(mask) = entry.prefix.is_match(ip) {
61 if acc.is_none_or(|prev_match: &Entry| mask >= prev_match.prefix.mask) {
62 return Some(entry);
63 }
64 }
65 acc
66 })
67 }
68}
69
70#[derive(Debug, Deserialize, Serialize, PartialEq)]
72pub struct Entry {
73 prefix: Prefix,
74 action: Action,
75}
76
/// An IP network: an address plus a prefix length (`mask`).
///
/// Values built through [`Prefix::new`] have every host bit below `mask` cleared,
/// and `mask` is within `1..=32` (IPv4) or `1..=128` (IPv6).
///
/// `Prefix` is a small plain value (an [`IpAddr`] and a `u8`). It is therefore `Copy`,
/// and `Eq`/`Hash` allow it to be used as a map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Prefix {
    /// Network address (host bits zeroed when built via `Prefix::new`).
    ip: IpAddr,
    /// Prefix length in bits.
    mask: u8,
}
83
84impl Prefix {
85 pub(crate) fn new(ip: IpAddr, mask: u8) -> Self {
86 let (ip, mask) = match ip {
88 IpAddr::V4(v4) => {
89 let mask = mask.clamp(1, 32);
90 let bit_mask = u32::MAX << (32 - mask);
91 (
92 IpAddr::V4(Ipv4Addr::from_bits(v4.to_bits() & bit_mask)),
93 mask,
94 )
95 }
96 IpAddr::V6(v6) => {
97 let mask = mask.clamp(1, 128);
98 let bit_mask = u128::MAX << (128 - mask);
99 (
100 IpAddr::V6(Ipv6Addr::from_bits(v6.to_bits() & bit_mask)),
101 mask,
102 )
103 }
104 };
105
106 Self { ip, mask }
107 }
108
109 pub(crate) fn is_match(&self, ip: IpAddr) -> Option<u8> {
112 let masked = Self::new(ip, self.mask);
113 if masked.ip == self.ip {
114 Some(self.mask)
115 } else {
116 None
117 }
118 }
119}
120
121impl Display for Prefix {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 f.write_fmt(format_args!("{}/{}", self.ip, self.mask))
124 }
125}
126
127impl<'de> Deserialize<'de> for Prefix {
128 fn deserialize<D>(de: D) -> Result<Self, D::Error>
129 where
130 D: Deserializer<'de>,
131 {
132 let v = String::deserialize(de)?;
133 let (ip, mask) = v.split_once('/').ok_or(D::Error::custom(format!(
134 "invalid format '{}': want IP/MASK",
135 v
136 )))?;
137
138 let mask = mask
139 .parse::<u8>()
140 .map_err(|err| D::Error::custom(format!("invalid prefix {}: {}", mask, err)))?;
141
142 let ip = match ip.contains(':') {
144 false => {
145 if !(1..=32).contains(&mask) {
146 return Err(D::Error::custom(format!(
147 "mask outside allowed range [1, 32]: {}",
148 mask
149 )));
150 }
151 ip.parse::<Ipv4Addr>().map(IpAddr::V4)
152 }
153 true => {
154 if !(1..=128).contains(&mask) {
155 return Err(D::Error::custom(format!(
156 "mask outside allowed range [1, 128]: {}",
157 mask
158 )));
159 }
160 ip.parse::<Ipv6Addr>().map(IpAddr::V6)
161 }
162 }
163 .map_err(|err| D::Error::custom(format!("invalid ip address {}: {}", ip, err)))?;
164
165 Ok(Self::new(ip, mask))
166 }
167}
168
169impl Serialize for Prefix {
170 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
171 where
172 S: Serializer,
173 {
174 serializer.serialize_str(format!("{}", self).as_str())
175 }
176}
177
/// Canonical wire spelling of [`Action::Allow`].
const ACTION_ALLOW: &str = "ALLOW";
/// Canonical wire spelling of [`Action::Block`].
const ACTION_BLOCK: &str = "BLOCK";

/// The action an ACL entry specifies for matching addresses.
#[derive(Debug, Clone, PartialEq)]
pub enum Action {
    /// Decoded from `"ALLOW"` in any letter case.
    Allow,
    /// Decoded from `"BLOCK"` in any letter case.
    Block,
    /// Any other action string, kept exactly as it appeared in the input.
    Other(String),
}
188
189impl<'de> Deserialize<'de> for Action {
190 fn deserialize<D>(de: D) -> Result<Self, D::Error>
191 where
192 D: Deserializer<'de>,
193 {
194 let action = String::deserialize(de)?;
195 Ok(match action.to_uppercase().as_str() {
196 ACTION_ALLOW => Self::Allow,
197 ACTION_BLOCK => Self::Block,
198 _ => Self::Other(action),
199 })
200 }
201}
202
203impl Serialize for Action {
204 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
205 where
206 S: Serializer,
207 {
208 match self {
209 Self::Allow => serializer.serialize_str(ACTION_ALLOW),
210 Self::Block => serializer.serialize_str(ACTION_BLOCK),
211 Self::Other(other) => serializer.serialize_str(format!("Other({})", other).as_str()),
212 }
213 }
214}
215
#[test]
fn prefix_is_match() {
    let v4_prefix = Prefix::new(Ipv4Addr::new(192, 168, 100, 0).into(), 16);
    for inside in [
        Ipv4Addr::new(192, 168, 100, 0),
        Ipv4Addr::new(192, 168, 200, 200),
    ] {
        assert_eq!(v4_prefix.is_match(inside.into()), Some(16));
    }
    for outside in [Ipv4Addr::new(192, 167, 0, 0), Ipv4Addr::new(192, 169, 0, 0)] {
        assert_eq!(v4_prefix.is_match(outside.into()), None);
    }

    let v6_prefix = Prefix::new(Ipv6Addr::new(0xFACE, 0, 0, 0, 0, 0, 0, 0).into(), 16);
    assert_eq!(
        v6_prefix.is_match(Ipv6Addr::new(0xFACE, 1, 2, 3, 4, 5, 6, 7).into()),
        Some(16)
    );

    // An IPv4 address and its IPv4-mapped IPv6 form are different families and must
    // not match each other's prefixes.
    let v4 = Ipv4Addr::new(192, 168, 200, 200);
    let mapped = v4.to_ipv6_mapped();
    assert_eq!(Prefix::new(v4.into(), 8).is_match(mapped.into()), None);
    assert_eq!(Prefix::new(mapped.into(), 8).is_match(v4.into()), None);
}
244
#[test]
fn acl_lookup() {
    // All entries share the same base address; only the mask differs.
    let entry = |mask: u8| Entry {
        prefix: Prefix::new(Ipv4Addr::new(192, 168, 100, 0).into(), mask),
        action: Action::Block,
    };
    let acl = Acl {
        entries: vec![entry(16), entry(24), entry(8)],
    };

    // Each probe address paired with the index of the entry that should win.
    let expectations = [
        (Ipv4Addr::new(192, 168, 100, 1), 1),
        (Ipv4Addr::new(192, 168, 200, 1), 0),
        (Ipv4Addr::new(192, 1, 1, 1), 2),
    ];
    for (probe, want_index) in expectations {
        let Some(lookup_match) = acl.lookup(probe.into()) else {
            panic!("expected lookup match");
        };
        assert_eq!(acl.entries[want_index], *lookup_match);
    }

    if let Some(lookup_match) = acl.lookup(Ipv4Addr::new(1, 1, 1, 1).into()) {
        panic!("expected no lookup match, got {:?}", lookup_match)
    };
}
289
#[test]
fn acl_json_parse() {
    // The `op` keys are not part of `Entry`; decoding must ignore them.
    let input = r#"
    { "entries": [
      { "op": "create", "prefix": "1.2.3.0/24", "action": "BLOCK" },
      { "op": "update", "prefix": "192.168.0.0/16", "action": "BLOCK" },
      { "op": "create", "prefix": "23.23.23.23/32", "action": "ALLOW" },
      { "op": "update", "prefix": "1.2.3.4/32", "action": "ALLOW" },
      { "op": "update", "prefix": "1.2.3.4/8", "action": "ALLOW" }
    ]}
    "#;
    let acl: Acl = serde_json::from_str(input).expect("can decode");

    // Builds the expected entry for an IPv4 network given as raw octets.
    let entry = |octets: [u8; 4], mask: u8, action: Action| Entry {
        prefix: Prefix {
            ip: IpAddr::V4(Ipv4Addr::from(octets)),
            mask,
        },
        action,
    };
    let want = vec![
        entry([1, 2, 3, 0], 24, Action::Block),
        entry([192, 168, 0, 0], 16, Action::Block),
        entry([23, 23, 23, 23], 32, Action::Allow),
        entry([1, 2, 3, 4], 32, Action::Allow),
        // "1.2.3.4/8" is normalised to its network address.
        entry([1, 0, 0, 0], 8, Action::Allow),
    ];

    assert_eq!(acl.entries, want);
}
346
#[test]
fn prefix_json_roundtrip() {
    // Decodes `input` as a JSON string, re-encodes it, and checks the result is `want`.
    let assert_roundtrips = |input: &str, want: &str| {
        let prefix: Prefix =
            serde_json::from_str(format!("\"{}\"", input).as_str()).expect("can decode");
        let got = serde_json::to_string(&prefix).expect("can encode");
        assert_eq!(
            got,
            format!("\"{}\"", want),
            "'{}' roundtrip: got {}, want {}",
            input,
            got,
            want
        );
    };
    // Asserts that `input`, as a JSON string, is rejected by the decoder.
    let assert_rejects = |input: &str| {
        assert!(
            serde_json::from_str::<Prefix>(format!("\"{}\"", input).as_str()).is_err(),
            "'{}' should fail to decode",
            input
        );
    };

    // Host bits below the mask are cleared on decode.
    assert_roundtrips("255.255.255.255/32", "255.255.255.255/32");
    assert_roundtrips("255.255.255.255/8", "255.0.0.0/8");

    assert_roundtrips("2002::1234:abcd:ffff:c0a8:101/64", "2002:0:0:1234::/64");
    assert_roundtrips("2000::AB/32", "2000::/32");

    // Mask outside the allowed range: both the lower bound (0) and above the family width.
    for bad in ["1.2.3.4/0", "1.2.3.4/33", "200::/0", "200::/129"] {
        assert_rejects(bad);
    }

    // Mask is not a number.
    for bad in ["1.2.3.4/x", "200::/none"] {
        assert_rejects(bad);
    }

    // Address does not parse (including an empty address).
    for bad in ["1.2.3.four/16", "200::end/32", "/24"] {
        assert_rejects(bad);
    }

    // Missing "/MASK" part.
    for bad in ["1.2.3.4", "200::"] {
        assert_rejects(bad);
    }
}
382
#[test]
fn action_json_roundtrip() {
    // (input, expected re-encoded value)
    let cases = [
        ("ALLOW", "ALLOW"),
        ("allow", "ALLOW"),
        ("BLOCK", "BLOCK"),
        ("block", "BLOCK"),
        // Unknown actions keep their original case and are wrapped in `Other(...)`.
        ("POTATO", "Other(POTATO)"),
        ("potato", "Other(potato)"),
    ];

    for (input, want) in cases {
        let action: Action =
            serde_json::from_str(format!("\"{}\"", input).as_str()).expect("can decode");
        let got = serde_json::to_string(&action).expect("can encode");
        assert_eq!(
            got,
            format!("\"{}\"", want),
            "'{}' roundtrip: got {}, want {}",
            input,
            got,
            want
        );
    }
}