ya_smoltcp/wire/
udp.rs

1use byteorder::{ByteOrder, NetworkEndian};
2use core::fmt;
3
4use crate::phy::ChecksumCapabilities;
5use crate::wire::ip::checksum;
6use crate::wire::{IpAddress, IpProtocol};
7use crate::{Error, Result};
8
/// A thin read/write view over a byte buffer that holds a User Datagram Protocol packet.
///
/// The wrapper owns nothing but the buffer; every accessor decodes the requested
/// header field directly from the buffer, and every setter encodes into it.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Packet<T: AsRef<[u8]>> {
    /// Backing storage: the UDP header immediately followed by the payload.
    buffer: T,
}
15
16mod field {
17    #![allow(non_snake_case)]
18
19    use crate::wire::field::*;
20
21    pub const SRC_PORT: Field = 0..2;
22    pub const DST_PORT: Field = 2..4;
23    pub const LENGTH: Field = 4..6;
24    pub const CHECKSUM: Field = 6..8;
25
26    pub fn PAYLOAD(length: u16) -> Field {
27        CHECKSUM.end..(length as usize)
28    }
29}
30
31pub const HEADER_LEN: usize = field::CHECKSUM.end;
32
33#[allow(clippy::len_without_is_empty)]
34impl<T: AsRef<[u8]>> Packet<T> {
35    /// Imbue a raw octet buffer with UDP packet structure.
36    pub fn new_unchecked(buffer: T) -> Packet<T> {
37        Packet { buffer }
38    }
39
40    /// Shorthand for a combination of [new_unchecked] and [check_len].
41    ///
42    /// [new_unchecked]: #method.new_unchecked
43    /// [check_len]: #method.check_len
44    pub fn new_checked(buffer: T) -> Result<Packet<T>> {
45        let packet = Self::new_unchecked(buffer);
46        packet.check_len()?;
47        Ok(packet)
48    }
49
50    /// Ensure that no accessor method will panic if called.
51    /// Returns `Err(Error::Truncated)` if the buffer is too short.
52    /// Returns `Err(Error::Malformed)` if the length field has a value smaller
53    /// than the header length.
54    ///
55    /// The result of this check is invalidated by calling [set_len].
56    ///
57    /// [set_len]: #method.set_len
58    pub fn check_len(&self) -> Result<()> {
59        let buffer_len = self.buffer.as_ref().len();
60        if buffer_len < HEADER_LEN {
61            Err(Error::Truncated)
62        } else {
63            let field_len = self.len() as usize;
64            if buffer_len < field_len {
65                Err(Error::Truncated)
66            } else if field_len < HEADER_LEN {
67                Err(Error::Malformed)
68            } else {
69                Ok(())
70            }
71        }
72    }
73
74    /// Consume the packet, returning the underlying buffer.
75    pub fn into_inner(self) -> T {
76        self.buffer
77    }
78
79    /// Return the source port field.
80    #[inline]
81    pub fn src_port(&self) -> u16 {
82        let data = self.buffer.as_ref();
83        NetworkEndian::read_u16(&data[field::SRC_PORT])
84    }
85
86    /// Return the destination port field.
87    #[inline]
88    pub fn dst_port(&self) -> u16 {
89        let data = self.buffer.as_ref();
90        NetworkEndian::read_u16(&data[field::DST_PORT])
91    }
92
93    /// Return the length field.
94    #[inline]
95    pub fn len(&self) -> u16 {
96        let data = self.buffer.as_ref();
97        NetworkEndian::read_u16(&data[field::LENGTH])
98    }
99
100    /// Return the checksum field.
101    #[inline]
102    pub fn checksum(&self) -> u16 {
103        let data = self.buffer.as_ref();
104        NetworkEndian::read_u16(&data[field::CHECKSUM])
105    }
106
107    /// Validate the packet checksum.
108    ///
109    /// # Panics
110    /// This function panics unless `src_addr` and `dst_addr` belong to the same family,
111    /// and that family is IPv4 or IPv6.
112    ///
113    /// # Fuzzing
114    /// This function always returns `true` when fuzzing.
115    pub fn verify_checksum(&self, src_addr: &IpAddress, dst_addr: &IpAddress) -> bool {
116        if cfg!(fuzzing) {
117            return true;
118        }
119
120        let data = self.buffer.as_ref();
121        checksum::combine(&[
122            checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, self.len() as u32),
123            checksum::data(&data[..self.len() as usize]),
124        ]) == !0
125    }
126}
127
128impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> {
129    /// Return a pointer to the payload.
130    #[inline]
131    pub fn payload(&self) -> &'a [u8] {
132        let length = self.len();
133        let data = self.buffer.as_ref();
134        &data[field::PAYLOAD(length)]
135    }
136}
137
138impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> {
139    /// Set the source port field.
140    #[inline]
141    pub fn set_src_port(&mut self, value: u16) {
142        let data = self.buffer.as_mut();
143        NetworkEndian::write_u16(&mut data[field::SRC_PORT], value)
144    }
145
146    /// Set the destination port field.
147    #[inline]
148    pub fn set_dst_port(&mut self, value: u16) {
149        let data = self.buffer.as_mut();
150        NetworkEndian::write_u16(&mut data[field::DST_PORT], value)
151    }
152
153    /// Set the length field.
154    #[inline]
155    pub fn set_len(&mut self, value: u16) {
156        let data = self.buffer.as_mut();
157        NetworkEndian::write_u16(&mut data[field::LENGTH], value)
158    }
159
160    /// Set the checksum field.
161    #[inline]
162    pub fn set_checksum(&mut self, value: u16) {
163        let data = self.buffer.as_mut();
164        NetworkEndian::write_u16(&mut data[field::CHECKSUM], value)
165    }
166
167    /// Compute and fill in the header checksum.
168    ///
169    /// # Panics
170    /// This function panics unless `src_addr` and `dst_addr` belong to the same family,
171    /// and that family is IPv4 or IPv6.
172    pub fn fill_checksum(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress) {
173        self.set_checksum(0);
174        let checksum = {
175            let data = self.buffer.as_ref();
176            !checksum::combine(&[
177                checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, self.len() as u32),
178                checksum::data(&data[..self.len() as usize]),
179            ])
180        };
181        // UDP checksum value of 0 means no checksum; if the checksum really is zero,
182        // use all-ones, which indicates that the remote end must verify the checksum.
183        // Arithmetically, RFC 1071 checksums of all-zeroes and all-ones behave identically,
184        // so no action is necessary on the remote end.
185        self.set_checksum(if checksum == 0 { 0xffff } else { checksum })
186    }
187
188    /// Return a mutable pointer to the payload.
189    #[inline]
190    pub fn payload_mut(&mut self) -> &mut [u8] {
191        let length = self.len();
192        let data = self.buffer.as_mut();
193        &mut data[field::PAYLOAD(length)]
194    }
195}
196
197impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> {
198    fn as_ref(&self) -> &[u8] {
199        self.buffer.as_ref()
200    }
201}
202
/// A high-level representation of a User Datagram Protocol packet header.
///
/// Only the two port numbers are kept; the length and checksum fields are
/// derived when the representation is emitted into a packet.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Repr {
    /// Value of the source port field.
    pub src_port: u16,
    /// Value of the destination port field.
    pub dst_port: u16,
}
210
211impl Repr {
212    /// Parse an User Datagram Protocol packet and return a high-level representation.
213    pub fn parse<T>(
214        packet: &Packet<&T>,
215        src_addr: &IpAddress,
216        dst_addr: &IpAddress,
217        checksum_caps: &ChecksumCapabilities,
218    ) -> Result<Repr>
219    where
220        T: AsRef<[u8]> + ?Sized,
221    {
222        // Destination port cannot be omitted (but source port can be).
223        if packet.dst_port() == 0 {
224            return Err(Error::Malformed);
225        }
226        // Valid checksum is expected...
227        if checksum_caps.udp.rx() && !packet.verify_checksum(src_addr, dst_addr) {
228            match (src_addr, dst_addr) {
229                // ... except on UDP-over-IPv4, where it can be omitted.
230                #[cfg(feature = "proto-ipv4")]
231                (&IpAddress::Ipv4(_), &IpAddress::Ipv4(_)) if packet.checksum() == 0 => (),
232                _ => return Err(Error::Checksum),
233            }
234        }
235
236        Ok(Repr {
237            src_port: packet.src_port(),
238            dst_port: packet.dst_port(),
239        })
240    }
241
242    /// Return the length of the packet header that will be emitted from this high-level representation.
243    pub fn header_len(&self) -> usize {
244        HEADER_LEN
245    }
246
247    /// Emit a high-level representation into an User Datagram Protocol packet.
248    pub fn emit<T: ?Sized>(
249        &self,
250        packet: &mut Packet<&mut T>,
251        src_addr: &IpAddress,
252        dst_addr: &IpAddress,
253        payload_len: usize,
254        emit_payload: impl FnOnce(&mut [u8]),
255        checksum_caps: &ChecksumCapabilities,
256    ) where
257        T: AsRef<[u8]> + AsMut<[u8]>,
258    {
259        packet.set_src_port(self.src_port);
260        packet.set_dst_port(self.dst_port);
261        packet.set_len((HEADER_LEN + payload_len) as u16);
262        emit_payload(packet.payload_mut());
263
264        if checksum_caps.udp.tx() {
265            packet.fill_checksum(src_addr, dst_addr)
266        } else {
267            // make sure we get a consistently zeroed checksum,
268            // since implementations might rely on it
269            packet.set_checksum(0);
270        }
271    }
272}
273
274impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
275    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
276        // Cannot use Repr::parse because we don't have the IP addresses.
277        write!(
278            f,
279            "UDP src={} dst={} len={}",
280            self.src_port(),
281            self.dst_port(),
282            self.payload().len()
283        )
284    }
285}
286
287impl fmt::Display for Repr {
288    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
289        write!(f, "UDP src={} dst={}", self.src_port, self.dst_port)
290    }
291}
292
293use crate::wire::pretty_print::{PrettyIndent, PrettyPrint};
294
295impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> {
296    fn pretty_print(
297        buffer: &dyn AsRef<[u8]>,
298        f: &mut fmt::Formatter,
299        indent: &mut PrettyIndent,
300    ) -> fmt::Result {
301        match Packet::new_checked(buffer) {
302            Err(err) => write!(f, "{}({})", indent, err),
303            Ok(packet) => write!(f, "{}{}", indent, packet),
304        }
305    }
306}
307
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(feature = "proto-ipv4")]
    use crate::wire::Ipv4Address;

    // Addresses used to build the pseudo-header in the checksum tests.
    #[cfg(feature = "proto-ipv4")]
    const SRC_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 1]);
    #[cfg(feature = "proto-ipv4")]
    const DST_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 2]);

    // Reference packet: ports 48896 -> 53, length 12, checksum 0x124d, 4 payload octets.
    #[cfg(feature = "proto-ipv4")]
    static PACKET_BYTES: [u8; 12] = [
        0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x12, 0x4d, 0xaa, 0x00, 0x00, 0xff,
    ];

    // Same packet with the checksum field zeroed (checksum omitted).
    #[cfg(feature = "proto-ipv4")]
    static NO_CHECKSUM_PACKET: [u8; 12] = [
        0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x00, 0x00, 0xaa, 0x00, 0x00, 0xff,
    ];

    #[cfg(feature = "proto-ipv4")]
    static PAYLOAD_BYTES: [u8; 4] = [0xaa, 0x00, 0x00, 0xff];

    /// The representation both reference packets are expected to parse into.
    #[cfg(feature = "proto-ipv4")]
    fn packet_repr() -> Repr {
        Repr {
            src_port: 48896,
            dst_port: 53,
        }
    }

    /// Parse `bytes` between the test addresses with default checksum
    /// capabilities, panicking on a parse error.
    #[cfg(feature = "proto-ipv4")]
    fn parse_with_defaults(bytes: &[u8]) -> Repr {
        let view = Packet::new_unchecked(bytes);
        Repr::parse(
            &view,
            &SRC_ADDR.into(),
            &DST_ADDR.into(),
            &ChecksumCapabilities::default(),
        )
        .unwrap()
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_deconstruct() {
        let view = Packet::new_unchecked(&PACKET_BYTES[..]);
        assert_eq!(view.src_port(), 48896);
        assert_eq!(view.dst_port(), 53);
        assert_eq!(view.len(), 12);
        assert_eq!(view.checksum(), 0x124d);
        assert_eq!(view.payload(), &PAYLOAD_BYTES[..]);
        assert!(view.verify_checksum(&SRC_ADDR.into(), &DST_ADDR.into()));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_construct() {
        let mut storage = vec![0xa5; 12];
        let mut view = Packet::new_unchecked(&mut storage);
        view.set_src_port(48896);
        view.set_dst_port(53);
        view.set_len(12);
        view.set_checksum(0xffff);
        view.payload_mut().copy_from_slice(&PAYLOAD_BYTES[..]);
        view.fill_checksum(&SRC_ADDR.into(), &DST_ADDR.into());
        assert_eq!(&view.into_inner()[..], &PACKET_BYTES[..]);
    }

    #[test]
    fn test_impossible_len() {
        let mut storage = vec![0; 12];
        let mut view = Packet::new_unchecked(&mut storage);
        // A length field shorter than the header itself must be rejected.
        view.set_len(4);
        assert_eq!(view.check_len(), Err(Error::Malformed));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_zero_checksum() {
        let mut storage = vec![0; 8];
        let mut view = Packet::new_unchecked(&mut storage);
        view.set_src_port(1);
        view.set_dst_port(31881);
        view.set_len(8);
        view.fill_checksum(&SRC_ADDR.into(), &DST_ADDR.into());
        // A computed checksum of zero must be stored as all-ones.
        assert_eq!(view.checksum(), 0xffff);
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_parse() {
        assert_eq!(parse_with_defaults(&PACKET_BYTES[..]), packet_repr());
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_emit() {
        let repr = packet_repr();
        let total_len = repr.header_len() + PAYLOAD_BYTES.len();
        let mut storage = vec![0xa5; total_len];
        let mut view = Packet::new_unchecked(&mut storage);
        repr.emit(
            &mut view,
            &SRC_ADDR.into(),
            &DST_ADDR.into(),
            PAYLOAD_BYTES.len(),
            |payload| payload.copy_from_slice(&PAYLOAD_BYTES),
            &ChecksumCapabilities::default(),
        );
        assert_eq!(&view.into_inner()[..], &PACKET_BYTES[..]);
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_checksum_omitted() {
        assert_eq!(parse_with_defaults(&NO_CHECKSUM_PACKET[..]), packet_repr());
    }
}