Skip to main content

stackforge_core/layer/ipv4/
builder.rs

1//! IPv4 packet builder.
2//!
3//! Provides a fluent API for constructing IPv4 packets with automatic
4//! field calculation (checksum, length, IHL).
5
6use std::net::Ipv4Addr;
7
8use super::checksum::ipv4_checksum;
9use super::header::{IPV4_MIN_HEADER_LEN, Ipv4Flags, Ipv4Layer, offsets};
10use super::options::{Ipv4Option, Ipv4Options, Ipv4OptionsBuilder};
11use super::protocol;
12use crate::layer::field::FieldError;
13
14/// Builder for IPv4 packets.
15///
16/// # Example
17///
18/// ```rust
19/// use stackforge_core::layer::ipv4::{Ipv4Builder, protocol};
20/// use std::net::Ipv4Addr;
21///
22/// let packet = Ipv4Builder::new()
23///     .src(Ipv4Addr::new(192, 168, 1, 1))
24///     .dst(Ipv4Addr::new(192, 168, 1, 2))
25///     .ttl(64)
26///     .protocol(protocol::TCP)
27///     .dont_fragment()
28///     .build();
29///
30/// assert_eq!(packet.len(), 20); // Minimum header, no payload
31/// ```
32#[derive(Debug, Clone)]
33pub struct Ipv4Builder {
34    // Header fields
35    version: u8,
36    ihl: Option<u8>,
37    tos: u8,
38    total_len: Option<u16>,
39    id: u16,
40    flags: Ipv4Flags,
41    frag_offset: u16,
42    ttl: u8,
43    protocol: u8,
44    checksum: Option<u16>,
45    src: Ipv4Addr,
46    dst: Ipv4Addr,
47
48    // Options
49    options: Ipv4Options,
50
51    // Payload
52    payload: Vec<u8>,
53
54    // Build options
55    auto_checksum: bool,
56    auto_length: bool,
57    auto_ihl: bool,
58}
59
60impl Default for Ipv4Builder {
61    fn default() -> Self {
62        Self {
63            version: 4,
64            ihl: None,
65            tos: 0,
66            total_len: None,
67            id: 1,
68            flags: Ipv4Flags::NONE,
69            frag_offset: 0,
70            ttl: 64,
71            protocol: 0,
72            checksum: None,
73            src: Ipv4Addr::new(127, 0, 0, 1),
74            dst: Ipv4Addr::new(127, 0, 0, 1),
75            options: Ipv4Options::new(),
76            payload: Vec::new(),
77            auto_checksum: true,
78            auto_length: true,
79            auto_ihl: true,
80        }
81    }
82}
83
84impl Ipv4Builder {
85    /// Create a new IPv4 builder with default values.
86    pub fn new() -> Self {
87        Self::default()
88    }
89
90    /// Create a builder initialized from an existing packet.
91    pub fn from_bytes(data: &[u8]) -> Result<Self, FieldError> {
92        let layer = Ipv4Layer::at_offset_dynamic(data, 0)?;
93
94        let mut builder = Self::new();
95        builder.version = layer.version(data)?;
96        builder.ihl = Some(layer.ihl(data)?);
97        builder.tos = layer.tos(data)?;
98        builder.total_len = Some(layer.total_len(data)?);
99        builder.id = layer.id(data)?;
100        builder.flags = layer.flags(data)?;
101        builder.frag_offset = layer.frag_offset(data)?;
102        builder.ttl = layer.ttl(data)?;
103        builder.protocol = layer.protocol(data)?;
104        builder.checksum = Some(layer.checksum(data)?);
105        builder.src = layer.src(data)?;
106        builder.dst = layer.dst(data)?;
107
108        // Parse options if present
109        if layer.options_len(data) > 0 {
110            builder.options = layer.options(data)?;
111        }
112
113        // Copy payload
114        let header_len = layer.calculate_header_len(data);
115        let total_len = layer.total_len(data)? as usize;
116        if total_len > header_len && data.len() >= total_len {
117            builder.payload = data[header_len..total_len].to_vec();
118        }
119
120        // Disable auto-calculation since we're copying exact values
121        builder.auto_checksum = false;
122        builder.auto_length = false;
123        builder.auto_ihl = false;
124
125        Ok(builder)
126    }
127
128    // ========== Header Field Setters ==========
129
130    /// Set the IP version (should normally be 4).
131    pub fn version(mut self, version: u8) -> Self {
132        self.version = version;
133        self
134    }
135
136    /// Set the Internet Header Length (in 32-bit words).
137    /// If not set, will be calculated automatically.
138    pub fn ihl(mut self, ihl: u8) -> Self {
139        self.ihl = Some(ihl);
140        self.auto_ihl = false;
141        self
142    }
143
144    /// Set the Type of Service field.
145    pub fn tos(mut self, tos: u8) -> Self {
146        self.tos = tos;
147        self
148    }
149
150    /// Set the DSCP (Differentiated Services Code Point).
151    pub fn dscp(mut self, dscp: u8) -> Self {
152        self.tos = (self.tos & 0x03) | ((dscp & 0x3F) << 2);
153        self
154    }
155
156    /// Set the ECN (Explicit Congestion Notification).
157    pub fn ecn(mut self, ecn: u8) -> Self {
158        self.tos = (self.tos & 0xFC) | (ecn & 0x03);
159        self
160    }
161
162    /// Set the total length field.
163    /// If not set, will be calculated automatically.
164    pub fn total_len(mut self, len: u16) -> Self {
165        self.total_len = Some(len);
166        self.auto_length = false;
167        self
168    }
169
170    /// Alias for total_len (Scapy compatibility).
171    pub fn len(self, len: u16) -> Self {
172        self.total_len(len)
173    }
174
175    /// Set the identification field.
176    pub fn id(mut self, id: u16) -> Self {
177        self.id = id;
178        self
179    }
180
181    /// Set the flags field.
182    pub fn flags(mut self, flags: Ipv4Flags) -> Self {
183        self.flags = flags;
184        self
185    }
186
187    /// Set the Don't Fragment flag.
188    pub fn dont_fragment(mut self) -> Self {
189        self.flags.df = true;
190        self
191    }
192
193    /// Clear the Don't Fragment flag.
194    pub fn allow_fragment(mut self) -> Self {
195        self.flags.df = false;
196        self
197    }
198
199    /// Set the More Fragments flag.
200    pub fn more_fragments(mut self) -> Self {
201        self.flags.mf = true;
202        self
203    }
204
205    /// Set the reserved/evil bit.
206    pub fn evil(mut self) -> Self {
207        self.flags.reserved = true;
208        self
209    }
210
211    /// Set the fragment offset (in 8-byte units).
212    pub fn frag_offset(mut self, offset: u16) -> Self {
213        self.frag_offset = offset & 0x1FFF;
214        self
215    }
216
217    /// Set the fragment offset in bytes (will be divided by 8).
218    pub fn frag_offset_bytes(mut self, offset: u32) -> Self {
219        self.frag_offset = ((offset / 8) & 0x1FFF) as u16;
220        self
221    }
222
223    /// Set the TTL (Time to Live).
224    pub fn ttl(mut self, ttl: u8) -> Self {
225        self.ttl = ttl;
226        self
227    }
228
229    /// Set the protocol number.
230    pub fn protocol(mut self, protocol: u8) -> Self {
231        self.protocol = protocol;
232        self
233    }
234
235    /// Alias for protocol (Scapy compatibility).
236    pub fn proto(self, protocol: u8) -> Self {
237        self.protocol(protocol)
238    }
239
240    /// Set the checksum manually.
241    /// If not set, will be calculated automatically.
242    pub fn checksum(mut self, checksum: u16) -> Self {
243        self.checksum = Some(checksum);
244        self.auto_checksum = false;
245        self
246    }
247
248    /// Alias for checksum (Scapy compatibility).
249    pub fn chksum(self, checksum: u16) -> Self {
250        self.checksum(checksum)
251    }
252
253    /// Set the source IP address.
254    pub fn src(mut self, src: Ipv4Addr) -> Self {
255        self.src = src;
256        self
257    }
258
259    /// Set the destination IP address.
260    pub fn dst(mut self, dst: Ipv4Addr) -> Self {
261        self.dst = dst;
262        self
263    }
264
265    // ========== Options ==========
266
267    /// Set the options.
268    pub fn options(mut self, options: Ipv4Options) -> Self {
269        self.options = options;
270        self
271    }
272
273    /// Add a single option.
274    pub fn option(mut self, option: Ipv4Option) -> Self {
275        self.options.push(option);
276        self
277    }
278
279    /// Add options using a builder function.
280    pub fn with_options<F>(mut self, f: F) -> Self
281    where
282        F: FnOnce(Ipv4OptionsBuilder) -> Ipv4OptionsBuilder,
283    {
284        self.options = f(Ipv4OptionsBuilder::new()).build();
285        self
286    }
287
288    /// Add a Record Route option.
289    pub fn record_route(mut self, slots: usize) -> Self {
290        self.options.push(Ipv4Option::RecordRoute {
291            pointer: 4,
292            route: vec![Ipv4Addr::UNSPECIFIED; slots],
293        });
294        self
295    }
296
297    /// Add a Loose Source Route option.
298    pub fn lsrr(mut self, route: Vec<Ipv4Addr>) -> Self {
299        self.options.push(Ipv4Option::Lsrr { pointer: 4, route });
300        self
301    }
302
303    /// Add a Strict Source Route option.
304    pub fn ssrr(mut self, route: Vec<Ipv4Addr>) -> Self {
305        self.options.push(Ipv4Option::Ssrr { pointer: 4, route });
306        self
307    }
308
309    /// Add a Router Alert option.
310    pub fn router_alert(mut self, value: u16) -> Self {
311        self.options.push(Ipv4Option::RouterAlert { value });
312        self
313    }
314
315    // ========== Payload ==========
316
317    /// Set the payload data.
318    pub fn payload(mut self, payload: impl Into<Vec<u8>>) -> Self {
319        self.payload = payload.into();
320        self
321    }
322
323    /// Append data to the payload.
324    pub fn append_payload(mut self, data: &[u8]) -> Self {
325        self.payload.extend_from_slice(data);
326        self
327    }
328
329    // ========== Build Options ==========
330
331    /// Enable or disable automatic checksum calculation.
332    pub fn auto_checksum(mut self, enabled: bool) -> Self {
333        self.auto_checksum = enabled;
334        self
335    }
336
337    /// Enable or disable automatic length calculation.
338    pub fn auto_length(mut self, enabled: bool) -> Self {
339        self.auto_length = enabled;
340        self
341    }
342
343    /// Enable or disable automatic IHL calculation.
344    pub fn auto_ihl(mut self, enabled: bool) -> Self {
345        self.auto_ihl = enabled;
346        self
347    }
348
349    // ========== Build Methods ==========
350
351    /// Calculate the header size (including options).
352    pub fn header_size(&self) -> usize {
353        if let Some(ihl) = self.ihl {
354            (ihl as usize) * 4
355        } else {
356            let opts_len = self.options.padded_len();
357            IPV4_MIN_HEADER_LEN + opts_len
358        }
359    }
360
361    /// Calculate the total packet size.
362    pub fn packet_size(&self) -> usize {
363        self.header_size() + self.payload.len()
364    }
365
366    /// Build the IPv4 packet.
367    pub fn build(&self) -> Vec<u8> {
368        let _header_size = self.header_size();
369        let total_size = self.packet_size();
370
371        let mut buf = vec![0u8; total_size];
372        self.build_into(&mut buf)
373            .expect("buffer is correctly sized");
374        buf
375    }
376
377    /// Build the IPv4 packet into an existing buffer.
378    pub fn build_into(&self, buf: &mut [u8]) -> Result<usize, FieldError> {
379        let header_size = self.header_size();
380        let total_size = self.packet_size();
381
382        if buf.len() < total_size {
383            return Err(FieldError::BufferTooShort {
384                offset: 0,
385                need: total_size,
386                have: buf.len(),
387            });
388        }
389
390        // Calculate IHL
391        let ihl = if self.auto_ihl {
392            (header_size / 4) as u8
393        } else {
394            self.ihl.unwrap_or(5)
395        };
396
397        // Calculate total length
398        let total_len = if self.auto_length {
399            total_size as u16
400        } else {
401            self.total_len.unwrap_or(total_size as u16)
402        };
403
404        // Version + IHL
405        buf[offsets::VERSION_IHL] = ((self.version & 0x0F) << 4) | (ihl & 0x0F);
406
407        // TOS
408        buf[offsets::TOS] = self.tos;
409
410        // Total Length
411        buf[offsets::TOTAL_LEN] = (total_len >> 8) as u8;
412        buf[offsets::TOTAL_LEN + 1] = (total_len & 0xFF) as u8;
413
414        // ID
415        buf[offsets::ID] = (self.id >> 8) as u8;
416        buf[offsets::ID + 1] = (self.id & 0xFF) as u8;
417
418        // Flags + Fragment Offset
419        let flags_frag = (self.flags.to_byte() as u16) << 8 | self.frag_offset;
420        buf[offsets::FLAGS_FRAG] = (flags_frag >> 8) as u8;
421        buf[offsets::FLAGS_FRAG + 1] = (flags_frag & 0xFF) as u8;
422
423        // TTL
424        buf[offsets::TTL] = self.ttl;
425
426        // Protocol
427        buf[offsets::PROTOCOL] = self.protocol;
428
429        // Checksum (initially 0)
430        buf[offsets::CHECKSUM] = 0;
431        buf[offsets::CHECKSUM + 1] = 0;
432
433        // Source IP
434        let src_octets = self.src.octets();
435        buf[offsets::SRC..offsets::SRC + 4].copy_from_slice(&src_octets);
436
437        // Destination IP
438        let dst_octets = self.dst.octets();
439        buf[offsets::DST..offsets::DST + 4].copy_from_slice(&dst_octets);
440
441        // Options
442        if !self.options.is_empty() {
443            let opts_bytes = self.options.to_bytes();
444            let opts_end = offsets::OPTIONS + opts_bytes.len();
445            if opts_end <= header_size {
446                buf[offsets::OPTIONS..opts_end].copy_from_slice(&opts_bytes);
447            }
448        }
449
450        // Payload
451        if !self.payload.is_empty() {
452            buf[header_size..header_size + self.payload.len()].copy_from_slice(&self.payload);
453        }
454
455        // Checksum (computed last)
456        let checksum = if self.auto_checksum {
457            ipv4_checksum(&buf[..header_size])
458        } else {
459            self.checksum.unwrap_or(0)
460        };
461        buf[offsets::CHECKSUM] = (checksum >> 8) as u8;
462        buf[offsets::CHECKSUM + 1] = (checksum & 0xFF) as u8;
463
464        Ok(total_size)
465    }
466
467    /// Build only the header (no payload).
468    pub fn build_header(&self) -> Vec<u8> {
469        let header_size = self.header_size();
470        let mut buf = vec![0u8; header_size];
471
472        // Temporarily clear payload for header-only build
473        let payload = std::mem::take(&mut self.payload.clone());
474        let builder = Self {
475            payload: Vec::new(),
476            ..self.clone()
477        };
478        builder
479            .build_into(&mut buf)
480            .expect("buffer is correctly sized");
481
482        // Don't actually need to restore since we cloned
483        drop(payload);
484
485        buf
486    }
487}
488
489// ========== Convenience Constructors ==========
490
491impl Ipv4Builder {
492    /// Create an ICMP packet builder.
493    pub fn icmp() -> Self {
494        Self::new().protocol(protocol::ICMP)
495    }
496
497    /// Create a TCP packet builder.
498    pub fn tcp() -> Self {
499        Self::new().protocol(protocol::TCP)
500    }
501
502    /// Create a UDP packet builder.
503    pub fn udp() -> Self {
504        Self::new().protocol(protocol::UDP)
505    }
506
507    /// Create an IP-in-IP tunnel packet builder.
508    pub fn ipip() -> Self {
509        Self::new().protocol(protocol::IPV4)
510    }
511
512    /// Create a GRE tunnel packet builder.
513    pub fn gre() -> Self {
514        Self::new().protocol(protocol::GRE)
515    }
516
517    /// Create a packet destined for a specific address.
518    pub fn to(dst: Ipv4Addr) -> Self {
519        Self::new().dst(dst)
520    }
521
522    /// Create a packet from a specific source.
523    pub fn from(src: Ipv4Addr) -> Self {
524        Self::new().src(src)
525    }
526}
527
528// ========== Random Values ==========
529
#[cfg(feature = "rand")]
impl Ipv4Builder {
    /// Replaces the identification field with a value drawn from `rand::rng()`.
    pub fn random_id(mut self) -> Self {
        use rand::Rng as _;
        self.id = rand::rng().random::<u16>();
        self
    }
}
539
#[cfg(test)]
mod tests {
    //! Round-trip tests: build with `Ipv4Builder`, then parse with `Ipv4Layer`.

