stackforge_core/layer/udp/
builder.rs1use std::net::{Ipv4Addr, Ipv6Addr};
21
22use super::checksum::{udp_checksum_ipv4, udp_checksum_ipv6};
23use super::{UDP_HEADER_LEN, offsets};
24use crate::layer::field::FieldError;
25
26#[derive(Debug, Clone)]
28pub struct UdpBuilder {
29 src_port: u16,
31 dst_port: u16,
32 length: Option<u16>,
33 checksum: Option<u16>,
34
35 payload: Vec<u8>,
37
38 auto_length: bool,
40 auto_checksum: bool,
41
42 src_ip: Option<IpAddr>,
44 dst_ip: Option<IpAddr>,
45}
46
/// An IPv4 or IPv6 address, used to build the UDP checksum pseudo-header.
///
/// Derives `PartialEq`, `Eq` and `Hash` so addresses can be compared and
/// used as map/set keys like any other small `Copy` value type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IpAddr {
    /// An IPv4 address.
    V4(Ipv4Addr),
    /// An IPv6 address.
    V6(Ipv6Addr),
}
53
54impl From<Ipv4Addr> for IpAddr {
55 fn from(addr: Ipv4Addr) -> Self {
56 IpAddr::V4(addr)
57 }
58}
59
60impl From<Ipv6Addr> for IpAddr {
61 fn from(addr: Ipv6Addr) -> Self {
62 IpAddr::V6(addr)
63 }
64}
65
66impl Default for UdpBuilder {
67 fn default() -> Self {
68 Self {
69 src_port: 53,
70 dst_port: 53,
71 length: None,
72 checksum: None,
73 payload: Vec::new(),
74 auto_length: true,
75 auto_checksum: true,
76 src_ip: None,
77 dst_ip: None,
78 }
79 }
80}
81
82impl UdpBuilder {
83 #[must_use]
85 pub fn new() -> Self {
86 Self::default()
87 }
88
89 pub fn from_bytes(data: &[u8]) -> Result<Self, FieldError> {
91 if data.len() < UDP_HEADER_LEN {
92 return Err(FieldError::BufferTooShort {
93 offset: 0,
94 need: UDP_HEADER_LEN,
95 have: data.len(),
96 });
97 }
98
99 let src_port = u16::from_be_bytes([data[offsets::SRC_PORT], data[offsets::SRC_PORT + 1]]);
100 let dst_port = u16::from_be_bytes([data[offsets::DST_PORT], data[offsets::DST_PORT + 1]]);
101 let length = u16::from_be_bytes([data[offsets::LENGTH], data[offsets::LENGTH + 1]]);
102 let checksum = u16::from_be_bytes([data[offsets::CHECKSUM], data[offsets::CHECKSUM + 1]]);
103
104 let mut builder = Self::new();
105 builder.src_port = src_port;
106 builder.dst_port = dst_port;
107 builder.length = Some(length);
108 builder.checksum = Some(checksum);
109
110 if data.len() > UDP_HEADER_LEN {
112 builder.payload = data[UDP_HEADER_LEN..].to_vec();
113 }
114
115 builder.auto_length = false;
117 builder.auto_checksum = false;
118
119 Ok(builder)
120 }
121
122 #[must_use]
126 pub fn src_port(mut self, port: u16) -> Self {
127 self.src_port = port;
128 self
129 }
130
131 #[must_use]
133 pub fn sport(self, port: u16) -> Self {
134 self.src_port(port)
135 }
136
137 #[must_use]
139 pub fn dst_port(mut self, port: u16) -> Self {
140 self.dst_port = port;
141 self
142 }
143
144 #[must_use]
146 pub fn dport(self, port: u16) -> Self {
147 self.dst_port(port)
148 }
149
150 #[must_use]
154 pub fn length(mut self, len: u16) -> Self {
155 self.length = Some(len);
156 self.auto_length = false;
157 self
158 }
159
160 #[must_use]
162 pub fn len(self, len: u16) -> Self {
163 self.length(len)
164 }
165
166 #[must_use]
170 pub fn checksum(mut self, csum: u16) -> Self {
171 self.checksum = Some(csum);
172 self.auto_checksum = false;
173 self
174 }
175
176 #[must_use]
178 pub fn chksum(self, csum: u16) -> Self {
179 self.checksum(csum)
180 }
181
182 #[must_use]
184 pub fn enable_auto_length(mut self) -> Self {
185 self.auto_length = true;
186 self.length = None;
187 self
188 }
189
190 #[must_use]
192 pub fn disable_auto_length(mut self) -> Self {
193 self.auto_length = false;
194 self
195 }
196
197 #[must_use]
199 pub fn enable_auto_checksum(mut self) -> Self {
200 self.auto_checksum = true;
201 self.checksum = None;
202 self
203 }
204
205 #[must_use]
207 pub fn disable_auto_checksum(mut self) -> Self {
208 self.auto_checksum = false;
209 self
210 }
211
212 #[must_use]
216 pub fn src_ipv4(mut self, addr: Ipv4Addr) -> Self {
217 self.src_ip = Some(IpAddr::V4(addr));
218 self
219 }
220
221 #[must_use]
223 pub fn dst_ipv4(mut self, addr: Ipv4Addr) -> Self {
224 self.dst_ip = Some(IpAddr::V4(addr));
225 self
226 }
227
228 #[must_use]
230 pub fn src_ipv6(mut self, addr: Ipv6Addr) -> Self {
231 self.src_ip = Some(IpAddr::V6(addr));
232 self
233 }
234
235 #[must_use]
237 pub fn dst_ipv6(mut self, addr: Ipv6Addr) -> Self {
238 self.dst_ip = Some(IpAddr::V6(addr));
239 self
240 }
241
242 #[must_use]
244 pub fn ipv4_addrs(self, src: Ipv4Addr, dst: Ipv4Addr) -> Self {
245 self.src_ipv4(src).dst_ipv4(dst)
246 }
247
248 #[must_use]
250 pub fn ipv6_addrs(self, src: Ipv6Addr, dst: Ipv6Addr) -> Self {
251 self.src_ipv6(src).dst_ipv6(dst)
252 }
253
254 pub fn payload<T: Into<Vec<u8>>>(mut self, data: T) -> Self {
258 self.payload = data.into();
259 self
260 }
261
262 pub fn append_payload<T: AsRef<[u8]>>(mut self, data: T) -> Self {
264 self.payload.extend_from_slice(data.as_ref());
265 self
266 }
267
268 #[must_use]
272 pub fn packet_size(&self) -> usize {
273 UDP_HEADER_LEN + self.payload.len()
274 }
275
276 #[must_use]
278 pub fn header_size(&self) -> usize {
279 UDP_HEADER_LEN
280 }
281
282 #[must_use]
286 pub fn build(&self) -> Vec<u8> {
287 let total_size = self.packet_size();
288 let mut buf = vec![0u8; total_size];
289 self.build_into(&mut buf)
290 .expect("buffer is correctly sized");
291 buf
292 }
293
294 pub fn build_into(&self, buf: &mut [u8]) -> Result<usize, FieldError> {
296 let total_size = self.packet_size();
297
298 if buf.len() < total_size {
299 return Err(FieldError::BufferTooShort {
300 offset: 0,
301 need: total_size,
302 have: buf.len(),
303 });
304 }
305
306 let length = if self.auto_length {
308 total_size as u16
309 } else {
310 self.length.unwrap_or(total_size as u16)
311 };
312
313 buf[offsets::SRC_PORT..offsets::SRC_PORT + 2].copy_from_slice(&self.src_port.to_be_bytes());
315
316 buf[offsets::DST_PORT..offsets::DST_PORT + 2].copy_from_slice(&self.dst_port.to_be_bytes());
318
319 buf[offsets::LENGTH..offsets::LENGTH + 2].copy_from_slice(&length.to_be_bytes());
321
322 buf[offsets::CHECKSUM..offsets::CHECKSUM + 2].copy_from_slice(&[0, 0]);
324
325 if !self.payload.is_empty() {
327 buf[UDP_HEADER_LEN..total_size].copy_from_slice(&self.payload);
328 }
329
330 if self.auto_checksum {
332 let checksum = self.calculate_checksum(&buf[..total_size]);
333 if let Some(csum) = checksum {
334 let final_csum = if csum == 0 { 0xFFFF } else { csum };
336 buf[offsets::CHECKSUM..offsets::CHECKSUM + 2]
337 .copy_from_slice(&final_csum.to_be_bytes());
338 }
339 } else if let Some(csum) = self.checksum {
340 buf[offsets::CHECKSUM..offsets::CHECKSUM + 2].copy_from_slice(&csum.to_be_bytes());
341 }
342
343 Ok(total_size)
344 }
345
346 #[must_use]
348 pub fn build_header(&self) -> Vec<u8> {
349 let mut buf = vec![0u8; UDP_HEADER_LEN];
350
351 let builder = Self {
353 payload: Vec::new(),
354 ..self.clone()
355 };
356 builder
357 .build_into(&mut buf)
358 .expect("buffer is correctly sized");
359
360 buf
361 }
362
363 fn calculate_checksum(&self, udp_packet: &[u8]) -> Option<u16> {
365 match (self.src_ip, self.dst_ip) {
366 (Some(IpAddr::V4(src)), Some(IpAddr::V4(dst))) => {
367 Some(udp_checksum_ipv4(src, dst, udp_packet))
368 },
369 (Some(IpAddr::V6(src)), Some(IpAddr::V6(dst))) => {
370 Some(udp_checksum_ipv6(src, dst, udp_packet))
371 },
372 _ => None, }
374 }
375}
376
377impl UdpBuilder {
380 #[must_use]
382 pub fn dns_query() -> Self {
383 Self::new().src_port(53).dst_port(53)
384 }
385
386 #[must_use]
388 pub fn dhcp_client() -> Self {
389 Self::new().src_port(68).dst_port(67)
390 }
391
392 #[must_use]
394 pub fn dhcp_server() -> Self {
395 Self::new().src_port(67).dst_port(68)
396 }
397
398 #[must_use]
400 pub fn ntp() -> Self {
401 Self::new().src_port(123).dst_port(123)
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the big-endian `u16` starting at byte `at`.
    fn be_u16(bytes: &[u8], at: usize) -> u16 {
        u16::from_be_bytes([bytes[at], bytes[at + 1]])
    }

    #[test]
    fn test_builder_defaults() {
        let builder = UdpBuilder::new();
        assert_eq!(builder.src_port, 53);
        assert_eq!(builder.dst_port, 53);
        assert!(builder.auto_length);
        assert!(builder.auto_checksum);
    }

    #[test]
    fn test_build_basic() {
        let packet = UdpBuilder::new()
            .src_port(12345)
            .dst_port(80)
            .payload(b"Hello")
            .build();

        assert_eq!(packet.len(), 8 + 5);
        assert_eq!(be_u16(&packet, 0), 12345);
        assert_eq!(be_u16(&packet, 2), 80);
        // Automatic length covers header + payload.
        assert_eq!(be_u16(&packet, 4), 13);
        assert_eq!(&packet[8..], b"Hello");
    }

    #[test]
    fn test_build_with_manual_length() {
        let packet = UdpBuilder::new()
            .src_port(1234)
            .dst_port(5678)
            .length(100)
            .build();

        assert_eq!(be_u16(&packet, 4), 100);
    }

    #[test]
    fn test_build_with_checksum() {
        let packet = UdpBuilder::new()
            .src_port(1234)
            .dst_port(5678)
            .src_ipv4(Ipv4Addr::new(192, 168, 1, 1))
            .dst_ipv4(Ipv4Addr::new(192, 168, 1, 2))
            .payload(b"test")
            .build();

        // With both addresses set, a checksum is computed; it is never zero
        // on the wire because a computed zero is sent as 0xFFFF.
        assert_ne!(be_u16(&packet, 6), 0);
    }

    #[test]
    fn test_build_without_addresses_leaves_checksum_zero() {
        let packet = UdpBuilder::new().payload(b"data").build();
        assert_eq!(be_u16(&packet, 6), 0);
    }

    #[test]
    fn test_manual_zero_checksum_written_verbatim() {
        // An explicit checksum bypasses the computed-zero -> 0xFFFF rule:
        // 0 ("no checksum") is emitted exactly as given.
        let packet = UdpBuilder::new()
            .src_port(0)
            .dst_port(0)
            .disable_auto_checksum()
            .checksum(0)
            .build();

        assert_eq!(be_u16(&packet, 6), 0);
    }

    #[test]
    fn test_from_bytes() {
        let original = UdpBuilder::new()
            .src_port(1234)
            .dst_port(5678)
            .payload(b"test data")
            .build();

        let rebuilt = UdpBuilder::from_bytes(&original).unwrap();
        assert_eq!(rebuilt.src_port, 1234);
        assert_eq!(rebuilt.dst_port, 5678);
        assert_eq!(rebuilt.payload, b"test data");
        // Parsed length/checksum are kept, so re-serializing is lossless.
        assert_eq!(rebuilt.build(), original);
    }

    #[test]
    fn test_from_bytes_rejects_short_input() {
        assert!(UdpBuilder::from_bytes(&[0u8; 7]).is_err());
    }

    #[test]
    fn test_build_into_checks_buffer_size() {
        let builder = UdpBuilder::new().payload(b"abc");

        let mut small = [0u8; 10];
        assert!(builder.build_into(&mut small).is_err());

        let mut large = [0xAAu8; 16];
        assert!(matches!(builder.build_into(&mut large), Ok(11)));
        // Bytes past the datagram are left untouched.
        assert!(large[11..].iter().all(|&b| b == 0xAA));
    }

    #[test]
    fn test_scapy_aliases() {
        let packet = UdpBuilder::new()
            .sport(1234)
            .dport(5678)
            .len(20)
            .chksum(0xABCD)
            .build();

        assert_eq!(be_u16(&packet, 0), 1234);
        assert_eq!(be_u16(&packet, 2), 5678);
        assert_eq!(be_u16(&packet, 4), 20);
        assert_eq!(be_u16(&packet, 6), 0xABCD);
    }

    #[test]
    fn test_convenience_constructors() {
        let dns = UdpBuilder::dns_query().build();
        assert_eq!(be_u16(&dns, 0), 53);
        assert_eq!(be_u16(&dns, 2), 53);

        let dhcp_client = UdpBuilder::dhcp_client().build();
        assert_eq!(be_u16(&dhcp_client, 0), 68);
        assert_eq!(be_u16(&dhcp_client, 2), 67);

        let dhcp_server = UdpBuilder::dhcp_server().build();
        assert_eq!(be_u16(&dhcp_server, 0), 67);
        assert_eq!(be_u16(&dhcp_server, 2), 68);

        let ntp = UdpBuilder::ntp().build();
        assert_eq!(be_u16(&ntp, 0), 123);
        assert_eq!(be_u16(&ntp, 2), 123);
    }

    #[test]
    fn test_build_header_only() {
        let header = UdpBuilder::new()
            .src_port(1234)
            .dst_port(5678)
            .payload(b"this should not be included")
            .build_header();

        assert_eq!(header.len(), 8);
        assert_eq!(be_u16(&header, 0), 1234);
        assert_eq!(be_u16(&header, 2), 5678);
    }
}