Skip to main content

socks_hub_core/acl/
mod.rs

//! Access Control List (ACL), adapted from shadowsocks.
//!
//! Provides fine-grained control over routing and blocking behavior in both the local (client)
//! and the proxy (server) side.
//!
//! The ACL has a single shared target-routing rule set used for both client-side and server-side
//! proxy decisions. Server-only rules are limited to peer (client) blocking and outbound blocking.
//!
//! Source: https://github.com/shadowsocks/shadowsocks-rust/blob/master/crates/shadowsocks-service/src/acl/mod.rs
//!
10
11use ipnet::{IpNet, Ipv4Net, Ipv6Net};
12use iprange::IpRange;
13use regex::bytes::{Regex, RegexBuilder, RegexSet, RegexSetBuilder};
14pub use socks5_impl::protocol::Address;
15use std::{
16    borrow::Cow,
17    collections::HashSet,
18    fmt,
19    fs::File,
20    io::{self, BufRead, BufReader, Error},
21    net::{IpAddr, SocketAddr},
22    path::{Path, PathBuf},
23    str,
24};
25
26mod sub_domains_tree;
27use sub_domains_tree::SubDomainsTree;
28
/// Outcome of routing a connection target through the ACL.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum TargetDecision {
    /// Send the connection through the proxy.
    Proxy,
    /// Connect to the target directly, without the proxy.
    Bypass,
    /// Refuse the connection.
    Block,
}

impl TargetDecision {
    /// Returns `true` only for [`TargetDecision::Proxy`].
    pub fn should_proxy(self) -> bool {
        self == Self::Proxy
    }

    /// Returns `true` only for [`TargetDecision::Bypass`].
    pub fn should_bypass(self) -> bool {
        self == Self::Bypass
    }

