sync_resolve/
socket.rs

//! Low-level UDP socket operations

use std::net::{IpAddr, Ipv6Addr, SocketAddr, ToSocketAddrs, UdpSocket};
use std::{error, fmt, io};

use crate::address::socket_address_equal;
use crate::message::{DecodeError, DnsError, EncodeError, Message, MESSAGE_LIMIT};

/// Represents a socket transmitting DNS messages.
///
/// This is a thin wrapper around a `UdpSocket`; the wrapped socket is
/// reachable through `DnsSocket::get`.
#[derive(Debug)]
pub struct DnsSocket {
    sock: UdpSocket,
}

14impl DnsSocket {
15    /// Returns a `DnsSocket`, bound to an unspecified address.
16    pub fn new() -> io::Result<DnsSocket> {
17        DnsSocket::bind(SocketAddr::new(
18            IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)),
19            0,
20        ))
21    }
22
23    /// Returns a `DnsSocket`, bound to the given address.
24    pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<DnsSocket> {
25        Ok(DnsSocket {
26            sock: UdpSocket::bind(addr)?,
27        })
28    }
29
30    /// Returns a reference to the wrapped `UdpSocket`.
31    pub fn get(&self) -> &UdpSocket {
32        &self.sock
33    }
34
35    /// Sends a message to the given address.
36    pub fn send_message<A: ToSocketAddrs>(&self, message: &Message, addr: A) -> Result<(), Error> {
37        let mut buf = [0; MESSAGE_LIMIT];
38        let data = message.encode(&mut buf)?;
39        self.sock.send_to(data, addr)?;
40        Ok(())
41    }
42
43    /// Receives a message, returning the address of the sender.
44    /// The given buffer is used to store and parse message data.
45    ///
46    /// The buffer should be exactly `MESSAGE_LIMIT` bytes in length.
47    pub fn recv_from<'buf>(
48        &self,
49        buf: &'buf mut [u8],
50    ) -> Result<(Message<'buf>, SocketAddr), Error> {
51        let (n, addr) = self.sock.recv_from(buf)?;
52
53        let msg = Message::decode(&buf[..n])?;
54        Ok((msg, addr))
55    }
56
57    /// Attempts to read a DNS message. The message will only be decoded if the
58    /// remote address matches `addr`. If a packet is received from a
59    /// non-matching address, the message is not decoded and `Ok(None)` is
60    /// returned.
61    ///
62    /// The buffer should be exactly `MESSAGE_LIMIT` bytes in length.
63    pub fn recv_message<'buf>(
64        &self,
65        addr: &SocketAddr,
66        buf: &'buf mut [u8],
67    ) -> Result<Option<Message<'buf>>, Error> {
68        let (n, recv_addr) = self.sock.recv_from(buf)?;
69
70        if !socket_address_equal(&recv_addr, addr) {
71            Ok(None)
72        } else {
73            let msg = Message::decode(&buf[..n])?;
74            Ok(Some(msg))
75        }
76    }
77}
78
79/// Represents an error in sending or receiving a DNS message.
80#[derive(Debug)]
81pub enum Error {
82    /// Error decoding received data
83    DecodeError(DecodeError),
84    /// Error encoding data to be sent
85    EncodeError(EncodeError),
86    /// Server responded with error message
87    DnsError(DnsError),
88    /// Error generated by network operation
89    IoError(io::Error),
90}
91
92impl Error {
93    /// Returns `true` if the error is the result of an operation having timed
94    /// out.
95    pub fn is_timeout(&self) -> bool {
96        match *self {
97            Error::IoError(ref e) => {
98                let kind = e.kind();
99                kind == io::ErrorKind::TimedOut || kind == io::ErrorKind::WouldBlock
100            }
101            _ => false,
102        }
103    }
104}
105
106impl fmt::Display for Error {
107    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
108        match *self {
109            Error::DecodeError(e) => write!(f, "error decoding message: {}", e),
110            Error::EncodeError(ref e) => write!(f, "error encoding message: {}", e),
111            Error::DnsError(e) => write!(f, "server responded with error: {}", e),
112            Error::IoError(ref e) => fmt::Display::fmt(e, f),
113        }
114    }
115}
116
117impl From<DecodeError> for Error {
118    fn from(err: DecodeError) -> Error {
119        Error::DecodeError(err)
120    }
121}
122
123impl From<EncodeError> for Error {
124    fn from(err: EncodeError) -> Error {
125        Error::EncodeError(err)
126    }
127}
128
129impl From<DnsError> for Error {
130    fn from(err: DnsError) -> Error {
131        Error::DnsError(err)
132    }
133}
134
135impl From<io::Error> for Error {
136    fn from(err: io::Error) -> Error {
137        Error::IoError(err)
138    }
139}