// srv_rs/resolver/libresolv/mod.rs

//! SRV Resolver backed by `libresolv`.

3use super::SrvResolver;
4use crate::SrvRecord;
5use async_trait::async_trait;
6use std::{
7    convert::TryInto,
8    ffi::CString,
9    time::{Duration, Instant},
10};
11
12mod ffi;
13
14/// Errors encountered by [`LibResolv`].
15#[derive(Debug, thiserror::Error, PartialEq, Eq)]
16pub enum LibResolvError {
17    /// Rust -> C string conversion errors.
18    #[error("srv name contained interior null byte: {0}")]
19    InteriorNul(#[from] std::ffi::NulError),
20    /// SRV resolver errors.
21    #[error("resolver: {0}")]
22    Resolver(#[from] ffi::ResolverError),
23    /// Tried to parse non-SRV record as SRV.
24    #[error("record type is not SRV")]
25    NotSrv,
26    /// DNS answer larger than allowed by RFC.
27    #[error("DNS answer did not fit in maximum message size (65535)")]
28    AnswerTooLarge,
29}
30
/// SRV resolver backed by `libresolv`.
///
/// DNS answers are read into a byte buffer whose initial size is
/// configurable. When an answer does not fit, the buffer is enlarged and the
/// query is retried, up to the maximum DNS message size.
#[derive(Debug)]
pub struct LibResolv {
    /// Size, in bytes, of the answer buffer used for the first query attempt.
    initial_buf_size: usize,
}

impl LibResolv {
    /// Initializes a resolver with a specific initial buffer size for DNS answers.
    ///
    /// `initial_buf_size` is only a starting point. Answers larger than it
    /// are still retrieved by re-running the query with a bigger buffer, at
    /// the cost of an extra query. [`LibResolv::default`] starts from
    /// `NS_PACKETSZ` bytes.
    #[must_use]
    pub const fn new(initial_buf_size: usize) -> Self {
        Self { initial_buf_size }
    }
}
43
44impl Default for LibResolv {
45    fn default() -> Self {
46        Self::new(ffi::NS_PACKETSZ as usize)
47    }
48}
49
50#[async_trait]
51impl SrvResolver for LibResolv {
52    type Record = LibResolvSrvRecord;
53    type Error = LibResolvError;
54
55    async fn get_srv_records_unordered(
56        &self,
57        srv: &str,
58    ) -> Result<(Vec<Self::Record>, Instant), Self::Error> {
59        let srv = CString::new(srv)?;
60        let mut buf = vec![0u8; self.initial_buf_size];
61        ffi::RESOLV_STATE.with(|state| {
62            let mut state = state.borrow_mut();
63            let (len, response_time) = loop {
64                let len = unsafe {
65                    ffi::res_nsearch(
66                        state.as_mut(),
67                        srv.as_ptr(),
68                        ffi::ns_c_in as i32,
69                        ffi::ns_t_srv as i32,
70                        buf.as_mut_ptr(),
71                        buf.len() as i32,
72                    )
73                };
74                let len = match state.check(len) {
75                    Ok(()) => len as usize,
76                    Err(e) => return Err(e.into()),
77                };
78                if len <= buf.len() {
79                    break (len, Instant::now());
80                } else if len <= ffi::NS_MAXMSG as usize {
81                    // Retry with larger buffer
82                    buf.resize(len, 0)
83                } else {
84                    return Err(LibResolvError::AnswerTooLarge);
85                }
86            };
87
88            let response = &buf[..len];
89            let mut msg = unsafe { std::mem::zeroed() };
90            let ret =
91                unsafe { ffi::ns_initparse(response.as_ptr(), response.len() as i32, &mut msg) };
92            state.check(ret)?;
93
94            let mut rr = unsafe { std::mem::zeroed() };
95            let num_records = ffi::ns_msg_count(msg, ffi::ns_s_an);
96            let mut records = Vec::with_capacity(num_records as usize);
97            let mut min_ttl = None;
98            for idx in 0..num_records {
99                let ret = unsafe { ffi::ns_parserr(&mut msg, ffi::ns_s_an, idx as i32, &mut rr) };
100                state.check(ret)?;
101                let (record, ttl) = LibResolvSrvRecord::try_parse(&state, msg, rr)?;
102                records.push(record);
103                min_ttl = min_ttl.min(Some(ttl)).or(Some(ttl));
104            }
105
106            Ok((records, response_time + min_ttl.unwrap_or_default()))
107        })
108    }
109}
110
/// Representation of SRV records used by [`LibResolv`].
///
/// Field meanings follow the SRV resource record format of RFC 2782.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LibResolvSrvRecord {
    /// Record's target: the host name providing the service.
    pub target: String,
    /// Record's port: the port on `target` where the service is offered.
    pub port: u16,
    /// Record's priority: targets with lower values are preferred.
    pub priority: u16,
    /// Record's weight: relative selection weight among records of equal priority.
    pub weight: u16,
}
123
124impl SrvRecord for LibResolvSrvRecord {
125    type Target = str;
126
127    fn target(&self) -> &Self::Target {
128        &self.target
129    }
130
131    fn port(&self) -> u16 {
132        self.port
133    }
134
135    fn priority(&self) -> u16 {
136        self.priority
137    }
138
139    fn weight(&self) -> u16 {
140        self.weight
141    }
142}
143
144impl LibResolvSrvRecord {
145    fn try_parse(
146        state: &ffi::ResolverState,
147        msg: ffi::ns_msg,
148        rr: ffi::ns_rr,
149    ) -> Result<(Self, Duration), LibResolvError> {
150        if rr.type_ as u32 != ffi::ns_t_srv {
151            return Err(LibResolvError::NotSrv);
152        }
153
154        let (header, rest) =
155            unsafe { std::slice::from_raw_parts(rr.rdata, rr.rdlength as usize) }.split_at(6);
156
157        let mut chunks = header
158            .chunks_exact(2)
159            .map(|chunk| u16::from_be_bytes(chunk.try_into().unwrap()));
160
161        let priority = chunks.next().unwrap();
162        let weight = chunks.next().unwrap();
163        let port = chunks.next().unwrap();
164
165        let mut name = [0u8; ffi::NS_MAXDNAME as usize];
166        let ret = unsafe {
167            ffi::dn_expand(
168                ffi::ns_msg_base(msg),
169                ffi::ns_msg_end(msg),
170                rest.as_ptr(),
171                name.as_mut_ptr().cast(),
172                name.len() as i32,
173            )
174        };
175        state.check(ret)?;
176
177        let target = unsafe { std::ffi::CStr::from_ptr(name.as_ptr().cast()) };
178        let ttl = Duration::from_secs(rr.ttl as u64);
179        let record = Self {
180            target: target.to_string_lossy().to_string(),
181            port,
182            priority,
183            weight,
184        };
185        Ok((record, ttl))
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolver with the default initial buffer size, used by every test.
    fn resolver() -> LibResolv {
        LibResolv::default()
    }

    #[tokio::test]
    async fn srv_lookup() -> Result<(), LibResolvError> {
        let (records, valid_until) = resolver()
            .get_srv_records_unordered(crate::EXAMPLE_SRV)
            .await?;
        assert!(!records.is_empty());
        assert!(Instant::now() < valid_until);
        Ok(())
    }

    #[tokio::test]
    async fn srv_lookup_ordered() -> Result<(), LibResolvError> {
        let (records, _) = resolver().get_srv_records(crate::EXAMPLE_SRV).await?;
        assert!(!records.is_empty());
        // Adjacent records must be in non-decreasing priority order.
        let priorities_ascend = records
            .windows(2)
            .all(|pair| pair[0].priority() <= pair[1].priority());
        assert!(priorities_ascend);
        Ok(())
    }

    #[tokio::test]
    async fn invalid_host() {
        let outcome = resolver()
            .get_srv_records("_http._tcp.foobar.deshaw.com")
            .await;
        assert_eq!(outcome, Err(ffi::ResolverError::HostNotFound.into()));
    }

    #[tokio::test]
    async fn malformed_srv_name() {
        let outcome = resolver().get_srv_records("_http.foobar.deshaw.com").await;
        assert_eq!(outcome, Err(ffi::ResolverError::HostNotFound.into()));
    }

    #[tokio::test]
    async fn very_malformed_srv_name() {
        let outcome = resolver().get_srv_records("  @#*^[_hsd flt.com").await;
        assert_eq!(outcome, Err(ffi::ResolverError::HostNotFound.into()));
    }

    #[tokio::test]
    async fn srv_name_containing_nul_terminator() {
        let outcome = resolver().get_srv_records("_http.\0_tcp.foo.com").await;
        assert!(matches!(outcome, Err(LibResolvError::InteriorNul(_))));
    }
}