    use super::*;

    #[test]
    fn test_basic_build() {
        let pkt = Ipv4Builder::new()
            .src(Ipv4Addr::new(192, 168, 1, 1))
            .dst(Ipv4Addr::new(192, 168, 1, 2))
            .ttl(64)
            .protocol(protocol::TCP)
            .build();

        assert_eq!(pkt.len(), 20);

        let layer = Ipv4Layer::at_offset(0);
        assert_eq!(layer.version(&pkt).unwrap(), 4);
        assert_eq!(layer.ihl(&pkt).unwrap(), 5);
        assert_eq!(layer.ttl(&pkt).unwrap(), 64);
        assert_eq!(layer.protocol(&pkt).unwrap(), protocol::TCP);
        assert_eq!(layer.src(&pkt).unwrap(), Ipv4Addr::new(192, 168, 1, 1));
        assert_eq!(layer.dst(&pkt).unwrap(), Ipv4Addr::new(192, 168, 1, 2));

        // Verify checksum
        assert!(layer.verify_checksum(&pkt).unwrap());
    }

    #[test]
    fn test_with_payload() {
        let payload = vec![1, 2, 3, 4, 5];
        let pkt = Ipv4Builder::new()
            .src(Ipv4Addr::new(10, 0, 0, 1))
            .dst(Ipv4Addr::new(10, 0, 0, 2))
            .protocol(protocol::UDP)
            .payload(payload.clone())
            .build();

        assert_eq!(pkt.len(), 25); // 20 + 5

        let layer = Ipv4Layer::at_offset(0);
        assert_eq!(layer.total_len(&pkt).unwrap(), 25);
        assert_eq!(layer.payload(&pkt).unwrap(), &payload[..]);
    }

