1use std::{
2 collections::{btree_set, BTreeSet},
3 convert::TryFrom,
4 fmt,
5 net::{self, Ipv4Addr, Ipv6Addr},
6 str::{self, FromStr},
7 string::FromUtf8Error,
8};
9
10use trust_dns::rr::{self, rdata};
11
12#[derive(Debug, Clone, Eq, PartialEq)]
18pub struct RecordSet {
19 name: rr::Name,
20 dns_class: rr::DNSClass,
21 data: RsData,
22}
23
24impl RecordSet {
25 pub fn new(name: rr::Name, data: RsData) -> Self {
26 RecordSet {
27 name,
28 dns_class: rr::DNSClass::IN,
29 data,
30 }
31 }
32
33 pub fn name(&self) -> &rr::Name {
34 &self.name
35 }
36
37 pub fn dns_class(&self) -> rr::DNSClass {
38 self.dns_class
39 }
40
41 pub fn record_type(&self) -> rr::RecordType {
42 self.data.record_type()
43 }
44
45 pub fn to_rrset(&self, ttl: u32) -> rr::RecordSet {
46 let mut rrset = rr::RecordSet::new(&self.name, self.record_type(), ttl);
47 for data in self.iter_data() {
48 rrset.add_rdata(data);
49 }
50 rrset
51 }
52
53 pub fn data(&self) -> &RsData {
54 &self.data
55 }
56
57 pub fn iter_data(&self) -> RsDataIter {
58 let inner = match &self.data {
59 RsData::TXT(txts) => RsDataIterInner::TXT(txts.iter()),
60 RsData::A(addrs) => RsDataIterInner::A(addrs.iter()),
61 RsData::AAAA(addrs) => RsDataIterInner::AAAA(addrs.iter()),
62 };
63 RsDataIter(inner)
64 }
65
66 pub fn contains(&self, entry: &rr::RData) -> bool {
67 match (&self.data, entry) {
68 (RsData::TXT(txts), rr::RData::TXT(txt)) => {
69 if let Ok(txt) = txt_string(txt) {
70 txts.contains(&txt)
71 } else {
72 false
73 }
74 }
75 (RsData::A(addrs), rr::RData::A(addr)) => addrs.contains(addr),
76 (RsData::AAAA(addrs), rr::RData::AAAA(addr)) => addrs.contains(addr),
77 _ => false,
78 }
79 }
80
81 pub fn is_empty(&self) -> bool {
82 match &self.data {
83 RsData::TXT(txts) => txts.is_empty(),
84 RsData::A(addrs) => addrs.is_empty(),
85 RsData::AAAA(addrs) => addrs.is_empty(),
86 }
87 }
88
89 pub fn is_subset(&self, other: &RecordSet) -> bool {
90 use RsData::*;
91 if self.name() != other.name() {
92 return false;
93 }
94 match (&self.data, &other.data) {
95 (TXT(txts), TXT(other_txts)) => txts.is_subset(other_txts),
96 (A(addrs), A(other_addrs)) => addrs.is_subset(other_addrs),
97 (AAAA(addrs), AAAA(other_addrs)) => addrs.is_subset(other_addrs),
98 _ => false,
99 }
100 }
101}
102
103impl fmt::Display for RecordSet {
104 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
105 write!(f, "{} ({})", self.name, self.data)
106 }
107}
108
109#[derive(Debug)]
110pub struct RsDataIter<'a>(RsDataIterInner<'a>);
111
112impl<'a> Iterator for RsDataIter<'a> {
113 type Item = rr::RData;
114
115 fn next(&mut self) -> Option<Self::Item> {
116 use RsDataIterInner::*;
117 match &mut self.0 {
118 A(iter) => iter.next().map(|item| rr::RData::A(*item)),
119 AAAA(iter) => iter.next().map(|item| rr::RData::AAAA(*item)),
120 TXT(iter) => iter
121 .next()
122 .map(|item| rr::RData::TXT(rdata::TXT::new(vec![item.into()]))),
123 }
124 }
125}
126
127#[derive(Debug)]
128enum RsDataIterInner<'a> {
129 TXT(btree_set::Iter<'a, String>),
130 A(btree_set::Iter<'a, Ipv4Addr>),
131 AAAA(btree_set::Iter<'a, Ipv6Addr>),
132}
133
134#[derive(Debug, Clone, Hash, Eq, PartialEq)]
135pub enum RsData {
136 TXT(BTreeSet<String>), A(BTreeSet<Ipv4Addr>),
138 AAAA(BTreeSet<Ipv6Addr>),
139}
140
141impl RsData {
142 pub fn record_type(&self) -> rr::RecordType {
143 match self {
144 RsData::TXT(_) => rr::RecordType::TXT,
145 RsData::A(_) => rr::RecordType::A,
146 RsData::AAAA(_) => rr::RecordType::AAAA,
147 }
148 }
149}
150
151impl fmt::Display for RsData {
152 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
153 match self {
155 RsData::A(addrs) => {
156 write!(f, "A:")?;
157 for (i, addr) in addrs.iter().enumerate() {
158 if i > 0 {
159 write!(f, ",")?;
160 }
161 write!(f, "{}", addr)?;
162 }
163 }
164 RsData::AAAA(addrs) => {
165 write!(f, "A:")?;
166 for (i, addr) in addrs.iter().enumerate() {
167 if i > 0 {
168 write!(f, ",")?;
169 }
170 write!(f, "{}", addr)?;
171 }
172 }
173 RsData::TXT(txts) => {
174 write!(f, "TXT:")?;
175 for (i, txt) in txts.iter().enumerate() {
176 if i > 0 {
177 write!(f, ",")?;
178 }
179 write!(f, "{}", txt)?;
180 }
181 }
182 }
183 Ok(())
184 }
185}
186
187impl FromStr for RsData {
188 type Err = RsDataParseError;
189
190 fn from_str(s: &str) -> Result<Self, Self::Err> {
191 let parts: Vec<_> = s.splitn(2, ':').collect();
192 if parts.len() == 1 {
193 return match parts[0] {
194 "TXT" => Ok(RsData::TXT(Default::default())),
195 "A" => Ok(RsData::A(Default::default())),
196 "AAAA" => Ok(RsData::AAAA(Default::default())),
197 _ => Err(RsDataParseError::UnknownType),
198 };
199 }
200 if parts.len() != 2 {
201 return Err(RsDataParseError::MissingType);
202 }
203 let (rtype, rdata) = (parts[0].to_uppercase(), parts[1]);
204 let rdata_parts = rdata.split(',');
205 match rtype.as_str() {
206 "TXT" => Ok(RsData::TXT(rdata_parts.map(|s| s.to_owned()).collect())),
207 "A" => {
208 let addrs = rdata_parts
209 .map(|part| part.parse().map_err(RsDataParseError::Addr))
210 .collect::<Result<_, _>>()?;
211 Ok(RsData::A(addrs))
212 }
213 "AAAA" => {
214 let addrs = rdata_parts
215 .map(|part| part.parse().map_err(RsDataParseError::Addr))
216 .collect::<Result<_, _>>()?;
217 Ok(RsData::AAAA(addrs))
218 }
219 _ => Err(RsDataParseError::UnknownType),
220 }
221 }
222}
223
224#[derive(Debug)]
225pub enum RsDataParseError {
226 MissingType,
227 UnknownType,
228 Addr(net::AddrParseError),
229}
230
231impl fmt::Display for RsDataParseError {
232 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
233 use RsDataParseError::*;
234 match self {
235 MissingType => write!(f, "missing type"),
236 UnknownType => write!(f, "unknown type"),
237 Addr(e) => write!(f, "invalid address: {}", e),
238 }
239 }
240}
241
242impl std::error::Error for RsDataParseError {}
243
244#[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
245pub struct RsKey {
246 name: rr::Name,
247 dns_class: rr::DNSClass,
248 record_type: rr::RecordType,
249}
250
251impl RsKey {
252 pub fn name(&self) -> &rr::Name {
253 &self.name
254 }
255 pub fn dns_class(&self) -> rr::DNSClass {
256 self.dns_class
257 }
258 pub fn record_type(&self) -> rr::RecordType {
259 self.record_type
260 }
261}
262
263impl From<&rr::Record> for RsKey {
264 fn from(rr: &rr::Record) -> Self {
265 RsKey {
266 name: rr.name().clone(),
267 dns_class: rr.dns_class(),
268 record_type: rr.record_type(),
269 }
270 }
271}
272
273fn txt_string(txt: &rdata::TXT) -> Result<String, TryFromRecordsError> {
274 let data = txt.txt_data();
275 if data.len() != 1 {
276 return Err(TryFromRecordsError::UnsupportedTxtValue);
277 }
278 str::from_utf8(&data[0])
279 .map(Into::into)
280 .map_err(TryFromRecordsError::Utf8)
281}
282
283impl TryFrom<&[rr::Record]> for RecordSet {
284 type Error = TryFromRecordsError;
285
286 fn try_from(rrs: &[rr::Record]) -> Result<Self, Self::Error> {
287 let keys: BTreeSet<RsKey> = rrs.iter().map(Into::into).collect();
288 match keys.len() {
289 0 => Err(TryFromRecordsError::Empty),
290 1 => {
291 let key = keys.iter().nth(0).unwrap();
292 let data = match key.record_type {
297 rr::RecordType::A => {
298 RsData::A(rrs.iter().map(|rr| *rr.rdata().as_a().unwrap()).collect())
299 }
300 rr::RecordType::AAAA => RsData::AAAA(
301 rrs.iter()
302 .map(|rr| *rr.rdata().as_aaaa().unwrap())
303 .collect(),
304 ),
305 rr::RecordType::TXT => RsData::TXT(
306 rrs.iter()
307 .map(|rr| txt_string(rr.rdata().as_txt().unwrap()))
308 .collect::<Result<_, _>>()?,
309 ),
310 rtype => return Err(TryFromRecordsError::UnsupportedType(rtype)),
311 };
312 Ok(RecordSet {
313 name: key.name.clone(),
314 dns_class: key.dns_class,
315 data,
316 })
317 }
318 _ => Err(TryFromRecordsError::MultipleKeys(keys)),
319 }
320 }
321}
322
323#[derive(Debug)]
324pub enum TryFromRecordsError {
325 Empty,
326 MultipleKeys(BTreeSet<RsKey>),
327 UnsupportedType(rr::RecordType),
328 UnsupportedTxtValue,
329 FromUtf8(FromUtf8Error),
330 Utf8(str::Utf8Error),
331}
332
333impl fmt::Display for TryFromRecordsError {
334 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
335 use TryFromRecordsError::*;
336 match self {
337 Empty => write!(f, "no records"),
338 MultipleKeys(_) => write!(f, "multiple keys"),
339 UnsupportedType(rtype) => write!(f, "unsupported record type {}", rtype),
340 UnsupportedTxtValue => write!(f, "unsupported TXT value"),
341 Utf8(e) => write!(f, "non-UTF8 content: {}", e),
342 FromUtf8(e) => write!(f, "non-UTF8 content: {}", e),
343 }
344 }
345}
346
347impl std::error::Error for TryFromRecordsError {}