Skip to main content

toe_beans/v4/message/socket/
mod.rs

1mod config;
2
3pub use config::*;
4
5use super::{Decodable, Encodable};
6use crate::v4::{MAX_MESSAGE_SIZE, Result, UndecodedMessage};
7use inherface::get_interfaces;
8use log::{debug, info};
9use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket};
10
/// A bound UDP socket that knows how to send `Encodable` messages and receive
/// `Decodable` messages.
///
/// It is used by both the Client and the Server. The underlying socket is created with
/// [UdpSocket::bind](https://doc.rust-lang.org/std/net/struct.UdpSocket.html#method.bind).
#[derive(Debug)]
pub struct Socket {
    /// The bound OS-level UDP socket that all traffic goes through.
    socket: UdpSocket,
    /// Destination IPv4 address used when broadcasting a message.
    broadcast: Ipv4Addr,
}
20
/// Fixed-size byte buffer that receives one datagram from the socket, to be decoded
/// into a Message.
///
/// Its length is `MAX_MESSAGE_SIZE`. NOTE(review): presumably this is the largest
/// message the crate accepts; confirm against its definition.
pub type MessageBuffer = [u8; MAX_MESSAGE_SIZE];
23
24impl Socket {
25    /// A `MessageBuffer` with all 0 bytes.
26    pub const EMPTY_MESSAGE_BUFFER: MessageBuffer = [0; MAX_MESSAGE_SIZE];
27
28    /// Bind to an ip address/port and require that broadcast is enabled on the socket.
29    ///
30    /// Should be be called by both `Server::new` and `Client::new`, so this can be slower/panic
31    /// since it is not in the `listen_once` hot path.
32    pub fn new(config: SocketConfig) -> Self {
33        let broadcast =
34            Self::get_interface_broadcast(config.interface).unwrap_or(Ipv4Addr::BROADCAST);
35
36        let socket = UdpSocket::bind(config.listen_address).expect("failed to bind to address");
37        info!("UDP socket bound on {}", config.listen_address);
38
39        // set_broadcast sets the SO_BROADCAST option on the socket
40        // which is a way of ensuring that the application
41        // can't _accidentally_ spam many devices with a broadcast
42        // address without intentionally intending to do so.
43        socket
44            .set_broadcast(true)
45            .expect("Failed to enable broadcasting");
46
47        Self { socket, broadcast }
48    }
49
50    /// Tries to get the directed broadcast address for the interface with the given name.
51    fn get_interface_broadcast(maybe_interface_name: Option<String>) -> Option<Ipv4Addr> {
52        let interface_name = maybe_interface_name?;
53        let interfaces = get_interfaces().ok()?;
54        let maybe_interface = interfaces.get(&interface_name);
55        let maybe_addr = maybe_interface?
56            .v4_addr
57            .iter()
58            .find(|address| address.broadcast.is_some());
59        let maybe_broadcast = maybe_addr?.broadcast;
60
61        debug!(
62            "Found ipv4 broadcast address ({:?}) in list of interface addresses",
63            maybe_broadcast
64        );
65
66        maybe_broadcast
67    }
68
69    /// Returns the ip address of the bound socket.
70    ///
71    /// Can panic as typically called outside of `listen_once` hot path.
72    pub fn get_ip(&self) -> Ipv4Addr {
73        match self.socket.local_addr().unwrap().ip() {
74            std::net::IpAddr::V4(ip) => ip,
75            std::net::IpAddr::V6(_) => todo!("ipv6 is not supported yet"),
76        }
77    }
78
79    /// Decodes received bytes into a "message" type that implements `Decodable`
80    /// and returns it and the source address.
81    pub fn receive<D: Decodable<Output = D>>(&self) -> Result<(D, SocketAddr)> {
82        // Any bytes over MAX_MESSAGE_SIZE will be discarded.
83        let mut buffer = Self::EMPTY_MESSAGE_BUFFER;
84        let (_, src) = match self.socket.recv_from(&mut buffer) {
85            Ok(values) => values,
86            Err(_) => return Err("Failed to receive data"),
87        };
88
89        let decoded = D::from_bytes(&UndecodedMessage::new(buffer));
90
91        debug!("Received dhcp message (from {}): {:?}", src, decoded);
92
93        Ok((decoded, src))
94    }
95
96    /// Returns received bytes directly without being decoded first,
97    /// which allows you to partially decode them yourself later.
98    pub fn receive_raw(&self) -> Result<(UndecodedMessage, SocketAddr)> {
99        // Any bytes over MAX_MESSAGE_SIZE will be discarded.
100        let mut buffer = Self::EMPTY_MESSAGE_BUFFER;
101        let (_, src) = match self.socket.recv_from(&mut buffer) {
102            Ok(values) => values,
103            Err(_) => return Err("Failed to receive data"),
104        };
105
106        debug!("Received dhcp message from {}", src);
107
108        Ok((UndecodedMessage::new(buffer), src))
109    }
110
111    /// Fills an empty undecoded message with passed bytes.
112    /// Used to mock receiving a specific `UndecodedMessage` in tests, etc.
113    pub fn receive_mock(partial_message: &[u8]) -> UndecodedMessage {
114        let mut buffer = Self::EMPTY_MESSAGE_BUFFER;
115        partial_message
116            .iter()
117            .enumerate()
118            .for_each(|(i, byte)| buffer[i] = *byte);
119        UndecodedMessage::new(buffer)
120    }
121
122    /// Converts a message to bytes and then sends it to the passed address.
123    pub fn unicast<E: Encodable>(&self, encodable: &E, address: SocketAddrV4) -> Result<()> {
124        let encoded = encodable.to_bytes();
125
126        let ip = address.ip();
127        debug!("Sending dhcp message (to {:?}): {:?}", ip, encodable);
128
129        match self.socket.send_to(&encoded, address) {
130            Ok(_) => Ok(()),
131            Err(_) => Err("Failed to send data"),
132        }
133    }
134
135    /// Send a message as bytes to many devices on the local network.
136    pub fn broadcast<E: Encodable>(&self, encodable: &E, port: u16) -> Result<()> {
137        self.unicast(encodable, SocketAddrV4::new(self.broadcast, port))
138    }
139}