ss_rs/acl/
ip_set.rs

1//! A set of ip networks.
2
3use std::net::IpAddr;
4
5use bitvec::{
6    order::{BitOrder, Msb0},
7    slice::BitSlice,
8    store::BitStore,
9    view::BitView,
10};
11
12use crate::acl::cidr::Cidr;
13
14struct Trie {
15    root: TrieNode,
16}
17
18impl Trie {
19    pub fn new() -> Self {
20        Trie {
21            root: TrieNode::new(),
22        }
23    }
24
25    pub fn insert_bits<T, O>(&mut self, bits: &BitSlice<T, O>)
26    where
27        T: BitStore,
28        O: BitOrder,
29    {
30        let mut cur = &mut self.root;
31
32        for bit in bits.iter() {
33            match *bit {
34                true => {
35                    if cur.right.is_none() {
36                        cur.right = Some(Box::new(TrieNode::new()));
37                    }
38
39                    cur = unsafe { cur.right.as_mut().unwrap_unchecked().as_mut() };
40                }
41                false => {
42                    if cur.left.is_none() {
43                        cur.left = Some(Box::new(TrieNode::new()));
44                    }
45
46                    cur = unsafe { cur.left.as_mut().unwrap_unchecked().as_mut() };
47                }
48            }
49        }
50
51        cur.is_complete = true;
52    }
53
54    pub fn contains(&self, data: &[u8]) -> bool {
55        let mut cur = &self.root;
56        let bits = data.view_bits::<Msb0>();
57
58        for bit in bits.iter() {
59            match *bit {
60                true => {
61                    if cur.right.is_none() {
62                        return false;
63                    }
64
65                    cur = unsafe { cur.right.as_ref().unwrap_unchecked().as_ref() };
66                }
67                false => {
68                    if cur.left.is_none() {
69                        return false;
70                    }
71
72                    cur = unsafe { cur.left.as_ref().unwrap_unchecked().as_ref() };
73                }
74            }
75
76            if cur.is_complete {
77                break;
78            }
79        }
80
81        true
82    }
83
84    pub fn clear(&mut self) {
85        self.root.left = None;
86        self.root.right = None;
87        self.root.is_complete = false;
88    }
89}
90
/// A single node of the prefix trie.
struct TrieNode {
    /// Child reached by a `0` bit, if any prefix continues that way.
    left: Option<Box<TrieNode>>,
    /// Child reached by a `1` bit, if any prefix continues that way.
    right: Option<Box<TrieNode>>,
    /// Set when an inserted prefix ends exactly at this node.
    is_complete: bool,
}

impl TrieNode {
    /// Builds a childless node that does not terminate any prefix.
    pub fn new() -> Self {
        let is_complete = false;
        Self {
            is_complete,
            right: None,
            left: None,
        }
    }
}
106
107/// Stores a set of ip networks.
108pub struct IpSet {
109    ipv4: Trie,
110    ipv6: Trie,
111}
112
113impl IpSet {
114    /// Creates a new ip set.
115    pub fn new() -> Self {
116        IpSet {
117            ipv4: Trie::new(),
118            ipv6: Trie::new(),
119        }
120    }
121
122    /// Inserts a new ip network into the set.
123    pub fn insert(&mut self, cidr: Cidr) {
124        let mask = cidr.mask as usize;
125
126        match cidr.addr {
127            IpAddr::V4(v4) => self
128                .ipv4
129                .insert_bits(&v4.octets().view_bits::<Msb0>()[..mask]),
130            IpAddr::V6(v6) => self
131                .ipv6
132                .insert_bits(&v6.octets().view_bits::<Msb0>()[..mask]),
133        }
134    }
135
136    /// Checks whether the given ip address is in the ip set.
137    pub fn contains(&self, addr: IpAddr) -> bool {
138        match addr {
139            IpAddr::V4(v4) => self.ipv4.contains(&v4.octets()),
140            IpAddr::V6(v6) => self.ipv6.contains(&v6.octets()),
141        }
142    }
143
144    /// Clears the ip set.
145    pub fn clear(&mut self) {
146        self.ipv4.clear();
147        self.ipv6.clear();
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `addr` and reports whether `set` contains it.
    fn has(set: &IpSet, addr: &str) -> bool {
        set.contains(addr.parse().unwrap())
    }

    #[test]
    fn test_ip_set() {
        let iplist = [
            "0.0.0.0/8",
            "127.0.0.0/8",
            "192.168.0.0/16",
            "220.160.0.0/11",
            "255.255.255.255/32",
            "::1/128",
            "::ffff:127.0.0.1/104",
            "fc00::/7",
            "fe80::/10",
            "2001:b28:f23d:f001::e/128",
        ];

        let mut set = IpSet::new();

        for ip in iplist {
            set.insert(ip.parse().unwrap());
        }

        // Addresses covered by one of the networks above.
        let members = [
            "0.0.0.1",
            "127.0.0.1",
            "192.168.0.1",
            "220.181.38.148",
            "255.255.255.255",
            "::1",
            "::ffff:127.0.0.1",
            "fc00::ffff",
            "fe80::1234",
            "2001:b28:f23d:f001::e",
        ];

        // Addresses outside every network above.
        let non_members = [
            "1.1.1.1",
            "128.0.0.1",
            "8.7.198.46",
            "210.181.38.251",
            "::ffff:192.0.0.1",
            "2001:b28:f23d:1::f",
        ];

        for addr in members {
            assert!(has(&set, addr), "{} should be in the set", addr);
        }

        for addr in non_members {
            assert!(!has(&set, addr), "{} should not be in the set", addr);
        }

        set.clear();

        // After clearing, none of the former members match any more.
        for addr in members {
            assert!(!has(&set, addr), "{} should be gone after clear", addr);
        }
    }
}