tag2upload_service_manager/
dns.rs
use crate::prelude::*;
/// A glob pattern that is matched against lowercased DNS names.
///
/// Build it via `FromStr`, which lowercases the pattern. Parsing rejects
/// strings containing `:`. Unless the string contains `*` or `[`, it also
/// rejects strings whose last `.`-separated part is empty or all digits.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct DnsGlobPattern(glob::Pattern);
/// One entry in an allow-list of permitted callers.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum AllowedCaller {
/// Permits any caller whose IP address lies within this network.
Addr(IpNet),
/// Permits a caller if one of its reverse-DNS names matches this pattern.
Name(DnsGlobPattern),
}
/// Token showing that a caller has been found to be allowed.
///
/// The private `_hidden` field stops code outside this module from
/// constructing it directly; `IsAllowedCaller::new_unchecked` is the
/// constructor.
pub struct IsAllowedCaller { _hidden: () }
/// Parse error for `AllowedCaller` and `DnsGlobPattern`.
///
/// Returned when a string is neither a valid IP network nor an
/// acceptable DNS name or glob pattern.
#[derive(Error, Debug, Clone)]
#[error("syntactically invalid IP network or DNS name (or glob pattern)")]
pub struct InvalidAllowedCaller;
/// Why a caller was not allowed by `AllowedCaller::list_contains`.
///
/// `Display` is implemented by hand rather than via `#[error]`.
#[derive(Error, Debug, Clone)]
pub enum DisallowedCaller {
/// No address entry matched, and no reverse DNS lookup was performed.
Unresolved { addr: IpAddr },
/// A reverse DNS lookup was performed, but neither the address nor any
/// of the returned (lowercased) `names` matched.
Resolved { addr: IpAddr, names: Vec<String> },
}
/// Convenience alias for hickory's Tokio-based asynchronous DNS resolver.
pub type Resolver = hickory_resolver::TokioAsyncResolver;
impl IsAllowedCaller {
pub fn new_unchecked() -> Self {
IsAllowedCaller { _hidden: () }
}
}
impl Display for DisallowedCaller {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
DisallowedCaller::Unresolved { addr } =>
write!(f, "unrecognised calling IP address {addr}")?,
DisallowedCaller::Resolved { addr, names } => {
write!(f, "unrecognised calling host [{addr}]")?;
for n in names {
write!(f, "/{n:?}")?;
}
}
}
Ok(())
}
}
impl FromStr for DnsGlobPattern {
    type Err = InvalidAllowedCaller;

