Skip to main content

toe_beans/v4/message/
socket.rs

1use super::{Decodable, Encodable};
2use crate::v4::{MAX_MESSAGE_SIZE, Result, UndecodedMessage};
3use inherface::get_interfaces;
4use log::{debug, info};
5use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket};
6
/// A connection to a [`UdpSocket`] that knows how to send [`Encodable`] messages and
/// receive [`Decodable`] ones.
/// Meant to be used by both Client and Server.
///
/// Currently wraps [UdpSocket::bind](https://doc.rust-lang.org/std/net/struct.UdpSocket.html#method.bind)
#[derive(Debug)]
pub struct Socket {
    /// The bound UDP socket; `Socket::new` enables `SO_BROADCAST` on it.
    socket: UdpSocket,
    /// Destination IP used by `Socket::broadcast`: the named interface's directed
    /// broadcast address when one was found, otherwise `255.255.255.255`.
    broadcast: Ipv4Addr,
}
16
17/// Bytes, received from the socket, that can be decoded into a Message.
18pub type MessageBuffer = [u8; MAX_MESSAGE_SIZE];
19
20impl Socket {
21    /// A `MessageBuffer` with all 0 bytes.
22    pub const EMPTY_MESSAGE_BUFFER: MessageBuffer = [0; MAX_MESSAGE_SIZE];
23
24    /// Bind to an ip address/port and require that broadcast is enabled on the socket.
25    ///
26    /// Should be be called by both `Server::new` and `Client::new`, so this can be slower/panic
27    /// since it is not in the `listen_once` hot path.
28    pub fn new(address: SocketAddrV4, interface: Option<&String>) -> Self {
29        let broadcast = Self::get_interface_broadcast(interface).unwrap_or(Ipv4Addr::BROADCAST);
30
31        let socket = UdpSocket::bind(address).expect("failed to bind to address");
32        info!("UDP socket bound on {}", address);
33
34        // set_broadcast sets the SO_BROADCAST option on the socket
35        // which is a way of ensuring that the application
36        // can't _accidentally_ spam many devices with a broadcast
37        // address without intentionally intending to do so.
38        socket
39            .set_broadcast(true)
40            .expect("Failed to enable broadcasting");
41
42        Self { socket, broadcast }
43    }
44
45    /// Tries to get the directed broadcast address for the interface with the given name.
46    fn get_interface_broadcast(maybe_interface_name: Option<&String>) -> Option<Ipv4Addr> {
47        let interface_name = maybe_interface_name?;
48        let interfaces = get_interfaces().ok()?;
49        let maybe_interface = interfaces.get(interface_name);
50        let maybe_addr = maybe_interface?
51            .v4_addr
52            .iter()
53            .find(|address| address.broadcast.is_some());
54        let maybe_broadcast = maybe_addr?.broadcast;
55
56        debug!(
57            "Found ipv4 broadcast address ({:?}) in list of interface addresses",
58            maybe_broadcast
59        );
60
61        maybe_broadcast
62    }
63
64    /// Returns the ip address of the bound socket.
65    ///
66    /// Can panic as typically called outside of `listen_once` hot path.
67    pub fn get_ip(&self) -> Ipv4Addr {
68        match self.socket.local_addr().unwrap().ip() {
69            std::net::IpAddr::V4(ip) => ip,
70            std::net::IpAddr::V6(_) => todo!("ipv6 is not supported yet"),
71        }
72    }
73
74    /// Decodes received bytes into a "message" type that implements `DecodeMessage`
75    /// and returns it and the source address.
76    pub fn receive<D: Decodable<Output = D>>(&self) -> Result<(D, SocketAddr)> {
77        // Any bytes over MAX_MESSAGE_SIZE will be discarded.
78        let mut buffer = Self::EMPTY_MESSAGE_BUFFER;
79        let (_, src) = match self.socket.recv_from(&mut buffer) {
80            Ok(values) => values,
81            Err(_) => return Err("Failed to receive data"),
82        };
83
84        let decoded = D::from_bytes(&UndecodedMessage::new(buffer));
85
86        debug!("Received dhcp message (from {}): {:?}", src, decoded);
87
88        Ok((decoded, src))
89    }
90
91    /// Returns received bytes directly without being decoded into a `Deliverable`,
92    /// which allows you to partially decode them yourself later.
93    pub fn receive_raw(&self) -> Result<(UndecodedMessage, SocketAddr)> {
94        // Any bytes over MAX_MESSAGE_SIZE will be discarded.
95        let mut buffer = Self::EMPTY_MESSAGE_BUFFER;
96        let (_, src) = match self.socket.recv_from(&mut buffer) {
97            Ok(values) => values,
98            Err(_) => return Err("Failed to receive data"),
99        };
100
101        debug!("Received dhcp message from {}", src);
102
103        Ok((UndecodedMessage::new(buffer), src))
104    }
105
106    /// Fills an empty undecoded message with passed bytes.
107    /// Used to mock receiving a specific `UndecodedMessage` in tests, etc.
108    pub fn receive_mock(partial_message: &[u8]) -> UndecodedMessage {
109        let mut buffer = Self::EMPTY_MESSAGE_BUFFER;
110        partial_message
111            .iter()
112            .enumerate()
113            .for_each(|(i, byte)| buffer[i] = *byte);
114        UndecodedMessage::new(buffer)
115    }
116
117    /// Converts a message to bytes and then sends it to the passed address.
118    pub fn unicast<E: Encodable>(&self, encodable: &E, address: SocketAddrV4) -> Result<()> {
119        let encoded = encodable.to_bytes();
120
121        let ip = address.ip();
122        debug!("Sending dhcp message (to {:?}): {:?}", ip, encodable);
123
124        match self.socket.send_to(&encoded, address) {
125            Ok(_) => Ok(()),
126            Err(_) => Err("Failed to send data"),
127        }
128    }
129
130    /// Send a message as bytes to many devices on the local network.
131    pub fn broadcast<E: Encodable>(&self, encodable: &E, port: u16) -> Result<()> {
132        self.unicast(encodable, SocketAddrV4::new(self.broadcast, port))
133    }
134}