1use ipnet::{IpNet, Ipv4Net, Ipv6Net};
9use iprange::IpRange;
10use regex::bytes::{Regex, RegexBuilder, RegexSet, RegexSetBuilder};
11pub use socks5_impl::protocol::Address;
12use std::{
13 borrow::Cow,
14 collections::HashSet,
15 fmt,
16 fs::File,
17 io::{self, BufRead, BufReader, Error},
18 net::{IpAddr, SocketAddr},
19 path::{Path, PathBuf},
20 str,
21};
22
23mod sub_domains_tree;
24use sub_domains_tree::SubDomainsTree;
25
/// Default policy of an [`AccessControl`] for targets that match no rule.
///
/// The mode starts as [`Mode::BlackList`] when a file is loaded and is
/// switched by the section markers described on each variant.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Mode {
    /// Selected by `[accept_all]` / `[proxy_all]`: everything is in the proxy
    /// list by default and only black-list matches are excluded.
    BlackList,
    /// Selected by `[reject_all]` / `[bypass_all]`: nothing is in the proxy
    /// list by default and only white-list matches are included.
    WhiteList,
}
34
/// Query-ready rules of a single ACL section.
///
/// Produced from a `ParsingRules` accumulator via `ParsingRules::into_rules`.
#[derive(Clone)]
struct Rules {
    /// IPv4 networks of this section.
    ipv4: IpRange<Ipv4Net>,
    /// IPv6 networks of this section.
    ipv6: IpRange<Ipv6Net>,
    /// Regular-expression host rules, matched against the host's bytes.
    rule_regex: RegexSet,
    /// Exact host names (stored lowercased by the parser).
    rule_set: HashSet<String>,
    /// Domain rules looked up through `SubDomainsTree::contains`.
    rule_tree: SubDomainsTree,
}
43
44impl fmt::Debug for Rules {
45 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
46 write!(f, "Rules {{ ipv4: {:?}, ipv6: {:?}, rule_regex: [", self.ipv4, self.ipv6)?;
47
48 let max_len = 2;
49 let has_more = self.rule_regex.len() > max_len;
50
51 for (idx, r) in self.rule_regex.patterns().iter().take(max_len).enumerate() {
52 if idx > 0 {
53 f.write_str(", ")?;
54 }
55 f.write_str(r)?;
56 }
57
58 if has_more {
59 f.write_str(", ...")?;
60 }
61
62 write!(f, "], rule_set: [")?;
63
64 let has_more = self.rule_set.len() > max_len;
65 for (idx, r) in self.rule_set.iter().take(max_len).enumerate() {
66 if idx > 0 {
67 f.write_str(", ")?;
68 }
69 f.write_str(r)?;
70 }
71
72 if has_more {
73 f.write_str(", ...")?;
74 }
75
76 write!(f, "], rule_tree: {:?} }}", self.rule_tree)
77 }
78}
79
80impl Rules {
81 fn new(
83 mut ipv4: IpRange<Ipv4Net>,
84 mut ipv6: IpRange<Ipv6Net>,
85 rule_regex: RegexSet,
86 rule_set: HashSet<String>,
87 rule_tree: SubDomainsTree,
88 ) -> Rules {
89 ipv4.simplify();
91 ipv6.simplify();
92
93 Rules {
94 ipv4,
95 ipv6,
96 rule_regex,
97 rule_set,
98 rule_tree,
99 }
100 }
101
102 #[allow(dead_code)]
104 fn check_address_matched(&self, addr: &Address) -> bool {
105 match *addr {
106 Address::SocketAddress(ref saddr) => self.check_ip_matched(&saddr.ip()),
107 Address::DomainAddress(ref domain, ..) => self.check_host_matched(domain),
108 }
109 }
110
111 fn check_ip_matched(&self, addr: &IpAddr) -> bool {
113 match addr {
114 IpAddr::V4(v4) => {
115 if self.ipv4.contains(v4) {
116 return true;
117 }
118
119 let mapped_ipv6 = v4.to_ipv6_mapped();
120 self.ipv6.contains(&mapped_ipv6)
121 }
122 IpAddr::V6(v6) => {
123 if self.ipv6.contains(v6) {
124 return true;
125 }
126
127 if let Some(mapped_ipv4) = v6.to_ipv4_mapped() {
128 return self.ipv4.contains(&mapped_ipv4);
129 }
130
131 false
132 }
133 }
134 }
135
136 fn check_host_matched(&self, host: &str) -> bool {
138 let host = host.trim_end_matches('.'); self.rule_set.contains(host) || self.rule_tree.contains(host) || self.rule_regex.is_match(host.as_bytes())
140 }
141
142 fn is_ip_empty(&self) -> bool {
144 self.ipv4.is_empty() && self.ipv6.is_empty()
145 }
146
147 fn is_host_empty(&self) -> bool {
149 self.rule_set.is_empty() && self.rule_tree.is_empty() && self.rule_regex.is_empty()
150 }
151}
152
/// Mutable accumulator for the rules of one ACL section while the file is
/// being read; converted into `Rules` by `into_rules`.
struct ParsingRules {
    /// Section label, used only to prefix error messages.
    name: &'static str,
    /// IPv4 networks collected so far.
    ipv4: IpRange<Ipv4Net>,
    /// IPv6 networks collected so far.
    ipv6: IpRange<Ipv6Net>,
    /// Regex sources, compiled together in `into_rules`.
    rules_regex: Vec<String>,
    /// Exact host names, stored lowercased.
    rules_set: HashSet<String>,
    /// Domain rules added through `SubDomainsTree::insert`.
    rules_tree: SubDomainsTree,
}
161
162impl ParsingRules {
163 fn new(name: &'static str) -> Self {
164 ParsingRules {
165 name,
166 ipv4: IpRange::new(),
167 ipv6: IpRange::new(),
168 rules_regex: Vec::new(),
169 rules_set: HashSet::new(),
170 rules_tree: SubDomainsTree::new(),
171 }
172 }
173
174 fn add_ipv4_rule(&mut self, rule: impl Into<Ipv4Net>) {
175 let rule = rule.into();
176 self.ipv4.add(rule);
178 }
179
180 fn add_ipv6_rule(&mut self, rule: impl Into<Ipv6Net>) {
181 let rule = rule.into();
182 log::trace!("IPV6-RULE {rule}");
183 self.ipv6.add(rule);
184 }
185
186 fn add_regex_rule(&mut self, mut rule: String) {
187 static TREE_SET_RULE_EQUIV: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
188 let regex = TREE_SET_RULE_EQUIV.get_or_init(|| {
189 RegexBuilder::new(r#"^(?:(?:\((?:\?:)?\^\|\\\.\)|(?:\^\.(?:\+|\*))?\\\.)((?:[\w-]+(?:\\\.)?)+)|\^((?:[\w-]+(?:\\\.)?)+))\$?$"#)
190 .unicode(false)
191 .build()
192 .unwrap()
193 });
194
195 if let Some(caps) = regex.captures(rule.as_bytes()) {
196 if let Some(tree_rule) = caps.get(1) {
197 if let Ok(tree_rule) = str::from_utf8(tree_rule.as_bytes()) {
198 let tree_rule = tree_rule.replace("\\.", ".");
199 if self.add_tree_rule_inner(&tree_rule).is_ok() {
200 return;
202 }
203 }
204 } else if let Some(set_rule) = caps.get(2) {
205 if let Ok(set_rule) = str::from_utf8(set_rule.as_bytes()) {
206 let set_rule = set_rule.replace("\\.", ".");
207 if self.add_set_rule_inner(&set_rule).is_ok() {
208 return;
210 }
211 }
212 }
213 }
214
215 rule.make_ascii_lowercase();
218
219 self.rules_regex.push(rule);
222 }
223
224 #[inline]
225 fn add_set_rule(&mut self, rule: &str) -> io::Result<()> {
226 log::trace!("SET-RULE {rule}");
227 self.add_set_rule_inner(rule)
228 }
229
230 fn add_set_rule_inner(&mut self, rule: &str) -> io::Result<()> {
231 self.rules_set.insert(self.check_is_ascii(rule)?.to_ascii_lowercase());
232 Ok(())
233 }
234
235 #[inline]
236 fn add_tree_rule(&mut self, rule: &str) -> io::Result<()> {
237 log::trace!("TREE-RULE {rule}");
238 self.add_tree_rule_inner(rule)
239 }
240
241 fn add_tree_rule_inner(&mut self, rule: &str) -> io::Result<()> {
242 self.rules_tree.insert(self.check_is_ascii(rule)?);
244 Ok(())
245 }
246
247 fn check_is_ascii<'a>(&self, str: &'a str) -> io::Result<&'a str> {
248 if str.is_ascii() {
249 Ok(str.trim_end_matches('.'))
251 } else {
252 Err(Error::other(format!(
253 "{} parsing error: Unicode not allowed here `{str}`",
254 self.name
255 )))
256 }
257 }
258
259 fn compile_regex(name: &'static str, regex_rules: Vec<String>) -> io::Result<RegexSet> {
260 const REGEX_SIZE_LIMIT: usize = usize::MAX;
261 RegexSetBuilder::new(regex_rules)
262 .size_limit(REGEX_SIZE_LIMIT)
263 .unicode(false)
264 .build()
265 .map_err(|err| Error::other(format!("{name} regex error: {err}")))
266 }
267
268 fn into_rules(self) -> io::Result<Rules> {
269 Ok(Rules::new(
270 self.ipv4,
271 self.ipv6,
272 Self::compile_regex(self.name, self.rules_regex)?,
273 self.rules_set,
274 self.rules_tree,
275 ))
276 }
277}
278
/// Access control list loaded from a file.
///
/// Holds three independent rule sections plus the default [`Mode`] that
/// decides what happens to targets matching none of them.
#[derive(Debug, Clone)]
pub struct AccessControl {
    /// Rules from `[outbound_block_list]`, consulted by `check_outbound_blocked`.
    outbound_block: Rules,
    /// Rules from `[black_list]` / `[bypass_list]`.
    black_list: Rules,
    /// Rules from `[white_list]` / `[proxy_list]`.
    white_list: Rules,
    /// Default policy selected by the `[*_all]` markers.
    mode: Mode,
    /// Path the ACL was loaded from.
    file_path: PathBuf,
}
336
337impl AccessControl {
338 pub fn load_from_file<P: AsRef<Path>>(p: P) -> io::Result<AccessControl> {
340 log::trace!("ACL loading from {:?}", p.as_ref());
341
342 let file_path_ref = p.as_ref();
343 let file_path = file_path_ref.to_path_buf();
344
345 let fp = File::open(file_path_ref)?;
346 let r = BufReader::new(fp);
347
348 let mut mode = Mode::BlackList;
349
350 let mut outbound_block = ParsingRules::new("[outbound_block_list]");
351 let mut bypass = ParsingRules::new("[black_list] or [bypass_list]");
352 let mut proxy = ParsingRules::new("[white_list] or [proxy_list]");
353 let mut curr = &mut bypass;
354
355 log::trace!("ACL parsing start from mode {mode:?} and black_list / bypass_list");
356
357 for line in r.lines() {
358 let line = line?;
359 if line.is_empty() {
360 continue;
361 }
362
363 if line.starts_with('#') {
365 continue;
366 }
367
368 let line = line.trim();
369
370 if !line.is_ascii() {
371 log::warn!("ACL rule {line} containing non-ASCII characters, skipped");
372 continue;
373 }
374
375 if let Some(rule) = line.strip_prefix("||") {
376 curr.add_tree_rule(rule)?;
377 continue;
378 }
379
380 if let Some(rule) = line.strip_prefix('|') {
381 curr.add_set_rule(rule)?;
382 continue;
383 }
384
385 match line {
386 "[reject_all]" | "[bypass_all]" => {
387 mode = Mode::WhiteList;
388 log::trace!("switch to mode {mode:?}");
389 }
390 "[accept_all]" | "[proxy_all]" => {
391 mode = Mode::BlackList;
392 log::trace!("switch to mode {mode:?}");
393 }
394 "[outbound_block_list]" => {
395 curr = &mut outbound_block;
396 log::trace!("loading outbound_block_list");
397 }
398 "[black_list]" | "[bypass_list]" => {
399 curr = &mut bypass;
400 log::trace!("loading black_list / bypass_list");
401 }
402 "[white_list]" | "[proxy_list]" => {
403 curr = &mut proxy;
404 log::trace!("loading white_list / proxy_list");
405 }
406 _ => {
407 match line.parse::<IpNet>() {
408 Ok(IpNet::V4(v4)) => {
409 curr.add_ipv4_rule(v4);
410 }
411 Ok(IpNet::V6(v6)) => {
412 curr.add_ipv6_rule(v6);
413 }
414 Err(..) => {
415 match line.parse::<IpAddr>() {
417 Ok(IpAddr::V4(v4)) => {
418 curr.add_ipv4_rule(v4);
419 }
420 Ok(IpAddr::V6(v6)) => {
421 curr.add_ipv6_rule(v6);
422 }
423 Err(..) => {
424 curr.add_regex_rule(line.to_owned());
425 }
426 }
427 }
428 }
429 }
430 }
431 }
432
433 Ok(AccessControl {
434 outbound_block: outbound_block.into_rules()?,
435 black_list: bypass.into_rules()?,
436 white_list: proxy.into_rules()?,
437 mode,
438 file_path,
439 })
440 }
441
442 pub fn file_path(&self) -> &Path {
444 &self.file_path
445 }
446
447 pub fn check_host_in_proxy_list(&self, host: &str) -> Option<bool> {
455 let host = Self::convert_to_ascii(host);
456 self.check_ascii_host_in_proxy_list(&host)
457 }
458
459 pub fn check_ascii_host_in_proxy_list(&self, host: &str) -> Option<bool> {
467 if self.white_list.check_host_matched(host) {
469 return Some(true);
470 }
471 if self.black_list.check_host_matched(host) {
473 return Some(false);
474 }
475 None
476 }
477
478 pub fn is_ip_empty(&self) -> bool {
480 match self.mode {
481 Mode::BlackList => self.black_list.is_ip_empty(),
482 Mode::WhiteList => self.white_list.is_ip_empty(),
483 }
484 }
485
486 pub fn is_host_empty(&self) -> bool {
488 self.black_list.is_host_empty() && self.white_list.is_host_empty()
489 }
490
491 pub fn check_ip_in_proxy_list(&self, ip: &IpAddr) -> bool {
493 match self.mode {
494 Mode::BlackList => !self.black_list.check_ip_matched(ip),
495 Mode::WhiteList => self.white_list.check_ip_matched(ip),
496 }
497 }
498
499 pub fn is_default_in_proxy_list(&self) -> bool {
505 match self.mode {
506 Mode::BlackList => true,
507 Mode::WhiteList => false,
508 }
509 }
510
511 fn convert_to_ascii(host: &str) -> Cow<str> {
514 idna::domain_to_ascii(host).map(From::from).unwrap_or_else(|_| host.into())
515 }
516
517 pub async fn check_target_bypassed(&self, addr: &Address) -> bool {
521 match *addr {
522 Address::SocketAddress(ref addr) => !self.check_ip_in_proxy_list(&addr.ip()),
523 Address::DomainAddress(ref host, port) => {
525 if let Some(value) = self.check_host_in_proxy_list(host) {
526 return !value;
527 }
528 if self.is_ip_empty() {
529 return !self.is_default_in_proxy_list();
530 }
531 if let Ok(vaddr) = dns_resolve(host, port).await {
532 for addr in vaddr {
533 if !self.check_ip_in_proxy_list(&addr.ip()) {
534 return true;
535 }
536 }
537 }
538 false
539 }
540 }
541 }
542
543 pub fn check_client_blocked(&self, addr: &SocketAddr) -> bool {
545 match self.mode {
546 Mode::BlackList => {
547 self.black_list.check_ip_matched(&addr.ip())
549 }
550 Mode::WhiteList => {
551 !self.white_list.check_ip_matched(&addr.ip())
553 }
554 }
555 }
556
557 pub async fn check_outbound_blocked(&self, outbound: &Address) -> bool {
562 match outbound {
563 Address::SocketAddress(saddr) => self.outbound_block.check_ip_matched(&saddr.ip()),
564 Address::DomainAddress(host, port) => {
565 if self.outbound_block.check_host_matched(&Self::convert_to_ascii(host)) {
566 return true;
567 }
568
569 if let Ok(vaddr) = dns_resolve(host, *port).await {
570 for addr in vaddr {
571 if self.outbound_block.check_ip_matched(&addr.ip()) {
572 return true;
573 }
574 }
575 }
576
577 false
578 }
579 }
580 }
581}
582
583async fn dns_resolve(domain: &str, port: u16) -> std::io::Result<Vec<std::net::SocketAddr>> {
584 let addrs = tokio::net::lookup_host((domain, port)).await?;
585 Ok(addrs.collect())
586}
587
588#[tokio::test]
589async fn test_dns_resolve() {
590 let addrs = dns_resolve("baidu.com", 80).await.unwrap();
591 println!("Resolved addresses: {addrs:?}");
592 assert!(!addrs.is_empty());
593
594 let addrs = dns_resolve("localhost", 80).await.unwrap();
595 println!("Resolved addresses: {addrs:?}");
596 assert!(!addrs.is_empty());
597
598 let addrs = dns_resolve("123.45.67.89", 65535).await.unwrap();
599 println!("Resolved addresses: {addrs:?}");
600 assert!(!addrs.is_empty());
601
602 let addrs = dns_resolve("xxxxsasasasd", 65535).await;
603 assert!(addrs.is_err());
604}
605
// Requires a `shadowsocks.acl` file in the directory the test runs from.
#[test]
fn test_acl() {
    let acl = AccessControl::load_from_file("shadowsocks.acl").unwrap();

    assert!(!acl.is_ip_empty());
    assert!(!acl.is_host_empty());

    let verdict = |host: &str| acl.check_host_in_proxy_list(host);

    // Hosts that must be explicitly in the proxy list.
    assert_eq!(verdict("www.google.com"), Some(true));
    assert_eq!(verdict("sex.com"), Some(true));
    assert_eq!(verdict("pornhub.com"), Some(true));
    assert_eq!(verdict("youtube.com"), Some(true));

    // Hosts that must not be in the proxy list (bypassed or unmatched).
    assert_ne!(verdict("www.baidu.com"), Some(true));
    assert_ne!(verdict("example.com"), Some(true));
}