1use ipnet::{IpNet, Ipv4Net, Ipv6Net};
12use iprange::IpRange;
13use regex::bytes::{Regex, RegexBuilder, RegexSet, RegexSetBuilder};
14pub use socks5_impl::protocol::Address;
15use std::{
16 borrow::Cow,
17 collections::HashSet,
18 fmt,
19 fs::File,
20 io::{self, BufRead, BufReader, Error},
21 net::{IpAddr, SocketAddr},
22 path::{Path, PathBuf},
23 str,
24};
25
26mod sub_domains_tree;
27use sub_domains_tree::SubDomainsTree;
28
/// Verdict the ACL produces for an outbound target.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum TargetDecision {
    /// Matched the `[proxy]` rules, or the `[default proxy]` action applied.
    Proxy,
    /// Matched the `[direct]` rules, or the `[default direct]` action applied.
    Bypass,
    /// Matched the outbound block list, or the `[default block]` action applied.
    Block,
}

impl TargetDecision {
    /// Returns `true` only for [`TargetDecision::Proxy`].
    pub fn should_proxy(self) -> bool {
        self == Self::Proxy
    }

    /// Returns `true` only for [`TargetDecision::Bypass`].
    pub fn should_bypass(self) -> bool {
        self == Self::Bypass
    }