    #[test]
    fn test_build_header_excludes_payload() {
        let builder = Ipv4Builder::new()
            .dst(Ipv4Addr::new(10, 0, 0, 2))
            .payload(vec![1, 2, 3, 4, 5]);

        let header = builder.build_header();
        assert_eq!(header.len(), 20);

        // Header-only build: total length covers the header alone.
        let layer = Ipv4Layer::at_offset(0);
        assert_eq!(layer.total_len(&header).unwrap(), 20);
        assert!(layer.verify_checksum(&header).unwrap());
    }

    #[test]
    fn test_with_options() {
        let pkt = Ipv4Builder::new()
            .src(Ipv4Addr::new(10, 0, 0, 1))
            .dst(Ipv4Addr::new(10, 0, 0, 2))
            .router_alert(0)
            .build();

        // Router Alert is 4 bytes, header should be 24 bytes
        assert_eq!(pkt.len(), 24);

        let layer = Ipv4Layer::at_offset(0);
        assert_eq!(layer.ihl(&pkt).unwrap(), 6); // 24/4 = 6
        assert!(layer.verify_checksum(&pkt).unwrap());
    }

    #[test]
    fn test_flags() {
        let pkt = Ipv4Builder::new()
            .dst(Ipv4Addr::new(8, 8, 8, 8))
            .dont_fragment()
            .build();

        let layer = Ipv4Layer::at_offset(0);
        let flags = layer.flags(&pkt).unwrap();
        assert!(flags.df);
        assert!(!flags.mf);
    }

