// stackforge_core/layer/udp/builder.rs

use std::convert::TryFrom;
use std::net::{Ipv4Addr, Ipv6Addr};

use super::checksum::{udp_checksum_ipv4, udp_checksum_ipv6};
use super::{UDP_HEADER_LEN, offsets};
use crate::layer::field::FieldError;

26#[derive(Debug, Clone)]
28pub struct UdpBuilder {
29 src_port: u16,
31 dst_port: u16,
32 length: Option<u16>,
33 checksum: Option<u16>,
34
35 payload: Vec<u8>,
37
38 auto_length: bool,
40 auto_checksum: bool,
41
42 src_ip: Option<IpAddr>,
44 dst_ip: Option<IpAddr>,
45}
46
/// An IPv4 or IPv6 address, used by [`UdpBuilder`] for the pseudo-header checksum.
///
/// `Ipv4Addr` and `Ipv6Addr` values convert into this type with `.into()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IpAddr {
    /// An IPv4 address.
    V4(Ipv4Addr),
    /// An IPv6 address.
    V6(Ipv6Addr),
}

impl From<Ipv4Addr> for IpAddr {
    /// Wraps an IPv4 address in [`IpAddr::V4`].
    fn from(addr: Ipv4Addr) -> Self {
        Self::V4(addr)
    }
}

impl From<Ipv6Addr> for IpAddr {
    /// Wraps an IPv6 address in [`IpAddr::V6`].
    fn from(addr: Ipv6Addr) -> Self {
        Self::V6(addr)
    }
}