    /// Parses a DNS name, or a glob pattern over DNS names, lowercasing it
    /// (ASCII only).
    ///
    /// The input is rejected if any of the following holds:
    ///  * it contains `:`, so IPv6 addresses are never taken as names;
    ///  * it contains neither `*` nor `[`, and the text after its last `.`
    ///    (or the whole string, if there is no `.`) is empty or entirely
    ///    ASCII digits. This refuses IPv4 addresses and bare numbers while
    ///    still allowing names like `3com`;
    ///  * `glob::Pattern::new` rejects the lowercased text.
    fn from_str(s: &str) -> Result<Self, InvalidAllowedCaller> {
        if s.contains(':') {
            return Err(InvalidAllowedCaller);
        }

        let has_glob_chars = s.chars().any(|c| matches!(c, '*' | '['));
        if !has_glob_chars {
            let last_label = match s.rsplit_once('.') {
                Some((_, tail)) => tail,
                None => s,
            };
            if last_label.bytes().all(|b| b.is_ascii_digit()) {
                return Err(InvalidAllowedCaller);
            }
        }

        match glob::Pattern::new(&s.to_ascii_lowercase()) {
            Ok(pattern) => Ok(DnsGlobPattern(pattern)),
            Err(_) => Err(InvalidAllowedCaller),
        }
    }
}
impl FromStr for AllowedCaller {
type Err = InvalidAllowedCaller;
fn from_str(s: &str) -> Result<Self, InvalidAllowedCaller> {
Ok(if let Ok(net) = s.parse() {
AllowedCaller::Addr(net)
} else if let Ok(pat) = s.parse() {
AllowedCaller::Name(pat)
} else {
return Err(InvalidAllowedCaller);
})
}
}
impl<'de> Deserialize<'de> for AllowedCaller {
    /// Deserializes an entry from a string, using the same syntax as
    /// `FromStr`.
    ///
    /// Invalid strings are reported as serde `invalid_value` errors.
    fn deserialize<D: Deserializer<'de>>(
        deser: D,
    ) -> Result<AllowedCaller, D::Error> {
        let text = String::deserialize(deser)?;
        match text.parse::<AllowedCaller>() {
            Ok(caller) => Ok(caller),
            Err(InvalidAllowedCaller) => Err(D::Error::invalid_value(
                serde::de::Unexpected::Str(&text),
                &"dns glob pattern or IP address mask",
            )),
        }
    }
}
impl AllowedCaller {
pub async fn list_contains(
allowed: &[AllowedCaller],
addr: IpAddr,
) -> Result<Result<IsAllowedCaller, DisallowedCaller>, AE> {
let mut any_names = None;
if let Some(y) = allowed.iter().find_map(|a| match a {
AllowedCaller::Addr(a) => a.contains(&addr)
.then(|| IsAllowedCaller::new_unchecked()),
AllowedCaller::Name(_) => {
any_names = Some(());
None
}
}) {
return Ok(Ok(y));
}
let Some(()) = any_names
else { return Ok(Err(DisallowedCaller::Unresolved { addr })) };
let names = globals().dns_resolver
.reverse_lookup(addr).await
.context("reverse lookup for {addr}")?
.iter()
.map(|ptr| ptr.to_lowercase().to_ascii())
.collect_vec();
if let Some(y) = allowed.iter().find_map(|a| {
let AllowedCaller::Name(g) = a
else { return None };
names.iter().find_map(|n| {
g.0.matches(n)
.then(|| IsAllowedCaller::new_unchecked())
})
}) {
return Ok(Ok(y))
}
Ok(Err(DisallowedCaller::Resolved { addr, names }))
}
}
#[test]
fn chk_allowed_caller() {
    // `s` must parse as a DNS name entry, identical to parsing `s`
    // directly as a glob (so only use already-lowercase inputs here).
    let chk_name = |s: &str| {
        let allow: AllowedCaller = s.parse().expect(s);
        let glob = DnsGlobPattern(s.parse().expect(s));
        assert_eq!(allow, AllowedCaller::Name(glob));
    };
    // `s` must parse as an IP network entry.
    let chk_addr = |s: &str| {
        let allow: AllowedCaller = s.parse().expect(s);
        let addr = s.parse().expect(s);
        assert_eq!(allow, AllowedCaller::Addr(addr));
    };
    // `s` must be rejected.
    let chk_err = |s: &str| {
        s.parse::<AllowedCaller>().err().expect(s);
    };
    // `mixed` must parse to exactly the same entry as its lowercase form.
    let chk_folded = |mixed: &str, lower: &str| {
        let a: AllowedCaller = mixed.parse().expect(mixed);
        let b: AllowedCaller = lower.parse().expect(lower);
        assert_eq!(a, b);
    };

    chk_name("*");
    chk_name("a");
    chk_name("a.b");
    chk_name("3com");
    chk_name("a.3com");
    chk_name("[a]");
    // Glob metacharacters bypass the "all-digits last label" check.
    chk_name("*.1");
    chk_name("[1]");

    chk_folded("A.Example.ORG", "a.example.org");
    chk_folded("*.Example.org", "*.example.org");

    chk_addr("::1/128");
    chk_addr("12.0.0.1/32");
    chk_addr("::/0");
    chk_addr("0.0.0.0/0");

    chk_err("1");
    chk_err("1.2");
    chk_err("127.0.0.1");
    chk_err("::1");
    chk_err("a:b");
    chk_err("");
}