srv_rs/resolver/libresolv/
mod.rs1use super::SrvResolver;
4use crate::SrvRecord;
5use async_trait::async_trait;
6use std::{
7 convert::TryInto,
8 ffi::CString,
9 time::{Duration, Instant},
10};
11
12mod ffi;
13
14#[derive(Debug, thiserror::Error, PartialEq, Eq)]
16pub enum LibResolvError {
17 #[error("srv name contained interior null byte: {0}")]
19 InteriorNul(#[from] std::ffi::NulError),
20 #[error("resolver: {0}")]
22 Resolver(#[from] ffi::ResolverError),
23 #[error("record type is not SRV")]
25 NotSrv,
26 #[error("DNS answer did not fit in maximum message size (65535)")]
28 AnswerTooLarge,
29}
30
/// SRV resolver that performs its lookups through libresolv's `res_nsearch`.
#[derive(Debug)]
pub struct LibResolv {
    /// Size in bytes of the answer buffer allocated at the start of each lookup.
    /// The buffer is enlarged when an answer does not fit.
    initial_buf_size: usize,
}

impl LibResolv {
    /// Creates a resolver whose answer buffer starts out at `initial_buf_size` bytes.
    pub fn new(initial_buf_size: usize) -> Self {
        LibResolv { initial_buf_size }
    }
}
43
44impl Default for LibResolv {
45 fn default() -> Self {
46 Self::new(ffi::NS_PACKETSZ as usize)
47 }
48}
49
50#[async_trait]
51impl SrvResolver for LibResolv {
52 type Record = LibResolvSrvRecord;
53 type Error = LibResolvError;
54
55 async fn get_srv_records_unordered(
56 &self,
57 srv: &str,
58 ) -> Result<(Vec<Self::Record>, Instant), Self::Error> {
59 let srv = CString::new(srv)?;
60 let mut buf = vec![0u8; self.initial_buf_size];
61 ffi::RESOLV_STATE.with(|state| {
62 let mut state = state.borrow_mut();
63 let (len, response_time) = loop {
64 let len = unsafe {
65 ffi::res_nsearch(
66 state.as_mut(),
67 srv.as_ptr(),
68 ffi::ns_c_in as i32,
69 ffi::ns_t_srv as i32,
70 buf.as_mut_ptr(),
71 buf.len() as i32,
72 )
73 };
74 let len = match state.check(len) {
75 Ok(()) => len as usize,
76 Err(e) => return Err(e.into()),
77 };
78 if len <= buf.len() {
79 break (len, Instant::now());
80 } else if len <= ffi::NS_MAXMSG as usize {
81 buf.resize(len, 0)
83 } else {
84 return Err(LibResolvError::AnswerTooLarge);
85 }
86 };
87
88 let response = &buf[..len];
89 let mut msg = unsafe { std::mem::zeroed() };
90 let ret =
91 unsafe { ffi::ns_initparse(response.as_ptr(), response.len() as i32, &mut msg) };
92 state.check(ret)?;
93
94 let mut rr = unsafe { std::mem::zeroed() };
95 let num_records = ffi::ns_msg_count(msg, ffi::ns_s_an);
96 let mut records = Vec::with_capacity(num_records as usize);
97 let mut min_ttl = None;
98 for idx in 0..num_records {
99 let ret = unsafe { ffi::ns_parserr(&mut msg, ffi::ns_s_an, idx as i32, &mut rr) };
100 state.check(ret)?;
101 let (record, ttl) = LibResolvSrvRecord::try_parse(&state, msg, rr)?;
102 records.push(record);
103 min_ttl = min_ttl.min(Some(ttl)).or(Some(ttl));
104 }
105
106 Ok((records, response_time + min_ttl.unwrap_or_default()))
107 })
108 }
109}
110
/// One SRV record decoded from a DNS answer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LibResolvSrvRecord {
    /// Name of the target host, as expanded from the record data.
    pub target: String,
    /// Port field of the record.
    pub port: u16,
    /// Priority field of the record.
    pub priority: u16,
    /// Weight field of the record.
    pub weight: u16,
}
123
124impl SrvRecord for LibResolvSrvRecord {
125 type Target = str;
126
127 fn target(&self) -> &Self::Target {
128 &self.target
129 }
130
131 fn port(&self) -> u16 {
132 self.port
133 }
134
135 fn priority(&self) -> u16 {
136 self.priority
137 }
138
139 fn weight(&self) -> u16 {
140 self.weight
141 }
142}
143
144impl LibResolvSrvRecord {
145 fn try_parse(
146 state: &ffi::ResolverState,
147 msg: ffi::ns_msg,
148 rr: ffi::ns_rr,
149 ) -> Result<(Self, Duration), LibResolvError> {
150 if rr.type_ as u32 != ffi::ns_t_srv {
151 return Err(LibResolvError::NotSrv);
152 }
153
154 let (header, rest) =
155 unsafe { std::slice::from_raw_parts(rr.rdata, rr.rdlength as usize) }.split_at(6);
156
157 let mut chunks = header
158 .chunks_exact(2)
159 .map(|chunk| u16::from_be_bytes(chunk.try_into().unwrap()));
160
161 let priority = chunks.next().unwrap();
162 let weight = chunks.next().unwrap();
163 let port = chunks.next().unwrap();
164
165 let mut name = [0u8; ffi::NS_MAXDNAME as usize];
166 let ret = unsafe {
167 ffi::dn_expand(
168 ffi::ns_msg_base(msg),
169 ffi::ns_msg_end(msg),
170 rest.as_ptr(),
171 name.as_mut_ptr().cast(),
172 name.len() as i32,
173 )
174 };
175 state.check(ret)?;
176
177 let target = unsafe { std::ffi::CStr::from_ptr(name.as_ptr().cast()) };
178 let ttl = Duration::from_secs(rr.ttl as u64);
179 let record = Self {
180 target: target.to_string_lossy().to_string(),
181 port,
182 priority,
183 weight,
184 };
185 Ok((record, ttl))
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolves `srv` with a default resolver and asserts that the lookup fails
    /// with `HostNotFound`.
    async fn assert_host_not_found(srv: &str) {
        let outcome = LibResolv::default().get_srv_records(srv).await;
        assert_eq!(outcome, Err(ffi::ResolverError::HostNotFound.into()));
    }

    // A successful lookup yields at least one record that is still valid.
    #[tokio::test]
    async fn srv_lookup() -> Result<(), LibResolvError> {
        let resolver = LibResolv::default();
        let (records, valid_until) = resolver
            .get_srv_records_unordered(crate::EXAMPLE_SRV)
            .await?;
        assert!(!records.is_empty());
        assert!(Instant::now() < valid_until);
        Ok(())
    }

    // The ordered lookup returns records with non-decreasing priority.
    #[tokio::test]
    async fn srv_lookup_ordered() -> Result<(), LibResolvError> {
        let resolver = LibResolv::default();
        let (records, _) = resolver.get_srv_records(crate::EXAMPLE_SRV).await?;
        assert!(!records.is_empty());
        let sorted_by_priority = records
            .windows(2)
            .all(|pair| pair[0].priority() <= pair[1].priority());
        assert!(sorted_by_priority);
        Ok(())
    }

    #[tokio::test]
    async fn invalid_host() {
        assert_host_not_found("_http._tcp.foobar.deshaw.com").await;
    }

    #[tokio::test]
    async fn malformed_srv_name() {
        assert_host_not_found("_http.foobar.deshaw.com").await;
    }

    #[tokio::test]
    async fn very_malformed_srv_name() {
        assert_host_not_found("  @#*^[_hsd flt.com").await;
    }

    // A name with an embedded NUL is rejected before any query is made.
    #[tokio::test]
    async fn srv_name_containing_nul_terminator() {
        let outcome = LibResolv::default()
            .get_srv_records("_http.\0_tcp.foo.com")
            .await;
        assert!(matches!(outcome, Err(LibResolvError::InteriorNul(_))));
    }
}