Skip to main content

stackforge_core/layer/udp/
builder.rs

1//! UDP packet builder.
2//!
3//! Provides a fluent API for constructing UDP packets with automatic
4//! field calculation (length, checksum).
5//!
6//! # Example
7//!
8//! ```rust
9//! use stackforge_core::layer::udp::UdpBuilder;
10//! use std::net::Ipv4Addr;
11//!
12//! // Build a DNS request packet
13//! let packet = UdpBuilder::new()
14//!     .src_port(12345)
15//!     .dst_port(53)
16//!     .payload(b"DNS query data")
17//!     .build();
18//! ```
19
20use std::net::{Ipv4Addr, Ipv6Addr};
21
22use super::checksum::{udp_checksum_ipv4, udp_checksum_ipv6};
23use super::{UDP_HEADER_LEN, offsets};
24use crate::layer::field::FieldError;
25
/// Builder for UDP packets.
///
/// Collects header fields, payload bytes and build options, then serializes
/// them with [`UdpBuilder::build`] or [`UdpBuilder::build_into`].
#[derive(Debug, Clone)]
pub struct UdpBuilder {
    // Header fields
    /// Source port, written big-endian into the header.
    src_port: u16,
    /// Destination port, written big-endian into the header.
    dst_port: u16,
    /// Explicit length value; only consulted when `auto_length` is false
    /// (if it is `None` the computed length is written instead).
    length: Option<u16>,
    /// Explicit checksum value; only written when `auto_checksum` is false
    /// (if it is `None` the checksum field is left zero).
    checksum: Option<u16>,

    // Payload
    /// Bytes placed directly after the UDP header.
    payload: Vec<u8>,

    // Build options
    /// When true, the length field is computed as header size + payload size.
    auto_length: bool,
    /// When true, the checksum is computed from `src_ip`/`dst_ip` if both are
    /// set and of the same IP version; otherwise the field is left zero.
    auto_checksum: bool,

    // IP addresses for checksum calculation
    /// Source address, used only for checksum calculation.
    src_ip: Option<IpAddr>,
    /// Destination address, used only for checksum calculation.
    dst_ip: Option<IpAddr>,
}
46
/// IP address enum for checksum calculation.
///
/// Holds either an IPv4 or an IPv6 address. The UDP builder only computes a
/// checksum when its source and destination are the same variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IpAddr {
    /// An IPv4 address.
    V4(Ipv4Addr),
    /// An IPv6 address.
    V6(Ipv6Addr),
}

impl From<Ipv4Addr> for IpAddr {
    /// Wrap an IPv4 address as [`IpAddr::V4`].
    fn from(addr: Ipv4Addr) -> Self {
        Self::V4(addr)
    }
}

