xenet_packet/
udp.rs

//! A UDP packet abstraction: header value type, packet layout and checksum helpers.
2
3use crate::ip::IpNextLevelProtocol;
4use crate::Packet;
5
6use alloc::vec::Vec;
7
8use xenet_macro::packet;
9use xenet_macro_helper::types::*;
10
11use crate::util;
12use std::net::{Ipv4Addr, Ipv6Addr};
13
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Serialize};
16
/// Length in bytes of the fixed UDP header.
///
/// The header is four 16-bit fields: source port, destination port, length
/// and checksum.
pub const UDP_HEADER_LEN: usize = 4 * std::mem::size_of::<u16>();
19
20/// Represents the UDP header.
21#[derive(Clone, Debug, PartialEq, Eq)]
22#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
23pub struct UdpHeader {
24    pub source: u16be,
25    pub destination: u16be,
26    pub length: u16be,
27    pub checksum: u16be,
28}
29
30impl UdpHeader {
31    /// Construct a UDP header from a byte slice.
32    pub fn from_bytes(packet: &[u8]) -> Result<UdpHeader, String> {
33        if packet.len() < UDP_HEADER_LEN {
34            return Err("Packet is too small for UDP header".to_string());
35        }
36        match UdpPacket::new(packet) {
37            Some(udp_packet) => Ok(UdpHeader {
38                source: udp_packet.get_source(),
39                destination: udp_packet.get_destination(),
40                length: udp_packet.get_length(),
41                checksum: udp_packet.get_checksum(),
42            }),
43            None => Err("Failed to parse UDP packet".to_string()),
44        }
45    }
46    /// Construct a UDP header from a UdpPacket.
47    pub(crate) fn from_packet(udp_packet: &UdpPacket) -> UdpHeader {
48        UdpHeader {
49            source: udp_packet.get_source(),
50            destination: udp_packet.get_destination(),
51            length: udp_packet.get_length(),
52            checksum: udp_packet.get_checksum(),
53        }
54    }
55}
56
57/// Represents a UDP Packet.
58#[packet]
59pub struct Udp {
60    pub source: u16be,
61    pub destination: u16be,
62    pub length: u16be,
63    pub checksum: u16be,
64    #[payload]
65    pub payload: Vec<u8>,
66}
67
68/// Calculate a checksum for a packet built on IPv4.
69pub fn ipv4_checksum(packet: &UdpPacket, source: &Ipv4Addr, destination: &Ipv4Addr) -> u16be {
70    ipv4_checksum_adv(packet, &[], source, destination)
71}
72
73/// Calculate a checksum for a packet built on IPv4. Advanced version which
74/// accepts an extra slice of data that will be included in the checksum
75/// as being part of the data portion of the packet.
76///
77/// If `packet` contains an odd number of bytes the last byte will not be
78/// counted as the first byte of a word together with the first byte of
79/// `extra_data`.
80pub fn ipv4_checksum_adv(
81    packet: &UdpPacket,
82    extra_data: &[u8],
83    source: &Ipv4Addr,
84    destination: &Ipv4Addr,
85) -> u16be {
86    util::ipv4_checksum(
87        packet.packet(),
88        3,
89        extra_data,
90        source,
91        destination,
92        IpNextLevelProtocol::Udp,
93    )
94}
95
/// Builds a UDP datagram inside an IPv4 buffer, then checks each setter, the
/// pseudo-header checksum and the resulting wire bytes.
#[test]
fn udp_header_ipv4_test() {
    use crate::ipv4::MutableIpv4Packet;

    // Buffer layout: 20-byte IPv4 header | 8-byte UDP header | 4-byte payload.
    const IP_LEN: usize = 20;
    const UDP_LEN: usize = 8;
    let mut packet = [0u8; IP_LEN + UDP_LEN + 4];

    let ipv4_source = Ipv4Addr::new(192, 168, 0, 1);
    let ipv4_destination = Ipv4Addr::new(192, 168, 0, 199);

    {
        let mut ip_header = MutableIpv4Packet::new(&mut packet[..]).unwrap();
        ip_header.set_next_level_protocol(IpNextLevelProtocol::Udp);
        ip_header.set_source(ipv4_source);
        ip_header.set_destination(ipv4_destination);
    }

    // The payload must be in place before the checksum is computed over it.
    packet[IP_LEN + UDP_LEN..].copy_from_slice(b"test");

    {
        let mut udp_header = MutableUdpPacket::new(&mut packet[IP_LEN..]).unwrap();

        udp_header.set_source(12345);
        assert_eq!(udp_header.get_source(), 12345);

        udp_header.set_destination(54321);
        assert_eq!(udp_header.get_destination(), 54321);

        udp_header.set_length(8 + 4);
        assert_eq!(udp_header.get_length(), 8 + 4);

        let checksum = ipv4_checksum(&udp_header.to_immutable(), &ipv4_source, &ipv4_destination);
        udp_header.set_checksum(checksum);
        assert_eq!(udp_header.get_checksum(), 0x9178);
    }

    let expected_header = [
        0x30, 0x39, // source port 12345
        0xd4, 0x31, // destination port 54321
        0x00, 0x0c, // length 12
        0x91, 0x78, // checksum
    ];
    assert_eq!(&expected_header[..], &packet[IP_LEN..IP_LEN + UDP_LEN]);
}
141
142/// Calculate a checksum for a packet built on IPv6.
143pub fn ipv6_checksum(packet: &UdpPacket, source: &Ipv6Addr, destination: &Ipv6Addr) -> u16be {
144    ipv6_checksum_adv(packet, &[], source, destination)
145}
146
147/// Calculate the checksum for a packet built on IPv6. Advanced version which
148/// accepts an extra slice of data that will be included in the checksum
149/// as being part of the data portion of the packet.
150///
151/// If `packet` contains an odd number of bytes the last byte will not be
152/// counted as the first byte of a word together with the first byte of
153/// `extra_data`.
154pub fn ipv6_checksum_adv(
155    packet: &UdpPacket,
156    extra_data: &[u8],
157    source: &Ipv6Addr,
158    destination: &Ipv6Addr,
159) -> u16be {
160    util::ipv6_checksum(
161        packet.packet(),
162        3,
163        extra_data,
164        source,
165        destination,
166        IpNextLevelProtocol::Udp,
167    )
168}
169
/// Builds a UDP datagram inside an IPv6 buffer, then checks each setter, the
/// pseudo-header checksum and the resulting wire bytes.
#[test]
fn udp_header_ipv6_test() {
    use crate::ipv6::MutableIpv6Packet;

    // Buffer layout: 40-byte IPv6 header | 8-byte UDP header | 4-byte payload.
    const IP_LEN: usize = 40;
    const UDP_LEN: usize = 8;
    let mut packet = [0u8; IP_LEN + UDP_LEN + 4];

    let ipv6_source = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1);
    let ipv6_destination = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1);

    {
        let mut ip_header = MutableIpv6Packet::new(&mut packet[..]).unwrap();
        ip_header.set_next_header(IpNextLevelProtocol::Udp);
        ip_header.set_source(ipv6_source);
        ip_header.set_destination(ipv6_destination);
    }

    // The payload must be in place before the checksum is computed over it.
    packet[IP_LEN + UDP_LEN..].copy_from_slice(b"test");

    {
        let mut udp_header = MutableUdpPacket::new(&mut packet[IP_LEN..]).unwrap();

        udp_header.set_source(12345);
        assert_eq!(udp_header.get_source(), 12345);

        udp_header.set_destination(54321);
        assert_eq!(udp_header.get_destination(), 54321);

        udp_header.set_length(8 + 4);
        assert_eq!(udp_header.get_length(), 8 + 4);

        let checksum = ipv6_checksum(&udp_header.to_immutable(), &ipv6_source, &ipv6_destination);
        udp_header.set_checksum(checksum);
        assert_eq!(udp_header.get_checksum(), 0x1390);
    }

    let expected_header = [
        0x30, 0x39, // source port 12345
        0xd4, 0x31, // destination port 54321
        0x00, 0x0c, // length 12
        0x13, 0x90, // checksum
    ];
    assert_eq!(&expected_header[..], &packet[IP_LEN..IP_LEN + UDP_LEN]);
}