    /// Returns `true` only for [`TargetDecision::Block`].
    pub fn should_block(self) -> bool {
        self == Self::Block
    }
}
50
51#[derive(Clone)]
52struct Rules {
53 ipv4: IpRange<Ipv4Net>,
54 ipv6: IpRange<Ipv6Net>,
55 rule_regex: RegexSet,
56 rule_set: HashSet<String>,
57 rule_tree: SubDomainsTree,
58}
59
60impl fmt::Debug for Rules {
61 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
62 write!(f, "Rules {{ ipv4: {:?}, ipv6: {:?}, rule_regex: [", self.ipv4, self.ipv6)?;
63
64 let max_len = 2;
65 let has_more = self.rule_regex.len() > max_len;
66
67 for (idx, r) in self.rule_regex.patterns().iter().take(max_len).enumerate() {
68 if idx > 0 {
69 f.write_str(", ")?;
70 }
71 f.write_str(r)?;
72 }
73
74 if has_more {
75 f.write_str(", ...")?;
76 }
77
78 write!(f, "], rule_set: [")?;
79
80 let has_more = self.rule_set.len() > max_len;
81 for (idx, r) in self.rule_set.iter().take(max_len).enumerate() {
82 if idx > 0 {
83 f.write_str(", ")?;
84 }
85 f.write_str(r)?;
86 }
87
88 if has_more {
89 f.write_str(", ...")?;
90 }
91
92 write!(f, "], rule_tree: {:?} }}", self.rule_tree)
93 }
94}
95
96impl Rules {
97 fn new(
99 mut ipv4: IpRange<Ipv4Net>,
100 mut ipv6: IpRange<Ipv6Net>,
101 rule_regex: RegexSet,
102 rule_set: HashSet<String>,
103 rule_tree: SubDomainsTree,
104 ) -> Rules {
105 ipv4.simplify();
107 ipv6.simplify();
108
109 Rules {
110 ipv4,
111 ipv6,
112 rule_regex,
113 rule_set,
114 rule_tree,
115 }
116 }
117
118 #[allow(dead_code)]
120 fn check_address_matched(&self, addr: &Address) -> bool {
121 match *addr {
122 Address::SocketAddress(ref saddr) => self.check_ip_matched(&saddr.ip()),
123 Address::DomainAddress(ref domain, ..) => self.check_host_matched(domain),
124 }
125 }
126
127 fn check_ip_matched(&self, addr: &IpAddr) -> bool {
129 match addr {
130 IpAddr::V4(v4) => {
131 if self.ipv4.contains(v4) {
132 return true;
133 }
134
135 let mapped_ipv6 = v4.to_ipv6_mapped();
136 self.ipv6.contains(&mapped_ipv6)
137 }
138 IpAddr::V6(v6) => {
139 if self.ipv6.contains(v6) {
140 return true;
141 }
142
143 if let Some(mapped_ipv4) = v6.to_ipv4_mapped() {
144 return self.ipv4.contains(&mapped_ipv4);
145 }
146
147 false
148 }
149 }
150 }
151
152 fn check_host_matched(&self, host: &str) -> bool {
154 let host = host.trim_end_matches('.'); self.rule_set.contains(host) || self.rule_tree.contains(host) || self.rule_regex.is_match(host.as_bytes())
156 }
157
158 fn is_ip_empty(&self) -> bool {
160 self.ipv4.is_empty() && self.ipv6.is_empty()
161 }
162
163 fn is_host_empty(&self) -> bool {
165 self.rule_set.is_empty() && self.rule_tree.is_empty() && self.rule_regex.is_empty()
166 }
167}
168
169struct ParsingRules {
170 name: &'static str,
171 ipv4: IpRange<Ipv4Net>,
172 ipv6: IpRange<Ipv6Net>,
173 rules_regex: Vec<String>,
174 rules_set: HashSet<String>,
175 rules_tree: SubDomainsTree,
176}
177
178impl ParsingRules {
179 fn new(name: &'static str) -> Self {
180 ParsingRules {
181 name,
182 ipv4: IpRange::new(),
183 ipv6: IpRange::new(),
184 rules_regex: Vec::new(),
185 rules_set: HashSet::new(),
186 rules_tree: SubDomainsTree::new(),
187 }
188 }
189
190 fn add_ipv4_rule(&mut self, rule: impl Into<Ipv4Net>) {
191 let rule = rule.into();
192 self.ipv4.add(rule);
194 }
195
196 fn add_ipv6_rule(&mut self, rule: impl Into<Ipv6Net>) {
197 let rule = rule.into();
198 log::trace!("IPV6-RULE {rule}");
199 self.ipv6.add(rule);
200 }
201
202 fn add_regex_rule(&mut self, mut rule: String) {
203 static TREE_SET_RULE_EQUIV: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
204 let regex = TREE_SET_RULE_EQUIV.get_or_init(|| {
205 RegexBuilder::new(r#"^(?:(?:\((?:\?:)?\^\|\\\.\)|(?:\^\.(?:\+|\*))?\\\.)((?:[\w-]+(?:\\\.)?)+)|\^((?:[\w-]+(?:\\\.)?)+))\$?$"#)
206 .unicode(false)
207 .build()
208 .unwrap()
209 });
210
211 if let Some(caps) = regex.captures(rule.as_bytes()) {
212 if let Some(tree_rule) = caps.get(1) {
213 if let Ok(tree_rule) = str::from_utf8(tree_rule.as_bytes()) {
214 let tree_rule = tree_rule.replace("\\.", ".");
215 if self.add_tree_rule_inner(&tree_rule).is_ok() {
216 return;
218 }
219 }
220 } else if let Some(set_rule) = caps.get(2) {
221 if let Ok(set_rule) = str::from_utf8(set_rule.as_bytes()) {
222 let set_rule = set_rule.replace("\\.", ".");
223 if self.add_set_rule_inner(&set_rule).is_ok() {
224 return;
226 }
227 }
228 }
229 }
230
231 rule.make_ascii_lowercase();
234
235 self.rules_regex.push(rule);
238 }
239
240 #[inline]
241 fn add_set_rule(&mut self, rule: &str) -> io::Result<()> {
242 log::trace!("SET-RULE {rule}");
243 self.add_set_rule_inner(rule)
244 }
245
246 fn add_set_rule_inner(&mut self, rule: &str) -> io::Result<()> {
247 self.rules_set.insert(self.check_is_ascii(rule)?.to_ascii_lowercase());
248 Ok(())
249 }
250
251 #[inline]
252 fn add_tree_rule(&mut self, rule: &str) -> io::Result<()> {
253 log::trace!("TREE-RULE {rule}");
254 self.add_tree_rule_inner(rule)
255 }
256
257 fn add_rule_line(&mut self, line: &str) -> io::Result<()> {
258 if let Some(rule) = line.strip_prefix("||") {
259 self.add_tree_rule(rule)?;
260 return Ok(());
261 }
262
263 if let Some(rule) = line.strip_prefix('|') {
264 self.add_set_rule(rule)?;
265 return Ok(());
266 }
267
268 match line.parse::<IpNet>() {
269 Ok(IpNet::V4(v4)) => {
270 self.add_ipv4_rule(v4);
271 Ok(())
272 }
273 Ok(IpNet::V6(v6)) => {
274 self.add_ipv6_rule(v6);
275 Ok(())
276 }
277 Err(..) => match line.parse::<IpAddr>() {
278 Ok(IpAddr::V4(v4)) => {
279 self.add_ipv4_rule(v4);
280 Ok(())
281 }
282 Ok(IpAddr::V6(v6)) => {
283 self.add_ipv6_rule(v6);
284 Ok(())
285 }
286 Err(..) => {
287 self.add_regex_rule(line.to_owned());
288 Ok(())
289 }
290 },
291 }
292 }
293
294 fn add_tree_rule_inner(&mut self, rule: &str) -> io::Result<()> {
295 self.rules_tree.insert(self.check_is_ascii(rule)?);
297 Ok(())
298 }
299
300 fn check_is_ascii<'a>(&self, str: &'a str) -> io::Result<&'a str> {
301 if str.is_ascii() {
302 Ok(str.trim_end_matches('.'))
304 } else {
305 Err(Error::other(format!(
306 "{} parsing error: Unicode not allowed here `{str}`",
307 self.name
308 )))
309 }
310 }
311
312 fn compile_regex(name: &'static str, regex_rules: Vec<String>) -> io::Result<RegexSet> {
313 const REGEX_SIZE_LIMIT: usize = usize::MAX;
314 RegexSetBuilder::new(regex_rules)
315 .size_limit(REGEX_SIZE_LIMIT)
316 .unicode(false)
317 .build()
318 .map_err(|err| Error::other(format!("{name} regex error: {err}")))
319 }
320
321 fn into_rules(self) -> io::Result<Rules> {
322 Ok(Rules::new(
323 self.ipv4,
324 self.ipv6,
325 Self::compile_regex(self.name, self.rules_regex)?,
326 self.rules_set,
327 self.rules_tree,
328 ))
329 }
330}
331
332#[derive(Debug, Clone)]
351pub struct AccessControl {
352 default_action: TargetDecision,
353 proxy_rules: Rules,
354 direct_rules: Rules,
355 client_block: Rules,
356 outbound_block: Rules,
357 file_path: PathBuf,
358}
359
360impl AccessControl {
361 pub fn load_from_file<P: AsRef<Path>>(p: P) -> io::Result<AccessControl> {
363 log::trace!("ACL loading from {:?}", p.as_ref());
364
365 let file_path_ref = p.as_ref();
366 let file_path = file_path_ref.to_path_buf();
367
368 let fp = File::open(file_path_ref)?;
369 let r = BufReader::new(fp);
370
371 let mut default_action = None;
372
373 let mut proxy = ParsingRules::new("[proxy_rules]");
374 let mut direct = ParsingRules::new("[direct_rules]");
375 let mut client_block = ParsingRules::new("[client_block]");
376 let mut outbound_block = ParsingRules::new("[outbound_block]");
377 let mut curr = &mut direct;
378
379 enum Section {
380 Default,
381 ProxyRules,
382 DirectRules,
383 ClientBlock,
384 OutboundBlock,
385 }
386
387 let mut section = Section::Default;
388
389 for line in r.lines() {
390 let line = line?;
391 let line = line.trim();
392
393 if line.is_empty() {
394 continue;
395 }
396
397 if line.starts_with('#') {
399 continue;
400 }
401
402 if !line.is_ascii() {
403 log::warn!("ACL rule {line} containing non-ASCII characters, skipped");
404 continue;
405 }
406
407 if line.starts_with('[') && line.ends_with(']') {
408 let header = line[1..line.len() - 1].trim().to_ascii_lowercase();
409 match header.as_str() {
410 "default proxy" => {
411 section = Section::Default;
412 default_action = Some(TargetDecision::Proxy);
413 curr = &mut direct;
414 }
415 "default direct" => {
416 section = Section::Default;
417 default_action = Some(TargetDecision::Bypass);
418 curr = &mut direct;
419 }
420 "default block" => {
421 section = Section::Default;
422 default_action = Some(TargetDecision::Block);
423 curr = &mut direct;
424 }
425 "proxy" | "proxy_rules" => {
426 section = Section::ProxyRules;
427 curr = &mut proxy;
428 }
429 "direct" | "direct_rules" => {
430 section = Section::DirectRules;
431 curr = &mut direct;
432 }
433 "client_block" => {
434 section = Section::ClientBlock;
435 curr = &mut client_block;
436 }
437 "outbound_block" | "block" => {
438 section = Section::OutboundBlock;
439 curr = &mut outbound_block;
440 }
441 _ => {
442 return Err(Error::other(format!("unknown ACL section: {line}")));
443 }
444 }
445
446 log::trace!("switch to section {line}");
447 continue;
448 }
449
450 match section {
451 Section::Default => {
452 let value = line.strip_prefix("default ").unwrap_or(line).trim();
453 if default_action.is_none() {
454 return Err(Error::other(format!("invalid default ACL action: {value}")));
455 }
456 log::trace!("set default action to {default_action:?}");
457 }
458 Section::ProxyRules | Section::DirectRules | Section::ClientBlock | Section::OutboundBlock => {
459 curr.add_rule_line(line)?;
460 }
461 }
462 }
463
464 Ok(AccessControl {
465 default_action: default_action.ok_or_else(|| Error::other("default action not specified in ACL file"))?,
466 proxy_rules: proxy.into_rules()?,
467 direct_rules: direct.into_rules()?,
468 client_block: client_block.into_rules()?,
469 outbound_block: outbound_block.into_rules()?,
470 file_path,
471 })
472 }
473
474 pub fn file_path(&self) -> &Path {
476 &self.file_path
477 }
478
479 pub fn is_ip_empty(&self) -> bool {
481 self.proxy_rules.is_ip_empty() && self.direct_rules.is_ip_empty()
482 }
483
484 pub fn is_host_empty(&self) -> bool {
486 self.proxy_rules.is_host_empty() && self.direct_rules.is_host_empty()
487 }
488
489 pub fn decide_host(&self, host: &str) -> Option<TargetDecision> {
494 let host = Self::normalize_host(host);
495 if self.direct_rules.check_host_matched(&host) {
496 return Some(TargetDecision::Bypass);
497 }
498 if self.proxy_rules.check_host_matched(&host) {
499 return Some(TargetDecision::Proxy);
500 }
501 None
502 }
503
504 fn normalize_host(host: &str) -> Cow<'_, str> {
509 idna::domain_to_ascii(host)
510 .map(|host| Cow::Owned(host.to_ascii_lowercase()))
511 .unwrap_or_else(|_| Cow::Owned(host.to_ascii_lowercase()))
512 }
513
514 pub async fn decide_target(&self, addr: &Address) -> TargetDecision {
516 match *addr {
517 Address::SocketAddress(ref addr) => {
518 if self.outbound_block.check_ip_matched(&addr.ip()) {
519 return TargetDecision::Block;
520 }
521 self.decide_socket_addr(&addr.ip())
522 }
523 Address::DomainAddress(ref host, port) => {
524 if self.outbound_block.check_host_matched(&Self::normalize_host(host)) {
525 return TargetDecision::Block;
526 }
527 if let Some(value) = self.decide_host(host) {
528 return value;
529 }
530 if self.proxy_rules.is_ip_empty() && self.direct_rules.is_ip_empty() {
531 return self.default_action;
532 }
533 if let Ok(vaddr) = dns_resolve(host, port).await {
534 if vaddr.iter().any(|addr| self.outbound_block.check_ip_matched(&addr.ip())) {
535 return TargetDecision::Block;
536 }
537 if let Some(decision) = self.decide_resolved_ips(&vaddr) {
538 return decision;
539 }
540 }
541 self.default_action
542 }
543 }
544 }
545
546 pub fn check_client_blocked(&self, addr: &SocketAddr) -> bool {
548 self.client_block.check_ip_matched(&addr.ip())
549 }
550
551 pub async fn check_outbound_blocked(&self, outbound: &Address) -> bool {
556 self.decide_target(outbound).await.should_block()
557 }
558
559 fn decide_socket_addr(&self, ip: &IpAddr) -> TargetDecision {
560 if self.direct_rules.check_ip_matched(ip) {
561 return TargetDecision::Bypass;
562 }
563 if self.proxy_rules.check_ip_matched(ip) {
564 return TargetDecision::Proxy;
565 }
566
567 self.default_action
568 }
569
570 fn decide_resolved_ips(&self, addrs: &[SocketAddr]) -> Option<TargetDecision> {
571 if addrs.iter().any(|addr| self.direct_rules.check_ip_matched(&addr.ip())) {
572 return Some(TargetDecision::Bypass);
573 }
574 if addrs.iter().any(|addr| self.proxy_rules.check_ip_matched(&addr.ip())) {
575 return Some(TargetDecision::Proxy);
576 }
577
578 None
579 }
580}
581
582async fn dns_resolve(domain: &str, port: u16) -> std::io::Result<Vec<std::net::SocketAddr>> {
583 let addrs = tokio::net::lookup_host((domain, port)).await?;
584 Ok(addrs.collect())
585}
586
587#[tokio::test]
588async fn test_dns_resolve() {
589 let addrs = dns_resolve("baidu.com", 80).await.unwrap();
590 println!("Resolved addresses: {addrs:?}");
591 assert!(!addrs.is_empty());
592
593 let addrs = dns_resolve("localhost", 80).await.unwrap();
594 println!("Resolved addresses: {addrs:?}");
595 assert!(!addrs.is_empty());
596
597 let addrs = dns_resolve("123.45.67.89", 65535).await.unwrap();
598 println!("Resolved addresses: {addrs:?}");
599 assert!(!addrs.is_empty());
600
601 let addrs = dns_resolve("xxxxsasasasd", 65535).await;
602 assert!(addrs.is_err());
603}
604
605#[tokio::test]
606async fn test_acl() {
607 let acl_path = std::env::temp_dir().join(format!(
608 "socks-hub-acl-v2-{}-{}.acl",
609 std::process::id(),
610 std::time::SystemTime::now()
611 .duration_since(std::time::UNIX_EPOCH)
612 .unwrap()
613 .as_nanos()
614 ));
615
616 std::fs::write(
617 &acl_path,
618 r#"
619[default proxy]
620[proxy]
621||google.com
622|sex.com
623[direct]
624127.0.0.1
625||baidu.com
626|example.com
627192.168.0.0/16
628[block]
62910.0.0.0/8
630"#,
631 )
632 .unwrap();
633
634 let acl = AccessControl::load_from_file(&acl_path).unwrap();
635 let _ = std::fs::remove_file(&acl_path);
636
637 assert!(!acl.is_ip_empty());
638 assert!(!acl.is_host_empty());
639
640 assert_eq!(acl.decide_host("www.google.com"), Some(TargetDecision::Proxy));
641 assert_eq!(acl.decide_host("www.baidu.com"), Some(TargetDecision::Bypass));
642 assert_eq!(acl.decide_host("sex.com"), Some(TargetDecision::Proxy));
643 assert_eq!(acl.decide_host("example.com"), Some(TargetDecision::Bypass));
644 assert_eq!(acl.decide_host("youtube.com"), None);
645
646 let proxy_addr = Address::SocketAddress(std::net::SocketAddr::from(([127, 0, 0, 1], 80)));
647 let direct_addr = Address::SocketAddress(std::net::SocketAddr::from(([192, 168, 1, 10], 80)));
648 let blocked_addr = Address::SocketAddress(std::net::SocketAddr::from(([10, 0, 0, 1], 80)));
649
650 assert_eq!(acl.decide_target(&proxy_addr).await, TargetDecision::Bypass);
651 assert_eq!(acl.decide_target(&direct_addr).await, TargetDecision::Bypass);
652 assert!(acl.check_outbound_blocked(&blocked_addr).await);
653
654 std::fs::write(
655 &acl_path,
656 r#"
657[default block]
658[proxy]
659||example.com
660"#,
661 )
662 .unwrap();
663
664 let acl = AccessControl::load_from_file(&acl_path).unwrap();
665 assert_eq!(
666 acl.decide_target(&Address::from(("unmatched.test", 80))).await,
667 TargetDecision::Block
668 );
669}