tag2upload_service_manager/
dns.rs

1
2use crate::prelude::*;
3
4#[derive(Debug, Clone, Eq, PartialEq)]
5pub struct DnsGlobPattern(glob::Pattern);
6
7#[derive(Debug, Clone, Eq, PartialEq, Deftly)]
8#[derive_deftly(DeserializeViaFromStr)]
9#[deftly(deser(expect = "dns glob pattern or IP address mask"))]
10pub enum AllowedClient {
11    Addr(IpNet),
12    /// Names must not contain `:`
13    /// and their last dotted component must contain some nondigits.
14    Name(DnsGlobPattern),
15}
16
/// Token value standing for "this client was permitted".
///
/// The private `_hidden` field stops code outside this module from building
/// one with a struct literal; it has to come from
/// `IsAllowedClient::new_unchecked` or `ActualClient::allowed_by`.
pub struct IsAllowedClient {
    _hidden: (),
}
18
19#[derive(Error, Debug, Clone)]
20#[error("syntactically invalid IP network or DNS name (or glob pattern)")]
21pub struct InvalidAllowedClient;
22
23#[derive(Error, Debug, Clone)]
24#[error("client not permitted: {client}")]
25pub struct DisallowedClient {
26    client: ActualClient,
27}
28
29pub type Resolver = hickory_resolver::TokioResolver;
30use hickory_resolver::ResolveError;
31
32impl IsAllowedClient {
33    pub fn new_unchecked() -> Self {
34        IsAllowedClient { _hidden: () }
35    }
36}
37
38/// Actual calling IP address, possibly DNS resolved
39#[derive(Debug, Clone)]
40pub struct ActualClient {
41    addr: IpAddr,
42    names: OnceLock<Result<Vec<String>, ResolveError>>,
43}
44
45impl ActualClient {
46    pub fn new(addr: IpAddr) -> Self {
47        ActualClient { addr, names: OnceLock::new() }
48    }
49}
50
51impl Display for ActualClient {
52    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
53        write!(f, "[{}]", self.addr)?;
54        match &self.names.get() {
55            None => {},
56            Some(Ok(names)) => {
57                for n in names {
58                    write!(f, "/{n:?}")?;
59                }
60            }
61            Some(Err(e)) => {
62                write!(f, "({})", e)?;
63            }
64        }
65        Ok(())
66    }
67}
68
69impl FromStr for DnsGlobPattern {
70    type Err = InvalidAllowedClient;
71
72    fn from_str(s: &str) -> Result<Self, InvalidAllowedClient> {
73        (|| {
74            if s.contains(':') {
75                return None;
76            }
77            if ! "*[".chars().any(|c| s.contains(c)) {
78                // It's not yet definitely a glob pattern;
79                // nor is it definitely an IPv6 address.
80                // See if the last component is all digits - then
81                // it's probably a (possibly invalid) IP address.
82                let rhs = s.rsplit_once('.').map(|(_, r)| r).unwrap_or(s);
83                if rhs.chars().all(|c| c.is_ascii_digit()) {
84                    return None;
85                }
86            }
87            let s = s.to_ascii_lowercase();
88            let g = glob::Pattern::new(&s).ok()?;
89            Some(DnsGlobPattern(g))
90        })().ok_or(InvalidAllowedClient)
91    }
92}
93
94impl FromStr for AllowedClient {
95    type Err = InvalidAllowedClient;
96    fn from_str(s: &str) -> Result<Self, InvalidAllowedClient> {
97        Ok(if let Ok(net) = s.parse() {
98            AllowedClient::Addr(net)
99        } else if let Ok(pat) = s.parse() {
100            AllowedClient::Name(pat)
101        } else {
102            return Err(InvalidAllowedClient);
103        })
104    }
105}
106
107impl ActualClient {
108    /// The `String` is the hostname, if it was found
109    pub async fn allowed_by(
110        &self,
111        allowed: &[AllowedClient],
112    ) -> Result<IsAllowedClient, DisallowedClient> {
113        self.allowed_by_inner(allowed).await
114            .map_err(|()| DisallowedClient { client: self.clone() })
115    }
116
117    pub async fn allowed_by_inner(
118        &self,
119        allowed: &[AllowedClient],
120    ) -> Result<IsAllowedClient, ()> {
121        let mut any_names = None;
122
123        if let Some(y) = allowed.iter().find_map(|a| match a {
124            AllowedClient::Addr(a) => a.contains(&self.addr)
125                .then(|| IsAllowedClient::new_unchecked()),
126            AllowedClient::Name(_) => {
127                any_names = Some(());
128                None
129            }
130        }) {
131            return Ok(y);
132        }
133
134        let () = any_names.ok_or(())?;
135        let names = self.resolve().await?;
136
137        if let Some(y) = allowed.iter().find_map(|a| {
138            let AllowedClient::Name(g) = a
139            else { return None };
140
141            names.iter().find_map(|n| {
142                g.0.matches(n)
143                    .then(|| IsAllowedClient::new_unchecked())
144            })
145        }) {
146            return Ok(y)
147        }
148
149        Err(())
150    }
151
152    async fn resolve(&self) -> Result<&[String], ()> {
153        if self.names.get().is_none() {
154            // (In theory we might run this twice in parallel;
155            // but we don't share &ActualClient, so that doesn't happen.)
156            let new_names = Self::resolve_inner(self.addr).await;
157            self.names.set(new_names)
158                .unwrap_or_else(|_: Result<Vec<String>, ResolveError>| ());
159        }
160        self.names
161            .get().expect("just resolved")
162            .as_deref()
163            .map_err(|_| ())
164    }
165
166    async fn resolve_inner(addr: IpAddr) -> Result<Vec<String>, ResolveError> {
167        let names = globals().dns_resolver
168            .reverse_lookup(addr).await?
169            .iter()
170            .map(|ptr| {
171                let name = ptr.to_lowercase().to_ascii();
172                // hickory returns a `PTR` which contains one of its `Name`s
173                // for which its `.to_ascii()` produces trailing `.`.  This
174                // will always be absolute since that's what a PTR is.  (This
175                // code copes if hickory does something different in future.)
176                let name = name.strip_suffix('.').unwrap_or(&name);
177                name.to_owned()
178            })
179            .collect_vec();
180        Ok(names)
181    }
182}
183
184impl DisallowedClient {
185    pub fn http_status(&self) -> rocket::http::Status {
186        use rocket::http::Status as S;
187        match self.client.names.get() {
188            None | Some(Ok(_)) => S::Forbidden,
189            // DNS resolution failure, best error code is probably 503
190            Some(Err(_)) => S::InternalServerError,
191        }
192    }
193}
194
/// Checks which strings parse as name patterns, which as networks,
/// and which are rejected.
#[test]
fn chk_allowed_client() {
    // Strings that must parse as name patterns.
    for s in ["*", "a", "a.b", "3com", "a.3com", "[a]"] {
        let parsed: AllowedClient = s.parse().expect(s);
        let pattern: glob::Pattern = s.parse().expect(s);
        assert_eq!(parsed, AllowedClient::Name(DnsGlobPattern(pattern)));
    }

    // Strings that must parse as IP networks.
    for s in ["::1/128", "12.0.0.1/32", "::/0", "0.0.0.0/0"] {
        let parsed: AllowedClient = s.parse().expect(s);
        let net: IpNet = s.parse().expect(s);
        assert_eq!(parsed, AllowedClient::Addr(net));
    }

    // Bare numbers and bare addresses (no prefix length) must be rejected.
    for s in ["1", "127.0.0.1", "::1"] {
        assert!(s.parse::<AllowedClient>().is_err(), "{s}");
    }
}