Skip to main content

stackforge_core/layer/ipv4/builder.rs

1//! IPv4 packet builder.
2//!
3//! Provides a fluent API for constructing IPv4 packets with automatic
4//! field calculation (checksum, length, IHL).
5
6use std::net::Ipv4Addr;
7
8use super::checksum::ipv4_checksum;
9use super::header::{IPV4_MIN_HEADER_LEN, Ipv4Flags, Ipv4Layer, offsets};
10use super::options::{Ipv4Option, Ipv4Options, Ipv4OptionsBuilder};
11use super::protocol;
12use crate::layer::field::FieldError;
13
14/// Builder for IPv4 packets.
15///
16/// # Example
17///
18/// ```rust
19/// use stackforge_core::layer::ipv4::{Ipv4Builder, protocol};
20/// use std::net::Ipv4Addr;
21///
22/// let packet = Ipv4Builder::new()
23///     .src(Ipv4Addr::new(192, 168, 1, 1))
24///     .dst(Ipv4Addr::new(192, 168, 1, 2))
25///     .ttl(64)
26///     .protocol(protocol::TCP)
27///     .dont_fragment()
28///     .build();
29///
30/// assert_eq!(packet.len(), 20); // Minimum header, no payload
31/// ```
32#[derive(Debug, Clone)]
33pub struct Ipv4Builder {
34    // Header fields
35    version: u8,
36    ihl: Option<u8>,
37    tos: u8,
38    total_len: Option<u16>,
39    id: u16,
40    flags: Ipv4Flags,
41    frag_offset: u16,
42    ttl: u8,
43    protocol: u8,
44    checksum: Option<u16>,
45    src: Ipv4Addr,
46    dst: Ipv4Addr,
47
48    // Options
49    options: Ipv4Options,
50
51    // Payload
52    payload: Vec<u8>,
53
54    // Build options
55    auto_checksum: bool,
56    auto_length: bool,
57    auto_ihl: bool,
58}
59
60impl Default for Ipv4Builder {
61    fn default() -> Self {
62        Self {
63            version: 4,
64            ihl: None,
65            tos: 0,
66            total_len: None,
67            id: 1,
68            flags: Ipv4Flags::NONE,
69            frag_offset: 0,
70            ttl: 64,
71            protocol: 0,
72            checksum: None,
73            src: Ipv4Addr::LOCALHOST,
74            dst: Ipv4Addr::LOCALHOST,
75            options: Ipv4Options::new(),
76            payload: Vec::new(),
77            auto_checksum: true,
78            auto_length: true,
79            auto_ihl: true,
80        }
81    }
82}
83
84impl Ipv4Builder {
85    /// Create a new IPv4 builder with default values.
86    #[must_use]
87    pub fn new() -> Self {
88        Self::default()
89    }
90
91    /// Create a builder initialized from an existing packet.
92    pub fn from_bytes(data: &[u8]) -> Result<Self, FieldError> {
93        let layer = Ipv4Layer::at_offset_dynamic(data, 0)?;
94
95        let mut builder = Self::new();
96        builder.version = layer.version(data)?;
97        builder.ihl = Some(layer.ihl(data)?);
98        builder.tos = layer.tos(data)?;
99        builder.total_len = Some(layer.total_len(data)?);
100        builder.id = layer.id(data)?;
101        builder.flags = layer.flags(data)?;
102        builder.frag_offset = layer.frag_offset(data)?;
103        builder.ttl = layer.ttl(data)?;
104        builder.protocol = layer.protocol(data)?;
105        builder.checksum = Some(layer.checksum(data)?);
106        builder.src = layer.src(data)?;
107        builder.dst = layer.dst(data)?;
108
109        // Parse options if present
110        if layer.options_len(data) > 0 {
111            builder.options = layer.options(data)?;
112        }
113
114        // Copy payload
115        let header_len = layer.calculate_header_len(data);
116        let total_len = layer.total_len(data)? as usize;
117        if total_len > header_len && data.len() >= total_len {
118            builder.payload = data[header_len..total_len].to_vec();
119        }
120
121        // Disable auto-calculation since we're copying exact values
122        builder.auto_checksum = false;
123        builder.auto_length = false;
124        builder.auto_ihl = false;
125
126        Ok(builder)
127    }
128
129    // ========== Header Field Setters ==========
130
131    /// Set the IP version (should normally be 4).
132    #[must_use]
133    pub fn version(mut self, version: u8) -> Self {
134        self.version = version;
135        self
136    }
137
138    /// Set the Internet Header Length (in 32-bit words).
139    /// If not set, will be calculated automatically.
140    #[must_use]
141    pub fn ihl(mut self, ihl: u8) -> Self {
142        self.ihl = Some(ihl);
143        self.auto_ihl = false;
144        self
145    }
146
147    /// Set the Type of Service field.
148    #[must_use]
149    pub fn tos(mut self, tos: u8) -> Self {
150        self.tos = tos;
151        self
152    }
153
154    /// Set the DSCP (Differentiated Services Code Point).
155    #[must_use]
156    pub fn dscp(mut self, dscp: u8) -> Self {
157        self.tos = (self.tos & 0x03) | ((dscp & 0x3F) << 2);
158        self
159    }
160
161    /// Set the ECN (Explicit Congestion Notification).
162    #[must_use]
163    pub fn ecn(mut self, ecn: u8) -> Self {
164        self.tos = (self.tos & 0xFC) | (ecn & 0x03);
165        self
166    }
167
168    /// Set the total length field.
169    /// If not set, will be calculated automatically.
170    #[must_use]
171    pub fn total_len(mut self, len: u16) -> Self {
172        self.total_len = Some(len);
173        self.auto_length = false;
174        self
175    }
176
177    /// Alias for `total_len` (Scapy compatibility).
178    #[must_use]
179    pub fn len(self, len: u16) -> Self {
180        self.total_len(len)
181    }
182
183    /// Set the identification field.
184    #[must_use]
185    pub fn id(mut self, id: u16) -> Self {
186        self.id = id;
187        self
188    }
189
190    /// Set the flags field.
191    #[must_use]
192    pub fn flags(mut self, flags: Ipv4Flags) -> Self {
193        self.flags = flags;
194        self
195    }
196
197    /// Set the Don't Fragment flag.
198    #[must_use]
199    pub fn dont_fragment(mut self) -> Self {
200        self.flags.df = true;
201        self
202    }
203
204    /// Clear the Don't Fragment flag.
205    #[must_use]
206    pub fn allow_fragment(mut self) -> Self {
207        self.flags.df = false;
208        self
209    }
210
211    /// Set the More Fragments flag.
212    #[must_use]
213    pub fn more_fragments(mut self) -> Self {
214        self.flags.mf = true;
215        self
216    }
217
218    /// Set the reserved/evil bit.
219    #[must_use]
220    pub fn evil(mut self) -> Self {
221        self.flags.reserved = true;
222        self
223    }
224
225    /// Set the fragment offset (in 8-byte units).
226    #[must_use]
227    pub fn frag_offset(mut self, offset: u16) -> Self {
228        self.frag_offset = offset & 0x1FFF;
229        self
230    }
231
232    /// Set the fragment offset in bytes (will be divided by 8).
233    #[must_use]
234    pub fn frag_offset_bytes(mut self, offset: u32) -> Self {
235        self.frag_offset = ((offset / 8) & 0x1FFF) as u16;
236        self
237    }
238
239    /// Set the TTL (Time to Live).
240    #[must_use]
241    pub fn ttl(mut self, ttl: u8) -> Self {
242        self.ttl = ttl;
243        self
244    }
245
246    /// Set the protocol number.
247    #[must_use]
248    pub fn protocol(mut self, protocol: u8) -> Self {
249        self.protocol = protocol;
250        self
251    }
252
253    /// Alias for protocol (Scapy compatibility).
254    #[must_use]
255    pub fn proto(self, protocol: u8) -> Self {
256        self.protocol(protocol)
257    }
258
259    /// Set the checksum manually.
260    /// If not set, will be calculated automatically.
261    #[must_use]
262    pub fn checksum(mut self, checksum: u16) -> Self {
263        self.checksum = Some(checksum);
264        self.auto_checksum = false;
265        self
266    }
267
268    /// Alias for checksum (Scapy compatibility).
269    #[must_use]
270    pub fn chksum(self, checksum: u16) -> Self {
271        self.checksum(checksum)
272    }
273
274    /// Set the source IP address.
275    #[must_use]
276    pub fn src(mut self, src: Ipv4Addr) -> Self {
277        self.src = src;
278        self
279    }
280
281    /// Set the destination IP address.
282    #[must_use]
283    pub fn dst(mut self, dst: Ipv4Addr) -> Self {
284        self.dst = dst;
285        self
286    }
287
288    // ========== Options ==========
289
290    /// Set the options.
291    #[must_use]
292    pub fn options(mut self, options: Ipv4Options) -> Self {
293        self.options = options;
294        self
295    }
296
297    /// Add a single option.
298    #[must_use]
299    pub fn option(mut self, option: Ipv4Option) -> Self {
300        self.options.push(option);
301        self
302    }
303
304    /// Add options using a builder function.
305    pub fn with_options<F>(mut self, f: F) -> Self
306    where
307        F: FnOnce(Ipv4OptionsBuilder) -> Ipv4OptionsBuilder,
308    {
309        self.options = f(Ipv4OptionsBuilder::new()).build();
310        self
311    }
312
313    /// Add a Record Route option.
314    #[must_use]
315    pub fn record_route(mut self, slots: usize) -> Self {
316        self.options.push(Ipv4Option::RecordRoute {
317            pointer: 4,
318            route: vec![Ipv4Addr::UNSPECIFIED; slots],
319        });
320        self
321    }
322
323    /// Add a Loose Source Route option.
324    #[must_use]
325    pub fn lsrr(mut self, route: Vec<Ipv4Addr>) -> Self {
326        self.options.push(Ipv4Option::Lsrr { pointer: 4, route });
327        self
328    }
329
330    /// Add a Strict Source Route option.
331    #[must_use]
332    pub fn ssrr(mut self, route: Vec<Ipv4Addr>) -> Self {
333        self.options.push(Ipv4Option::Ssrr { pointer: 4, route });
334        self
335    }
336
337    /// Add a Router Alert option.
338    #[must_use]
339    pub fn router_alert(mut self, value: u16) -> Self {
340        self.options.push(Ipv4Option::RouterAlert { value });
341        self
342    }
343
344    // ========== Payload ==========
345
346    /// Set the payload data.
347    pub fn payload(mut self, payload: impl Into<Vec<u8>>) -> Self {
348        self.payload = payload.into();
349        self
350    }
351
352    /// Append data to the payload.
353    #[must_use]
354    pub fn append_payload(mut self, data: &[u8]) -> Self {
355        self.payload.extend_from_slice(data);
356        self
357    }
358
359    // ========== Build Options ==========
360
361    /// Enable or disable automatic checksum calculation.
362    #[must_use]
363    pub fn auto_checksum(mut self, enabled: bool) -> Self {
364        self.auto_checksum = enabled;
365        self
366    }
367
368    /// Enable or disable automatic length calculation.
369    #[must_use]
370    pub fn auto_length(mut self, enabled: bool) -> Self {
371        self.auto_length = enabled;
372        self
373    }
374
375    /// Enable or disable automatic IHL calculation.
376    #[must_use]
377    pub fn auto_ihl(mut self, enabled: bool) -> Self {
378        self.auto_ihl = enabled;
379        self
380    }
381
382    // ========== Build Methods ==========
383
384    /// Calculate the header size (including options).
385    #[must_use]
386    pub fn header_size(&self) -> usize {
387        if let Some(ihl) = self.ihl {
388            (ihl as usize) * 4
389        } else {
390            let opts_len = self.options.padded_len();
391            IPV4_MIN_HEADER_LEN + opts_len
392        }
393    }
394
395    /// Calculate the total packet size.
396    #[must_use]
397    pub fn packet_size(&self) -> usize {
398        self.header_size() + self.payload.len()
399    }
400
401    /// Build the IPv4 packet.
402    #[must_use]
403    pub fn build(&self) -> Vec<u8> {
404        let _header_size = self.header_size();
405        let total_size = self.packet_size();
406
407        let mut buf = vec![0u8; total_size];
408        self.build_into(&mut buf)
409            .expect("buffer is correctly sized");
410        buf
411    }
412
413    /// Build the IPv4 packet into an existing buffer.
414    pub fn build_into(&self, buf: &mut [u8]) -> Result<usize, FieldError> {
415        let header_size = self.header_size();
416        let total_size = self.packet_size();
417
418        if buf.len() < total_size {
419            return Err(FieldError::BufferTooShort {
420                offset: 0,
421                need: total_size,
422                have: buf.len(),
423            });
424        }
425
426        // Calculate IHL
427        let ihl = if self.auto_ihl {
428            (header_size / 4) as u8
429        } else {
430            self.ihl.unwrap_or(5)
431        };
432
433        // Calculate total length
434        let total_len = if self.auto_length {
435            total_size as u16
436        } else {
437            self.total_len.unwrap_or(total_size as u16)
438        };
439
440        // Version + IHL
441        buf[offsets::VERSION_IHL] = ((self.version & 0x0F) << 4) | (ihl & 0x0F);
442
443        // TOS
444        buf[offsets::TOS] = self.tos;
445
446        // Total Length
447        buf[offsets::TOTAL_LEN] = (total_len >> 8) as u8;
448        buf[offsets::TOTAL_LEN + 1] = (total_len & 0xFF) as u8;
449
450        // ID
451        buf[offsets::ID] = (self.id >> 8) as u8;
452        buf[offsets::ID + 1] = (self.id & 0xFF) as u8;
453
454        // Flags + Fragment Offset
455        let flags_frag = u16::from(self.flags.to_byte()) << 8 | self.frag_offset;
456        buf[offsets::FLAGS_FRAG] = (flags_frag >> 8) as u8;
457        buf[offsets::FLAGS_FRAG + 1] = (flags_frag & 0xFF) as u8;
458
459        // TTL
460        buf[offsets::TTL] = self.ttl;
461
462        // Protocol
463        buf[offsets::PROTOCOL] = self.protocol;
464
465        // Checksum (initially 0)
466        buf[offsets::CHECKSUM] = 0;
467        buf[offsets::CHECKSUM + 1] = 0;
468
469        // Source IP
470        let src_octets = self.src.octets();
471        buf[offsets::SRC..offsets::SRC + 4].copy_from_slice(&src_octets);
472
473        // Destination IP
474        let dst_octets = self.dst.octets();
475        buf[offsets::DST..offsets::DST + 4].copy_from_slice(&dst_octets);
476
477        // Options
478        if !self.options.is_empty() {
479            let opts_bytes = self.options.to_bytes();
480            let opts_end = offsets::OPTIONS + opts_bytes.len();
481            if opts_end <= header_size {
482                buf[offsets::OPTIONS..opts_end].copy_from_slice(&opts_bytes);
483            }
484        }
485
486        // Payload
487        if !self.payload.is_empty() {
488            buf[header_size..header_size + self.payload.len()].copy_from_slice(&self.payload);
489        }
490
491        // Checksum (computed last)
492        let checksum = if self.auto_checksum {
493            ipv4_checksum(&buf[..header_size])
494        } else {
495            self.checksum.unwrap_or(0)
496        };
497        buf[offsets::CHECKSUM] = (checksum >> 8) as u8;
498        buf[offsets::CHECKSUM + 1] = (checksum & 0xFF) as u8;
499
500        Ok(total_size)
501    }
502
503    /// Build only the header (no payload).
504    #[must_use]
505    pub fn build_header(&self) -> Vec<u8> {
506        let header_size = self.header_size();
507        let mut buf = vec![0u8; header_size];
508
509        // Temporarily clear payload for header-only build
510        let payload = std::mem::take(&mut self.payload.clone());
511        let builder = Self {
512            payload: Vec::new(),
513            ..self.clone()
514        };
515        builder
516            .build_into(&mut buf)
517            .expect("buffer is correctly sized");
518
519        // Don't actually need to restore since we cloned
520        drop(payload);
521
522        buf
523    }
524}
525
526// ========== Convenience Constructors ==========
527
528impl Ipv4Builder {
529    /// Create an ICMP packet builder.
530    #[must_use]
531    pub fn icmp() -> Self {
532        Self::new().protocol(protocol::ICMP)
533    }
534
535    /// Create a TCP packet builder.
536    #[must_use]
537    pub fn tcp() -> Self {
538        Self::new().protocol(protocol::TCP)
539    }
540
541    /// Create a UDP packet builder.
542    #[must_use]
543    pub fn udp() -> Self {
544        Self::new().protocol(protocol::UDP)
545    }
546
547    /// Create an IP-in-IP tunnel packet builder.
548    #[must_use]
549    pub fn ipip() -> Self {
550        Self::new().protocol(protocol::IPV4)
551    }
552
553    /// Create a GRE tunnel packet builder.
554    #[must_use]
555    pub fn gre() -> Self {
556        Self::new().protocol(protocol::GRE)
557    }
558
559    /// Create a packet destined for a specific address.
560    #[must_use]
561    pub fn to(dst: Ipv4Addr) -> Self {
562        Self::new().dst(dst)
563    }
564
565    /// Create a packet from a specific source.
566    #[must_use]
567    pub fn from(src: Ipv4Addr) -> Self {
568        Self::new().src(src)
569    }
570}
571
// ========== Random Values ==========