    #[test]
    fn test_fragment() {
        let pkt = Ipv4Builder::new()
            .dst(Ipv4Addr::new(8, 8, 8, 8))
            .more_fragments()
            .frag_offset(100)
            .build();

        let layer = Ipv4Layer::at_offset(0);
        let flags = layer.flags(&pkt).unwrap();
        assert!(flags.mf);
        assert_eq!(layer.frag_offset(&pkt).unwrap(), 100);
    }

    #[test]
    fn test_dscp_ecn() {
        let pkt = Ipv4Builder::new()
            .dst(Ipv4Addr::new(8, 8, 8, 8))
            .dscp(46) // EF
            .ecn(2) // ECT(0)
            .build();

        let layer = Ipv4Layer::at_offset(0);
        assert_eq!(layer.dscp(&pkt).unwrap(), 46);
        assert_eq!(layer.ecn(&pkt).unwrap(), 2);
    }

    #[test]
    fn test_from_bytes() {
        let original = Ipv4Builder::new()
            .src(Ipv4Addr::new(192, 168, 1, 100))
            .dst(Ipv4Addr::new(192, 168, 1, 200))
            .ttl(128)
            .id(0xABCD)
            .protocol(protocol::ICMP)
            .payload(vec![8, 0, 0, 0, 0, 1, 0, 1]) // ICMP echo
            .build();

        let rebuilt = Ipv4Builder::from_bytes(&original)
            .unwrap()
            .auto_checksum(true)
            .build();

        // Every field is copied verbatim, so the rebuild must be byte-identical.
        assert_eq!(original, rebuilt);

        let layer = Ipv4Layer::at_offset(0);
        assert_eq!(layer.src(&original).unwrap(), layer.src(&rebuilt).unwrap());
        assert_eq!(layer.dst(&original).unwrap(), layer.dst(&rebuilt).unwrap());
        assert_eq!(layer.ttl(&original).unwrap(), layer.ttl(&rebuilt).unwrap());
        assert_eq!(layer.id(&original).unwrap(), layer.id(&rebuilt).unwrap());
    }

