Skip to main content

stalwart_lib/common/src/config/smtp/
resolver.rs

1/*
2 * SPDX-FileCopyrightText: 2020 Stalwart Labs LLC <hello@stalw.art>
3 *
4 * SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-SEL
5 */
6
7use std::{
8    fmt::Display,
9    hash::{DefaultHasher, Hash, Hasher},
10    net::{IpAddr, Ipv4Addr, SocketAddr},
11    time::Duration,
12};
13
14use crate::utils::{
15    cache::CacheItemWeight,
16    config::{Config, utils::ParseValue},
17};
18use mail_auth::{
19    MessageAuthenticator,
20    hickory_resolver::{
21        TokioResolver,
22        config::{NameServerConfig, ProtocolConfig, ResolverConfig, ResolverOpts},
23        name_server::TokioConnectionProvider,
24        system_conf::read_system_conf,
25    },
26};
27use serde::{Deserialize, Serialize};
28
29use crate::Server;
30
/// The DNS resolvers used by the server.
///
/// `dns` is the general-purpose resolver wrapped in `mail_auth`'s
/// `MessageAuthenticator`; `dnssec` is a second resolver over the same
/// upstream servers, built with `ResolverOpts::validate = true` by the
/// constructors in this file (`Resolvers::parse` and `Resolvers::default`).
pub struct Resolvers {
    /// Resolver used through `mail_auth`'s `MessageAuthenticator`.
    pub dns: MessageAuthenticator,
    /// Resolver built with DNSSEC validation enabled.
    pub dnssec: DnssecResolver,
}
35
/// Thin cloneable wrapper around a Hickory `TokioResolver`.
///
/// NOTE(review): every constructor visible in this file sets
/// `ResolverOpts::validate = true` on the wrapped resolver; confirm no other
/// code path builds one without validation.
#[derive(Clone)]
pub struct DnssecResolver {
    pub resolver: TokioResolver,
}
40
/// One decoded TLSA record entry (presumably from a DANE TLSA lookup — the
/// decoder is not in this file).
///
/// Field order is significant: it feeds the derived `Hash` and serde output.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct TlsaEntry {
    // presumably: certificate usage refers to the end-entity certificate
    // (DANE-EE) rather than a trust anchor — TODO confirm against the decoder
    pub is_end_entity: bool,
    // presumably: matching type is SHA2-256 (otherwise SHA2-512) — TODO confirm
    pub is_sha256: bool,
    // presumably: selector is SubjectPublicKeyInfo (otherwise full cert) — TODO confirm
    pub is_spki: bool,
    /// Raw certificate-association data of the record.
    pub data: Vec<u8>,
}
48
/// A set of TLSA entries together with summary flags.
///
/// Stored in a cache; its weight is computed by its `CacheItemWeight` impl.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct Tlsa {
    /// The individual TLSA entries.
    pub entries: Vec<TlsaEntry>,
    // presumably: true when any entry has `is_end_entity` set — TODO confirm with the builder
    pub has_end_entities: bool,
    // presumably: true when any entry refers to a non-end-entity certificate — TODO confirm
    pub has_intermediates: bool,
}
55
/// MTA-STS policy mode (the `mode:` line of an RFC 8461 policy).
///
/// Parsed from configuration via `ParseValue` (`"enforce"`, `"testing"`/`"test"`,
/// `"none"`) and serialised by serde in camelCase. `None` is the default.
///
/// Variant order is significant: the derived `Hash` feeds `Policy::id`.
#[derive(Debug, PartialEq, Eq, Hash, Default, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum Mode {
    Enforce,
    Testing,
    #[default]
    None,
}
64
/// An MX host pattern of an MTA-STS policy.
///
/// `Equals` holds a literal host name; `StartsWith` holds the domain that
/// follows a `*.` wildcard (it is displayed as `*.{domain}`). How patterns are
/// matched against MX hosts is not defined in this file.
///
/// The derived `Ord` (variant first, then string) determines the order in which
/// `Policy::mx` is sorted, and therefore the policy `id` and `mx:` line order.
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum MxPattern {
    Equals(String),
    StartsWith(String),
}
71
/// An MTA-STS policy (RFC 8461) as published by this server.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct Policy {
    /// Policy identifier: a hash of `mode`, `max_age` and `mx`, assigned once
    /// the MX list is known (empty until then).
    pub id: String,
    /// Policy mode (`mode:` line).
    pub mode: Mode,
    /// MX patterns (`mx:` lines), sorted when the policy is finalised.
    pub mx: Vec<MxPattern>,
    /// Policy lifetime in seconds (`max_age:` line).
    pub max_age: u64,
}
79
80impl CacheItemWeight for Tlsa {
81    fn weight(&self) -> u64 {
82        self.entries
83            .iter()
84            .map(|entry| (entry.data.len() + std::mem::size_of::<TlsaEntry>()) as u64)
85            .sum::<u64>()
86            + std::mem::size_of::<Tlsa>() as u64
87    }
88}
89
90impl CacheItemWeight for Policy {
91    fn weight(&self) -> u64 {
92        (std::mem::size_of::<Policy>()
93            + self
94                .mx
95                .iter()
96                .map(|mx| match mx {
97                    MxPattern::Equals(t) => t.len(),
98                    MxPattern::StartsWith(t) => t.len(),
99                })
100                .sum::<usize>()) as u64
101    }
102}
103
104impl Resolvers {
105    pub async fn parse(config: &mut Config) -> Self {
106        let (resolver_config, mut opts) = match config.value("resolver.type").unwrap_or("system") {
107            "cloudflare" => (ResolverConfig::cloudflare(), ResolverOpts::default()),
108            "cloudflare-tls" => (ResolverConfig::cloudflare_tls(), ResolverOpts::default()),
109            "quad9" => (ResolverConfig::quad9(), ResolverOpts::default()),
110            "quad9-tls" => (ResolverConfig::quad9_tls(), ResolverOpts::default()),
111            "google" => (ResolverConfig::google(), ResolverOpts::default()),
112            "system" => read_system_conf()
113                .map_err(|err| {
114                    config.new_build_error(
115                        "resolver.type",
116                        format!("Failed to read system DNS config: {err}"),
117                    )
118                })
119                .unwrap_or_else(|_| (ResolverConfig::cloudflare(), ResolverOpts::default())),
120            "custom" => {
121                let mut resolver_config = ResolverConfig::default();
122                for url in config
123                    .values("resolver.custom")
124                    .map(|(_, v)| v.to_string())
125                    .collect::<Vec<_>>()
126                {
127                    let (proto, host) = if let Some((proto, host)) = url
128                        .split_once("://")
129                        .map(|(a, b)| (a.to_string(), b.to_string()))
130                    {
131                        (
132                            match proto.as_str() {
133                                "udp" => ProtocolConfig::Udp,
134                                "tcp" => ProtocolConfig::Tcp,
135                                "tls" => ProtocolConfig::Tls {
136                                    server_name: host.clone().into(),
137                                },
138                                _ => {
139                                    config.new_parse_error(
140                                        "resolver.custom",
141                                        format!("Invalid custom resolver protocol {url:?}"),
142                                    );
143                                    ProtocolConfig::Udp
144                                }
145                            },
146                            host.to_string(),
147                        )
148                    } else {
149                        (ProtocolConfig::Udp, url)
150                    };
151
152                    let (host, port) = if let Some(host) = host.strip_prefix('[') {
153                        let (host, maybe_port) = host.rsplit_once(']').unwrap_or_default();
154
155                        (
156                            host,
157                            maybe_port
158                                .rsplit_once(':')
159                                .map(|(_, port)| port)
160                                .unwrap_or("53"),
161                        )
162                    } else if let Some((host, port)) = host.split_once(':') {
163                        (host, port)
164                    } else {
165                        (host.as_str(), "53")
166                    };
167
168                    let port = port
169                        .parse::<u16>()
170                        .map_err(|err| {
171                            config.new_parse_error(
172                                "resolver.custom",
173                                format!("Invalid custom resolver port {port:?}: {err}"),
174                            );
175                        })
176                        .unwrap_or(53);
177
178                    let host = host
179                        .parse::<IpAddr>()
180                        .map_err(|err| {
181                            config.new_parse_error(
182                                "resolver.custom",
183                                format!("Invalid custom resolver IP {host:?}: {err}"),
184                            )
185                        })
186                        .unwrap_or(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)));
187                    resolver_config
188                        .add_name_server(NameServerConfig::new(SocketAddr::new(host, port), proto));
189                }
190                if !resolver_config.name_servers().is_empty() {
191                    (resolver_config, ResolverOpts::default())
192                } else {
193                    config.new_parse_error(
194                        "resolver.custom",
195                        "At least one custom resolver must be specified.",
196                    );
197                    (ResolverConfig::cloudflare(), ResolverOpts::default())
198                }
199            }
200            other => {
201                let err = format!("Unknown resolver type {other:?}.");
202                config.new_parse_error("resolver.custom", err);
203                (ResolverConfig::cloudflare(), ResolverOpts::default())
204            }
205        };
206        if let Some(concurrency) = config.property("resolver.concurrency") {
207            opts.num_concurrent_reqs = concurrency;
208        }
209        if let Some(timeout) = config.property("resolver.timeout") {
210            opts.timeout = timeout;
211        }
212        if let Some(preserve) = config.property("resolver.preserve-intermediates") {
213            opts.preserve_intermediates = preserve;
214        }
215        if let Some(try_tcp_on_error) = config.property("resolver.try-tcp-on-error") {
216            opts.try_tcp_on_error = try_tcp_on_error;
217        }
218        if let Some(attempts) = config.property("resolver.attempts") {
219            opts.attempts = attempts;
220        }
221        opts.edns0 = config
222            .property_or_default("resolver.edns", "true")
223            .unwrap_or(true);
224
225        // We already have a cache, so disable the built-in cache
226        opts.cache_size = 0;
227
228        // Prepare DNSSEC resolver options
229        let config_dnssec = resolver_config.clone();
230        let mut opts_dnssec = opts.clone();
231        opts_dnssec.validate = true;
232
233        Resolvers {
234            dns: MessageAuthenticator::new(resolver_config, opts).unwrap(),
235            dnssec: DnssecResolver {
236                resolver: TokioResolver::builder_with_config(
237                    config_dnssec,
238                    TokioConnectionProvider::default(),
239                )
240                .with_options(opts_dnssec)
241                .build(),
242            },
243        }
244    }
245}
246
247impl Policy {
248    pub fn try_parse(config: &mut Config) -> Option<Self> {
249        let mode = config
250            .property_or_default::<Option<Mode>>("session.mta-sts.mode", "testing")
251            .unwrap_or_default()?;
252        let max_age = config
253            .property_or_default::<Duration>("session.mta-sts.max-age", "7d")
254            .unwrap_or_else(|| Duration::from_secs(604800))
255            .as_secs();
256        let mut mx = Vec::new();
257
258        for (_, item) in config.values("session.mta-sts.mx") {
259            if let Some(item) = item.strip_prefix("*.") {
260                mx.push(MxPattern::StartsWith(item.to_string()));
261            } else {
262                mx.push(MxPattern::Equals(item.to_string()));
263            }
264        }
265
266        let mut policy = Self {
267            id: Default::default(),
268            mode,
269            mx,
270            max_age,
271        };
272
273        if !policy.mx.is_empty() {
274            policy.mx.sort_unstable();
275            policy.id = policy.hash().to_string();
276        }
277
278        policy.into()
279    }
280
281    pub fn try_build<I, T>(mut self, names: I) -> Option<Self>
282    where
283        I: IntoIterator<Item = T>,
284        T: AsRef<str>,
285    {
286        if self.mx.is_empty() {
287            for name in names {
288                let name = name.as_ref();
289                if let Some(domain) = name.strip_prefix('.') {
290                    self.mx.push(MxPattern::StartsWith(domain.to_string()));
291                } else if name != "*" && !name.is_empty() {
292                    self.mx.push(MxPattern::Equals(name.to_string()));
293                }
294            }
295
296            if !self.mx.is_empty() {
297                self.mx.sort_unstable();
298                self.id = self.hash().to_string();
299                Some(self)
300            } else {
301                None
302            }
303        } else {
304            Some(self)
305        }
306    }
307
308    fn hash(&self) -> u64 {
309        let mut s = DefaultHasher::new();
310        self.mode.hash(&mut s);
311        self.max_age.hash(&mut s);
312        self.mx.hash(&mut s);
313        s.finish()
314    }
315}
316
317impl Server {
318    pub fn build_mta_sts_policy(&self) -> Option<Policy> {
319        self.core
320            .smtp
321            .session
322            .mta_sts_policy
323            .clone()
324            .and_then(|policy| {
325                policy.try_build(
326                    self.inner
327                        .data
328                        .tls_certificates
329                        .load()
330                        .keys()
331                        .filter(|key| {
332                            !key.starts_with("mta-sts.")
333                                && !key.starts_with("autoconfig.")
334                                && !key.starts_with("autodiscover.")
335                        }),
336                )
337            })
338    }
339}
340
341impl ParseValue for Mode {
342    fn parse_value(value: &str) -> Result<Self, String> {
343        match value {
344            "enforce" => Ok(Self::Enforce),
345            "testing" | "test" => Ok(Self::Testing),
346            "none" => Ok(Self::None),
347            _ => Err(format!("Invalid mode value {value:?}")),
348        }
349    }
350}
351
352impl Default for Resolvers {
353    fn default() -> Self {
354        let (config, opts) = match read_system_conf() {
355            Ok(conf) => conf,
356            Err(_) => (ResolverConfig::cloudflare(), ResolverOpts::default()),
357        };
358
359        let config_dnssec = config.clone();
360        let mut opts_dnssec = opts.clone();
361        opts_dnssec.validate = true;
362
363        Self {
364            dns: MessageAuthenticator::new(config, opts).expect("Failed to build DNS resolver"),
365            dnssec: DnssecResolver {
366                resolver: TokioResolver::builder_with_config(
367                    config_dnssec,
368                    TokioConnectionProvider::default(),
369                )
370                .with_options(opts_dnssec)
371                .build(),
372            },
373        }
374    }
375}
376
377impl Display for Policy {
378    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379        f.write_str("version: STSv1\r\n")?;
380        f.write_str("mode: ")?;
381        match self.mode {
382            Mode::Enforce => f.write_str("enforce")?,
383            Mode::Testing => f.write_str("testing")?,
384            Mode::None => unreachable!(),
385        }
386        f.write_str("\r\nmax_age: ")?;
387        self.max_age.fmt(f)?;
388        f.write_str("\r\n")?;
389
390        for mx in &self.mx {
391            f.write_str("mx: ")?;
392            mx.fmt(f)?;
393            f.write_str("\r\n")?;
394        }
395
396        Ok(())
397    }
398}
399
400impl Display for MxPattern {
401    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402        match self {
403            MxPattern::Equals(mx) => f.write_str(mx),
404            MxPattern::StartsWith(mx) => {
405                f.write_str("*.")?;
406                f.write_str(mx)
407            }
408        }
409    }
410}
411
412impl Clone for Resolvers {
413    fn clone(&self) -> Self {
414        Self {
415            dns: self.dns.clone(),
416            dnssec: self.dnssec.clone(),
417        }
418    }
419}