socks_hub/acl/
mod.rs

1//! Access Control List (ACL) from shadowsocks
2//!
//! This provides advanced control over server behavior for both local and proxy servers.
4//!
5//! source link https://github.com/shadowsocks/shadowsocks-rust/blob/master/crates/shadowsocks-service/src/acl/mod.rs
6//!
7
8use ipnet::{IpNet, Ipv4Net, Ipv6Net};
9use iprange::IpRange;
10use regex::bytes::{Regex, RegexBuilder, RegexSet, RegexSetBuilder};
11pub use socks5_impl::protocol::Address;
12use std::{
13    borrow::Cow,
14    collections::HashSet,
15    fmt,
16    fs::File,
17    io::{self, BufRead, BufReader, Error},
18    net::{IpAddr, SocketAddr},
19    path::{Path, PathBuf},
20    str,
21};
22
23mod sub_domains_tree;
24use sub_domains_tree::SubDomainsTree;
25
/// Strategy mode that ACL is running
///
/// The mode decides what happens to an address that matches no rule.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Mode {
    /// BlackList mode, accepts or proxies all requests by default;
    /// only addresses matched by `[black_list]` / `[bypass_list]` are rejected or bypassed.
    ///
    /// Selected by `[accept_all]` / `[proxy_all]`, and used when the ACL file
    /// does not name a mode.
    BlackList,
    /// WhiteList mode, rejects or bypasses all requests by default;
    /// only addresses matched by `[white_list]` / `[proxy_list]` are accepted or proxied.
    ///
    /// Selected by `[reject_all]` / `[bypass_all]`.
    WhiteList,
}
34
35#[derive(Clone)]
36struct Rules {
37    ipv4: IpRange<Ipv4Net>,
38    ipv6: IpRange<Ipv6Net>,
39    rule_regex: RegexSet,
40    rule_set: HashSet<String>,
41    rule_tree: SubDomainsTree,
42}
43
44impl fmt::Debug for Rules {
45    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
46        write!(f, "Rules {{ ipv4: {:?}, ipv6: {:?}, rule_regex: [", self.ipv4, self.ipv6)?;
47
48        let max_len = 2;
49        let has_more = self.rule_regex.len() > max_len;
50
51        for (idx, r) in self.rule_regex.patterns().iter().take(max_len).enumerate() {
52            if idx > 0 {
53                f.write_str(", ")?;
54            }
55            f.write_str(r)?;
56        }
57
58        if has_more {
59            f.write_str(", ...")?;
60        }
61
62        write!(f, "], rule_set: [")?;
63
64        let has_more = self.rule_set.len() > max_len;
65        for (idx, r) in self.rule_set.iter().take(max_len).enumerate() {
66            if idx > 0 {
67                f.write_str(", ")?;
68            }
69            f.write_str(r)?;
70        }
71
72        if has_more {
73            f.write_str(", ...")?;
74        }
75
76        write!(f, "], rule_tree: {:?} }}", self.rule_tree)
77    }
78}
79
80impl Rules {
81    /// Create a new rule
82    fn new(
83        mut ipv4: IpRange<Ipv4Net>,
84        mut ipv6: IpRange<Ipv6Net>,
85        rule_regex: RegexSet,
86        rule_set: HashSet<String>,
87        rule_tree: SubDomainsTree,
88    ) -> Rules {
89        // Optimization, merging networks
90        ipv4.simplify();
91        ipv6.simplify();
92
93        Rules {
94            ipv4,
95            ipv6,
96            rule_regex,
97            rule_set,
98            rule_tree,
99        }
100    }
101
102    /// Check if the specified address matches these rules
103    #[allow(dead_code)]
104    fn check_address_matched(&self, addr: &Address) -> bool {
105        match *addr {
106            Address::SocketAddress(ref saddr) => self.check_ip_matched(&saddr.ip()),
107            Address::DomainAddress(ref domain, ..) => self.check_host_matched(domain),
108        }
109    }
110
111    /// Check if the specified address matches any rules
112    fn check_ip_matched(&self, addr: &IpAddr) -> bool {
113        match addr {
114            IpAddr::V4(v4) => {
115                if self.ipv4.contains(v4) {
116                    return true;
117                }
118
119                let mapped_ipv6 = v4.to_ipv6_mapped();
120                self.ipv6.contains(&mapped_ipv6)
121            }
122            IpAddr::V6(v6) => {
123                if self.ipv6.contains(v6) {
124                    return true;
125                }
126
127                if let Some(mapped_ipv4) = v6.to_ipv4_mapped() {
128                    return self.ipv4.contains(&mapped_ipv4);
129                }
130
131                false
132            }
133        }
134    }
135
136    /// Check if the specified ASCII host matches any rules
137    fn check_host_matched(&self, host: &str) -> bool {
138        let host = host.trim_end_matches('.'); // FQDN, removes the last `.`
139        self.rule_set.contains(host) || self.rule_tree.contains(host) || self.rule_regex.is_match(host.as_bytes())
140    }
141
142    /// Check if there are no rules for IP addresses
143    fn is_ip_empty(&self) -> bool {
144        self.ipv4.is_empty() && self.ipv6.is_empty()
145    }
146
147    /// Check if there are no rules for domain names
148    fn is_host_empty(&self) -> bool {
149        self.rule_set.is_empty() && self.rule_tree.is_empty() && self.rule_regex.is_empty()
150    }
151}
152
153struct ParsingRules {
154    name: &'static str,
155    ipv4: IpRange<Ipv4Net>,
156    ipv6: IpRange<Ipv6Net>,
157    rules_regex: Vec<String>,
158    rules_set: HashSet<String>,
159    rules_tree: SubDomainsTree,
160}
161
162impl ParsingRules {
163    fn new(name: &'static str) -> Self {
164        ParsingRules {
165            name,
166            ipv4: IpRange::new(),
167            ipv6: IpRange::new(),
168            rules_regex: Vec::new(),
169            rules_set: HashSet::new(),
170            rules_tree: SubDomainsTree::new(),
171        }
172    }
173
174    fn add_ipv4_rule(&mut self, rule: impl Into<Ipv4Net>) {
175        let rule = rule.into();
176        // log::trace!("IPV4-RULE {}", rule);
177        self.ipv4.add(rule);
178    }
179
180    fn add_ipv6_rule(&mut self, rule: impl Into<Ipv6Net>) {
181        let rule = rule.into();
182        log::trace!("IPV6-RULE {rule}");
183        self.ipv6.add(rule);
184    }
185
186    fn add_regex_rule(&mut self, mut rule: String) {
187        static TREE_SET_RULE_EQUIV: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
188        let regex = TREE_SET_RULE_EQUIV.get_or_init(|| {
189            RegexBuilder::new(r#"^(?:(?:\((?:\?:)?\^\|\\\.\)|(?:\^\.(?:\+|\*))?\\\.)((?:[\w-]+(?:\\\.)?)+)|\^((?:[\w-]+(?:\\\.)?)+))\$?$"#)
190                .unicode(false)
191                .build()
192                .unwrap()
193        });
194
195        if let Some(caps) = regex.captures(rule.as_bytes()) {
196            if let Some(tree_rule) = caps.get(1) {
197                if let Ok(tree_rule) = str::from_utf8(tree_rule.as_bytes()) {
198                    let tree_rule = tree_rule.replace("\\.", ".");
199                    if self.add_tree_rule_inner(&tree_rule).is_ok() {
200                        // log::trace!("REGEX-RULE {} => TREE-RULE {}", rule, tree_rule);
201                        return;
202                    }
203                }
204            } else if let Some(set_rule) = caps.get(2) {
205                if let Ok(set_rule) = str::from_utf8(set_rule.as_bytes()) {
206                    let set_rule = set_rule.replace("\\.", ".");
207                    if self.add_set_rule_inner(&set_rule).is_ok() {
208                        // log::trace!("REGEX-RULE {} => SET-RULE {}", rule, set_rule);
209                        return;
210                    }
211                }
212            }
213        }
214
215        // log::trace!("REGEX-RULE {}", rule);
216
217        rule.make_ascii_lowercase();
218
219        // Handle it as a normal REGEX
220        // FIXME: If this line is not a valid regex, how can we know without actually compile it?
221        self.rules_regex.push(rule);
222    }
223
224    #[inline]
225    fn add_set_rule(&mut self, rule: &str) -> io::Result<()> {
226        log::trace!("SET-RULE {rule}");
227        self.add_set_rule_inner(rule)
228    }
229
230    fn add_set_rule_inner(&mut self, rule: &str) -> io::Result<()> {
231        self.rules_set.insert(self.check_is_ascii(rule)?.to_ascii_lowercase());
232        Ok(())
233    }
234
235    #[inline]
236    fn add_tree_rule(&mut self, rule: &str) -> io::Result<()> {
237        log::trace!("TREE-RULE {rule}");
238        self.add_tree_rule_inner(rule)
239    }
240
241    fn add_tree_rule_inner(&mut self, rule: &str) -> io::Result<()> {
242        // SubDomainsTree do lowercase conversion inside insert
243        self.rules_tree.insert(self.check_is_ascii(rule)?);
244        Ok(())
245    }
246
247    fn check_is_ascii<'a>(&self, str: &'a str) -> io::Result<&'a str> {
248        if str.is_ascii() {
249            // Remove the last `.` of FQDN
250            Ok(str.trim_end_matches('.'))
251        } else {
252            Err(Error::other(format!(
253                "{} parsing error: Unicode not allowed here `{str}`",
254                self.name
255            )))
256        }
257    }
258
259    fn compile_regex(name: &'static str, regex_rules: Vec<String>) -> io::Result<RegexSet> {
260        const REGEX_SIZE_LIMIT: usize = usize::MAX;
261        RegexSetBuilder::new(regex_rules)
262            .size_limit(REGEX_SIZE_LIMIT)
263            .unicode(false)
264            .build()
265            .map_err(|err| Error::other(format!("{name} regex error: {err}")))
266    }
267
268    fn into_rules(self) -> io::Result<Rules> {
269        Ok(Rules::new(
270            self.ipv4,
271            self.ipv6,
272            Self::compile_regex(self.name, self.rules_regex)?,
273            self.rules_set,
274            self.rules_tree,
275        ))
276    }
277}
278
279/// ACL rules
280///
281/// ## Sections
282///
283/// ACL File is formatted in sections, each section has a name with surrounded by brackets `[` and `]`
284/// followed by Rules line by line.
285///
286/// ```plain
287/// [SECTION-1]
288/// RULE-1
289/// RULE-2
290/// RULE-3
291///
292/// [SECTION-2]
293/// RULE-1
294/// RULE-2
295/// RULE-3
296/// ```
297///
298/// Available sections are
299///
300/// - For local servers (`sslocal`, `ssredir`, ...)
301///     * `[bypass_all]` - ACL runs in `BlackList` mode.
302///     * `[proxy_all]` - ACL runs in `WhiteList` mode.
303///     * `[bypass_list]` - Rules for connecting directly
304///     * `[proxy_list]` - Rules for connecting through proxies
305/// - For remote servers (`ssserver`)
306///     * `[reject_all]` - ACL runs in `BlackList` mode.
307///     * `[accept_all]` - ACL runs in `WhiteList` mode.
308///     * `[black_list]` - Rules for rejecting
309///     * `[white_list]` - Rules for allowing
310///     * `[outbound_block_list]` - Rules for blocking outbound addresses.
311///
312/// ## Mode
313///
314/// Mode is the default ACL strategy for those addresses that are not in configuration file.
315///
316/// - `BlackList` - Bypasses / Rejects all addresses except those in `[proxy_list]` or `[white_list]`
317/// - `WhiteList` - Proxies / Accepts all addresses except those in `[bypass_list]` or `[black_list]`
318///
319/// ## Rules
320///
321/// Rules can be either
322///
323/// - CIDR form network addresses, like `10.9.0.32/16`
324/// - IP addresses, like `127.0.0.1` or `::1`
325/// - Regular Expression for matching hosts, like `(^|\.)gmail\.com$`
326/// - Domain with preceding `|` for exact matching, like `|google.com`
327/// - Domain with preceding `||` for matching with subdomains, like `||google.com`
328#[derive(Debug, Clone)]
329pub struct AccessControl {
330    outbound_block: Rules,
331    black_list: Rules,
332    white_list: Rules,
333    mode: Mode,
334    file_path: PathBuf,
335}
336
337impl AccessControl {
338    /// Load ACL rules from a file
339    pub fn load_from_file<P: AsRef<Path>>(p: P) -> io::Result<AccessControl> {
340        log::trace!("ACL loading from {:?}", p.as_ref());
341
342        let file_path_ref = p.as_ref();
343        let file_path = file_path_ref.to_path_buf();
344
345        let fp = File::open(file_path_ref)?;
346        let r = BufReader::new(fp);
347
348        let mut mode = Mode::BlackList;
349
350        let mut outbound_block = ParsingRules::new("[outbound_block_list]");
351        let mut bypass = ParsingRules::new("[black_list] or [bypass_list]");
352        let mut proxy = ParsingRules::new("[white_list] or [proxy_list]");
353        let mut curr = &mut bypass;
354
355        log::trace!("ACL parsing start from mode {mode:?} and black_list / bypass_list");
356
357        for line in r.lines() {
358            let line = line?;
359            if line.is_empty() {
360                continue;
361            }
362
363            // Comments
364            if line.starts_with('#') {
365                continue;
366            }
367
368            let line = line.trim();
369
370            if !line.is_ascii() {
371                log::warn!("ACL rule {line} containing non-ASCII characters, skipped");
372                continue;
373            }
374
375            if let Some(rule) = line.strip_prefix("||") {
376                curr.add_tree_rule(rule)?;
377                continue;
378            }
379
380            if let Some(rule) = line.strip_prefix('|') {
381                curr.add_set_rule(rule)?;
382                continue;
383            }
384
385            match line {
386                "[reject_all]" | "[bypass_all]" => {
387                    mode = Mode::WhiteList;
388                    log::trace!("switch to mode {mode:?}");
389                }
390                "[accept_all]" | "[proxy_all]" => {
391                    mode = Mode::BlackList;
392                    log::trace!("switch to mode {mode:?}");
393                }
394                "[outbound_block_list]" => {
395                    curr = &mut outbound_block;
396                    log::trace!("loading outbound_block_list");
397                }
398                "[black_list]" | "[bypass_list]" => {
399                    curr = &mut bypass;
400                    log::trace!("loading black_list / bypass_list");
401                }
402                "[white_list]" | "[proxy_list]" => {
403                    curr = &mut proxy;
404                    log::trace!("loading white_list / proxy_list");
405                }
406                _ => {
407                    match line.parse::<IpNet>() {
408                        Ok(IpNet::V4(v4)) => {
409                            curr.add_ipv4_rule(v4);
410                        }
411                        Ok(IpNet::V6(v6)) => {
412                            curr.add_ipv6_rule(v6);
413                        }
414                        Err(..) => {
415                            // Maybe it is a pure IpAddr
416                            match line.parse::<IpAddr>() {
417                                Ok(IpAddr::V4(v4)) => {
418                                    curr.add_ipv4_rule(v4);
419                                }
420                                Ok(IpAddr::V6(v6)) => {
421                                    curr.add_ipv6_rule(v6);
422                                }
423                                Err(..) => {
424                                    curr.add_regex_rule(line.to_owned());
425                                }
426                            }
427                        }
428                    }
429                }
430            }
431        }
432
433        Ok(AccessControl {
434            outbound_block: outbound_block.into_rules()?,
435            black_list: bypass.into_rules()?,
436            white_list: proxy.into_rules()?,
437            mode,
438            file_path,
439        })
440    }
441
442    /// Get ACL file path
443    pub fn file_path(&self) -> &Path {
444        &self.file_path
445    }
446
447    /// Check if domain name is in proxy_list.
448    /// If so, it should be resolved from remote (for Android's DNS relay)
449    ///
450    /// Return
451    /// - `Some(true)` if `host` is in `white_list` (should be proxied)
452    /// - `Some(false)` if `host` is in `black_list` (should be bypassed)
453    /// - `None` if `host` doesn't match any rules
454    pub fn check_host_in_proxy_list(&self, host: &str) -> Option<bool> {
455        let host = Self::convert_to_ascii(host);
456        self.check_ascii_host_in_proxy_list(&host)
457    }
458
459    /// Check if ASCII domain name is in proxy_list.
460    /// If so, it should be resolved from remote (for Android's DNS relay)
461    ///
462    /// Return
463    /// - `Some(true)` if `host` is in `white_list` (should be proxied)
464    /// - `Some(false)` if `host` is in `black_list` (should be bypassed)
465    /// - `None` if `host` doesn't match any rules
466    pub fn check_ascii_host_in_proxy_list(&self, host: &str) -> Option<bool> {
467        // Addresses in proxy_list will be proxied
468        if self.white_list.check_host_matched(host) {
469            return Some(true);
470        }
471        // Addresses in bypass_list will be bypassed
472        if self.black_list.check_host_matched(host) {
473            return Some(false);
474        }
475        None
476    }
477
478    /// If there are no IP rules
479    pub fn is_ip_empty(&self) -> bool {
480        match self.mode {
481            Mode::BlackList => self.black_list.is_ip_empty(),
482            Mode::WhiteList => self.white_list.is_ip_empty(),
483        }
484    }
485
486    /// If there are no domain name rules
487    pub fn is_host_empty(&self) -> bool {
488        self.black_list.is_host_empty() && self.white_list.is_host_empty()
489    }
490
491    /// Check if `IpAddr` should be proxied
492    pub fn check_ip_in_proxy_list(&self, ip: &IpAddr) -> bool {
493        match self.mode {
494            Mode::BlackList => !self.black_list.check_ip_matched(ip),
495            Mode::WhiteList => self.white_list.check_ip_matched(ip),
496        }
497    }
498
499    /// Default mode
500    ///
501    /// Default behavior for hosts that are not configured
502    /// - `true` - Proxied
503    /// - `false` - Bypassed
504    pub fn is_default_in_proxy_list(&self) -> bool {
505        match self.mode {
506            Mode::BlackList => true,
507            Mode::WhiteList => false,
508        }
509    }
510
511    /// Returns the ASCII representation a domain name,
512    /// if conversion fails returns original string
513    fn convert_to_ascii(host: &str) -> Cow<str> {
514        idna::domain_to_ascii(host).map(From::from).unwrap_or_else(|_| host.into())
515    }
516
517    /// Check if target address should be bypassed (for client)
518    ///
519    /// This function may perform a DNS resolution
520    pub async fn check_target_bypassed(&self, addr: &Address) -> bool {
521        match *addr {
522            Address::SocketAddress(ref addr) => !self.check_ip_in_proxy_list(&addr.ip()),
523            // Resolve hostname and check the list
524            Address::DomainAddress(ref host, port) => {
525                if let Some(value) = self.check_host_in_proxy_list(host) {
526                    return !value;
527                }
528                if self.is_ip_empty() {
529                    return !self.is_default_in_proxy_list();
530                }
531                if let Ok(vaddr) = dns_resolve(host, port).await {
532                    for addr in vaddr {
533                        if !self.check_ip_in_proxy_list(&addr.ip()) {
534                            return true;
535                        }
536                    }
537                }
538                false
539            }
540        }
541    }
542
543    /// Check if client address should be blocked (for server)
544    pub fn check_client_blocked(&self, addr: &SocketAddr) -> bool {
545        match self.mode {
546            Mode::BlackList => {
547                // Only clients in black_list will be blocked
548                self.black_list.check_ip_matched(&addr.ip())
549            }
550            Mode::WhiteList => {
551                // Only clients in white_list will be proxied
552                !self.white_list.check_ip_matched(&addr.ip())
553            }
554        }
555    }
556
557    /// Check if outbound address is blocked (for server)
558    ///
559    /// NOTE: `Address::DomainAddress` is only validated by regex rules,
560    ///       resolved addresses are checked in the `lookup_outbound_then!` macro
561    pub async fn check_outbound_blocked(&self, outbound: &Address) -> bool {
562        match outbound {
563            Address::SocketAddress(saddr) => self.outbound_block.check_ip_matched(&saddr.ip()),
564            Address::DomainAddress(host, port) => {
565                if self.outbound_block.check_host_matched(&Self::convert_to_ascii(host)) {
566                    return true;
567                }
568
569                if let Ok(vaddr) = dns_resolve(host, *port).await {
570                    for addr in vaddr {
571                        if self.outbound_block.check_ip_matched(&addr.ip()) {
572                            return true;
573                        }
574                    }
575                }
576
577                false
578            }
579        }
580    }
581}
582
583async fn dns_resolve(domain: &str, port: u16) -> std::io::Result<Vec<std::net::SocketAddr>> {
584    let addrs = tokio::net::lookup_host((domain, port)).await?;
585    Ok(addrs.collect())
586}
587
588#[tokio::test]
589async fn test_dns_resolve() {
590    let addrs = dns_resolve("baidu.com", 80).await.unwrap();
591    println!("Resolved addresses: {addrs:?}");
592    assert!(!addrs.is_empty());
593
594    let addrs = dns_resolve("localhost", 80).await.unwrap();
595    println!("Resolved addresses: {addrs:?}");
596    assert!(!addrs.is_empty());
597
598    let addrs = dns_resolve("123.45.67.89", 65535).await.unwrap();
599    println!("Resolved addresses: {addrs:?}");
600    assert!(!addrs.is_empty());
601
602    let addrs = dns_resolve("xxxxsasasasd", 65535).await;
603    assert!(addrs.is_err());
604}
605
/// Loads `shadowsocks.acl` from the working directory and spot-checks a few
/// host decisions.
#[test]
fn test_acl() {
    let acl = AccessControl::load_from_file("shadowsocks.acl").unwrap();

    // The file must provide both IP and host rules
    assert!(!acl.is_ip_empty());
    assert!(!acl.is_host_empty());

    let lookup = |host: &str| acl.check_host_in_proxy_list(host);

    assert!(lookup("www.google.com").unwrap());
    assert!(!lookup("www.baidu.com").unwrap_or_default());
    assert!(lookup("sex.com").unwrap());
    assert!(lookup("pornhub.com").unwrap_or_default());
    assert!(!lookup("example.com").unwrap_or_default());
    assert!(lookup("youtube.com").unwrap_or_default());
}
616    assert!(acl.check_host_in_proxy_list("pornhub.com").unwrap_or_default());
617    assert!(!acl.check_host_in_proxy_list("example.com").unwrap_or_default());
618    assert!(acl.check_host_in_proxy_list("youtube.com").unwrap_or_default());
619}