Skip to main content

seer_core/dns/
records.rs

1use serde::{Deserialize, Serialize};
2use std::fmt;
3use std::str::FromStr;
4
5use crate::error::{Result, SeerError};
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
8#[serde(rename_all = "UPPERCASE")]
9pub enum RecordType {
10    A,
11    AAAA,
12    CNAME,
13    MX,
14    NS,
15    TXT,
16    SOA,
17    PTR,
18    SRV,
19    CAA,
20    NAPTR,
21    DNSKEY,
22    DS,
23    TLSA,
24    SSHFP,
25    ANY,
26}
27
28impl fmt::Display for RecordType {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        match self {
31            RecordType::A => write!(f, "A"),
32            RecordType::AAAA => write!(f, "AAAA"),
33            RecordType::CNAME => write!(f, "CNAME"),
34            RecordType::MX => write!(f, "MX"),
35            RecordType::NS => write!(f, "NS"),
36            RecordType::TXT => write!(f, "TXT"),
37            RecordType::SOA => write!(f, "SOA"),
38            RecordType::PTR => write!(f, "PTR"),
39            RecordType::SRV => write!(f, "SRV"),
40            RecordType::CAA => write!(f, "CAA"),
41            RecordType::NAPTR => write!(f, "NAPTR"),
42            RecordType::DNSKEY => write!(f, "DNSKEY"),
43            RecordType::DS => write!(f, "DS"),
44            RecordType::TLSA => write!(f, "TLSA"),
45            RecordType::SSHFP => write!(f, "SSHFP"),
46            RecordType::ANY => write!(f, "ANY"),
47        }
48    }
49}
50
51impl FromStr for RecordType {
52    type Err = SeerError;
53
54    fn from_str(s: &str) -> Result<Self> {
55        match s.to_uppercase().as_str() {
56            "A" => Ok(RecordType::A),
57            "AAAA" => Ok(RecordType::AAAA),
58            "CNAME" => Ok(RecordType::CNAME),
59            "MX" => Ok(RecordType::MX),
60            "NS" => Ok(RecordType::NS),
61            "TXT" => Ok(RecordType::TXT),
62            "SOA" => Ok(RecordType::SOA),
63            "PTR" => Ok(RecordType::PTR),
64            "SRV" => Ok(RecordType::SRV),
65            "CAA" => Ok(RecordType::CAA),
66            "NAPTR" => Ok(RecordType::NAPTR),
67            "DNSKEY" => Ok(RecordType::DNSKEY),
68            "DS" => Ok(RecordType::DS),
69            "TLSA" => Ok(RecordType::TLSA),
70            "SSHFP" => Ok(RecordType::SSHFP),
71            "ANY" | "*" => Ok(RecordType::ANY),
72            _ => Err(SeerError::InvalidRecordType(s.to_string())),
73        }
74    }
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct DnsRecord {
79    pub name: String,
80    pub record_type: RecordType,
81    pub ttl: u32,
82    pub data: RecordData,
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
86#[serde(tag = "record_type", content = "value", rename_all = "UPPERCASE")]
87#[allow(clippy::upper_case_acronyms)]
88pub enum RecordData {
89    A {
90        address: String,
91    },
92    AAAA {
93        address: String,
94    },
95    CNAME {
96        target: String,
97    },
98    MX {
99        preference: u16,
100        exchange: String,
101    },
102    NS {
103        nameserver: String,
104    },
105    TXT {
106        text: String,
107    },
108    SOA {
109        mname: String,
110        rname: String,
111        serial: u32,
112        refresh: u32,
113        retry: u32,
114        expire: u32,
115        minimum: u32,
116    },
117    PTR {
118        target: String,
119    },
120    SRV {
121        priority: u16,
122        weight: u16,
123        port: u16,
124        target: String,
125    },
126    CAA {
127        flags: u8,
128        tag: String,
129        value: String,
130    },
131    DNSKEY {
132        flags: u16,
133        protocol: u8,
134        algorithm: u8,
135        public_key: String,
136    },
137    DS {
138        key_tag: u16,
139        algorithm: u8,
140        digest_type: u8,
141        digest: String,
142    },
143    Unknown {
144        raw: String,
145    },
146}
147
148impl fmt::Display for RecordData {
149    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
150        match self {
151            RecordData::A { address } => write!(f, "{}", address),
152            RecordData::AAAA { address } => write!(f, "{}", address),
153            RecordData::CNAME { target } => write!(f, "{}", target),
154            RecordData::MX {
155                preference,
156                exchange,
157            } => write!(f, "{} {}", preference, exchange),
158            RecordData::NS { nameserver } => write!(f, "{}", nameserver),
159            RecordData::TXT { text } => write!(f, "\"{}\"", text),
160            RecordData::SOA {
161                mname,
162                rname,
163                serial,
164                refresh,
165                retry,
166                expire,
167                minimum,
168            } => write!(
169                f,
170                "{} {} {} {} {} {} {}",
171                mname, rname, serial, refresh, retry, expire, minimum
172            ),
173            RecordData::PTR { target } => write!(f, "{}", target),
174            RecordData::SRV {
175                priority,
176                weight,
177                port,
178                target,
179            } => write!(f, "{} {} {} {}", priority, weight, port, target),
180            RecordData::CAA { flags, tag, value } => write!(f, "{} {} \"{}\"", flags, tag, value),
181            RecordData::DNSKEY {
182                flags,
183                protocol,
184                algorithm,
185                public_key,
186            } => write!(f, "{} {} {} {}", flags, protocol, algorithm, public_key),
187            RecordData::DS {
188                key_tag,
189                algorithm,
190                digest_type,
191                digest,
192            } => write!(f, "{} {} {} {}", key_tag, algorithm, digest_type, digest),
193            RecordData::Unknown { raw } => write!(f, "{}", raw),
194        }
195    }
196}
197
198impl DnsRecord {
199    pub fn format_short(&self) -> String {
200        format!("{}", self.data)
201    }
202
203    pub fn format_full(&self) -> String {
204        format!(
205            "{}\t{}\tIN\t{}\t{}",
206            self.name, self.ttl, self.record_type, self.data
207        )
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `RecordType` variant, so the parse/display checks cover them all.
    const ALL_TYPES: [RecordType; 16] = [
        RecordType::A,
        RecordType::AAAA,
        RecordType::CNAME,
        RecordType::MX,
        RecordType::NS,
        RecordType::TXT,
        RecordType::SOA,
        RecordType::PTR,
        RecordType::SRV,
        RecordType::CAA,
        RecordType::NAPTR,
        RecordType::DNSKEY,
        RecordType::DS,
        RecordType::TLSA,
        RecordType::SSHFP,
        RecordType::ANY,
    ];

    /// Shared fixture: `example.com. 300 IN A 1.2.3.4`.
    fn a_record() -> DnsRecord {
        DnsRecord {
            name: "example.com".to_string(),
            record_type: RecordType::A,
            ttl: 300,
            data: RecordData::A {
                address: "1.2.3.4".to_string(),
            },
        }
    }

    #[test]
    fn test_record_type_from_str() {
        assert_eq!("A".parse::<RecordType>().unwrap(), RecordType::A);
        assert_eq!("aaaa".parse::<RecordType>().unwrap(), RecordType::AAAA);
        assert_eq!("MX".parse::<RecordType>().unwrap(), RecordType::MX);
        assert_eq!("*".parse::<RecordType>().unwrap(), RecordType::ANY);
        assert!("INVALID".parse::<RecordType>().is_err());
        assert!("".parse::<RecordType>().is_err());
    }

    #[test]
    fn test_record_type_display() {
        assert_eq!(RecordType::A.to_string(), "A");
        assert_eq!(RecordType::AAAA.to_string(), "AAAA");
        assert_eq!(RecordType::MX.to_string(), "MX");
        assert_eq!(RecordType::SOA.to_string(), "SOA");
    }

    #[test]
    fn test_record_type_display_parse_roundtrip() {
        // `Display` and `FromStr` must agree for every variant, in any case.
        for &rt in ALL_TYPES.iter() {
            let shown = rt.to_string();
            assert_eq!(shown.parse::<RecordType>().unwrap(), rt);
            assert_eq!(shown.to_lowercase().parse::<RecordType>().unwrap(), rt);
        }
    }

    #[test]
    fn test_dns_record_format_short() {
        assert_eq!(a_record().format_short(), "1.2.3.4");
    }

    #[test]
    fn test_dns_record_format_full() {
        assert_eq!(
            a_record().format_full(),
            "example.com\t300\tIN\tA\t1.2.3.4"
        );
    }

    #[test]
    fn test_record_data_display() {
        let mx = RecordData::MX {
            preference: 10,
            exchange: "mail.example.com".to_string(),
        };
        assert_eq!(format!("{}", mx), "10 mail.example.com");

        let txt = RecordData::TXT {
            text: "v=spf1 include:example.com".to_string(),
        };
        assert_eq!(format!("{}", txt), "\"v=spf1 include:example.com\"");

        let srv = RecordData::SRV {
            priority: 10,
            weight: 5,
            port: 443,
            target: "server.example.com".to_string(),
        };
        assert_eq!(format!("{}", srv), "10 5 443 server.example.com");

        let caa = RecordData::CAA {
            flags: 0,
            tag: "issue".to_string(),
            value: "letsencrypt.org".to_string(),
        };
        assert_eq!(format!("{}", caa), "0 issue \"letsencrypt.org\"");
    }

    #[test]
    fn test_record_serialization_roundtrip() {
        let record = a_record();
        let json = serde_json::to_string(&record).unwrap();

        // Pin the wire format: `data` is adjacently tagged under
        // `record_type` / `value`, next to the top-level fields.
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["name"], "example.com");
        assert_eq!(value["record_type"], "A");
        assert_eq!(value["ttl"], 300_u32);
        assert_eq!(value["data"]["record_type"], "A");
        assert_eq!(value["data"]["value"]["address"], "1.2.3.4");

        // The JSON must also deserialize back into an equivalent record.
        let back: DnsRecord = serde_json::from_str(&json).unwrap();
        assert_eq!(back.name, record.name);
        assert_eq!(back.record_type, record.record_type);
        assert_eq!(back.ttl, record.ttl);
        assert_eq!(back.format_full(), record.format_full());
    }

    #[test]
    fn test_soa_display() {
        let soa = RecordData::SOA {
            mname: "ns1.example.com".to_string(),
            rname: "admin.example.com".to_string(),
            serial: 2024010101,
            refresh: 3600,
            retry: 900,
            expire: 604800,
            minimum: 86400,
        };
        assert_eq!(
            format!("{}", soa),
            "ns1.example.com admin.example.com 2024010101 3600 900 604800 86400"
        );
    }
}