1use crate::ip::IpNextLevelProtocol;
4use crate::Packet;
5
6use alloc::vec::Vec;
7
8use xenet_macro::packet;
9use xenet_macro_helper::types::*;
10
11use crate::util;
12use std::net::{Ipv4Addr, Ipv6Addr};
13
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Serialize};
16
/// Length in bytes of the fixed UDP header.
///
/// The header is four 2-byte fields: source port, destination port, length and checksum.
pub const UDP_HEADER_LEN: usize = 2 + 2 + 2 + 2;
19
20#[derive(Clone, Debug, PartialEq, Eq)]
22#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
23pub struct UdpHeader {
24 pub source: u16be,
25 pub destination: u16be,
26 pub length: u16be,
27 pub checksum: u16be,
28}
29
30impl UdpHeader {
31 pub fn from_bytes(packet: &[u8]) -> Result<UdpHeader, String> {
33 if packet.len() < UDP_HEADER_LEN {
34 return Err("Packet is too small for UDP header".to_string());
35 }
36 match UdpPacket::new(packet) {
37 Some(udp_packet) => Ok(UdpHeader {
38 source: udp_packet.get_source(),
39 destination: udp_packet.get_destination(),
40 length: udp_packet.get_length(),
41 checksum: udp_packet.get_checksum(),
42 }),
43 None => Err("Failed to parse UDP packet".to_string()),
44 }
45 }
46 pub(crate) fn from_packet(udp_packet: &UdpPacket) -> UdpHeader {
48 UdpHeader {
49 source: udp_packet.get_source(),
50 destination: udp_packet.get_destination(),
51 length: udp_packet.get_length(),
52 checksum: udp_packet.get_checksum(),
53 }
54 }
55}
56
/// Wire layout of a UDP datagram: four big-endian 16-bit header fields, then the payload.
///
/// The reference bytes in the tests of this file show the fields appearing on the wire in
/// declaration order (source, destination, length, checksum). Do not reorder them.
/// NOTE(review): `UdpPacket` / `MutableUdpPacket` used in this file are presumably
/// generated from this definition by `#[packet]` — confirm against `xenet_macro`.
#[packet]
pub struct Udp {
    // Source port (bytes 0..2).
    pub source: u16be,
    // Destination port (bytes 2..4).
    pub destination: u16be,
    // Length of header plus payload in bytes (the tests set 8 + 4 for a 4-byte payload).
    pub length: u16be,
    // Checksum (bytes 6..8), computed by `ipv4_checksum` / `ipv6_checksum`.
    pub checksum: u16be,
    // Everything after the 8-byte header.
    #[payload]
    pub payload: Vec<u8>,
}
67
68pub fn ipv4_checksum(packet: &UdpPacket, source: &Ipv4Addr, destination: &Ipv4Addr) -> u16be {
70 ipv4_checksum_adv(packet, &[], source, destination)
71}
72
73pub fn ipv4_checksum_adv(
81 packet: &UdpPacket,
82 extra_data: &[u8],
83 source: &Ipv4Addr,
84 destination: &Ipv4Addr,
85) -> u16be {
86 util::ipv4_checksum(
87 packet.packet(),
88 3,
89 extra_data,
90 source,
91 destination,
92 IpNextLevelProtocol::Udp,
93 )
94}
95
/// Builds an IPv4 + UDP packet carrying the payload `b"test"` and fills in the UDP header.
/// Then checks the header bytes, including the computed checksum, against a reference.
#[test]
fn udp_header_ipv4_test() {
    use crate::ipv4::MutableIpv4Packet;

    // 20-byte IPv4 header + 8-byte UDP header + 4-byte payload.
    let mut packet = [0u8; 20 + 8 + 4];
    let ipv4_source = Ipv4Addr::new(192, 168, 0, 1);
    let ipv4_destination = Ipv4Addr::new(192, 168, 0, 199);
    {
        let mut ip_header = MutableIpv4Packet::new(&mut packet[..]).unwrap();
        ip_header.set_next_level_protocol(IpNextLevelProtocol::Udp);
        ip_header.set_source(ipv4_source);
        ip_header.set_destination(ipv4_destination);
    }

    // UDP payload, placed right after the UDP header.
    packet[20 + 8..].copy_from_slice(b"test");

    {
        let mut udp_header = MutableUdpPacket::new(&mut packet[20..]).unwrap();
        udp_header.set_source(12345);
        assert_eq!(udp_header.get_source(), 12345);

        udp_header.set_destination(54321);
        assert_eq!(udp_header.get_destination(), 54321);

        udp_header.set_length(8 + 4);
        assert_eq!(udp_header.get_length(), 8 + 4);

        let checksum = ipv4_checksum(&udp_header.to_immutable(), &ipv4_source, &ipv4_destination);
        udp_header.set_checksum(checksum);
        assert_eq!(udp_header.get_checksum(), 0x9178);
    }

    let ref_packet = [0x30, 0x39, 0xd4, 0x31, 0x00, 0x0c, 0x91, 0x78];
    assert_eq!(&ref_packet[..], &packet[20..28]);
}
141
142pub fn ipv6_checksum(packet: &UdpPacket, source: &Ipv6Addr, destination: &Ipv6Addr) -> u16be {
144 ipv6_checksum_adv(packet, &[], source, destination)
145}
146
147pub fn ipv6_checksum_adv(
155 packet: &UdpPacket,
156 extra_data: &[u8],
157 source: &Ipv6Addr,
158 destination: &Ipv6Addr,
159) -> u16be {
160 util::ipv6_checksum(
161 packet.packet(),
162 3,
163 extra_data,
164 source,
165 destination,
166 IpNextLevelProtocol::Udp,
167 )
168}
169
/// Builds an IPv6 + UDP packet carrying the payload `b"test"` and fills in the UDP header.
/// Then checks the header bytes, including the computed checksum, against a reference.
#[test]
fn udp_header_ipv6_test() {
    use crate::ipv6::MutableIpv6Packet;

    // 40-byte IPv6 header + 8-byte UDP header + 4-byte payload.
    let mut packet = [0u8; 40 + 8 + 4];
    let ipv6_source = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1);
    let ipv6_destination = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1);
    {
        let mut ip_header = MutableIpv6Packet::new(&mut packet[..]).unwrap();
        ip_header.set_next_header(IpNextLevelProtocol::Udp);
        ip_header.set_source(ipv6_source);
        ip_header.set_destination(ipv6_destination);
    }

    // UDP payload, placed right after the UDP header.
    packet[40 + 8..].copy_from_slice(b"test");

    {
        let mut udp_header = MutableUdpPacket::new(&mut packet[40..]).unwrap();
        udp_header.set_source(12345);
        assert_eq!(udp_header.get_source(), 12345);

        udp_header.set_destination(54321);
        assert_eq!(udp_header.get_destination(), 54321);

        udp_header.set_length(8 + 4);
        assert_eq!(udp_header.get_length(), 8 + 4);

        let checksum = ipv6_checksum(&udp_header.to_immutable(), &ipv6_source, &ipv6_destination);
        udp_header.set_checksum(checksum);
        assert_eq!(udp_header.get_checksum(), 0x1390);
    }

    let ref_packet = [0x30, 0x39, 0xd4, 0x31, 0x00, 0x0c, 0x13, 0x90];
    assert_eq!(&ref_packet[..], &packet[40..48]);
}