1use std::{
8 fmt::Display,
9 hash::{DefaultHasher, Hash, Hasher},
10 net::{IpAddr, Ipv4Addr, SocketAddr},
11 time::Duration,
12};
13
14use crate::utils::{
15 cache::CacheItemWeight,
16 config::{Config, utils::ParseValue},
17};
18use mail_auth::{
19 MessageAuthenticator,
20 hickory_resolver::{
21 TokioResolver,
22 config::{NameServerConfig, ProtocolConfig, ResolverConfig, ResolverOpts},
23 name_server::TokioConnectionProvider,
24 system_conf::read_system_conf,
25 },
26};
27use serde::{Deserialize, Serialize};
28
29use crate::Server;
30
/// The DNS resolvers built from the `resolver.*` configuration
/// (see [`Resolvers::parse`]) or from the system configuration (`Default`).
pub struct Resolvers {
    /// General-purpose resolver, wrapped in a `mail_auth` [`MessageAuthenticator`].
    pub dns: MessageAuthenticator,
    /// Resolver built from the same settings but with DNSSEC validation enabled.
    pub dnssec: DnssecResolver,
}
35
/// Thin wrapper around a hickory [`TokioResolver`].
///
/// The constructors in this file build it with `ResolverOpts::validate` set,
/// so lookups through it are DNSSEC-validated.
#[derive(Clone)]
pub struct DnssecResolver {
    /// The underlying hickory resolver.
    pub resolver: TokioResolver,
}
40
/// One TLSA record (RFC 6698) in decoded form.
///
/// NOTE(review): the meaning of each flag below is inferred from its name;
/// confirm against the code that constructs these entries. Field order is
/// significant for the derived `Hash`/`Serialize` impls — do not reorder.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct TlsaEntry {
    /// Presumably `true` for end-entity certificate usages and `false` for
    /// CA / trust-anchor usages — TODO confirm.
    pub is_end_entity: bool,
    /// Presumably `true` when the matching type is SHA-256 — TODO confirm.
    pub is_sha256: bool,
    /// Presumably `true` when the selector is SubjectPublicKeyInfo rather
    /// than the full certificate — TODO confirm.
    pub is_spki: bool,
    /// Raw association data (digest or full value) of the record.
    pub data: Vec<u8>,
}
48
/// A set of TLSA entries plus summary flags about them.
///
/// Field order is significant for the derived `Hash`/`Serialize` impls.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct Tlsa {
    /// The individual TLSA records.
    pub entries: Vec<TlsaEntry>,
    /// Presumably whether any entry is an end-entity record — verify against the builder.
    pub has_end_entities: bool,
    /// Presumably whether any entry refers to an intermediate / CA certificate — verify.
    pub has_intermediates: bool,
}
55
/// MTA-STS policy mode.
///
/// Serialized in camelCase (`enforce`, `testing`, `none`); parsed from
/// configuration text by the `ParseValue` impl in this file. Defaults to `None`.
#[derive(Debug, PartialEq, Eq, Hash, Default, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum Mode {
    /// `enforce` mode.
    Enforce,
    /// `testing` mode.
    Testing,
    /// `none` mode (the default).
    #[default]
    None,
}
64
/// One `mx:` pattern of an MTA-STS policy.
///
/// `Ord` is derived so that a policy's pattern list can be sorted before its
/// id is hashed.
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum MxPattern {
    /// An exact host name.
    Equals(String),
    /// A wildcard `*.<domain>` pattern; holds the domain without the leading
    /// `*.` (which `Display` adds back).
    StartsWith(String),
}
71
/// An MTA-STS policy; its `Display` impl renders the policy-file text.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct Policy {
    /// Policy identifier: the decimal form of a hash over `mode`, `max_age`
    /// and the sorted `mx` list. Left empty until MX patterns are known.
    pub id: String,
    /// Policy mode.
    pub mode: Mode,
    /// Allowed MX host patterns; sorted before `id` is computed.
    pub mx: Vec<MxPattern>,
    /// Maximum policy age, in seconds.
    pub max_age: u64,
}
79
80impl CacheItemWeight for Tlsa {
81 fn weight(&self) -> u64 {
82 self.entries
83 .iter()
84 .map(|entry| (entry.data.len() + std::mem::size_of::<TlsaEntry>()) as u64)
85 .sum::<u64>()
86 + std::mem::size_of::<Tlsa>() as u64
87 }
88}
89
90impl CacheItemWeight for Policy {
91 fn weight(&self) -> u64 {
92 (std::mem::size_of::<Policy>()
93 + self
94 .mx
95 .iter()
96 .map(|mx| match mx {
97 MxPattern::Equals(t) => t.len(),
98 MxPattern::StartsWith(t) => t.len(),
99 })
100 .sum::<usize>()) as u64
101 }
102}
103
104impl Resolvers {
105 pub async fn parse(config: &mut Config) -> Self {
106 let (resolver_config, mut opts) = match config.value("resolver.type").unwrap_or("system") {
107 "cloudflare" => (ResolverConfig::cloudflare(), ResolverOpts::default()),
108 "cloudflare-tls" => (ResolverConfig::cloudflare_tls(), ResolverOpts::default()),
109 "quad9" => (ResolverConfig::quad9(), ResolverOpts::default()),
110 "quad9-tls" => (ResolverConfig::quad9_tls(), ResolverOpts::default()),
111 "google" => (ResolverConfig::google(), ResolverOpts::default()),
112 "system" => read_system_conf()
113 .map_err(|err| {
114 config.new_build_error(
115 "resolver.type",
116 format!("Failed to read system DNS config: {err}"),
117 )
118 })
119 .unwrap_or_else(|_| (ResolverConfig::cloudflare(), ResolverOpts::default())),
120 "custom" => {
121 let mut resolver_config = ResolverConfig::default();
122 for url in config
123 .values("resolver.custom")
124 .map(|(_, v)| v.to_string())
125 .collect::<Vec<_>>()
126 {
127 let (proto, host) = if let Some((proto, host)) = url
128 .split_once("://")
129 .map(|(a, b)| (a.to_string(), b.to_string()))
130 {
131 (
132 match proto.as_str() {
133 "udp" => ProtocolConfig::Udp,
134 "tcp" => ProtocolConfig::Tcp,
135 "tls" => ProtocolConfig::Tls {
136 server_name: host.clone().into(),
137 },
138 _ => {
139 config.new_parse_error(
140 "resolver.custom",
141 format!("Invalid custom resolver protocol {url:?}"),
142 );
143 ProtocolConfig::Udp
144 }
145 },
146 host.to_string(),
147 )
148 } else {
149 (ProtocolConfig::Udp, url)
150 };
151
152 let (host, port) = if let Some(host) = host.strip_prefix('[') {
153 let (host, maybe_port) = host.rsplit_once(']').unwrap_or_default();
154
155 (
156 host,
157 maybe_port
158 .rsplit_once(':')
159 .map(|(_, port)| port)
160 .unwrap_or("53"),
161 )
162 } else if let Some((host, port)) = host.split_once(':') {
163 (host, port)
164 } else {
165 (host.as_str(), "53")
166 };
167
168 let port = port
169 .parse::<u16>()
170 .map_err(|err| {
171 config.new_parse_error(
172 "resolver.custom",
173 format!("Invalid custom resolver port {port:?}: {err}"),
174 );
175 })
176 .unwrap_or(53);
177
178 let host = host
179 .parse::<IpAddr>()
180 .map_err(|err| {
181 config.new_parse_error(
182 "resolver.custom",
183 format!("Invalid custom resolver IP {host:?}: {err}"),
184 )
185 })
186 .unwrap_or(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)));
187 resolver_config
188 .add_name_server(NameServerConfig::new(SocketAddr::new(host, port), proto));
189 }
190 if !resolver_config.name_servers().is_empty() {
191 (resolver_config, ResolverOpts::default())
192 } else {
193 config.new_parse_error(
194 "resolver.custom",
195 "At least one custom resolver must be specified.",
196 );
197 (ResolverConfig::cloudflare(), ResolverOpts::default())
198 }
199 }
200 other => {
201 let err = format!("Unknown resolver type {other:?}.");
202 config.new_parse_error("resolver.custom", err);
203 (ResolverConfig::cloudflare(), ResolverOpts::default())
204 }
205 };
206 if let Some(concurrency) = config.property("resolver.concurrency") {
207 opts.num_concurrent_reqs = concurrency;
208 }
209 if let Some(timeout) = config.property("resolver.timeout") {
210 opts.timeout = timeout;
211 }
212 if let Some(preserve) = config.property("resolver.preserve-intermediates") {
213 opts.preserve_intermediates = preserve;
214 }
215 if let Some(try_tcp_on_error) = config.property("resolver.try-tcp-on-error") {
216 opts.try_tcp_on_error = try_tcp_on_error;
217 }
218 if let Some(attempts) = config.property("resolver.attempts") {
219 opts.attempts = attempts;
220 }
221 opts.edns0 = config
222 .property_or_default("resolver.edns", "true")
223 .unwrap_or(true);
224
225 opts.cache_size = 0;
227
228 let config_dnssec = resolver_config.clone();
230 let mut opts_dnssec = opts.clone();
231 opts_dnssec.validate = true;
232
233 Resolvers {
234 dns: MessageAuthenticator::new(resolver_config, opts).unwrap(),
235 dnssec: DnssecResolver {
236 resolver: TokioResolver::builder_with_config(
237 config_dnssec,
238 TokioConnectionProvider::default(),
239 )
240 .with_options(opts_dnssec)
241 .build(),
242 },
243 }
244 }
245}
246
247impl Policy {
248 pub fn try_parse(config: &mut Config) -> Option<Self> {
249 let mode = config
250 .property_or_default::<Option<Mode>>("session.mta-sts.mode", "testing")
251 .unwrap_or_default()?;
252 let max_age = config
253 .property_or_default::<Duration>("session.mta-sts.max-age", "7d")
254 .unwrap_or_else(|| Duration::from_secs(604800))
255 .as_secs();
256 let mut mx = Vec::new();
257
258 for (_, item) in config.values("session.mta-sts.mx") {
259 if let Some(item) = item.strip_prefix("*.") {
260 mx.push(MxPattern::StartsWith(item.to_string()));
261 } else {
262 mx.push(MxPattern::Equals(item.to_string()));
263 }
264 }
265
266 let mut policy = Self {
267 id: Default::default(),
268 mode,
269 mx,
270 max_age,
271 };
272
273 if !policy.mx.is_empty() {
274 policy.mx.sort_unstable();
275 policy.id = policy.hash().to_string();
276 }
277
278 policy.into()
279 }
280
281 pub fn try_build<I, T>(mut self, names: I) -> Option<Self>
282 where
283 I: IntoIterator<Item = T>,
284 T: AsRef<str>,
285 {
286 if self.mx.is_empty() {
287 for name in names {
288 let name = name.as_ref();
289 if let Some(domain) = name.strip_prefix('.') {
290 self.mx.push(MxPattern::StartsWith(domain.to_string()));
291 } else if name != "*" && !name.is_empty() {
292 self.mx.push(MxPattern::Equals(name.to_string()));
293 }
294 }
295
296 if !self.mx.is_empty() {
297 self.mx.sort_unstable();
298 self.id = self.hash().to_string();
299 Some(self)
300 } else {
301 None
302 }
303 } else {
304 Some(self)
305 }
306 }
307
308 fn hash(&self) -> u64 {
309 let mut s = DefaultHasher::new();
310 self.mode.hash(&mut s);
311 self.max_age.hash(&mut s);
312 self.mx.hash(&mut s);
313 s.finish()
314 }
315}
316
317impl Server {
318 pub fn build_mta_sts_policy(&self) -> Option<Policy> {
319 self.core
320 .smtp
321 .session
322 .mta_sts_policy
323 .clone()
324 .and_then(|policy| {
325 policy.try_build(
326 self.inner
327 .data
328 .tls_certificates
329 .load()
330 .keys()
331 .filter(|key| {
332 !key.starts_with("mta-sts.")
333 && !key.starts_with("autoconfig.")
334 && !key.starts_with("autodiscover.")
335 }),
336 )
337 })
338 }
339}
340
341impl ParseValue for Mode {
342 fn parse_value(value: &str) -> Result<Self, String> {
343 match value {
344 "enforce" => Ok(Self::Enforce),
345 "testing" | "test" => Ok(Self::Testing),
346 "none" => Ok(Self::None),
347 _ => Err(format!("Invalid mode value {value:?}")),
348 }
349 }
350}
351
352impl Default for Resolvers {
353 fn default() -> Self {
354 let (config, opts) = match read_system_conf() {
355 Ok(conf) => conf,
356 Err(_) => (ResolverConfig::cloudflare(), ResolverOpts::default()),
357 };
358
359 let config_dnssec = config.clone();
360 let mut opts_dnssec = opts.clone();
361 opts_dnssec.validate = true;
362
363 Self {
364 dns: MessageAuthenticator::new(config, opts).expect("Failed to build DNS resolver"),
365 dnssec: DnssecResolver {
366 resolver: TokioResolver::builder_with_config(
367 config_dnssec,
368 TokioConnectionProvider::default(),
369 )
370 .with_options(opts_dnssec)
371 .build(),
372 },
373 }
374 }
375}
376
377impl Display for Policy {
378 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379 f.write_str("version: STSv1\r\n")?;
380 f.write_str("mode: ")?;
381 match self.mode {
382 Mode::Enforce => f.write_str("enforce")?,
383 Mode::Testing => f.write_str("testing")?,
384 Mode::None => unreachable!(),
385 }
386 f.write_str("\r\nmax_age: ")?;
387 self.max_age.fmt(f)?;
388 f.write_str("\r\n")?;
389
390 for mx in &self.mx {
391 f.write_str("mx: ")?;
392 mx.fmt(f)?;
393 f.write_str("\r\n")?;
394 }
395
396 Ok(())
397 }
398}
399
400impl Display for MxPattern {
401 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402 match self {
403 MxPattern::Equals(mx) => f.write_str(mx),
404 MxPattern::StartsWith(mx) => {
405 f.write_str("*.")?;
406 f.write_str(mx)
407 }
408 }
409 }
410}
411
412impl Clone for Resolvers {
413 fn clone(&self) -> Self {
414 Self {
415 dns: self.dns.clone(),
416 dnssec: self.dnssec.clone(),
417 }
418 }
419}