toe_beans/v4/message/
socket.rs

1use super::{Decodable, Encodable};
2use crate::v4::{Result, UndecodedMessage};
3use inherface::get_interfaces;
4use log::{debug, info};
5use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs, UdpSocket};
6
/// A UDP socket wrapper that knows how to send and receive DHCP messages.
///
/// Outgoing messages must implement `Encodable` and incoming messages `Decodable`.
/// Meant to be used by both Client and Server.
///
/// Wraps a [`UdpSocket`] (created with
/// [UdpSocket::bind](https://doc.rust-lang.org/std/net/struct.UdpSocket.html#method.bind))
/// together with the address that broadcasts are sent to.
#[derive(Debug)]
pub struct Socket {
    /// The bound UDP socket all traffic goes through.
    socket: UdpSocket,
    /// Destination ip used when broadcasting messages.
    broadcast: Ipv4Addr,
}
15
16/// Bytes, received from the socket, that can be decoded into a Message.
17pub type MessageBuffer = [u8; Socket::MAX_MESSAGE_SIZE as usize];
18
19impl Socket {
20    /// "A DHCP client must be prepared to receive DHCP messages with an options
21    /// field of at least length 312 octets"
22    ///
23    /// All other fields add to length 236, so 236 + 312 = 548
24    pub const MAX_MESSAGE_SIZE: u16 = 548;
25
26    /// A `MessageBuffer` with all 0 bytes.
27    pub const EMPTY_MESSAGE_BUFFER: MessageBuffer = [0; Self::MAX_MESSAGE_SIZE as usize];
28
29    /// Bind to an ip address/port and require that broadcast is enabled on the socket.
30    ///
31    /// Should be be called by both `Server::new` and `Client::new`, so this can be slower/panic
32    /// since it is not in the `listen_once` hot path.
33    pub fn new(address: SocketAddrV4, interface: Option<&String>) -> Self {
34        // hack to fix integration tests
35        let broadcast = if cfg!(feature = "integration") {
36            Ipv4Addr::LOCALHOST
37        } else {
38            Self::get_interface_broadcast(interface).unwrap_or(Ipv4Addr::BROADCAST)
39        };
40
41        let socket = UdpSocket::bind(address).expect("failed to bind to address");
42        info!("UDP socket bound on {}", address);
43
44        // set_broadcast sets the SO_BROADCAST option on the socket
45        // which is a way of ensuring that the application
46        // can't _accidentally_ spam many devices with a broadcast
47        // address without intentionally intending to do so.
48        socket
49            .set_broadcast(true)
50            .expect("Failed to enable broadcasting");
51
52        Self { socket, broadcast }
53    }
54
55    fn get_interface_broadcast(maybe_interface_name: Option<&String>) -> Option<Ipv4Addr> {
56        let interface = maybe_interface_name?;
57        let interfaces = get_interfaces().ok()?;
58        let maybe_interface = interfaces.get(interface);
59        let maybe_addr = maybe_interface?
60            .v4_addr
61            .iter()
62            .find(|address| address.broadcast.is_some());
63        let maybe_broadcast = maybe_addr?.broadcast;
64
65        debug!(
66            "Found ipv4 broadcast address ({:?}) in list of interface addresses",
67            maybe_broadcast
68        );
69
70        maybe_broadcast
71    }
72
73    /// Returns the ip address of the bound socket.
74    ///
75    /// Can panic as typically called outside of `listen_once` hot path.
76    pub fn get_ip(&self) -> Ipv4Addr {
77        match self.socket.local_addr().unwrap().ip() {
78            std::net::IpAddr::V4(ip) => ip,
79            std::net::IpAddr::V6(_) => todo!("ipv6 is not supported yet"),
80        }
81    }
82
83    /// Decodes received bytes into a "message" type that implements `DecodeMessage`
84    /// and returns it and the source address.
85    pub fn receive<M: Decodable<Output = M>>(&self) -> Result<(M, SocketAddr)> {
86        // Any bytes over MAX_MESSAGE_SIZE will be discarded.
87        let mut buffer = Self::EMPTY_MESSAGE_BUFFER;
88        let (_, src) = match self.socket.recv_from(&mut buffer) {
89            Ok(values) => values,
90            Err(_) => return Err("Failed to receive data"),
91        };
92
93        let decoded = M::from_bytes(&UndecodedMessage::new(buffer));
94
95        debug!("Received dhcp message (from {}): {:?}", src, decoded);
96
97        Ok((decoded, src))
98    }
99
100    /// Returns received bytes directly without being decoded into a `Deliverable`,
101    /// which allows you to partially decode them yourself later.
102    pub fn receive_raw(&self) -> Result<(UndecodedMessage, SocketAddr)> {
103        // Any bytes over MAX_MESSAGE_SIZE will be discarded.
104        let mut buffer = Self::EMPTY_MESSAGE_BUFFER;
105        let (_, src) = match self.socket.recv_from(&mut buffer) {
106            Ok(values) => values,
107            Err(_) => return Err("Failed to receive data"),
108        };
109
110        debug!("Received dhcp message from {}", src);
111
112        Ok((UndecodedMessage::new(buffer), src))
113    }
114
115    /// Fills an empty undecoded message with passed bytes.
116    /// Used to mock receiving a specific `UndecodedMessage` in tests, etc.
117    pub fn receive_mock(partial_message: &[u8]) -> UndecodedMessage {
118        let mut buffer = Self::EMPTY_MESSAGE_BUFFER;
119        partial_message
120            .iter()
121            .enumerate()
122            .for_each(|(i, byte)| buffer[i] = *byte);
123        UndecodedMessage::new(buffer)
124    }
125
126    /// Converts a message to bytes and then sends it to the passed address.
127    pub fn unicast<A: ToSocketAddrs, M: Encodable>(&self, message: &M, address: A) -> Result<()> {
128        let encoded = message.to_bytes();
129
130        let address = address.to_socket_addrs().unwrap().next().unwrap();
131        debug!("Sending dhcp message (to {:?}): {:?}", address, message);
132
133        match self.socket.send_to(&encoded, address) {
134            Ok(_) => Ok(()),
135            Err(_) => Err("Failed to send data"),
136        }
137    }
138
139    /// Send a message as bytes to many devices on the local network.
140    pub fn broadcast<M: Encodable>(&self, message: &M, port: u16) -> Result<()> {
141        self.unicast(message, SocketAddr::new(IpAddr::V4(self.broadcast), port))
142    }
143}
144
145// -----
146
#[cfg(test)]
mod tests {
    use super::*;

    /// DHCP messages carry 236 bytes of fixed fields plus an options field that
    /// clients must accept at a length of at least 312 octets.
    #[test]
    fn test_max_message_size_is_large_enough() {
        // A smaller buffer would violate the RFC and could cause panics elsewhere.
        let rfc_minimum: u16 = 236 + 312;
        let configured = Socket::MAX_MESSAGE_SIZE;
        assert!(configured >= rfc_minimum);
    }
}