Skip to main content

stackforge_core/layer/udp/
builder.rs

1//! UDP packet builder.
2//!
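//! Provides a fluent API for constructing UDP packets with automatic
//! calculation of the length and checksum fields. The checksum can only be
//! computed when both a source and a destination IP address of the same
//! version are supplied; otherwise the checksum field is left as 0.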
5//!
6//! # Example
7//!
8//! ```rust
9//! use stackforge_core::layer::udp::UdpBuilder;
10//! use std::net::Ipv4Addr;
11//!
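//! // Build a DNS request packet. Supplying the IP addresses lets the
//! // builder compute the checksum over the pseudo-header.
//! let packet = UdpBuilder::new()
//!     .src_port(12345)
//!     .dst_port(53)
//!     .ipv4_addrs(Ipv4Addr::new(192, 168, 1, 10), Ipv4Addr::new(192, 168, 1, 1))
//!     .payload(b"DNS query data")
//!     .build();
//! assert_eq!(packet.len(), 8 + 14);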
18//! ```
19
20use std::net::{Ipv4Addr, Ipv6Addr};
21
22use super::checksum::{udp_checksum_ipv4, udp_checksum_ipv6};
23use super::{UDP_HEADER_LEN, offsets};
24use crate::layer::field::FieldError;
25
26/// Builder for UDP packets.
27#[derive(Debug, Clone)]
28pub struct UdpBuilder {
29    // Header fields
30    src_port: u16,
31    dst_port: u16,
32    length: Option<u16>,
33    checksum: Option<u16>,
34
35    // Payload
36    payload: Vec<u8>,
37
38    // Build options
39    auto_length: bool,
40    auto_checksum: bool,
41
42    // IP addresses for checksum calculation
43    src_ip: Option<IpAddr>,
44    dst_ip: Option<IpAddr>,
45}
46
/// IP address used as input to the UDP checksum calculation.
///
/// This is a crate-local type, distinct from [`std::net::IpAddr`]. Build one
/// with `From<Ipv4Addr>` / `From<Ipv6Addr>` or with the variants directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IpAddr {
    /// An IPv4 address.
    V4(Ipv4Addr),
    /// An IPv6 address.
    V6(Ipv6Addr),
}
53
54impl From<Ipv4Addr> for IpAddr {
55    fn from(addr: Ipv4Addr) -> Self {
56        IpAddr::V4(addr)
57    }
58}
59
60impl From<Ipv6Addr> for IpAddr {
61    fn from(addr: Ipv6Addr) -> Self {
62        IpAddr::V6(addr)
63    }
64}
65
66impl Default for UdpBuilder {
67    fn default() -> Self {
68        Self {
69            src_port: 53,
70            dst_port: 53,
71            length: None,
72            checksum: None,
73            payload: Vec::new(),
74            auto_length: true,
75            auto_checksum: true,
76            src_ip: None,
77            dst_ip: None,
78        }
79    }
80}
81
82impl UdpBuilder {
83    /// Create a new UDP builder with default values.
84    pub fn new() -> Self {
85        Self::default()
86    }
87
88    /// Create a builder initialized from an existing packet.
89    pub fn from_bytes(data: &[u8]) -> Result<Self, FieldError> {
90        if data.len() < UDP_HEADER_LEN {
91            return Err(FieldError::BufferTooShort {
92                offset: 0,
93                need: UDP_HEADER_LEN,
94                have: data.len(),
95            });
96        }
97
98        let src_port = u16::from_be_bytes([data[offsets::SRC_PORT], data[offsets::SRC_PORT + 1]]);
99        let dst_port = u16::from_be_bytes([data[offsets::DST_PORT], data[offsets::DST_PORT + 1]]);
100        let length = u16::from_be_bytes([data[offsets::LENGTH], data[offsets::LENGTH + 1]]);
101        let checksum = u16::from_be_bytes([data[offsets::CHECKSUM], data[offsets::CHECKSUM + 1]]);
102
103        let mut builder = Self::new();
104        builder.src_port = src_port;
105        builder.dst_port = dst_port;
106        builder.length = Some(length);
107        builder.checksum = Some(checksum);
108
109        // Copy payload if present
110        if data.len() > UDP_HEADER_LEN {
111            builder.payload = data[UDP_HEADER_LEN..].to_vec();
112        }
113
114        // Disable auto-calculation since we're copying exact values
115        builder.auto_length = false;
116        builder.auto_checksum = false;
117
118        Ok(builder)
119    }
120
121    // ========== Header Field Setters ==========
122
123    /// Set the source port.
124    pub fn src_port(mut self, port: u16) -> Self {
125        self.src_port = port;
126        self
127    }
128
129    /// Alias for src_port (Scapy compatibility).
130    pub fn sport(self, port: u16) -> Self {
131        self.src_port(port)
132    }
133
134    /// Set the destination port.
135    pub fn dst_port(mut self, port: u16) -> Self {
136        self.dst_port = port;
137        self
138    }
139
140    /// Alias for dst_port (Scapy compatibility).
141    pub fn dport(self, port: u16) -> Self {
142        self.dst_port(port)
143    }
144
145    /// Set the UDP length manually.
146    ///
147    /// If not set, the length will be calculated automatically (8 + payload length).
148    pub fn length(mut self, len: u16) -> Self {
149        self.length = Some(len);
150        self.auto_length = false;
151        self
152    }
153
154    /// Alias for length (Scapy compatibility).
155    pub fn len(self, len: u16) -> Self {
156        self.length(len)
157    }
158
159    /// Set the checksum manually.
160    ///
161    /// If not set, the checksum will be calculated automatically if IP addresses are provided.
162    pub fn checksum(mut self, csum: u16) -> Self {
163        self.checksum = Some(csum);
164        self.auto_checksum = false;
165        self
166    }
167
168    /// Alias for checksum (Scapy compatibility).
169    pub fn chksum(self, csum: u16) -> Self {
170        self.checksum(csum)
171    }
172
173    /// Enable automatic length calculation (default).
174    pub fn enable_auto_length(mut self) -> Self {
175        self.auto_length = true;
176        self.length = None;
177        self
178    }
179
180    /// Disable automatic length calculation.
181    pub fn disable_auto_length(mut self) -> Self {
182        self.auto_length = false;
183        self
184    }
185
186    /// Enable automatic checksum calculation (default).
187    pub fn enable_auto_checksum(mut self) -> Self {
188        self.auto_checksum = true;
189        self.checksum = None;
190        self
191    }
192
193    /// Disable automatic checksum calculation.
194    pub fn disable_auto_checksum(mut self) -> Self {
195        self.auto_checksum = false;
196        self
197    }
198
199    // ========== IP Address Setters ==========
200
201    /// Set source IPv4 address for checksum calculation.
202    pub fn src_ipv4(mut self, addr: Ipv4Addr) -> Self {
203        self.src_ip = Some(IpAddr::V4(addr));
204        self
205    }
206
207    /// Set destination IPv4 address for checksum calculation.
208    pub fn dst_ipv4(mut self, addr: Ipv4Addr) -> Self {
209        self.dst_ip = Some(IpAddr::V4(addr));
210        self
211    }
212
213    /// Set source IPv6 address for checksum calculation.
214    pub fn src_ipv6(mut self, addr: Ipv6Addr) -> Self {
215        self.src_ip = Some(IpAddr::V6(addr));
216        self
217    }
218
219    /// Set destination IPv6 address for checksum calculation.
220    pub fn dst_ipv6(mut self, addr: Ipv6Addr) -> Self {
221        self.dst_ip = Some(IpAddr::V6(addr));
222        self
223    }
224
225    /// Set both source and destination IPv4 addresses.
226    pub fn ipv4_addrs(self, src: Ipv4Addr, dst: Ipv4Addr) -> Self {
227        self.src_ipv4(src).dst_ipv4(dst)
228    }
229
230    /// Set both source and destination IPv6 addresses.
231    pub fn ipv6_addrs(self, src: Ipv6Addr, dst: Ipv6Addr) -> Self {
232        self.src_ipv6(src).dst_ipv6(dst)
233    }
234
235    // ========== Payload Setters ==========
236
237    /// Set the payload data.
238    pub fn payload<T: Into<Vec<u8>>>(mut self, data: T) -> Self {
239        self.payload = data.into();
240        self
241    }
242
243    /// Append to the payload data.
244    pub fn append_payload<T: AsRef<[u8]>>(mut self, data: T) -> Self {
245        self.payload.extend_from_slice(data.as_ref());
246        self
247    }
248
249    // ========== Size Calculation ==========
250
251    /// Get the total packet size (header + payload).
252    pub fn packet_size(&self) -> usize {
253        UDP_HEADER_LEN + self.payload.len()
254    }
255
256    /// Get the header size (always 8 bytes for UDP).
257    pub fn header_size(&self) -> usize {
258        UDP_HEADER_LEN
259    }
260
261    // ========== Build Methods ==========
262
263    /// Build the UDP packet into a new buffer.
264    pub fn build(&self) -> Vec<u8> {
265        let total_size = self.packet_size();
266        let mut buf = vec![0u8; total_size];
267        self.build_into(&mut buf)
268            .expect("buffer is correctly sized");
269        buf
270    }
271
272    /// Build the UDP packet into an existing buffer.
273    pub fn build_into(&self, buf: &mut [u8]) -> Result<usize, FieldError> {
274        let total_size = self.packet_size();
275
276        if buf.len() < total_size {
277            return Err(FieldError::BufferTooShort {
278                offset: 0,
279                need: total_size,
280                have: buf.len(),
281            });
282        }
283
284        // Calculate length (header + payload)
285        let length = if self.auto_length {
286            total_size as u16
287        } else {
288            self.length.unwrap_or(total_size as u16)
289        };
290
291        // Source Port (big-endian)
292        buf[offsets::SRC_PORT..offsets::SRC_PORT + 2].copy_from_slice(&self.src_port.to_be_bytes());
293
294        // Destination Port (big-endian)
295        buf[offsets::DST_PORT..offsets::DST_PORT + 2].copy_from_slice(&self.dst_port.to_be_bytes());
296
297        // Length (big-endian)
298        buf[offsets::LENGTH..offsets::LENGTH + 2].copy_from_slice(&length.to_be_bytes());
299
300        // Checksum (initially 0, calculated later if auto_checksum is enabled)
301        buf[offsets::CHECKSUM..offsets::CHECKSUM + 2].copy_from_slice(&[0, 0]);
302
303        // Copy payload
304        if !self.payload.is_empty() {
305            buf[UDP_HEADER_LEN..total_size].copy_from_slice(&self.payload);
306        }
307
308        // Calculate checksum if enabled and IP addresses are available
309        if self.auto_checksum {
310            let checksum = self.calculate_checksum(&buf[..total_size]);
311            if let Some(csum) = checksum {
312                // RFC 768: If checksum is 0, it should be 0xFFFF
313                let final_csum = if csum == 0 { 0xFFFF } else { csum };
314                buf[offsets::CHECKSUM..offsets::CHECKSUM + 2]
315                    .copy_from_slice(&final_csum.to_be_bytes());
316            }
317        } else if let Some(csum) = self.checksum {
318            buf[offsets::CHECKSUM..offsets::CHECKSUM + 2].copy_from_slice(&csum.to_be_bytes());
319        }
320
321        Ok(total_size)
322    }
323
324    /// Build just the UDP header (without payload).
325    pub fn build_header(&self) -> Vec<u8> {
326        let mut buf = vec![0u8; UDP_HEADER_LEN];
327
328        // Create a copy without payload for header-only build
329        let builder = Self {
330            payload: Vec::new(),
331            ..self.clone()
332        };
333        builder
334            .build_into(&mut buf)
335            .expect("buffer is correctly sized");
336
337        buf
338    }
339
340    /// Calculate the checksum based on IP addresses.
341    fn calculate_checksum(&self, udp_packet: &[u8]) -> Option<u16> {
342        match (self.src_ip, self.dst_ip) {
343            (Some(IpAddr::V4(src)), Some(IpAddr::V4(dst))) => {
344                Some(udp_checksum_ipv4(src, dst, udp_packet))
345            }
346            (Some(IpAddr::V6(src)), Some(IpAddr::V6(dst))) => {
347                Some(udp_checksum_ipv6(src, dst, udp_packet))
348            }
349            _ => None, // Can't calculate without IP addresses or with mismatched versions
350        }
351    }
352}
353
354// ========== Convenience Constructors ==========
355
356impl UdpBuilder {
357    /// Create a DNS query packet builder (port 53).
358    pub fn dns_query() -> Self {
359        Self::new().src_port(53).dst_port(53)
360    }
361
362    /// Create a DHCP client packet builder (ports 68 -> 67).
363    pub fn dhcp_client() -> Self {
364        Self::new().src_port(68).dst_port(67)
365    }
366
367    /// Create a DHCP server packet builder (ports 67 -> 68).
368    pub fn dhcp_server() -> Self {
369        Self::new().src_port(67).dst_port(68)
370    }
371
372    /// Create a NTP packet builder (port 123).
373    pub fn ntp() -> Self {
374        Self::new().src_port(123).dst_port(123)
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_defaults() {
        let builder = UdpBuilder::new();
        assert_eq!(builder.src_port, 53);
        assert_eq!(builder.dst_port, 53);
        assert!(builder.auto_length);
        assert!(builder.auto_checksum);
    }

    #[test]
    fn test_build_basic() {
        let packet = UdpBuilder::new()
            .src_port(12345)
            .dst_port(80)
            .payload(b"Hello")
            .build();

        // Check header
        assert_eq!(packet.len(), 8 + 5); // header + payload
        assert_eq!(u16::from_be_bytes([packet[0], packet[1]]), 12345); // sport
        assert_eq!(u16::from_be_bytes([packet[2], packet[3]]), 80); // dport
        assert_eq!(u16::from_be_bytes([packet[4], packet[5]]), 13); // length (8+5)

        // Check payload
        assert_eq!(&packet[8..], b"Hello");
    }

    #[test]
    fn test_build_with_manual_length() {
        let packet = UdpBuilder::new()
            .src_port(1234)
            .dst_port(5678)
            .length(100)
            .build();

        assert_eq!(u16::from_be_bytes([packet[4], packet[5]]), 100);
    }

    #[test]
    fn test_build_with_checksum() {
        let packet = UdpBuilder::new()
            .src_port(1234)
            .dst_port(5678)
            .src_ipv4(Ipv4Addr::new(192, 168, 1, 1))
            .dst_ipv4(Ipv4Addr::new(192, 168, 1, 2))
            .payload(b"test")
            .build();

        // Checksum should be calculated
        let checksum = u16::from_be_bytes([packet[6], packet[7]]);
        assert_ne!(checksum, 0); // Should have a non-zero checksum
    }

    #[test]
    fn test_build_with_ipv6_checksum() {
        let packet = UdpBuilder::new()
            .ipv6_addrs(Ipv6Addr::LOCALHOST, Ipv6Addr::LOCALHOST)
            .payload(b"v6")
            .build();

        // A computed checksum is never transmitted as 0 (0 becomes 0xFFFF).
        assert_ne!(u16::from_be_bytes([packet[6], packet[7]]), 0);
    }

    #[test]
    fn test_mismatched_ip_versions_leave_checksum_zero() {
        let packet = UdpBuilder::new()
            .src_ipv4(Ipv4Addr::new(10, 0, 0, 1))
            .dst_ipv6(Ipv6Addr::LOCALHOST)
            .payload(b"x")
            .build();

        assert_eq!(u16::from_be_bytes([packet[6], packet[7]]), 0);
    }

    #[test]
    fn test_manual_zero_checksum_is_kept() {
        // Only a *computed* checksum of 0 is rewritten to 0xFFFF (RFC 768);
        // a manually supplied checksum is written verbatim, even when 0.
        let packet = UdpBuilder::new()
            .src_port(0)
            .dst_port(0)
            .checksum(0)
            .build();

        let checksum = u16::from_be_bytes([packet[6], packet[7]]);
        assert_eq!(checksum, 0);
    }

    #[test]
    fn test_from_bytes() {
        let original = UdpBuilder::new()
            .src_port(1234)
            .dst_port(5678)
            .payload(b"test data")
            .build();

        let rebuilt = UdpBuilder::from_bytes(&original).unwrap();
        assert_eq!(rebuilt.src_port, 1234);
        assert_eq!(rebuilt.dst_port, 5678);
        assert_eq!(rebuilt.payload, b"test data");
    }

    #[test]
    fn test_from_bytes_round_trip() {
        let original = UdpBuilder::new()
            .src_port(4000)
            .dst_port(5000)
            .ipv4_addrs(Ipv4Addr::new(10, 0, 0, 1), Ipv4Addr::new(10, 0, 0, 2))
            .payload(b"round trip")
            .build();

        let rebuilt = UdpBuilder::from_bytes(&original).unwrap().build();
        assert_eq!(rebuilt, original);
    }

    #[test]
    fn test_from_bytes_too_short() {
        let result = UdpBuilder::from_bytes(&[0u8; 7]);
        assert!(matches!(
            result,
            Err(FieldError::BufferTooShort { need: 8, have: 7, .. })
        ));
    }

    #[test]
    fn test_build_into_rejects_short_buffer() {
        let mut buf = [0u8; 4];
        let result = UdpBuilder::new().payload(b"abc").build_into(&mut buf);
        assert!(matches!(
            result,
            Err(FieldError::BufferTooShort { need: 11, have: 4, .. })
        ));
    }

    #[test]
    fn test_build_into_leaves_tail_untouched() {
        let mut buf = [0xAAu8; 20];
        let written = UdpBuilder::new()
            .payload(b"abc")
            .build_into(&mut buf)
            .unwrap();

        assert_eq!(written, 11);
        assert_eq!(&buf[8..11], b"abc");
        assert!(buf[11..].iter().all(|&b| b == 0xAA));
    }

    #[test]
    fn test_scapy_aliases() {
        let packet = UdpBuilder::new()
            .sport(1234) // alias for src_port
            .dport(5678) // alias for dst_port
            .len(20) // alias for length
            .chksum(0xABCD) // alias for checksum
            .build();

        assert_eq!(u16::from_be_bytes([packet[0], packet[1]]), 1234);
        assert_eq!(u16::from_be_bytes([packet[2], packet[3]]), 5678);
        assert_eq!(u16::from_be_bytes([packet[4], packet[5]]), 20);
        assert_eq!(u16::from_be_bytes([packet[6], packet[7]]), 0xABCD);
    }

    #[test]
    fn test_convenience_constructors() {
        let dns = UdpBuilder::dns_query().build();
        assert_eq!(u16::from_be_bytes([dns[0], dns[1]]), 53);
        assert_eq!(u16::from_be_bytes([dns[2], dns[3]]), 53);

        let dhcp_client = UdpBuilder::dhcp_client().build();
        assert_eq!(u16::from_be_bytes([dhcp_client[0], dhcp_client[1]]), 68);
        assert_eq!(u16::from_be_bytes([dhcp_client[2], dhcp_client[3]]), 67);
    }

    #[test]
    fn test_build_header_only() {
        let header = UdpBuilder::new()
            .src_port(1234)
            .dst_port(5678)
            .payload(b"this should not be included")
            .build_header();

        assert_eq!(header.len(), 8); // Header only, no payload
        assert_eq!(u16::from_be_bytes([header[0], header[1]]), 1234);
    }
}