impl From<Ipv6Addr> for IpAddr {
    /// Wrap an IPv6 address as [`IpAddr::V6`].
    fn from(addr: Ipv6Addr) -> Self {
        Self::V6(addr)
    }
}
65
66impl Default for UdpBuilder {
67    fn default() -> Self {
68        Self {
69            src_port: 53,
70            dst_port: 53,
71            length: None,
72            checksum: None,
73            payload: Vec::new(),
74            auto_length: true,
75            auto_checksum: true,
76            src_ip: None,
77            dst_ip: None,
78        }
79    }
80}
81
82impl UdpBuilder {
83    /// Create a new UDP builder with default values.
84    #[must_use]
85    pub fn new() -> Self {
86        Self::default()
87    }
88
89    /// Create a builder initialized from an existing packet.
90    pub fn from_bytes(data: &[u8]) -> Result<Self, FieldError> {
91        if data.len() < UDP_HEADER_LEN {
92            return Err(FieldError::BufferTooShort {
93                offset: 0,
94                need: UDP_HEADER_LEN,
95                have: data.len(),
96            });
97        }
98
99        let src_port = u16::from_be_bytes([data[offsets::SRC_PORT], data[offsets::SRC_PORT + 1]]);
100        let dst_port = u16::from_be_bytes([data[offsets::DST_PORT], data[offsets::DST_PORT + 1]]);
101        let length = u16::from_be_bytes([data[offsets::LENGTH], data[offsets::LENGTH + 1]]);
102        let checksum = u16::from_be_bytes([data[offsets::CHECKSUM], data[offsets::CHECKSUM + 1]]);
103
104        let mut builder = Self::new();
105        builder.src_port = src_port;
106        builder.dst_port = dst_port;
107        builder.length = Some(length);
108        builder.checksum = Some(checksum);
109
110        // Copy payload if present
111        if data.len() > UDP_HEADER_LEN {
112            builder.payload = data[UDP_HEADER_LEN..].to_vec();
113        }
114
115        // Disable auto-calculation since we're copying exact values
116        builder.auto_length = false;
117        builder.auto_checksum = false;
118
119        Ok(builder)
120    }
121
122    // ========== Header Field Setters ==========
123
124    /// Set the source port.
125    #[must_use]
126    pub fn src_port(mut self, port: u16) -> Self {
127        self.src_port = port;
128        self
129    }
130
131    /// Alias for `src_port` (Scapy compatibility).
132    #[must_use]
133    pub fn sport(self, port: u16) -> Self {
134        self.src_port(port)
135    }
136
137    /// Set the destination port.
138    #[must_use]
139    pub fn dst_port(mut self, port: u16) -> Self {
140        self.dst_port = port;
141        self
142    }
143
144    /// Alias for `dst_port` (Scapy compatibility).
145    #[must_use]
146    pub fn dport(self, port: u16) -> Self {
147        self.dst_port(port)
148    }
149
150    /// Set the UDP length manually.
151    ///
152    /// If not set, the length will be calculated automatically (8 + payload length).
153    #[must_use]
154    pub fn length(mut self, len: u16) -> Self {
155        self.length = Some(len);
156        self.auto_length = false;
157        self
158    }
159
160    /// Alias for length (Scapy compatibility).
161    #[must_use]
162    pub fn len(self, len: u16) -> Self {
163        self.length(len)
164    }
165
166    /// Set the checksum manually.
167    ///
168    /// If not set, the checksum will be calculated automatically if IP addresses are provided.
169    #[must_use]
170    pub fn checksum(mut self, csum: u16) -> Self {
171        self.checksum = Some(csum);
172        self.auto_checksum = false;
173        self
174    }
175
176    /// Alias for checksum (Scapy compatibility).
177    #[must_use]
178    pub fn chksum(self, csum: u16) -> Self {
179        self.checksum(csum)
180    }
181
182    /// Enable automatic length calculation (default).
183    #[must_use]
184    pub fn enable_auto_length(mut self) -> Self {
185        self.auto_length = true;
186        self.length = None;
187        self
188    }
189
190    /// Disable automatic length calculation.
191    #[must_use]
192    pub fn disable_auto_length(mut self) -> Self {
193        self.auto_length = false;
194        self
195    }
196
197    /// Enable automatic checksum calculation (default).
198    #[must_use]
199    pub fn enable_auto_checksum(mut self) -> Self {
200        self.auto_checksum = true;
201        self.checksum = None;
202        self
203    }
204
205    /// Disable automatic checksum calculation.
206    #[must_use]
207    pub fn disable_auto_checksum(mut self) -> Self {
208        self.auto_checksum = false;
209        self
210    }
211
212    // ========== IP Address Setters ==========
213
214    /// Set source IPv4 address for checksum calculation.
215    #[must_use]
216    pub fn src_ipv4(mut self, addr: Ipv4Addr) -> Self {
217        self.src_ip = Some(IpAddr::V4(addr));
218        self
219    }
220
221    /// Set destination IPv4 address for checksum calculation.
222    #[must_use]
223    pub fn dst_ipv4(mut self, addr: Ipv4Addr) -> Self {
224        self.dst_ip = Some(IpAddr::V4(addr));
225        self
226    }
227
228    /// Set source IPv6 address for checksum calculation.
229    #[must_use]
230    pub fn src_ipv6(mut self, addr: Ipv6Addr) -> Self {
231        self.src_ip = Some(IpAddr::V6(addr));
232        self
233    }
234
235    /// Set destination IPv6 address for checksum calculation.
236    #[must_use]
237    pub fn dst_ipv6(mut self, addr: Ipv6Addr) -> Self {
238        self.dst_ip = Some(IpAddr::V6(addr));
239        self
240    }
241
242    /// Set both source and destination IPv4 addresses.
243    #[must_use]
244    pub fn ipv4_addrs(self, src: Ipv4Addr, dst: Ipv4Addr) -> Self {
245        self.src_ipv4(src).dst_ipv4(dst)
246    }
247
248    /// Set both source and destination IPv6 addresses.
249    #[must_use]
250    pub fn ipv6_addrs(self, src: Ipv6Addr, dst: Ipv6Addr) -> Self {
251        self.src_ipv6(src).dst_ipv6(dst)
252    }
253
254    // ========== Payload Setters ==========
255
256    /// Set the payload data.
257    pub fn payload<T: Into<Vec<u8>>>(mut self, data: T) -> Self {
258        self.payload = data.into();
259        self
260    }
261
262    /// Append to the payload data.
263    pub fn append_payload<T: AsRef<[u8]>>(mut self, data: T) -> Self {
264        self.payload.extend_from_slice(data.as_ref());
265        self
266    }
267
268    // ========== Size Calculation ==========
269
270    /// Get the total packet size (header + payload).
271    #[must_use]
272    pub fn packet_size(&self) -> usize {
273        UDP_HEADER_LEN + self.payload.len()
274    }
275
276    /// Get the header size (always 8 bytes for UDP).
277    #[must_use]
278    pub fn header_size(&self) -> usize {
279        UDP_HEADER_LEN
280    }
281
282    // ========== Build Methods ==========
283
284    /// Build the UDP packet into a new buffer.
285    #[must_use]
286    pub fn build(&self) -> Vec<u8> {
287        let total_size = self.packet_size();
288        let mut buf = vec![0u8; total_size];
289        self.build_into(&mut buf)
290            .expect("buffer is correctly sized");
291        buf
292    }
293
294    /// Build the UDP packet into an existing buffer.
295    pub fn build_into(&self, buf: &mut [u8]) -> Result<usize, FieldError> {
296        let total_size = self.packet_size();
297
298        if buf.len() < total_size {
299            return Err(FieldError::BufferTooShort {
300                offset: 0,
301                need: total_size,
302                have: buf.len(),
303            });
304        }
305
306        // Calculate length (header + payload)
307        let length = if self.auto_length {
308            total_size as u16
309        } else {
310            self.length.unwrap_or(total_size as u16)
311        };
312
313        // Source Port (big-endian)
314        buf[offsets::SRC_PORT..offsets::SRC_PORT + 2].copy_from_slice(&self.src_port.to_be_bytes());
315
316        // Destination Port (big-endian)
317        buf[offsets::DST_PORT..offsets::DST_PORT + 2].copy_from_slice(&self.dst_port.to_be_bytes());
318
319        // Length (big-endian)
320        buf[offsets::LENGTH..offsets::LENGTH + 2].copy_from_slice(&length.to_be_bytes());
321
322        // Checksum (initially 0, calculated later if auto_checksum is enabled)
323        buf[offsets::CHECKSUM..offsets::CHECKSUM + 2].copy_from_slice(&[0, 0]);
324
325        // Copy payload
326        if !self.payload.is_empty() {
327            buf[UDP_HEADER_LEN..total_size].copy_from_slice(&self.payload);
328        }
329
330        // Calculate checksum if enabled and IP addresses are available
331        if self.auto_checksum {
332            let checksum = self.calculate_checksum(&buf[..total_size]);
333            if let Some(csum) = checksum {
334                // RFC 768: If checksum is 0, it should be 0xFFFF
335                let final_csum = if csum == 0 { 0xFFFF } else { csum };
336                buf[offsets::CHECKSUM..offsets::CHECKSUM + 2]
337                    .copy_from_slice(&final_csum.to_be_bytes());
338            }
339        } else if let Some(csum) = self.checksum {
340            buf[offsets::CHECKSUM..offsets::CHECKSUM + 2].copy_from_slice(&csum.to_be_bytes());
341        }
342
343        Ok(total_size)
344    }
345
346    /// Build just the UDP header (without payload).
347    #[must_use]
348    pub fn build_header(&self) -> Vec<u8> {
349        let mut buf = vec![0u8; UDP_HEADER_LEN];
350
351        // Create a copy without payload for header-only build
352        let builder = Self {
353            payload: Vec::new(),
354            ..self.clone()
355        };
356        builder
357            .build_into(&mut buf)
358            .expect("buffer is correctly sized");
359
360        buf
361    }
362
363    /// Calculate the checksum based on IP addresses.
364    fn calculate_checksum(&self, udp_packet: &[u8]) -> Option<u16> {
365        match (self.src_ip, self.dst_ip) {
366            (Some(IpAddr::V4(src)), Some(IpAddr::V4(dst))) => {
367                Some(udp_checksum_ipv4(src, dst, udp_packet))
368            },
369            (Some(IpAddr::V6(src)), Some(IpAddr::V6(dst))) => {
370                Some(udp_checksum_ipv6(src, dst, udp_packet))
371            },
372            _ => None, // Can't calculate without IP addresses or with mismatched versions
373        }
374    }
375}
376
377// ========== Convenience Constructors ==========
378
379impl UdpBuilder {
380    /// Create a DNS query packet builder (port 53).
381    #[must_use]
382    pub fn dns_query() -> Self {
383        Self::new().src_port(53).dst_port(53)
384    }
385
386    /// Create a DHCP client packet builder (ports 68 -> 67).
387    #[must_use]
388    pub fn dhcp_client() -> Self {
389        Self::new().src_port(68).dst_port(67)
390    }
391
392    /// Create a DHCP server packet builder (ports 67 -> 68).
393    #[must_use]
394    pub fn dhcp_server() -> Self {
395        Self::new().src_port(67).dst_port(68)
396    }
397
398    /// Create a NTP packet builder (port 123).
399    #[must_use]
400    pub fn ntp() -> Self {
401        Self::new().src_port(123).dst_port(123)
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_defaults() {
        let builder = UdpBuilder::new();
        assert_eq!(builder.src_port, 53);
        assert_eq!(builder.dst_port, 53);
        assert!(builder.auto_length);
        assert!(builder.auto_checksum);
    }

    #[test]
    fn test_build_basic() {
        let packet = UdpBuilder::new()
            .src_port(12345)
            .dst_port(80)
            .payload(b"Hello")
            .build();

        // Check header
        assert_eq!(packet.len(), 8 + 5); // header + payload
        assert_eq!(u16::from_be_bytes([packet[0], packet[1]]), 12345); // sport
        assert_eq!(u16::from_be_bytes([packet[2], packet[3]]), 80); // dport
        assert_eq!(u16::from_be_bytes([packet[4], packet[5]]), 13); // length (8+5)

        // Check payload
        assert_eq!(&packet[8..], b"Hello");
    }

    #[test]
    fn test_build_with_manual_length() {
        let packet = UdpBuilder::new()
            .src_port(1234)
            .dst_port(5678)
            .length(100)
            .build();

        assert_eq!(u16::from_be_bytes([packet[4], packet[5]]), 100);
    }

    #[test]
    fn test_build_with_checksum() {
        let packet = UdpBuilder::new()
            .src_port(1234)
            .dst_port(5678)
            .src_ipv4(Ipv4Addr::new(192, 168, 1, 1))
            .dst_ipv4(Ipv4Addr::new(192, 168, 1, 2))
            .payload(b"test")
            .build();

        // Checksum should be calculated
        let checksum = u16::from_be_bytes([packet[6], packet[7]]);
        assert_ne!(checksum, 0); // Should have a non-zero checksum
    }

    #[test]
    fn test_build_with_ipv6_checksum() {
        let packet = UdpBuilder::new()
            .ipv6_addrs(Ipv6Addr::LOCALHOST, Ipv6Addr::LOCALHOST)
            .payload(b"test")
            .build();

        // A computed checksum is never transmitted as 0 (0 becomes 0xFFFF).
        assert_ne!(u16::from_be_bytes([packet[6], packet[7]]), 0);
    }

    #[test]
    fn test_checksum_without_ip_addresses_is_zero() {
        let packet = UdpBuilder::new().payload(b"test").build();
        assert_eq!(u16::from_be_bytes([packet[6], packet[7]]), 0);
    }

    #[test]
    fn test_checksum_with_mismatched_ip_versions_is_zero() {
        let packet = UdpBuilder::new()
            .src_ipv4(Ipv4Addr::new(10, 0, 0, 1))
            .dst_ipv6(Ipv6Addr::LOCALHOST)
            .payload(b"test")
            .build();
        assert_eq!(u16::from_be_bytes([packet[6], packet[7]]), 0);
    }

    #[test]
    fn test_manual_zero_checksum_is_preserved() {
        // A manually supplied checksum is written verbatim, so 0 stays 0; the
        // RFC 768 zero -> 0xFFFF substitution only applies to computed checksums.
        let packet = UdpBuilder::new()
            .src_port(0)
            .dst_port(0)
            .checksum(0)
            .build();

        let checksum = u16::from_be_bytes([packet[6], packet[7]]);
        assert_eq!(checksum, 0);
    }

    #[test]
    fn test_from_bytes() {
        let original = UdpBuilder::new()
            .src_port(1234)
            .dst_port(5678)
            .payload(b"test data")
            .build();

        let rebuilt = UdpBuilder::from_bytes(&original).unwrap();
        assert_eq!(rebuilt.src_port, 1234);
        assert_eq!(rebuilt.dst_port, 5678);
        assert_eq!(rebuilt.payload, b"test data");
        // Header values are copied verbatim, so rebuilding is byte-identical.
        assert_eq!(rebuilt.build(), original);
    }

    #[test]
    fn test_from_bytes_rejects_short_input() {
        let result = UdpBuilder::from_bytes(&[0u8; 5]);
        assert!(matches!(
            result,
            Err(FieldError::BufferTooShort { need: 8, have: 5, .. })
        ));
    }

    #[test]
    fn test_build_into_rejects_short_buffer() {
        let mut buf = [0u8; 4];
        let result = UdpBuilder::new().payload(b"abc").build_into(&mut buf);
        assert!(matches!(
            result,
            Err(FieldError::BufferTooShort { need: 11, have: 4, .. })
        ));
    }

    #[test]
    fn test_build_into_matches_build() {
        let builder = UdpBuilder::new().src_port(1).dst_port(2).payload(b"xyz");
        let mut buf = [0xAAu8; 32];
        let written = builder.build_into(&mut buf).unwrap();
        assert_eq!(written, 11);
        assert_eq!(&buf[..written], builder.build().as_slice());
        // Bytes past the packet are left untouched.
        assert!(buf[written..].iter().all(|&b| b == 0xAA));
    }

    #[test]
    fn test_scapy_aliases() {
        let packet = UdpBuilder::new()
            .sport(1234) // alias for src_port
            .dport(5678) // alias for dst_port
            .len(20) // alias for length
            .chksum(0xABCD) // alias for checksum
            .build();

        assert_eq!(u16::from_be_bytes([packet[0], packet[1]]), 1234);
        assert_eq!(u16::from_be_bytes([packet[2], packet[3]]), 5678);
        assert_eq!(u16::from_be_bytes([packet[4], packet[5]]), 20);
        assert_eq!(u16::from_be_bytes([packet[6], packet[7]]), 0xABCD);
    }

    #[test]
    fn test_convenience_constructors() {
        let dns = UdpBuilder::dns_query().build();
        assert_eq!(u16::from_be_bytes([dns[0], dns[1]]), 53);
        assert_eq!(u16::from_be_bytes([dns[2], dns[3]]), 53);

        let dhcp_client = UdpBuilder::dhcp_client().build();
        assert_eq!(u16::from_be_bytes([dhcp_client[0], dhcp_client[1]]), 68);
        assert_eq!(u16::from_be_bytes([dhcp_client[2], dhcp_client[3]]), 67);

        let dhcp_server = UdpBuilder::dhcp_server().build();
        assert_eq!(u16::from_be_bytes([dhcp_server[0], dhcp_server[1]]), 67);
        assert_eq!(u16::from_be_bytes([dhcp_server[2], dhcp_server[3]]), 68);

        let ntp = UdpBuilder::ntp().build();
        assert_eq!(u16::from_be_bytes([ntp[0], ntp[1]]), 123);
        assert_eq!(u16::from_be_bytes([ntp[2], ntp[3]]), 123);
    }

    #[test]
    fn test_build_header_only() {
        let header = UdpBuilder::new()
            .src_port(1234)
            .dst_port(5678)
            .payload(b"this should not be included")
            .build_header();

        assert_eq!(header.len(), 8); // Header only, no payload
        assert_eq!(u16::from_be_bytes([header[0], header[1]]), 1234);
    }
}