    /// Returns `true` only for [`TargetDecision::Block`].
    pub fn should_block(self) -> bool {
        self == Self::Block
    }
}
50
51#[derive(Clone)]
52struct Rules {
53    ipv4: IpRange<Ipv4Net>,
54    ipv6: IpRange<Ipv6Net>,
55    rule_regex: RegexSet,
56    rule_set: HashSet<String>,
57    rule_tree: SubDomainsTree,
58}
59
60impl fmt::Debug for Rules {
61    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
62        write!(f, "Rules {{ ipv4: {:?}, ipv6: {:?}, rule_regex: [", self.ipv4, self.ipv6)?;
63
64        let max_len = 2;
65        let has_more = self.rule_regex.len() > max_len;
66
67        for (idx, r) in self.rule_regex.patterns().iter().take(max_len).enumerate() {
68            if idx > 0 {
69                f.write_str(", ")?;
70            }
71            f.write_str(r)?;
72        }
73
74        if has_more {
75            f.write_str(", ...")?;
76        }
77
78        write!(f, "], rule_set: [")?;
79
80        let has_more = self.rule_set.len() > max_len;
81        for (idx, r) in self.rule_set.iter().take(max_len).enumerate() {
82            if idx > 0 {
83                f.write_str(", ")?;
84            }
85            f.write_str(r)?;
86        }
87
88        if has_more {
89            f.write_str(", ...")?;
90        }
91
92        write!(f, "], rule_tree: {:?} }}", self.rule_tree)
93    }
94}
95
96impl Rules {
97    /// Create a new rule
98    fn new(
99        mut ipv4: IpRange<Ipv4Net>,
100        mut ipv6: IpRange<Ipv6Net>,
101        rule_regex: RegexSet,
102        rule_set: HashSet<String>,
103        rule_tree: SubDomainsTree,
104    ) -> Rules {
105        // Optimization, merging networks
106        ipv4.simplify();
107        ipv6.simplify();
108
109        Rules {
110            ipv4,
111            ipv6,
112            rule_regex,
113            rule_set,
114            rule_tree,
115        }
116    }
117
118    /// Check if the specified address matches these rules
119    #[allow(dead_code)]
120    fn check_address_matched(&self, addr: &Address) -> bool {
121        match *addr {
122            Address::SocketAddress(ref saddr) => self.check_ip_matched(&saddr.ip()),
123            Address::DomainAddress(ref domain, ..) => self.check_host_matched(domain),
124        }
125    }
126
127    /// Check if the specified address matches any rules
128    fn check_ip_matched(&self, addr: &IpAddr) -> bool {
129        match addr {
130            IpAddr::V4(v4) => {
131                if self.ipv4.contains(v4) {
132                    return true;
133                }
134
135                let mapped_ipv6 = v4.to_ipv6_mapped();
136                self.ipv6.contains(&mapped_ipv6)
137            }
138            IpAddr::V6(v6) => {
139                if self.ipv6.contains(v6) {
140                    return true;
141                }
142
143                if let Some(mapped_ipv4) = v6.to_ipv4_mapped() {
144                    return self.ipv4.contains(&mapped_ipv4);
145                }
146
147                false
148            }
149        }
150    }
151
152    /// Check if the specified ASCII host matches any rules
153    fn check_host_matched(&self, host: &str) -> bool {
154        let host = host.trim_end_matches('.'); // FQDN, removes the last `.`
155        self.rule_set.contains(host) || self.rule_tree.contains(host) || self.rule_regex.is_match(host.as_bytes())
156    }
157
158    /// Check if there are no rules for IP addresses
159    fn is_ip_empty(&self) -> bool {
160        self.ipv4.is_empty() && self.ipv6.is_empty()
161    }
162
163    /// Check if there are no rules for domain names
164    fn is_host_empty(&self) -> bool {
165        self.rule_set.is_empty() && self.rule_tree.is_empty() && self.rule_regex.is_empty()
166    }
167}
168
169struct ParsingRules {
170    name: &'static str,
171    ipv4: IpRange<Ipv4Net>,
172    ipv6: IpRange<Ipv6Net>,
173    rules_regex: Vec<String>,
174    rules_set: HashSet<String>,
175    rules_tree: SubDomainsTree,
176}
177
178impl ParsingRules {
179    fn new(name: &'static str) -> Self {
180        ParsingRules {
181            name,
182            ipv4: IpRange::new(),
183            ipv6: IpRange::new(),
184            rules_regex: Vec::new(),
185            rules_set: HashSet::new(),
186            rules_tree: SubDomainsTree::new(),
187        }
188    }
189
190    fn add_ipv4_rule(&mut self, rule: impl Into<Ipv4Net>) {
191        let rule = rule.into();
192        // log::trace!("IPV4-RULE {}", rule);
193        self.ipv4.add(rule);
194    }
195
196    fn add_ipv6_rule(&mut self, rule: impl Into<Ipv6Net>) {
197        let rule = rule.into();
198        log::trace!("IPV6-RULE {rule}");
199        self.ipv6.add(rule);
200    }
201
202    fn add_regex_rule(&mut self, mut rule: String) {
203        static TREE_SET_RULE_EQUIV: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
204        let regex = TREE_SET_RULE_EQUIV.get_or_init(|| {
205            RegexBuilder::new(r#"^(?:(?:\((?:\?:)?\^\|\\\.\)|(?:\^\.(?:\+|\*))?\\\.)((?:[\w-]+(?:\\\.)?)+)|\^((?:[\w-]+(?:\\\.)?)+))\$?$"#)
206                .unicode(false)
207                .build()
208                .unwrap()
209        });
210
211        if let Some(caps) = regex.captures(rule.as_bytes()) {
212            if let Some(tree_rule) = caps.get(1) {
213                if let Ok(tree_rule) = str::from_utf8(tree_rule.as_bytes()) {
214                    let tree_rule = tree_rule.replace("\\.", ".");
215                    if self.add_tree_rule_inner(&tree_rule).is_ok() {
216                        // log::trace!("REGEX-RULE {} => TREE-RULE {}", rule, tree_rule);
217                        return;
218                    }
219                }
220            } else if let Some(set_rule) = caps.get(2) {
221                if let Ok(set_rule) = str::from_utf8(set_rule.as_bytes()) {
222                    let set_rule = set_rule.replace("\\.", ".");
223                    if self.add_set_rule_inner(&set_rule).is_ok() {
224                        // log::trace!("REGEX-RULE {} => SET-RULE {}", rule, set_rule);
225                        return;
226                    }
227                }
228            }
229        }
230
231        // log::trace!("REGEX-RULE {}", rule);
232
233        rule.make_ascii_lowercase();
234
235        // Handle it as a normal REGEX
236        // FIXME: If this line is not a valid regex, how can we know without actually compile it?
237        self.rules_regex.push(rule);
238    }
239
240    #[inline]
241    fn add_set_rule(&mut self, rule: &str) -> io::Result<()> {
242        log::trace!("SET-RULE {rule}");
243        self.add_set_rule_inner(rule)
244    }
245
246    fn add_set_rule_inner(&mut self, rule: &str) -> io::Result<()> {
247        self.rules_set.insert(self.check_is_ascii(rule)?.to_ascii_lowercase());
248        Ok(())
249    }
250
251    #[inline]
252    fn add_tree_rule(&mut self, rule: &str) -> io::Result<()> {
253        log::trace!("TREE-RULE {rule}");
254        self.add_tree_rule_inner(rule)
255    }
256
257    fn add_rule_line(&mut self, line: &str) -> io::Result<()> {
258        if let Some(rule) = line.strip_prefix("||") {
259            self.add_tree_rule(rule)?;
260            return Ok(());
261        }
262
263        if let Some(rule) = line.strip_prefix('|') {
264            self.add_set_rule(rule)?;
265            return Ok(());
266        }
267
268        match line.parse::<IpNet>() {
269            Ok(IpNet::V4(v4)) => {
270                self.add_ipv4_rule(v4);
271                Ok(())
272            }
273            Ok(IpNet::V6(v6)) => {
274                self.add_ipv6_rule(v6);
275                Ok(())
276            }
277            Err(..) => match line.parse::<IpAddr>() {
278                Ok(IpAddr::V4(v4)) => {
279                    self.add_ipv4_rule(v4);
280                    Ok(())
281                }
282                Ok(IpAddr::V6(v6)) => {
283                    self.add_ipv6_rule(v6);
284                    Ok(())
285                }
286                Err(..) => {
287                    self.add_regex_rule(line.to_owned());
288                    Ok(())
289                }
290            },
291        }
292    }
293
294    fn add_tree_rule_inner(&mut self, rule: &str) -> io::Result<()> {
295        // SubDomainsTree do lowercase conversion inside insert
296        self.rules_tree.insert(self.check_is_ascii(rule)?);
297        Ok(())
298    }
299
300    fn check_is_ascii<'a>(&self, str: &'a str) -> io::Result<&'a str> {
301        if str.is_ascii() {
302            // Remove the last `.` of FQDN
303            Ok(str.trim_end_matches('.'))
304        } else {
305            Err(Error::other(format!(
306                "{} parsing error: Unicode not allowed here `{str}`",
307                self.name
308            )))
309        }
310    }
311
312    fn compile_regex(name: &'static str, regex_rules: Vec<String>) -> io::Result<RegexSet> {
313        const REGEX_SIZE_LIMIT: usize = usize::MAX;
314        RegexSetBuilder::new(regex_rules)
315            .size_limit(REGEX_SIZE_LIMIT)
316            .unicode(false)
317            .build()
318            .map_err(|err| Error::other(format!("{name} regex error: {err}")))
319    }
320
321    fn into_rules(self) -> io::Result<Rules> {
322        Ok(Rules::new(
323            self.ipv4,
324            self.ipv6,
325            Self::compile_regex(self.name, self.rules_regex)?,
326            self.rules_set,
327            self.rules_tree,
328        ))
329    }
330}
331
332/// ACL rules v2
333///
334/// ACL files are small ordered routing tables. They have one default action and a handful of
335/// explicit sections:
336///
337/// - `[default proxy]` / `[default direct]` / `[default block]` - one line, specifies the default action
338/// - `[proxy_rules]` - targets that must go through proxy
339/// - `[direct_rules]` - targets that must connect directly
340/// - `[client_block]` - client addresses that must be rejected by the server
341/// - `[outbound_block]` / `[block]` - targets that must be blocked
342///
343/// Rule lines can be one of:
344///
345/// - CIDR network, like `10.9.0.32/16`
346/// - IP address, like `127.0.0.1` or `::1`
347/// - Exact domain, like `|google.com`
348/// - Domain suffix, like `||google.com`
349/// - Regular expression, like `(^|\.)gmail\.com$`
350#[derive(Debug, Clone)]
351pub struct AccessControl {
352    default_action: TargetDecision,
353    proxy_rules: Rules,
354    direct_rules: Rules,
355    client_block: Rules,
356    outbound_block: Rules,
357    file_path: PathBuf,
358}
359
360impl AccessControl {
361    /// Load ACL rules from a file
362    pub fn load_from_file<P: AsRef<Path>>(p: P) -> io::Result<AccessControl> {
363        log::trace!("ACL loading from {:?}", p.as_ref());
364
365        let file_path_ref = p.as_ref();
366        let file_path = file_path_ref.to_path_buf();
367
368        let fp = File::open(file_path_ref)?;
369        let r = BufReader::new(fp);
370
371        let mut default_action = None;
372
373        let mut proxy = ParsingRules::new("[proxy_rules]");
374        let mut direct = ParsingRules::new("[direct_rules]");
375        let mut client_block = ParsingRules::new("[client_block]");
376        let mut outbound_block = ParsingRules::new("[outbound_block]");
377        let mut curr = &mut direct;
378
379        enum Section {
380            Default,
381            ProxyRules,
382            DirectRules,
383            ClientBlock,
384            OutboundBlock,
385        }
386
387        let mut section = Section::Default;
388
389        for line in r.lines() {
390            let line = line?;
391            let line = line.trim();
392
393            if line.is_empty() {
394                continue;
395            }
396
397            // Comments
398            if line.starts_with('#') {
399                continue;
400            }
401
402            if !line.is_ascii() {
403                log::warn!("ACL rule {line} containing non-ASCII characters, skipped");
404                continue;
405            }
406
407            if line.starts_with('[') && line.ends_with(']') {
408                let header = line[1..line.len() - 1].trim().to_ascii_lowercase();
409                match header.as_str() {
410                    "default proxy" => {
411                        section = Section::Default;
412                        default_action = Some(TargetDecision::Proxy);
413                        curr = &mut direct;
414                    }
415                    "default direct" => {
416                        section = Section::Default;
417                        default_action = Some(TargetDecision::Bypass);
418                        curr = &mut direct;
419                    }
420                    "default block" => {
421                        section = Section::Default;
422                        default_action = Some(TargetDecision::Block);
423                        curr = &mut direct;
424                    }
425                    "proxy" | "proxy_rules" => {
426                        section = Section::ProxyRules;
427                        curr = &mut proxy;
428                    }
429                    "direct" | "direct_rules" => {
430                        section = Section::DirectRules;
431                        curr = &mut direct;
432                    }
433                    "client_block" => {
434                        section = Section::ClientBlock;
435                        curr = &mut client_block;
436                    }
437                    "outbound_block" | "block" => {
438                        section = Section::OutboundBlock;
439                        curr = &mut outbound_block;
440                    }
441                    _ => {
442                        return Err(Error::other(format!("unknown ACL section: {line}")));
443                    }
444                }
445
446                log::trace!("switch to section {line}");
447                continue;
448            }
449
450            match section {
451                Section::Default => {
452                    let value = line.strip_prefix("default ").unwrap_or(line).trim();
453                    if default_action.is_none() {
454                        return Err(Error::other(format!("invalid default ACL action: {value}")));
455                    }
456                    log::trace!("set default action to {default_action:?}");
457                }
458                Section::ProxyRules | Section::DirectRules | Section::ClientBlock | Section::OutboundBlock => {
459                    curr.add_rule_line(line)?;
460                }
461            }
462        }
463
464        Ok(AccessControl {
465            default_action: default_action.ok_or_else(|| Error::other("default action not specified in ACL file"))?,
466            proxy_rules: proxy.into_rules()?,
467            direct_rules: direct.into_rules()?,
468            client_block: client_block.into_rules()?,
469            outbound_block: outbound_block.into_rules()?,
470            file_path,
471        })
472    }
473
474    /// Get ACL file path
475    pub fn file_path(&self) -> &Path {
476        &self.file_path
477    }
478
479    /// Check if there are no IP routing rules.
480    pub fn is_ip_empty(&self) -> bool {
481        self.proxy_rules.is_ip_empty() && self.direct_rules.is_ip_empty()
482    }
483
484    /// Check if there are no host routing rules.
485    pub fn is_host_empty(&self) -> bool {
486        self.proxy_rules.is_host_empty() && self.direct_rules.is_host_empty()
487    }
488
489    /// Decide how an ASCII domain should be handled.
490    ///
491    /// Returns the first matching action, or `None` if no rule matches.
492    /// The caller can then fall back to the default action.
493    pub fn decide_host(&self, host: &str) -> Option<TargetDecision> {
494        let host = Self::normalize_host(host);
495        if self.direct_rules.check_host_matched(&host) {
496            return Some(TargetDecision::Bypass);
497        }
498        if self.proxy_rules.check_host_matched(&host) {
499            return Some(TargetDecision::Proxy);
500        }
501        None
502    }
503
504    /// Normalize a domain name for rule matching.
505    ///
506    /// Hostnames are converted to ASCII when possible, then folded to lower-case because
507    /// rule storage is case-insensitive.
508    fn normalize_host(host: &str) -> Cow<'_, str> {
509        idna::domain_to_ascii(host)
510            .map(|host| Cow::Owned(host.to_ascii_lowercase()))
511            .unwrap_or_else(|_| Cow::Owned(host.to_ascii_lowercase()))
512    }
513
514    /// Decide how a target should be handled.
515    pub async fn decide_target(&self, addr: &Address) -> TargetDecision {
516        match *addr {
517            Address::SocketAddress(ref addr) => {
518                if self.outbound_block.check_ip_matched(&addr.ip()) {
519                    return TargetDecision::Block;
520                }
521                self.decide_socket_addr(&addr.ip())
522            }
523            Address::DomainAddress(ref host, port) => {
524                if self.outbound_block.check_host_matched(&Self::normalize_host(host)) {
525                    return TargetDecision::Block;
526                }
527                if let Some(value) = self.decide_host(host) {
528                    return value;
529                }
530                if self.proxy_rules.is_ip_empty() && self.direct_rules.is_ip_empty() {
531                    return self.default_action;
532                }
533                if let Ok(vaddr) = dns_resolve(host, port).await {
534                    if vaddr.iter().any(|addr| self.outbound_block.check_ip_matched(&addr.ip())) {
535                        return TargetDecision::Block;
536                    }
537                    if let Some(decision) = self.decide_resolved_ips(&vaddr) {
538                        return decision;
539                    }
540                }
541                self.default_action
542            }
543        }
544    }
545
546    /// Check if client address should be blocked (for server)
547    pub fn check_client_blocked(&self, addr: &SocketAddr) -> bool {
548        self.client_block.check_ip_matched(&addr.ip())
549    }
550
551    /// Check if outbound address is blocked (for server)
552    ///
553    /// NOTE: `Address::DomainAddress` is only validated by regex rules,
554    ///       resolved addresses are checked in the `lookup_outbound_then!` macro
555    pub async fn check_outbound_blocked(&self, outbound: &Address) -> bool {
556        self.decide_target(outbound).await.should_block()
557    }
558
559    fn decide_socket_addr(&self, ip: &IpAddr) -> TargetDecision {
560        if self.direct_rules.check_ip_matched(ip) {
561            return TargetDecision::Bypass;
562        }
563        if self.proxy_rules.check_ip_matched(ip) {
564            return TargetDecision::Proxy;
565        }
566
567        self.default_action
568    }
569
570    fn decide_resolved_ips(&self, addrs: &[SocketAddr]) -> Option<TargetDecision> {
571        if addrs.iter().any(|addr| self.direct_rules.check_ip_matched(&addr.ip())) {
572            return Some(TargetDecision::Bypass);
573        }
574        if addrs.iter().any(|addr| self.proxy_rules.check_ip_matched(&addr.ip())) {
575            return Some(TargetDecision::Proxy);
576        }
577
578        None
579    }
580}
581
582async fn dns_resolve(domain: &str, port: u16) -> std::io::Result<Vec<std::net::SocketAddr>> {
583    let addrs = tokio::net::lookup_host((domain, port)).await?;
584    Ok(addrs.collect())
585}
586
587#[tokio::test]
588async fn test_dns_resolve() {
589    let addrs = dns_resolve("baidu.com", 80).await.unwrap();
590    println!("Resolved addresses: {addrs:?}");
591    assert!(!addrs.is_empty());
592
593    let addrs = dns_resolve("localhost", 80).await.unwrap();
594    println!("Resolved addresses: {addrs:?}");
595    assert!(!addrs.is_empty());
596
597    let addrs = dns_resolve("123.45.67.89", 65535).await.unwrap();
598    println!("Resolved addresses: {addrs:?}");
599    assert!(!addrs.is_empty());
600
601    let addrs = dns_resolve("xxxxsasasasd", 65535).await;
602    assert!(addrs.is_err());
603}
604
605#[tokio::test]
606async fn test_acl() {
607    let acl_path = std::env::temp_dir().join(format!(
608        "socks-hub-acl-v2-{}-{}.acl",
609        std::process::id(),
610        std::time::SystemTime::now()
611            .duration_since(std::time::UNIX_EPOCH)
612            .unwrap()
613            .as_nanos()
614    ));
615
616    std::fs::write(
617        &acl_path,
618        r#"
619[default proxy]
620[proxy]
621||google.com
622|sex.com
623[direct]
624127.0.0.1
625||baidu.com
626|example.com
627192.168.0.0/16
628[block]
62910.0.0.0/8
630"#,
631    )
632    .unwrap();
633
634    let acl = AccessControl::load_from_file(&acl_path).unwrap();
635    let _ = std::fs::remove_file(&acl_path);
636
637    assert!(!acl.is_ip_empty());
638    assert!(!acl.is_host_empty());
639
640    assert_eq!(acl.decide_host("www.google.com"), Some(TargetDecision::Proxy));
641    assert_eq!(acl.decide_host("www.baidu.com"), Some(TargetDecision::Bypass));
642    assert_eq!(acl.decide_host("sex.com"), Some(TargetDecision::Proxy));
643    assert_eq!(acl.decide_host("example.com"), Some(TargetDecision::Bypass));
644    assert_eq!(acl.decide_host("youtube.com"), None);
645
646    let proxy_addr = Address::SocketAddress(std::net::SocketAddr::from(([127, 0, 0, 1], 80)));
647    let direct_addr = Address::SocketAddress(std::net::SocketAddr::from(([192, 168, 1, 10], 80)));
648    let blocked_addr = Address::SocketAddress(std::net::SocketAddr::from(([10, 0, 0, 1], 80)));
649
650    assert_eq!(acl.decide_target(&proxy_addr).await, TargetDecision::Bypass);
651    assert_eq!(acl.decide_target(&direct_addr).await, TargetDecision::Bypass);
652    assert!(acl.check_outbound_blocked(&blocked_addr).await);
653
654    std::fs::write(
655        &acl_path,
656        r#"
657[default block]
658[proxy]
659||example.com
660"#,
661    )
662    .unwrap();
663
664    let acl = AccessControl::load_from_file(&acl_path).unwrap();
665    assert_eq!(
666        acl.decide_target(&Address::from(("unmatched.test", 80))).await,
667        TargetDecision::Block
668    );
669}