#[cfg(feature = "rand")]
impl Ipv4Builder {
    /// Replace the identification field with a random `u16` drawn from the
    /// thread-local RNG returned by `rand::rng()`.
    #[must_use]
    pub fn random_id(self) -> Self {
        use rand::Rng;
        let id: u16 = rand::rng().random();
        self.id(id)
    }
}
584
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_build() {
        let pkt = Ipv4Builder::new()
            .src(Ipv4Addr::new(192, 168, 1, 1))
            .dst(Ipv4Addr::new(192, 168, 1, 2))
            .ttl(64)
            .protocol(protocol::TCP)
            .build();

        assert_eq!(pkt.len(), 20);

        let layer = Ipv4Layer::at_offset(0);
        assert_eq!(layer.version(&pkt).unwrap(), 4);
        assert_eq!(layer.ihl(&pkt).unwrap(), 5);
        assert_eq!(layer.ttl(&pkt).unwrap(), 64);
        assert_eq!(layer.protocol(&pkt).unwrap(), protocol::TCP);
        assert_eq!(layer.src(&pkt).unwrap(), Ipv4Addr::new(192, 168, 1, 1));
        assert_eq!(layer.dst(&pkt).unwrap(), Ipv4Addr::new(192, 168, 1, 2));

        // Verify checksum
        assert!(layer.verify_checksum(&pkt).unwrap());
    }

    #[test]
    fn test_with_payload() {
        let payload = vec![1, 2, 3, 4, 5];
        let pkt = Ipv4Builder::new()
            .src(Ipv4Addr::new(10, 0, 0, 1))
            .dst(Ipv4Addr::new(10, 0, 0, 2))
            .protocol(protocol::UDP)
            .payload(payload.clone())
            .build();

        assert_eq!(pkt.len(), 25); // 20 + 5

        let layer = Ipv4Layer::at_offset(0);
        assert_eq!(layer.total_len(&pkt).unwrap(), 25);
        assert_eq!(layer.payload(&pkt).unwrap(), &payload[..]);
    }

    #[test]
    fn test_with_options() {
        let pkt = Ipv4Builder::new()
            .src(Ipv4Addr::new(10, 0, 0, 1))
            .dst(Ipv4Addr::new(10, 0, 0, 2))
            .router_alert(0)
            .build();

        // Router Alert is 4 bytes, header should be 24 bytes
        assert_eq!(pkt.len(), 24);

        let layer = Ipv4Layer::at_offset(0);
        assert_eq!(layer.ihl(&pkt).unwrap(), 6); // 24/4 = 6
        assert!(layer.verify_checksum(&pkt).unwrap());
    }

    #[test]
    fn test_flags() {
        let pkt = Ipv4Builder::new()
            .dst(Ipv4Addr::new(8, 8, 8, 8))
            .dont_fragment()
            .build();

        let layer = Ipv4Layer::at_offset(0);
        let flags = layer.flags(&pkt).unwrap();
        assert!(flags.df);
        assert!(!flags.mf);
    }

    #[test]
    fn test_fragment() {
        let pkt = Ipv4Builder::new()
            .dst(Ipv4Addr::new(8, 8, 8, 8))
            .more_fragments()
            .frag_offset(100)
            .build();

        let layer = Ipv4Layer::at_offset(0);
        let flags = layer.flags(&pkt).unwrap();
        assert!(flags.mf);
        assert_eq!(layer.frag_offset(&pkt).unwrap(), 100);
    }

    #[test]
    fn test_dscp_ecn() {
        let pkt = Ipv4Builder::new()
            .dst(Ipv4Addr::new(8, 8, 8, 8))
            .dscp(46) // EF
            .ecn(2) // ECT(0)
            .build();

        let layer = Ipv4Layer::at_offset(0);
        assert_eq!(layer.dscp(&pkt).unwrap(), 46);
        assert_eq!(layer.ecn(&pkt).unwrap(), 2);
    }

    #[test]
    fn test_from_bytes() {
        let original = Ipv4Builder::new()
            .src(Ipv4Addr::new(192, 168, 1, 100))
            .dst(Ipv4Addr::new(192, 168, 1, 200))
            .ttl(128)
            .id(0xABCD)
            .protocol(protocol::ICMP)
            .payload(vec![8, 0, 0, 0, 0, 1, 0, 1]) // ICMP echo
            .build();

        let rebuilt = Ipv4Builder::from_bytes(&original)
            .unwrap()
            .auto_checksum(true)
            .build();

        // Field-level checks first, for clearer failure messages.
        let layer = Ipv4Layer::at_offset(0);
        assert_eq!(original.len(), rebuilt.len());
        assert_eq!(layer.src(&original).unwrap(), layer.src(&rebuilt).unwrap());
        assert_eq!(layer.dst(&original).unwrap(), layer.dst(&rebuilt).unwrap());
        assert_eq!(layer.ttl(&original).unwrap(), layer.ttl(&rebuilt).unwrap());
        assert_eq!(layer.id(&original).unwrap(), layer.id(&rebuilt).unwrap());

        // Every field is copied verbatim and the checksum is recomputed over an
        // identical header, so the round trip must be byte-for-byte exact.
        assert_eq!(original, rebuilt);
    }

    #[test]
    fn test_convenience_constructors() {
        let layer = Ipv4Layer::at_offset(0);

        let icmp = Ipv4Builder::icmp().build();
        assert_eq!(layer.protocol(&icmp).unwrap(), protocol::ICMP);

        let tcp = Ipv4Builder::tcp().build();
        assert_eq!(layer.protocol(&tcp).unwrap(), protocol::TCP);

        let udp = Ipv4Builder::udp().build();
        assert_eq!(layer.protocol(&udp).unwrap(), protocol::UDP);
    }

    #[test]
    fn test_manual_fields() {
        let pkt = Ipv4Builder::new()
            .dst(Ipv4Addr::new(8, 8, 8, 8))
            .total_len(100)
            .checksum(0x1234)
            .ihl(5)
            .build();

        let layer = Ipv4Layer::at_offset(0);
        assert_eq!(layer.total_len(&pkt).unwrap(), 100);
        assert_eq!(layer.checksum(&pkt).unwrap(), 0x1234);
        assert_eq!(layer.ihl(&pkt).unwrap(), 5);
    }

    #[test]
    fn test_source_route_option() {
        let route = vec![
            Ipv4Addr::new(10, 0, 0, 1),
            Ipv4Addr::new(10, 0, 0, 2),
            Ipv4Addr::new(10, 0, 0, 3),
        ];

        let pkt = Ipv4Builder::new()
            .dst(Ipv4Addr::new(10, 0, 0, 4))
            .lsrr(route.clone())
            .build();

        let layer = Ipv4Layer::at_offset(0);
        let options = layer.options(&pkt).unwrap();

        // Parsing may yield extra padding/NOP options, so search for the LSRR.
        let parsed_route = options
            .options
            .iter()
            .find_map(|opt| match opt {
                Ipv4Option::Lsrr { route, .. } => Some(route),
                _ => None,
            })
            .expect("Expected LSRR option");

        assert_eq!(parsed_route, &route);
    }
}