webrtc_mdns/conn/mod.rs

1use core::sync::atomic;
2use std::net::{IpAddr, Ipv4Addr, SocketAddr};
3use std::sync::Arc;
4use std::time::Duration;
5
6use socket2::SockAddr;
7use tokio::net::{ToSocketAddrs, UdpSocket};
8use tokio::sync::{mpsc, Mutex};
9use util::ifaces;
10
11use crate::config::*;
12use crate::error::*;
13use crate::message::header::*;
14use crate::message::name::*;
15use crate::message::parser::*;
16use crate::message::question::*;
17use crate::message::resource::a::*;
18use crate::message::resource::*;
19use crate::message::*;
20
21mod conn_test;
22
/// Multicast address and port that mDNS questions and answers are sent to.
pub const DEFAULT_DEST_ADDR: &str = "224.0.0.251:5353";

/// Size of the receive buffer; 65 535 bytes is enough for any UDP datagram.
const INBOUND_BUFFER_SIZE: usize = 65_535;
/// Re-send interval for pending queries when `Config::query_interval` is zero.
const DEFAULT_QUERY_INTERVAL: Duration = Duration::from_millis(1_000);
/// Per-packet processing cap: the question and answer loops each look at up to
/// `MAX_MESSAGE_RECORDS + 1` records (they iterate over `0..=MAX_MESSAGE_RECORDS`).
const MAX_MESSAGE_RECORDS: usize = 3;
/// TTL, in seconds, advertised on the A records this responder sends.
const RESPONSE_TTL: u32 = 120;
29
30// Conn represents a mDNS Server
31pub struct DnsConn {
32    socket: Arc<UdpSocket>,
33    dst_addr: SocketAddr,
34
35    query_interval: Duration,
36    queries: Arc<Mutex<Vec<Query>>>,
37
38    is_server_closed: Arc<atomic::AtomicBool>,
39    close_server: mpsc::Sender<()>,
40}
41
42struct Query {
43    name_with_suffix: String,
44    query_result_chan: mpsc::Sender<QueryResult>,
45}
46
47struct QueryResult {
48    answer: ResourceHeader,
49    addr: SocketAddr,
50}
51
52impl DnsConn {
53    /// server establishes a mDNS connection over an existing connection
54    pub fn server(addr: SocketAddr, config: Config) -> Result<Self> {
55        let socket = socket2::Socket::new(
56            socket2::Domain::IPV4,
57            socket2::Type::DGRAM,
58            Some(socket2::Protocol::UDP),
59        )?;
60
61        #[cfg(all(target_family = "unix", feature = "reuse_port"))]
62        socket.set_reuse_port(true)?;
63
64        socket.set_reuse_address(true)?;
65        socket.set_broadcast(true)?;
66        socket.set_nonblocking(true)?;
67
68        socket.bind(&SockAddr::from(addr))?;
69        {
70            let mut join_error_count = 0;
71            let interfaces = match ifaces::ifaces() {
72                Ok(e) => e,
73                Err(e) => {
74                    log::error!("Error getting interfaces: {e:?}");
75                    return Err(Error::Other(e.to_string()));
76                }
77            };
78
79            for interface in &interfaces {
80                if let Some(SocketAddr::V4(e)) = interface.addr {
81                    if let Err(e) = socket.join_multicast_v4(&Ipv4Addr::new(224, 0, 0, 251), e.ip())
82                    {
83                        log::trace!("Error connecting multicast, error: {e:?}");
84                        join_error_count += 1;
85                        continue;
86                    }
87
88                    log::trace!("Connected to interface address {e:?}");
89                }
90            }
91
92            if join_error_count >= interfaces.len() {
93                return Err(Error::ErrJoiningMulticastGroup);
94            }
95        }
96
97        let socket = UdpSocket::from_std(socket.into())?;
98
99        let local_names = config
100            .local_names
101            .iter()
102            .map(|l| l.to_string() + ".")
103            .collect();
104
105        let dst_addr: SocketAddr = DEFAULT_DEST_ADDR.parse()?;
106
107        let is_server_closed = Arc::new(atomic::AtomicBool::new(false));
108
109        let (close_server_send, close_server_rcv) = mpsc::channel(1);
110
111        let c = DnsConn {
112            query_interval: if config.query_interval != Duration::from_secs(0) {
113                config.query_interval
114            } else {
115                DEFAULT_QUERY_INTERVAL
116            },
117
118            queries: Arc::new(Mutex::new(vec![])),
119            socket: Arc::new(socket),
120            dst_addr,
121            is_server_closed: Arc::clone(&is_server_closed),
122            close_server: close_server_send,
123        };
124
125        let queries = c.queries.clone();
126        let socket = Arc::clone(&c.socket);
127
128        tokio::spawn(async move {
129            DnsConn::start(
130                close_server_rcv,
131                is_server_closed,
132                socket,
133                local_names,
134                dst_addr,
135                queries,
136            )
137            .await
138        });
139
140        Ok(c)
141    }
142
143    /// Close closes the mDNS Conn
144    pub async fn close(&self) -> Result<()> {
145        log::info!("Closing connection");
146        if self.is_server_closed.load(atomic::Ordering::SeqCst) {
147            return Err(Error::ErrConnectionClosed);
148        }
149
150        log::trace!("Sending close command to server");
151        match self.close_server.send(()).await {
152            Ok(_) => {
153                log::trace!("Close command sent");
154                Ok(())
155            }
156            Err(e) => {
157                log::warn!("Error sending close command to server: {e:?}");
158                Err(Error::ErrConnectionClosed)
159            }
160        }
161    }
162
163    /// Query sends mDNS Queries for the following name until
164    /// either there's a close signal or we get a result
165    pub async fn query(
166        &self,
167        name: &str,
168        mut close_query_signal: mpsc::Receiver<()>,
169    ) -> Result<(ResourceHeader, SocketAddr)> {
170        if self.is_server_closed.load(atomic::Ordering::SeqCst) {
171            return Err(Error::ErrConnectionClosed);
172        }
173
174        let name_with_suffix = name.to_owned() + ".";
175
176        let (query_tx, mut query_rx) = mpsc::channel(1);
177        {
178            let mut queries = self.queries.lock().await;
179            queries.push(Query {
180                name_with_suffix: name_with_suffix.clone(),
181                query_result_chan: query_tx,
182            });
183        }
184
185        log::trace!("Sending query");
186        self.send_question(&name_with_suffix).await;
187
188        loop {
189            tokio::select! {
190                _ = tokio::time::sleep(self.query_interval) => {
191                    log::trace!("Sending query");
192                    self.send_question(&name_with_suffix).await
193                },
194
195                _ = close_query_signal.recv() => {
196                    log::info!("Query close signal received.");
197                    return Err(Error::ErrConnectionClosed)
198                },
199
200                res_opt = query_rx.recv() =>{
201                    log::info!("Received query result");
202                    if let Some(res) = res_opt{
203                        return Ok((res.answer, res.addr));
204                    }
205                }
206            }
207        }
208    }
209
210    async fn send_question(&self, name: &str) {
211        let packed_name = match Name::new(name) {
212            Ok(pn) => pn,
213            Err(err) => {
214                log::warn!("Failed to construct mDNS packet: {err}");
215                return;
216            }
217        };
218
219        let raw_query = {
220            let mut msg = Message {
221                header: Header::default(),
222                questions: vec![Question {
223                    typ: DnsType::A,
224                    class: DNSCLASS_INET,
225                    name: packed_name,
226                }],
227                ..Default::default()
228            };
229
230            match msg.pack() {
231                Ok(v) => v,
232                Err(err) => {
233                    log::error!("Failed to construct mDNS packet {err}");
234                    return;
235                }
236            }
237        };
238
239        log::trace!("{:?} sending {:?}...", self.socket.local_addr(), raw_query);
240        if let Err(err) = self.socket.send_to(&raw_query, self.dst_addr).await {
241            log::error!("Failed to send mDNS packet {err}");
242        }
243    }
244
245    async fn start(
246        mut closed_rx: mpsc::Receiver<()>,
247        close_server: Arc<atomic::AtomicBool>,
248        socket: Arc<UdpSocket>,
249        local_names: Vec<String>,
250        dst_addr: SocketAddr,
251        queries: Arc<Mutex<Vec<Query>>>,
252    ) -> Result<()> {
253        log::info!("Looping and listening {:?}", socket.local_addr());
254
255        let mut b = vec![0u8; INBOUND_BUFFER_SIZE];
256        let (mut n, mut src);
257
258        loop {
259            tokio::select! {
260                _ = closed_rx.recv() => {
261                    log::info!("Closing server connection");
262                    close_server.store(true, atomic::Ordering::SeqCst);
263
264                    return Ok(());
265                }
266
267                result = socket.recv_from(&mut b) => {
268                    match result{
269                        Ok((len, addr)) => {
270                            n = len;
271                            src = addr;
272                            log::trace!("Received new connection from {addr:?}");
273                        },
274
275                        Err(err) => {
276                            log::error!("Error receiving from socket connection: {err:?}");
277                            continue;
278                        },
279                    }
280                }
281            }
282
283            let mut p = Parser::default();
284            if let Err(err) = p.start(&b[..n]) {
285                log::error!("Failed to parse mDNS packet {err}");
286                continue;
287            }
288
289            run(&mut p, &socket, &local_names, src, dst_addr, &queries).await
290        }
291    }
292}
293
294async fn run(
295    p: &mut Parser<'_>,
296    socket: &Arc<UdpSocket>,
297    local_names: &[String],
298    src: SocketAddr,
299    dst_addr: SocketAddr,
300    queries: &Arc<Mutex<Vec<Query>>>,
301) {
302    let mut interface_addr = None;
303    for _ in 0..=MAX_MESSAGE_RECORDS {
304        let q = match p.question() {
305            Ok(q) => q,
306            Err(err) => {
307                if Error::ErrSectionDone == err {
308                    log::trace!("Parsing has completed");
309                    break;
310                } else {
311                    log::error!("Failed to parse mDNS packet {err}");
312                    return;
313                }
314            }
315        };
316
317        for local_name in local_names {
318            if *local_name == q.name.data {
319                let interface_addr = match interface_addr {
320                    Some(addr) => addr,
321                    None => match get_interface_addr_for_ip(src).await {
322                        Ok(addr) => {
323                            interface_addr.replace(addr);
324                            addr
325                        }
326                        Err(e) => {
327                            log::warn!(
328                                "Failed to get local interface to communicate with {}: {:?}",
329                                &src,
330                                e
331                            );
332                            continue;
333                        }
334                    },
335                };
336
337                log::trace!(
338                    "Found local name: {} to send answer, IP {}, interface addr {}",
339                    local_name,
340                    src.ip(),
341                    interface_addr
342                );
343                if let Err(e) =
344                    send_answer(socket, &interface_addr, &q.name.data, src.ip(), dst_addr).await
345                {
346                    log::error!("Error sending answer to client: {e:?}");
347                    continue;
348                };
349            }
350        }
351    }
352
353    // There might be more than MAX_MESSAGE_RECORDS questions, so skip the rest
354    let _ = p.skip_all_questions();
355
356    for _ in 0..=MAX_MESSAGE_RECORDS {
357        let a = match p.answer_header() {
358            Ok(a) => a,
359            Err(err) => {
360                if Error::ErrSectionDone != err {
361                    log::warn!("Failed to parse mDNS packet {err}");
362                }
363                return;
364            }
365        };
366
367        if a.typ != DnsType::A && a.typ != DnsType::Aaaa {
368            continue;
369        }
370
371        let mut qs = queries.lock().await;
372        for j in (0..qs.len()).rev() {
373            if qs[j].name_with_suffix == a.name.data {
374                let _ = qs[j]
375                    .query_result_chan
376                    .send(QueryResult {
377                        answer: a.clone(),
378                        addr: src,
379                    })
380                    .await;
381                qs.remove(j);
382            }
383        }
384    }
385}
386
387async fn send_answer(
388    socket: &Arc<UdpSocket>,
389    interface_addr: &SocketAddr,
390    name: &str,
391    dst: IpAddr,
392    dst_addr: SocketAddr,
393) -> Result<()> {
394    let raw_answer = {
395        let mut msg = Message {
396            header: Header {
397                response: true,
398                authoritative: true,
399                ..Default::default()
400            },
401
402            answers: vec![Resource {
403                header: ResourceHeader {
404                    typ: DnsType::A,
405                    class: DNSCLASS_INET,
406                    name: Name::new(name)?,
407                    ttl: RESPONSE_TTL,
408                    ..Default::default()
409                },
410                body: Some(Box::new(AResource {
411                    a: match interface_addr.ip() {
412                        IpAddr::V4(ip) => ip.octets(),
413                        IpAddr::V6(_) => {
414                            return Err(Error::Other("Unexpected IpV6 addr".to_owned()))
415                        }
416                    },
417                })),
418            }],
419            ..Default::default()
420        };
421
422        msg.pack()?
423    };
424
425    socket.send_to(&raw_answer, dst_addr).await?;
426    log::trace!("Sent answer to IP {dst}");
427
428    Ok(())
429}
430
431async fn get_interface_addr_for_ip(addr: impl ToSocketAddrs) -> std::io::Result<SocketAddr> {
432    let socket = UdpSocket::bind("0.0.0.0:0").await?;
433    socket.connect(addr).await?;
434    socket.local_addr()
435}