zero_packet/transport/
udp.rs1use crate::network::checksum::internet_checksum;
2use core::fmt;
3
/// Size of the fixed UDP header in bytes.
///
/// The header consists of four 16-bit fields: source port, destination port,
/// length and checksum.
pub const UDP_HEADER_LENGTH: usize = 4 * core::mem::size_of::<u16>();
6
7pub struct UdpWriter<'a> {
9 pub bytes: &'a mut [u8],
10}
11
12impl<'a> UdpWriter<'a> {
13 #[inline]
15 pub fn new(bytes: &'a mut [u8]) -> Result<Self, &'static str> {
16 if bytes.len() < UDP_HEADER_LENGTH {
17 return Err("Slice is too short to contain a UDP header.");
18 }
19
20 Ok(Self { bytes })
21 }
22
23 #[inline]
25 pub fn header_len(&self) -> usize {
26 UDP_HEADER_LENGTH
27 }
28
29 #[inline]
31 pub fn packet_len(&self) -> usize {
32 self.bytes.len()
33 }
34
35 #[inline]
37 pub fn set_src_port(&mut self, src_port: u16) {
38 self.bytes[0] = (src_port >> 8) as u8;
39 self.bytes[1] = src_port as u8;
40 }
41
42 #[inline]
44 pub fn set_dest_port(&mut self, dst_port: u16) {
45 self.bytes[2] = (dst_port >> 8) as u8;
46 self.bytes[3] = dst_port as u8;
47 }
48
49 #[inline]
53 pub fn set_length(&mut self, length: u16) {
54 self.bytes[4] = (length >> 8) as u8;
55 self.bytes[5] = length as u8;
56 }
57
58 #[inline]
65 pub fn set_checksum(&mut self, pseudo_sum: u32) {
66 self.bytes[6] = 0;
67 self.bytes[7] = 0;
68 let checksum = internet_checksum(self.bytes, pseudo_sum);
69 self.bytes[6] = (checksum >> 8) as u8;
70 self.bytes[7] = (checksum & 0xff) as u8;
71 }
72
73 #[inline]
79 pub fn set_payload(&mut self, payload: &[u8]) -> Result<(), &'static str> {
80 let start = self.header_len();
81 let payload_len = payload.len();
82
83 if self.packet_len() - start < payload_len {
84 return Err("Payload is too large to fit in the TCP packet.");
85 }
86
87 let end = start + payload_len;
88 self.bytes[start..end].copy_from_slice(payload);
89
90 Ok(())
91 }
92}
93
94#[derive(PartialEq)]
96pub struct UdpReader<'a> {
97 pub bytes: &'a [u8],
98}
99
100impl<'a> UdpReader<'a> {
101 #[inline]
103 pub fn new(bytes: &'a [u8]) -> Result<Self, &'static str> {
104 if bytes.len() < UDP_HEADER_LENGTH {
105 return Err("Slice is too short to contain a UDP header.");
106 }
107
108 Ok(Self { bytes })
109 }
110
111 #[inline]
113 pub fn src_port(&self) -> u16 {
114 ((self.bytes[0] as u16) << 8) | (self.bytes[1] as u16)
115 }
116
117 #[inline]
119 pub fn dest_port(&self) -> u16 {
120 ((self.bytes[2] as u16) << 8) | (self.bytes[3] as u16)
121 }
122
123 #[inline]
125 pub fn checksum(&self) -> u16 {
126 ((self.bytes[6] as u16) << 8) | (self.bytes[7] as u16)
127 }
128
129 #[inline]
133 pub fn length(&self) -> u16 {
134 ((self.bytes[4] as u16) << 8) | (self.bytes[5] as u16)
135 }
136
137 #[inline]
139 pub fn header_len(&self) -> usize {
140 UDP_HEADER_LENGTH
141 }
142
143 #[inline]
145 pub fn header(&self) -> &'a [u8] {
146 &self.bytes[..UDP_HEADER_LENGTH]
147 }
148
149 #[inline]
151 pub fn payload(&self) -> &'a [u8] {
152 &self.bytes[UDP_HEADER_LENGTH..]
153 }
154}
155
156impl fmt::Debug for UdpReader<'_> {
157 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158 f.debug_struct("UdpDatagram")
159 .field("src_port", &self.src_port())
160 .field("dest_port", &self.dest_port())
161 .field("length", &self.length())
162 .finish()
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::network::checksum::pseudo_header;

    /// Writes every header field with `UdpWriter` and reads the ports and
    /// length back with `UdpReader` from the same buffer.
    #[test]
    fn getters_and_setters() {
        // Values to round-trip through the header.
        let (src_port, dst_port, length) = (12345, 54321, 8);
        let src_ip = [192, 168, 1, 1];
        let dest_ip = [192, 168, 1, 2];

        // Header-only datagram: no payload bytes.
        let mut buffer = [0u8; 8];

        {
            let mut writer = UdpWriter::new(&mut buffer).unwrap();

            writer.set_src_port(src_port);
            writer.set_dest_port(dst_port);
            writer.set_length(length);

            // 17 is the IP protocol number for UDP.
            let pseudo_sum = pseudo_header(&src_ip, &dest_ip, 17, writer.packet_len());
            writer.set_checksum(pseudo_sum);
        }

        let reader = UdpReader::new(&buffer).unwrap();

        assert_eq!(reader.src_port(), src_port);
        assert_eq!(reader.dest_port(), dst_port);
        assert_eq!(reader.length(), length);
    }
}