Skip to main content

seer_core/dns/
records.rs

1use serde::{Deserialize, Serialize};
2use std::fmt;
3use std::str::FromStr;
4
5use crate::error::{Result, SeerError};
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
8#[serde(rename_all = "UPPERCASE")]
9pub enum RecordType {
10    A,
11    AAAA,
12    CNAME,
13    MX,
14    NS,
15    TXT,
16    SOA,
17    PTR,
18    SRV,
19    CAA,
20    NAPTR,
21    DNSKEY,
22    DS,
23    TLSA,
24    SSHFP,
25    ANY,
26}
27
28impl fmt::Display for RecordType {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        match self {
31            RecordType::A => write!(f, "A"),
32            RecordType::AAAA => write!(f, "AAAA"),
33            RecordType::CNAME => write!(f, "CNAME"),
34            RecordType::MX => write!(f, "MX"),
35            RecordType::NS => write!(f, "NS"),
36            RecordType::TXT => write!(f, "TXT"),
37            RecordType::SOA => write!(f, "SOA"),
38            RecordType::PTR => write!(f, "PTR"),
39            RecordType::SRV => write!(f, "SRV"),
40            RecordType::CAA => write!(f, "CAA"),
41            RecordType::NAPTR => write!(f, "NAPTR"),
42            RecordType::DNSKEY => write!(f, "DNSKEY"),
43            RecordType::DS => write!(f, "DS"),
44            RecordType::TLSA => write!(f, "TLSA"),
45            RecordType::SSHFP => write!(f, "SSHFP"),
46            RecordType::ANY => write!(f, "ANY"),
47        }
48    }
49}
50
51impl FromStr for RecordType {
52    type Err = SeerError;
53
54    fn from_str(s: &str) -> Result<Self> {
55        match s.to_uppercase().as_str() {
56            "A" => Ok(RecordType::A),
57            "AAAA" => Ok(RecordType::AAAA),
58            "CNAME" => Ok(RecordType::CNAME),
59            "MX" => Ok(RecordType::MX),
60            "NS" => Ok(RecordType::NS),
61            "TXT" => Ok(RecordType::TXT),
62            "SOA" => Ok(RecordType::SOA),
63            "PTR" => Ok(RecordType::PTR),
64            "SRV" => Ok(RecordType::SRV),
65            "CAA" => Ok(RecordType::CAA),
66            "NAPTR" => Ok(RecordType::NAPTR),
67            "DNSKEY" => Ok(RecordType::DNSKEY),
68            "DS" => Ok(RecordType::DS),
69            "TLSA" => Ok(RecordType::TLSA),
70            "SSHFP" => Ok(RecordType::SSHFP),
71            "ANY" | "*" => Ok(RecordType::ANY),
72            _ => Err(SeerError::InvalidRecordType(s.to_string())),
73        }
74    }
75}
76
/// A single resource record returned by a DNS lookup.
///
/// NOTE(review): `record_type` and the variant of `data` are stored separately,
/// and nothing in this type enforces that they agree. Confirm that the code
/// constructing records keeps them in sync.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DnsRecord {
    /// Owner name of the record (e.g. `example.com`).
    pub name: String,
    /// Record type; `format_full` prints it as the type column.
    pub record_type: RecordType,
    /// Time-to-live as reported for the record (DNS TTLs are in seconds).
    pub ttl: u32,
    /// Type-specific record payload.
    pub data: RecordData,
}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
86#[serde(tag = "record_type", content = "value", rename_all = "UPPERCASE")]
87#[allow(clippy::upper_case_acronyms)]
88pub enum RecordData {
89    A {
90        address: String,
91    },
92    AAAA {
93        address: String,
94    },
95    CNAME {
96        target: String,
97    },
98    MX {
99        preference: u16,
100        exchange: String,
101    },
102    NS {
103        nameserver: String,
104    },
105    TXT {
106        text: String,
107    },
108    SOA {
109        mname: String,
110        rname: String,
111        serial: u32,
112        refresh: u32,
113        retry: u32,
114        expire: u32,
115        minimum: u32,
116    },
117    PTR {
118        target: String,
119    },
120    SRV {
121        priority: u16,
122        weight: u16,
123        port: u16,
124        target: String,
125    },
126    CAA {
127        flags: u8,
128        tag: String,
129        value: String,
130    },
131    DNSKEY {
132        flags: u16,
133        protocol: u8,
134        algorithm: u8,
135        public_key: String,
136    },
137    DS {
138        key_tag: u16,
139        algorithm: u8,
140        digest_type: u8,
141        digest: String,
142    },
143    TLSA {
144        cert_usage: u8,
145        selector: u8,
146        matching: u8,
147        /// Hex-encoded certificate association data (uppercase).
148        cert_data: String,
149    },
150    SSHFP {
151        algorithm: u8,
152        fingerprint_type: u8,
153        /// Hex-encoded fingerprint (uppercase).
154        fingerprint: String,
155    },
156    Unknown {
157        raw: String,
158    },
159}
160
161impl fmt::Display for RecordData {
162    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163        match self {
164            RecordData::A { address } => write!(f, "{}", address),
165            RecordData::AAAA { address } => write!(f, "{}", address),
166            RecordData::CNAME { target } => write!(f, "{}", target),
167            RecordData::MX {
168                preference,
169                exchange,
170            } => write!(f, "{} {}", preference, exchange),
171            RecordData::NS { nameserver } => write!(f, "{}", nameserver),
172            RecordData::TXT { text } => write!(f, "\"{}\"", text),
173            RecordData::SOA {
174                mname,
175                rname,
176                serial,
177                refresh,
178                retry,
179                expire,
180                minimum,
181            } => write!(
182                f,
183                "{} {} {} {} {} {} {}",
184                mname, rname, serial, refresh, retry, expire, minimum
185            ),
186            RecordData::PTR { target } => write!(f, "{}", target),
187            RecordData::SRV {
188                priority,
189                weight,
190                port,
191                target,
192            } => write!(f, "{} {} {} {}", priority, weight, port, target),
193            RecordData::CAA { flags, tag, value } => write!(f, "{} {} \"{}\"", flags, tag, value),
194            RecordData::DNSKEY {
195                flags,
196                protocol,
197                algorithm,
198                public_key,
199            } => write!(f, "{} {} {} {}", flags, protocol, algorithm, public_key),
200            RecordData::DS {
201                key_tag,
202                algorithm,
203                digest_type,
204                digest,
205            } => write!(f, "{} {} {} {}", key_tag, algorithm, digest_type, digest),
206            RecordData::TLSA {
207                cert_usage,
208                selector,
209                matching,
210                cert_data,
211            } => write!(f, "{} {} {} {}", cert_usage, selector, matching, cert_data),
212            RecordData::SSHFP {
213                algorithm,
214                fingerprint_type,
215                fingerprint,
216            } => write!(f, "{} {} {}", algorithm, fingerprint_type, fingerprint),
217            RecordData::Unknown { raw } => write!(f, "{}", raw),
218        }
219    }
220}
221
222impl DnsRecord {
223    pub fn format_short(&self) -> String {
224        format!("{}", self.data)
225    }
226
227    pub fn format_full(&self) -> String {
228        format!(
229            "{}\t{}\tIN\t{}\t{}",
230            self.name, self.ttl, self.record_type, self.data
231        )
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `example.com 300 IN A 1.2.3.4` record shared by several tests.
    fn sample_a_record() -> DnsRecord {
        DnsRecord {
            name: "example.com".to_string(),
            record_type: RecordType::A,
            ttl: 300,
            data: RecordData::A {
                address: "1.2.3.4".to_string(),
            },
        }
    }

    #[test]
    fn test_record_type_from_str() {
        assert_eq!("A".parse::<RecordType>().unwrap(), RecordType::A);
        assert_eq!("aaaa".parse::<RecordType>().unwrap(), RecordType::AAAA);
        assert_eq!("MX".parse::<RecordType>().unwrap(), RecordType::MX);
        assert_eq!("*".parse::<RecordType>().unwrap(), RecordType::ANY);
        assert!("INVALID".parse::<RecordType>().is_err());
    }

    #[test]
    fn test_record_type_display() {
        assert_eq!(RecordType::A.to_string(), "A");
        assert_eq!(RecordType::AAAA.to_string(), "AAAA");
        assert_eq!(RecordType::MX.to_string(), "MX");
        assert_eq!(RecordType::SOA.to_string(), "SOA");
    }

    #[test]
    fn test_record_type_display_parse_roundtrip() {
        // Every variant's printed mnemonic must parse back to the same variant.
        let all = [
            RecordType::A,
            RecordType::AAAA,
            RecordType::CNAME,
            RecordType::MX,
            RecordType::NS,
            RecordType::TXT,
            RecordType::SOA,
            RecordType::PTR,
            RecordType::SRV,
            RecordType::CAA,
            RecordType::NAPTR,
            RecordType::DNSKEY,
            RecordType::DS,
            RecordType::TLSA,
            RecordType::SSHFP,
            RecordType::ANY,
        ];
        for &rt in &all {
            assert_eq!(rt.to_string().parse::<RecordType>().unwrap(), rt);
        }
    }

    #[test]
    fn test_dns_record_format_short() {
        assert_eq!(sample_a_record().format_short(), "1.2.3.4");
    }

    #[test]
    fn test_dns_record_format_full() {
        assert_eq!(
            sample_a_record().format_full(),
            "example.com\t300\tIN\tA\t1.2.3.4"
        );
    }

    #[test]
    fn test_record_data_display() {
        let mx = RecordData::MX {
            preference: 10,
            exchange: "mail.example.com".to_string(),
        };
        assert_eq!(mx.to_string(), "10 mail.example.com");

        let txt = RecordData::TXT {
            text: "v=spf1 include:example.com".to_string(),
        };
        assert_eq!(txt.to_string(), "\"v=spf1 include:example.com\"");

        let srv = RecordData::SRV {
            priority: 10,
            weight: 5,
            port: 443,
            target: "server.example.com".to_string(),
        };
        assert_eq!(srv.to_string(), "10 5 443 server.example.com");
    }

    #[test]
    fn test_record_serialization_roundtrip() {
        let record = sample_a_record();
        let json = serde_json::to_string(&record).unwrap();
        assert!(json.contains("\"A\""));
        assert!(json.contains("1.2.3.4"));

        // Deserialize back and check every field survived the trip.
        let back: DnsRecord = serde_json::from_str(&json).unwrap();
        assert_eq!(back.name, record.name);
        assert_eq!(back.record_type, record.record_type);
        assert_eq!(back.ttl, record.ttl);
        match back.data {
            RecordData::A { address } => assert_eq!(address, "1.2.3.4"),
            other => panic!("unexpected record data after roundtrip: {:?}", other),
        }
    }

    #[test]
    fn test_soa_display() {
        let soa = RecordData::SOA {
            mname: "ns1.example.com".to_string(),
            rname: "admin.example.com".to_string(),
            serial: 2024010101,
            refresh: 3600,
            retry: 900,
            expire: 604800,
            minimum: 86400,
        };
        let display = soa.to_string();
        assert!(display.contains("ns1.example.com"));
        assert!(display.contains("2024010101"));
    }
}