66impl Default for UdpBuilder {
67 fn default() -> Self {
68 Self {
69 src_port: 53,
70 dst_port: 53,
71 length: None,
72 checksum: None,
73 payload: Vec::new(),
74 auto_length: true,
75 auto_checksum: true,
76 src_ip: None,
77 dst_ip: None,
78 }
79 }
80}
81
82impl UdpBuilder {
83 pub fn new() -> Self {
85 Self::default()
86 }
87
88 pub fn from_bytes(data: &[u8]) -> Result<Self, FieldError> {
90 if data.len() < UDP_HEADER_LEN {
91 return Err(FieldError::BufferTooShort {
92 offset: 0,
93 need: UDP_HEADER_LEN,
94 have: data.len(),
95 });
96 }
97
98 let src_port = u16::from_be_bytes([data[offsets::SRC_PORT], data[offsets::SRC_PORT + 1]]);
99 let dst_port = u16::from_be_bytes([data[offsets::DST_PORT], data[offsets::DST_PORT + 1]]);
100 let length = u16::from_be_bytes([data[offsets::LENGTH], data[offsets::LENGTH + 1]]);
101 let checksum = u16::from_be_bytes([data[offsets::CHECKSUM], data[offsets::CHECKSUM + 1]]);
102
103 let mut builder = Self::new();
104 builder.src_port = src_port;
105 builder.dst_port = dst_port;
106 builder.length = Some(length);
107 builder.checksum = Some(checksum);
108
109 if data.len() > UDP_HEADER_LEN {
111 builder.payload = data[UDP_HEADER_LEN..].to_vec();
112 }
113
114 builder.auto_length = false;
116 builder.auto_checksum = false;
117
118 Ok(builder)
119 }
120
121 pub fn src_port(mut self, port: u16) -> Self {
125 self.src_port = port;
126 self
127 }
128
129 pub fn sport(self, port: u16) -> Self {
131 self.src_port(port)
132 }
133
134 pub fn dst_port(mut self, port: u16) -> Self {
136 self.dst_port = port;
137 self
138 }
139
140 pub fn dport(self, port: u16) -> Self {
142 self.dst_port(port)
143 }
144
145 pub fn length(mut self, len: u16) -> Self {
149 self.length = Some(len);
150 self.auto_length = false;
151 self
152 }
153
154 pub fn len(self, len: u16) -> Self {
156 self.length(len)
157 }
158
159 pub fn checksum(mut self, csum: u16) -> Self {
163 self.checksum = Some(csum);
164 self.auto_checksum = false;
165 self
166 }
167
168 pub fn chksum(self, csum: u16) -> Self {
170 self.checksum(csum)
171 }
172
173 pub fn enable_auto_length(mut self) -> Self {
175 self.auto_length = true;
176 self.length = None;
177 self
178 }
179
180 pub fn disable_auto_length(mut self) -> Self {
182 self.auto_length = false;
183 self
184 }
185
186 pub fn enable_auto_checksum(mut self) -> Self {
188 self.auto_checksum = true;
189 self.checksum = None;
190 self
191 }
192
193 pub fn disable_auto_checksum(mut self) -> Self {
195 self.auto_checksum = false;
196 self
197 }
198
199 pub fn src_ipv4(mut self, addr: Ipv4Addr) -> Self {
203 self.src_ip = Some(IpAddr::V4(addr));
204 self
205 }
206
207 pub fn dst_ipv4(mut self, addr: Ipv4Addr) -> Self {
209 self.dst_ip = Some(IpAddr::V4(addr));
210 self
211 }
212
213 pub fn src_ipv6(mut self, addr: Ipv6Addr) -> Self {
215 self.src_ip = Some(IpAddr::V6(addr));
216 self
217 }
218
219 pub fn dst_ipv6(mut self, addr: Ipv6Addr) -> Self {
221 self.dst_ip = Some(IpAddr::V6(addr));
222 self
223 }
224
225 pub fn ipv4_addrs(self, src: Ipv4Addr, dst: Ipv4Addr) -> Self {
227 self.src_ipv4(src).dst_ipv4(dst)
228 }
229
230 pub fn ipv6_addrs(self, src: Ipv6Addr, dst: Ipv6Addr) -> Self {
232 self.src_ipv6(src).dst_ipv6(dst)
233 }
234
235 pub fn payload<T: Into<Vec<u8>>>(mut self, data: T) -> Self {
239 self.payload = data.into();
240 self
241 }
242
243 pub fn append_payload<T: AsRef<[u8]>>(mut self, data: T) -> Self {
245 self.payload.extend_from_slice(data.as_ref());
246 self
247 }
248
249 pub fn packet_size(&self) -> usize {
253 UDP_HEADER_LEN + self.payload.len()
254 }
255
256 pub fn header_size(&self) -> usize {
258 UDP_HEADER_LEN
259 }
260
261 pub fn build(&self) -> Vec<u8> {
265 let total_size = self.packet_size();
266 let mut buf = vec![0u8; total_size];
267 self.build_into(&mut buf)
268 .expect("buffer is correctly sized");
269 buf
270 }
271
272 pub fn build_into(&self, buf: &mut [u8]) -> Result<usize, FieldError> {
274 let total_size = self.packet_size();
275
276 if buf.len() < total_size {
277 return Err(FieldError::BufferTooShort {
278 offset: 0,
279 need: total_size,
280 have: buf.len(),
281 });
282 }
283
284 let length = if self.auto_length {
286 total_size as u16
287 } else {
288 self.length.unwrap_or(total_size as u16)
289 };
290
291 buf[offsets::SRC_PORT..offsets::SRC_PORT + 2].copy_from_slice(&self.src_port.to_be_bytes());
293
294 buf[offsets::DST_PORT..offsets::DST_PORT + 2].copy_from_slice(&self.dst_port.to_be_bytes());
296
297 buf[offsets::LENGTH..offsets::LENGTH + 2].copy_from_slice(&length.to_be_bytes());
299
300 buf[offsets::CHECKSUM..offsets::CHECKSUM + 2].copy_from_slice(&[0, 0]);
302
303 if !self.payload.is_empty() {
305 buf[UDP_HEADER_LEN..total_size].copy_from_slice(&self.payload);
306 }
307
308 if self.auto_checksum {
310 let checksum = self.calculate_checksum(&buf[..total_size]);
311 if let Some(csum) = checksum {
312 let final_csum = if csum == 0 { 0xFFFF } else { csum };
314 buf[offsets::CHECKSUM..offsets::CHECKSUM + 2]
315 .copy_from_slice(&final_csum.to_be_bytes());
316 }
317 } else if let Some(csum) = self.checksum {
318 buf[offsets::CHECKSUM..offsets::CHECKSUM + 2].copy_from_slice(&csum.to_be_bytes());
319 }
320
321 Ok(total_size)
322 }
323
324 pub fn build_header(&self) -> Vec<u8> {
326 let mut buf = vec![0u8; UDP_HEADER_LEN];
327
328 let builder = Self {
330 payload: Vec::new(),
331 ..self.clone()
332 };
333 builder
334 .build_into(&mut buf)
335 .expect("buffer is correctly sized");
336
337 buf
338 }
339
340 fn calculate_checksum(&self, udp_packet: &[u8]) -> Option<u16> {
342 match (self.src_ip, self.dst_ip) {
343 (Some(IpAddr::V4(src)), Some(IpAddr::V4(dst))) => {
344 Some(udp_checksum_ipv4(src, dst, udp_packet))
345 }
346 (Some(IpAddr::V6(src)), Some(IpAddr::V6(dst))) => {
347 Some(udp_checksum_ipv6(src, dst, udp_packet))
348 }
349 _ => None, }
351 }
352}
353
354impl UdpBuilder {
357 pub fn dns_query() -> Self {
359 Self::new().src_port(53).dst_port(53)
360 }
361
362 pub fn dhcp_client() -> Self {
364 Self::new().src_port(68).dst_port(67)
365 }
366
367 pub fn dhcp_server() -> Self {
369 Self::new().src_port(67).dst_port(68)
370 }
371
372 pub fn ntp() -> Self {
374 Self::new().src_port(123).dst_port(123)
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the big-endian `u16` starting at `offset` of a serialized datagram.
    fn be16(bytes: &[u8], offset: usize) -> u16 {
        u16::from_be_bytes([bytes[offset], bytes[offset + 1]])
    }

    #[test]
    fn test_builder_defaults() {
        let builder = UdpBuilder::new();
        assert_eq!(builder.src_port, 53);
        assert_eq!(builder.dst_port, 53);
        assert!(builder.auto_length);
        assert!(builder.auto_checksum);
    }

    #[test]
    fn test_build_basic() {
        let packet = UdpBuilder::new()
            .src_port(12345)
            .dst_port(80)
            .payload(b"Hello")
            .build();

        assert_eq!(packet.len(), 8 + 5);
        assert_eq!(be16(&packet, 0), 12345);
        assert_eq!(be16(&packet, 2), 80);
        // Automatic length covers header + payload.
        assert_eq!(be16(&packet, 4), 13);
        assert_eq!(&packet[8..], b"Hello");
    }

    #[test]
    fn test_build_with_manual_length() {
        let packet = UdpBuilder::new()
            .src_port(1234)
            .dst_port(5678)
            .length(100)
            .build();

        assert_eq!(be16(&packet, 4), 100);
    }

    #[test]
    fn test_build_with_checksum() {
        let packet = UdpBuilder::new()
            .src_port(1234)
            .dst_port(5678)
            .src_ipv4(Ipv4Addr::new(192, 168, 1, 1))
            .dst_ipv4(Ipv4Addr::new(192, 168, 1, 2))
            .payload(b"test")
            .build();

        // An automatically computed checksum is never transmitted as 0.
        assert_ne!(be16(&packet, 6), 0);
    }

    #[test]
    fn test_build_with_ipv6_checksum() {
        let packet = UdpBuilder::new()
            .src_port(1234)
            .dst_port(5678)
            .ipv6_addrs(
                Ipv6Addr::LOCALHOST,
                Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1),
            )
            .payload(b"test")
            .build();

        assert_ne!(be16(&packet, 6), 0);
    }

    #[test]
    fn test_mixed_address_families_leave_checksum_zero() {
        let packet = UdpBuilder::new()
            .src_ipv4(Ipv4Addr::LOCALHOST)
            .dst_ipv6(Ipv6Addr::LOCALHOST)
            .payload(b"test")
            .build();

        assert_eq!(be16(&packet, 6), 0);
    }

    #[test]
    fn test_manual_zero_checksum_is_written_verbatim() {
        // A manually supplied 0 means "no checksum" and must not be remapped to 0xFFFF.
        let packet = UdpBuilder::new()
            .src_port(0)
            .dst_port(0)
            .checksum(0)
            .build();

        assert_eq!(be16(&packet, 6), 0);
    }

    #[test]
    fn test_from_bytes() {
        let original = UdpBuilder::new()
            .src_port(1234)
            .dst_port(5678)
            .payload(b"test data")
            .build();

        let rebuilt = UdpBuilder::from_bytes(&original).unwrap();
        assert_eq!(rebuilt.src_port, 1234);
        assert_eq!(rebuilt.dst_port, 5678);
        assert_eq!(rebuilt.payload, b"test data");
    }

    #[test]
    fn test_from_bytes_round_trips_header_fields() {
        let original = UdpBuilder::new()
            .src_port(1000)
            .dst_port(2000)
            .length(42)
            .checksum(0x1234)
            .payload(b"abc")
            .build();

        let rebuilt = UdpBuilder::from_bytes(&original).unwrap().build();
        assert_eq!(rebuilt, original);
    }

    #[test]
    fn test_from_bytes_rejects_short_input() {
        assert!(matches!(
            UdpBuilder::from_bytes(&[0u8; 4]),
            Err(FieldError::BufferTooShort { need: 8, have: 4, .. })
        ));
    }

    #[test]
    fn test_scapy_aliases() {
        let packet = UdpBuilder::new()
            .sport(1234)
            .dport(5678)
            .len(20)
            .chksum(0xABCD)
            .build();

        assert_eq!(be16(&packet, 0), 1234);
        assert_eq!(be16(&packet, 2), 5678);
        assert_eq!(be16(&packet, 4), 20);
        assert_eq!(be16(&packet, 6), 0xABCD);
    }

    #[test]
    fn test_convenience_constructors() {
        let dns = UdpBuilder::dns_query().build();
        assert_eq!(be16(&dns, 0), 53);
        assert_eq!(be16(&dns, 2), 53);

        let dhcp_client = UdpBuilder::dhcp_client().build();
        assert_eq!(be16(&dhcp_client, 0), 68);
        assert_eq!(be16(&dhcp_client, 2), 67);

        let dhcp_server = UdpBuilder::dhcp_server().build();
        assert_eq!(be16(&dhcp_server, 0), 67);
        assert_eq!(be16(&dhcp_server, 2), 68);

        let ntp = UdpBuilder::ntp().build();
        assert_eq!(be16(&ntp, 0), 123);
        assert_eq!(be16(&ntp, 2), 123);
    }

    #[test]
    fn test_build_header_only() {
        let header = UdpBuilder::new()
            .src_port(1234)
            .dst_port(5678)
            .payload(b"this should not be included")
            .build_header();

        assert_eq!(header.len(), 8);
        assert_eq!(be16(&header, 0), 1234);
    }
}