// zero_packet/transport/udp.rs

1use crate::network::checksum::internet_checksum;
2use core::fmt;
3
/// Size of the fixed UDP header, in bytes.
///
/// The header consists of four 16-bit fields: source port, destination port,
/// length and checksum.
pub const UDP_HEADER_LENGTH: usize = 4 * core::mem::size_of::<u16>();
6
7/// Writes UDP header fields.
8pub struct UdpWriter<'a> {
9    pub bytes: &'a mut [u8],
10}
11
12impl<'a> UdpWriter<'a> {
13    /// Creates a new `UdpPacket` from the given data slice.
14    #[inline]
15    pub fn new(bytes: &'a mut [u8]) -> Result<Self, &'static str> {
16        if bytes.len() < UDP_HEADER_LENGTH {
17            return Err("Slice is too short to contain a UDP header.");
18        }
19
20        Ok(Self { bytes })
21    }
22
23    /// Returns the header length in bytes.
24    #[inline]
25    pub fn header_len(&self) -> usize {
26        UDP_HEADER_LENGTH
27    }
28
29    /// Returns the total length of the packet in bytes.
30    #[inline]
31    pub fn packet_len(&self) -> usize {
32        self.bytes.len()
33    }
34
35    /// Sets the source port field.
36    #[inline]
37    pub fn set_src_port(&mut self, src_port: u16) {
38        self.bytes[0] = (src_port >> 8) as u8;
39        self.bytes[1] = src_port as u8;
40    }
41
42    /// Sets the destination port field.
43    #[inline]
44    pub fn set_dest_port(&mut self, dst_port: u16) {
45        self.bytes[2] = (dst_port >> 8) as u8;
46        self.bytes[3] = dst_port as u8;
47    }
48
49    /// Sets the length field.
50    ///
51    /// Should include the entire packet (header and payload).
52    #[inline]
53    pub fn set_length(&mut self, length: u16) {
54        self.bytes[4] = (length >> 8) as u8;
55        self.bytes[5] = length as u8;
56    }
57
58    /// Calculates the checksum field for error checking.
59    ///
60    /// Includes the UDP header, payload and IPv4 pseudo header.
61    ///
62    /// Checksum is optional in UDP for IPv4, but mandatory for IPv6.
63    /// Although, it is strongly recommended to use checksums for both. See: RFC 1122.
64    #[inline]
65    pub fn set_checksum(&mut self, pseudo_sum: u32) {
66        self.bytes[6] = 0;
67        self.bytes[7] = 0;
68        let checksum = internet_checksum(self.bytes, pseudo_sum);
69        self.bytes[6] = (checksum >> 8) as u8;
70        self.bytes[7] = (checksum & 0xff) as u8;
71    }
72
73    /// Sets the payload field.
74    ///
75    /// The `PacketBuilder` sets the payload before the checksum.
76    ///
77    /// That is, because the checksum is invalidated if a payload is set after it.
78    #[inline]
79    pub fn set_payload(&mut self, payload: &[u8]) -> Result<(), &'static str> {
80        let start = self.header_len();
81        let payload_len = payload.len();
82
83        if self.packet_len() - start < payload_len {
84            return Err("Payload is too large to fit in the TCP packet.");
85        }
86
87        let end = start + payload_len;
88        self.bytes[start..end].copy_from_slice(payload);
89
90        Ok(())
91    }
92}
93
94/// Reads UDP header fields.
95#[derive(PartialEq)]
96pub struct UdpReader<'a> {
97    pub bytes: &'a [u8],
98}
99
100impl<'a> UdpReader<'a> {
101    /// Creates a new `UdpPacket` from the given data slice.
102    #[inline]
103    pub fn new(bytes: &'a [u8]) -> Result<Self, &'static str> {
104        if bytes.len() < UDP_HEADER_LENGTH {
105            return Err("Slice is too short to contain a UDP header.");
106        }
107
108        Ok(Self { bytes })
109    }
110
111    /// Returns the source port field.
112    #[inline]
113    pub fn src_port(&self) -> u16 {
114        ((self.bytes[0] as u16) << 8) | (self.bytes[1] as u16)
115    }
116
117    /// Returns the destination port field.
118    #[inline]
119    pub fn dest_port(&self) -> u16 {
120        ((self.bytes[2] as u16) << 8) | (self.bytes[3] as u16)
121    }
122
123    /// Returns the checksum field.
124    #[inline]
125    pub fn checksum(&self) -> u16 {
126        ((self.bytes[6] as u16) << 8) | (self.bytes[7] as u16)
127    }
128
129    /// Returns the length field.
130    ///
131    /// Includes the entire packet (header and payload).
132    #[inline]
133    pub fn length(&self) -> u16 {
134        ((self.bytes[4] as u16) << 8) | (self.bytes[5] as u16)
135    }
136
137    /// Returns the header length in bytes.
138    #[inline]
139    pub fn header_len(&self) -> usize {
140        UDP_HEADER_LENGTH
141    }
142
143    /// Returns a reference to the header.
144    #[inline]
145    pub fn header(&self) -> &'a [u8] {
146        &self.bytes[..UDP_HEADER_LENGTH]
147    }
148
149    /// Returns a reference to the payload.
150    #[inline]
151    pub fn payload(&self) -> &'a [u8] {
152        &self.bytes[UDP_HEADER_LENGTH..]
153    }
154}
155
156impl fmt::Debug for UdpReader<'_> {
157    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158        f.debug_struct("UdpDatagram")
159            .field("src_port", &self.src_port())
160            .field("dest_port", &self.dest_port())
161            .field("length", &self.length())
162            .finish()
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::network::checksum::pseudo_header;

    #[test]
    fn getters_and_setters() {
        // Raw packet.
        let mut bytes = [0u8; 8];

        // Arbitrary fixed test values.
        let src_port = 12345;
        let dst_port = 54321;
        let length = 8;
        let src_ip = [192, 168, 1, 1];
        let dest_ip = [192, 168, 1, 2];

        // Create a UDP packet writer.
        let mut writer = UdpWriter::new(&mut bytes).unwrap();

        // Set the fields.
        writer.set_src_port(src_port);
        writer.set_dest_port(dst_port);
        writer.set_length(length);

        // Calculate the checksum (17 is the IP protocol number for UDP).
        let pseudo_sum = pseudo_header(&src_ip, &dest_ip, 17, writer.packet_len());
        writer.set_checksum(pseudo_sum);

        // Create a UDP packet reader.
        let reader = UdpReader::new(&bytes).unwrap();

        // Ensure the fields are set and retrieved correctly.
        assert_eq!(reader.src_port(), src_port);
        assert_eq!(reader.dest_port(), dst_port);
        assert_eq!(reader.length(), length);
        // A header-only buffer has no payload.
        assert!(reader.payload().is_empty());
    }

    #[test]
    fn payload_round_trip() {
        let mut bytes = [0u8; UDP_HEADER_LENGTH + 4];
        let payload = [0xde, 0xad, 0xbe, 0xef];

        let mut writer = UdpWriter::new(&mut bytes).unwrap();
        writer.set_payload(&payload).unwrap();
        // A payload larger than the space after the header must be rejected.
        assert!(writer.set_payload(&[0u8; 5]).is_err());

        let reader = UdpReader::new(&bytes).unwrap();
        assert_eq!(reader.header_len(), UDP_HEADER_LENGTH);
        assert_eq!(reader.payload(), &payload);
        // Only the payload was written, so the header is still zeroed.
        assert_eq!(reader.header(), &[0u8; UDP_HEADER_LENGTH]);
    }

    #[test]
    fn rejects_short_slices() {
        let mut short = [0u8; UDP_HEADER_LENGTH - 1];
        assert!(UdpReader::new(&short).is_err());
        assert!(UdpWriter::new(&mut short).is_err());
    }
}