1use std::net::IpAddr;
4
5use bitvec::{
6 order::{BitOrder, Msb0},
7 slice::BitSlice,
8 store::BitStore,
9 view::BitView,
10};
11
12use crate::acl::cidr::Cidr;
13
14struct Trie {
15 root: TrieNode,
16}
17
18impl Trie {
19 pub fn new() -> Self {
20 Trie {
21 root: TrieNode::new(),
22 }
23 }
24
25 pub fn insert_bits<T, O>(&mut self, bits: &BitSlice<T, O>)
26 where
27 T: BitStore,
28 O: BitOrder,
29 {
30 let mut cur = &mut self.root;
31
32 for bit in bits.iter() {
33 match *bit {
34 true => {
35 if cur.right.is_none() {
36 cur.right = Some(Box::new(TrieNode::new()));
37 }
38
39 cur = unsafe { cur.right.as_mut().unwrap_unchecked().as_mut() };
40 }
41 false => {
42 if cur.left.is_none() {
43 cur.left = Some(Box::new(TrieNode::new()));
44 }
45
46 cur = unsafe { cur.left.as_mut().unwrap_unchecked().as_mut() };
47 }
48 }
49 }
50
51 cur.is_complete = true;
52 }
53
54 pub fn contains(&self, data: &[u8]) -> bool {
55 let mut cur = &self.root;
56 let bits = data.view_bits::<Msb0>();
57
58 for bit in bits.iter() {
59 match *bit {
60 true => {
61 if cur.right.is_none() {
62 return false;
63 }
64
65 cur = unsafe { cur.right.as_ref().unwrap_unchecked().as_ref() };
66 }
67 false => {
68 if cur.left.is_none() {
69 return false;
70 }
71
72 cur = unsafe { cur.left.as_ref().unwrap_unchecked().as_ref() };
73 }
74 }
75
76 if cur.is_complete {
77 break;
78 }
79 }
80
81 true
82 }
83
84 pub fn clear(&mut self) {
85 self.root.left = None;
86 self.root.right = None;
87 self.root.is_complete = false;
88 }
89}
90
/// A single node of the bitwise prefix trie.
struct TrieNode {
    /// Child reached by a `0` (`false`) bit.
    left: Option<Box<TrieNode>>,
    /// Child reached by a `1` (`true`) bit.
    right: Option<Box<TrieNode>>,
    /// Set when an inserted bit string ends at this node.
    is_complete: bool,
}

impl TrieNode {
    /// Builds a childless node that terminates no inserted bit string.
    pub fn new() -> Self {
        Self {
            is_complete: false,
            left: None,
            right: None,
        }
    }
}
106
107pub struct IpSet {
109 ipv4: Trie,
110 ipv6: Trie,
111}
112
113impl IpSet {
114 pub fn new() -> Self {
116 IpSet {
117 ipv4: Trie::new(),
118 ipv6: Trie::new(),
119 }
120 }
121
122 pub fn insert(&mut self, cidr: Cidr) {
124 let mask = cidr.mask as usize;
125
126 match cidr.addr {
127 IpAddr::V4(v4) => self
128 .ipv4
129 .insert_bits(&v4.octets().view_bits::<Msb0>()[..mask]),
130 IpAddr::V6(v6) => self
131 .ipv6
132 .insert_bits(&v6.octets().view_bits::<Msb0>()[..mask]),
133 }
134 }
135
136 pub fn contains(&self, addr: IpAddr) -> bool {
138 match addr {
139 IpAddr::V4(v4) => self.ipv4.contains(&v4.octets()),
140 IpAddr::V6(v6) => self.ipv6.contains(&v6.octets()),
141 }
142 }
143
144 pub fn clear(&mut self) {
146 self.ipv4.clear();
147 self.ipv6.clear();
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// CIDR blocks inserted by the tests below.
    const CIDRS: [&str; 10] = [
        "0.0.0.0/8",
        "127.0.0.0/8",
        "192.168.0.0/16",
        "220.160.0.0/11",
        "255.255.255.255/32",
        "::1/128",
        "::ffff:127.0.0.1/104",
        "fc00::/7",
        "fe80::/10",
        "2001:b28:f23d:f001::e/128",
    ];

    /// Addresses covered by at least one entry of `CIDRS`.
    const MEMBERS: [&str; 10] = [
        "0.0.0.1",
        "127.0.0.1",
        "192.168.0.1",
        "220.181.38.148",
        "255.255.255.255",
        "::1",
        "::ffff:127.0.0.1",
        "fc00::ffff",
        "fe80::1234",
        "2001:b28:f23d:f001::e",
    ];

    /// Addresses covered by no entry of `CIDRS`.
    const NON_MEMBERS: [&str; 6] = [
        "1.1.1.1",
        "128.0.0.1",
        "8.7.198.46",
        "210.181.38.251",
        "::ffff:192.0.0.1",
        "2001:b28:f23d:1::f",
    ];

    /// Parses an address literal, failing the test on malformed input.
    fn ip(s: &str) -> std::net::IpAddr {
        s.parse().unwrap()
    }

    /// Builds a set populated with every entry of `CIDRS`.
    fn populated_set() -> IpSet {
        let mut set = IpSet::new();
        for cidr in CIDRS {
            set.insert(cidr.parse().unwrap());
        }
        set
    }

    #[test]
    fn test_ip_set() {
        let mut set = populated_set();

        for addr in MEMBERS {
            assert!(set.contains(ip(addr)), "{} should be in the set", addr);
        }
        for addr in NON_MEMBERS {
            assert!(!set.contains(ip(addr)), "{} should not be in the set", addr);
        }

        set.clear();

        for addr in MEMBERS {
            assert!(!set.contains(ip(addr)), "{} should be gone after clear", addr);
        }
    }

    #[test]
    fn reinsert_after_clear() {
        let mut set = populated_set();
        set.clear();
        set.insert("10.0.0.0/8".parse().unwrap());

        assert!(set.contains(ip("10.1.2.3")));
        assert!(!set.contains(ip("127.0.0.1")));
    }
}