    #[test]
    fn test_convenience_constructors() {
        let icmp = Ipv4Builder::icmp().build();
        let layer = Ipv4Layer::at_offset(0);
        assert_eq!(layer.protocol(&icmp).unwrap(), protocol::ICMP);

        let tcp = Ipv4Builder::tcp().build();
        assert_eq!(layer.protocol(&tcp).unwrap(), protocol::TCP);

        let udp = Ipv4Builder::udp().build();
        assert_eq!(layer.protocol(&udp).unwrap(), protocol::UDP);
    }

    #[test]
    fn test_manual_fields() {
        let pkt = Ipv4Builder::new()
            .dst(Ipv4Addr::new(8, 8, 8, 8))
            .total_len(100)
            .checksum(0x1234)
            .ihl(5)
            .build();

        let layer = Ipv4Layer::at_offset(0);
        assert_eq!(layer.total_len(&pkt).unwrap(), 100);
        assert_eq!(layer.checksum(&pkt).unwrap(), 0x1234);
        assert_eq!(layer.ihl(&pkt).unwrap(), 5);
    }

    #[test]
    fn test_source_route_option() {
        let route = vec![
            Ipv4Addr::new(10, 0, 0, 1),
            Ipv4Addr::new(10, 0, 0, 2),
            Ipv4Addr::new(10, 0, 0, 3),
        ];

        let pkt = Ipv4Builder::new()
            .dst(Ipv4Addr::new(10, 0, 0, 4))
            .lsrr(route.clone())
            .build();

        let layer = Ipv4Layer::at_offset(0);
        let options = layer.options(&pkt).unwrap();

        // The parser may also report padding/NOP options, so search for LSRR
        // rather than assuming it is the only option.
        let lsrr_option = options
            .options
            .iter()
            .find(|opt| matches!(opt, Ipv4Option::Lsrr { .. }));

        assert!(lsrr_option.is_some(), "Expected LSRR option");

        if let Some(Ipv4Option::Lsrr {
            route: parsed_route,
            ..
        }) = lsrr_option
        {
            assert_eq!(parsed_route, &route);
        }
    }
}