1use std::net::Ipv4Addr;
7
8use super::checksum::ipv4_checksum;
9use super::header::{IPV4_MIN_HEADER_LEN, Ipv4Flags, Ipv4Layer, offsets};
10use super::options::{Ipv4Option, Ipv4Options, Ipv4OptionsBuilder};
11use super::protocol;
12use crate::layer::field::FieldError;
13
14#[derive(Debug, Clone)]
33pub struct Ipv4Builder {
34 version: u8,
36 ihl: Option<u8>,
37 tos: u8,
38 total_len: Option<u16>,
39 id: u16,
40 flags: Ipv4Flags,
41 frag_offset: u16,
42 ttl: u8,
43 protocol: u8,
44 checksum: Option<u16>,
45 src: Ipv4Addr,
46 dst: Ipv4Addr,
47
48 options: Ipv4Options,
50
51 payload: Vec<u8>,
53
54 auto_checksum: bool,
56 auto_length: bool,
57 auto_ihl: bool,
58}
59
60impl Default for Ipv4Builder {
61 fn default() -> Self {
62 Self {
63 version: 4,
64 ihl: None,
65 tos: 0,
66 total_len: None,
67 id: 1,
68 flags: Ipv4Flags::NONE,
69 frag_offset: 0,
70 ttl: 64,
71 protocol: 0,
72 checksum: None,
73 src: Ipv4Addr::LOCALHOST,
74 dst: Ipv4Addr::LOCALHOST,
75 options: Ipv4Options::new(),
76 payload: Vec::new(),
77 auto_checksum: true,
78 auto_length: true,
79 auto_ihl: true,
80 }
81 }
82}
83
84impl Ipv4Builder {
85 #[must_use]
87 pub fn new() -> Self {
88 Self::default()
89 }
90
91 pub fn from_bytes(data: &[u8]) -> Result<Self, FieldError> {
93 let layer = Ipv4Layer::at_offset_dynamic(data, 0)?;
94
95 let mut builder = Self::new();
96 builder.version = layer.version(data)?;
97 builder.ihl = Some(layer.ihl(data)?);
98 builder.tos = layer.tos(data)?;
99 builder.total_len = Some(layer.total_len(data)?);
100 builder.id = layer.id(data)?;
101 builder.flags = layer.flags(data)?;
102 builder.frag_offset = layer.frag_offset(data)?;
103 builder.ttl = layer.ttl(data)?;
104 builder.protocol = layer.protocol(data)?;
105 builder.checksum = Some(layer.checksum(data)?);
106 builder.src = layer.src(data)?;
107 builder.dst = layer.dst(data)?;
108
109 if layer.options_len(data) > 0 {
111 builder.options = layer.options(data)?;
112 }
113
114 let header_len = layer.calculate_header_len(data);
116 let total_len = layer.total_len(data)? as usize;
117 if total_len > header_len && data.len() >= total_len {
118 builder.payload = data[header_len..total_len].to_vec();
119 }
120
121 builder.auto_checksum = false;
123 builder.auto_length = false;
124 builder.auto_ihl = false;
125
126 Ok(builder)
127 }
128
129 #[must_use]
133 pub fn version(mut self, version: u8) -> Self {
134 self.version = version;
135 self
136 }
137
138 #[must_use]
141 pub fn ihl(mut self, ihl: u8) -> Self {
142 self.ihl = Some(ihl);
143 self.auto_ihl = false;
144 self
145 }
146
147 #[must_use]
149 pub fn tos(mut self, tos: u8) -> Self {
150 self.tos = tos;
151 self
152 }
153
154 #[must_use]
156 pub fn dscp(mut self, dscp: u8) -> Self {
157 self.tos = (self.tos & 0x03) | ((dscp & 0x3F) << 2);
158 self
159 }
160
161 #[must_use]
163 pub fn ecn(mut self, ecn: u8) -> Self {
164 self.tos = (self.tos & 0xFC) | (ecn & 0x03);
165 self
166 }
167
168 #[must_use]
171 pub fn total_len(mut self, len: u16) -> Self {
172 self.total_len = Some(len);
173 self.auto_length = false;
174 self
175 }
176
177 #[must_use]
179 pub fn len(self, len: u16) -> Self {
180 self.total_len(len)
181 }
182
183 #[must_use]
185 pub fn id(mut self, id: u16) -> Self {
186 self.id = id;
187 self
188 }
189
190 #[must_use]
192 pub fn flags(mut self, flags: Ipv4Flags) -> Self {
193 self.flags = flags;
194 self
195 }
196
197 #[must_use]
199 pub fn dont_fragment(mut self) -> Self {
200 self.flags.df = true;
201 self
202 }
203
204 #[must_use]
206 pub fn allow_fragment(mut self) -> Self {
207 self.flags.df = false;
208 self
209 }
210
211 #[must_use]
213 pub fn more_fragments(mut self) -> Self {
214 self.flags.mf = true;
215 self
216 }
217
218 #[must_use]
220 pub fn evil(mut self) -> Self {
221 self.flags.reserved = true;
222 self
223 }
224
225 #[must_use]
227 pub fn frag_offset(mut self, offset: u16) -> Self {
228 self.frag_offset = offset & 0x1FFF;
229 self
230 }
231
232 #[must_use]
234 pub fn frag_offset_bytes(mut self, offset: u32) -> Self {
235 self.frag_offset = ((offset / 8) & 0x1FFF) as u16;
236 self
237 }
238
239 #[must_use]
241 pub fn ttl(mut self, ttl: u8) -> Self {
242 self.ttl = ttl;
243 self
244 }
245
246 #[must_use]
248 pub fn protocol(mut self, protocol: u8) -> Self {
249 self.protocol = protocol;
250 self
251 }
252
253 #[must_use]
255 pub fn proto(self, protocol: u8) -> Self {
256 self.protocol(protocol)
257 }
258
259 #[must_use]
262 pub fn checksum(mut self, checksum: u16) -> Self {
263 self.checksum = Some(checksum);
264 self.auto_checksum = false;
265 self
266 }
267
268 #[must_use]
270 pub fn chksum(self, checksum: u16) -> Self {
271 self.checksum(checksum)
272 }
273
274 #[must_use]
276 pub fn src(mut self, src: Ipv4Addr) -> Self {
277 self.src = src;
278 self
279 }
280
281 #[must_use]
283 pub fn dst(mut self, dst: Ipv4Addr) -> Self {
284 self.dst = dst;
285 self
286 }
287
288 #[must_use]
292 pub fn options(mut self, options: Ipv4Options) -> Self {
293 self.options = options;
294 self
295 }
296
297 #[must_use]
299 pub fn option(mut self, option: Ipv4Option) -> Self {
300 self.options.push(option);
301 self
302 }
303
304 pub fn with_options<F>(mut self, f: F) -> Self
306 where
307 F: FnOnce(Ipv4OptionsBuilder) -> Ipv4OptionsBuilder,
308 {
309 self.options = f(Ipv4OptionsBuilder::new()).build();
310 self
311 }
312
313 #[must_use]
315 pub fn record_route(mut self, slots: usize) -> Self {
316 self.options.push(Ipv4Option::RecordRoute {
317 pointer: 4,
318 route: vec![Ipv4Addr::UNSPECIFIED; slots],
319 });
320 self
321 }
322
323 #[must_use]
325 pub fn lsrr(mut self, route: Vec<Ipv4Addr>) -> Self {
326 self.options.push(Ipv4Option::Lsrr { pointer: 4, route });
327 self
328 }
329
330 #[must_use]
332 pub fn ssrr(mut self, route: Vec<Ipv4Addr>) -> Self {
333 self.options.push(Ipv4Option::Ssrr { pointer: 4, route });
334 self
335 }
336
337 #[must_use]
339 pub fn router_alert(mut self, value: u16) -> Self {
340 self.options.push(Ipv4Option::RouterAlert { value });
341 self
342 }
343
344 pub fn payload(mut self, payload: impl Into<Vec<u8>>) -> Self {
348 self.payload = payload.into();
349 self
350 }
351
352 #[must_use]
354 pub fn append_payload(mut self, data: &[u8]) -> Self {
355 self.payload.extend_from_slice(data);
356 self
357 }
358
359 #[must_use]
363 pub fn auto_checksum(mut self, enabled: bool) -> Self {
364 self.auto_checksum = enabled;
365 self
366 }
367
368 #[must_use]
370 pub fn auto_length(mut self, enabled: bool) -> Self {
371 self.auto_length = enabled;
372 self
373 }
374
375 #[must_use]
377 pub fn auto_ihl(mut self, enabled: bool) -> Self {
378 self.auto_ihl = enabled;
379 self
380 }
381
382 #[must_use]
386 pub fn header_size(&self) -> usize {
387 if let Some(ihl) = self.ihl {
388 (ihl as usize) * 4
389 } else {
390 let opts_len = self.options.padded_len();
391 IPV4_MIN_HEADER_LEN + opts_len
392 }
393 }
394
395 #[must_use]
397 pub fn packet_size(&self) -> usize {
398 self.header_size() + self.payload.len()
399 }
400
401 #[must_use]
403 pub fn build(&self) -> Vec<u8> {
404 let _header_size = self.header_size();
405 let total_size = self.packet_size();
406
407 let mut buf = vec![0u8; total_size];
408 self.build_into(&mut buf)
409 .expect("buffer is correctly sized");
410 buf
411 }
412
413 pub fn build_into(&self, buf: &mut [u8]) -> Result<usize, FieldError> {
415 let header_size = self.header_size();
416 let total_size = self.packet_size();
417
418 if buf.len() < total_size {
419 return Err(FieldError::BufferTooShort {
420 offset: 0,
421 need: total_size,
422 have: buf.len(),
423 });
424 }
425
426 let ihl = if self.auto_ihl {
428 (header_size / 4) as u8
429 } else {
430 self.ihl.unwrap_or(5)
431 };
432
433 let total_len = if self.auto_length {
435 total_size as u16
436 } else {
437 self.total_len.unwrap_or(total_size as u16)
438 };
439
440 buf[offsets::VERSION_IHL] = ((self.version & 0x0F) << 4) | (ihl & 0x0F);
442
443 buf[offsets::TOS] = self.tos;
445
446 buf[offsets::TOTAL_LEN] = (total_len >> 8) as u8;
448 buf[offsets::TOTAL_LEN + 1] = (total_len & 0xFF) as u8;
449
450 buf[offsets::ID] = (self.id >> 8) as u8;
452 buf[offsets::ID + 1] = (self.id & 0xFF) as u8;
453
454 let flags_frag = u16::from(self.flags.to_byte()) << 8 | self.frag_offset;
456 buf[offsets::FLAGS_FRAG] = (flags_frag >> 8) as u8;
457 buf[offsets::FLAGS_FRAG + 1] = (flags_frag & 0xFF) as u8;
458
459 buf[offsets::TTL] = self.ttl;
461
462 buf[offsets::PROTOCOL] = self.protocol;
464
465 buf[offsets::CHECKSUM] = 0;
467 buf[offsets::CHECKSUM + 1] = 0;
468
469 let src_octets = self.src.octets();
471 buf[offsets::SRC..offsets::SRC + 4].copy_from_slice(&src_octets);
472
473 let dst_octets = self.dst.octets();
475 buf[offsets::DST..offsets::DST + 4].copy_from_slice(&dst_octets);
476
477 if !self.options.is_empty() {
479 let opts_bytes = self.options.to_bytes();
480 let opts_end = offsets::OPTIONS + opts_bytes.len();
481 if opts_end <= header_size {
482 buf[offsets::OPTIONS..opts_end].copy_from_slice(&opts_bytes);
483 }
484 }
485
486 if !self.payload.is_empty() {
488 buf[header_size..header_size + self.payload.len()].copy_from_slice(&self.payload);
489 }
490
491 let checksum = if self.auto_checksum {
493 ipv4_checksum(&buf[..header_size])
494 } else {
495 self.checksum.unwrap_or(0)
496 };
497 buf[offsets::CHECKSUM] = (checksum >> 8) as u8;
498 buf[offsets::CHECKSUM + 1] = (checksum & 0xFF) as u8;
499
500 Ok(total_size)
501 }
502
503 #[must_use]
505 pub fn build_header(&self) -> Vec<u8> {
506 let header_size = self.header_size();
507 let mut buf = vec![0u8; header_size];
508
509 let payload = std::mem::take(&mut self.payload.clone());
511 let builder = Self {
512 payload: Vec::new(),
513 ..self.clone()
514 };
515 builder
516 .build_into(&mut buf)
517 .expect("buffer is correctly sized");
518
519 drop(payload);
521
522 buf
523 }
524}
525
526impl Ipv4Builder {
529 #[must_use]
531 pub fn icmp() -> Self {
532 Self::new().protocol(protocol::ICMP)
533 }
534
535 #[must_use]
537 pub fn tcp() -> Self {
538 Self::new().protocol(protocol::TCP)
539 }
540
541 #[must_use]
543 pub fn udp() -> Self {
544 Self::new().protocol(protocol::UDP)
545 }
546
547 #[must_use]
549 pub fn ipip() -> Self {
550 Self::new().protocol(protocol::IPV4)
551 }
552
553 #[must_use]
555 pub fn gre() -> Self {
556 Self::new().protocol(protocol::GRE)
557 }
558
559 #[must_use]
561 pub fn to(dst: Ipv4Addr) -> Self {
562 Self::new().dst(dst)
563 }
564
565 #[must_use]
567 pub fn from(src: Ipv4Addr) -> Self {
568 Self::new().src(src)
569 }
570}
571
#[cfg(feature = "rand")]
impl Ipv4Builder {
    /// Replace the identification field with a value from the thread-local RNG.
    #[must_use]
    pub fn random_id(self) -> Self {
        use rand::Rng;
        let id: u16 = rand::rng().random();
        Self { id, ..self }
    }
}
584
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_build() {
        let src = Ipv4Addr::new(192, 168, 1, 1);
        let dst = Ipv4Addr::new(192, 168, 1, 2);
        let pkt = Ipv4Builder::new()
            .src(src)
            .dst(dst)
            .ttl(64)
            .protocol(protocol::TCP)
            .build();

        // No options and no payload: only the fixed header.
        assert_eq!(pkt.len(), 20);

        let ip = Ipv4Layer::at_offset(0);
        assert_eq!(ip.version(&pkt).unwrap(), 4);
        assert_eq!(ip.ihl(&pkt).unwrap(), 5);
        assert_eq!(ip.ttl(&pkt).unwrap(), 64);
        assert_eq!(ip.protocol(&pkt).unwrap(), protocol::TCP);
        assert_eq!(ip.src(&pkt).unwrap(), src);
        assert_eq!(ip.dst(&pkt).unwrap(), dst);
        assert!(ip.verify_checksum(&pkt).unwrap());
    }

    #[test]
    fn test_with_payload() {
        let payload: &[u8] = &[1, 2, 3, 4, 5];
        let pkt = Ipv4Builder::new()
            .src(Ipv4Addr::new(10, 0, 0, 1))
            .dst(Ipv4Addr::new(10, 0, 0, 2))
            .protocol(protocol::UDP)
            .payload(payload)
            .build();

        // 20-byte header + 5-byte payload.
        assert_eq!(pkt.len(), 25);

        let ip = Ipv4Layer::at_offset(0);
        assert_eq!(ip.total_len(&pkt).unwrap(), 25);
        assert_eq!(ip.payload(&pkt).unwrap(), payload);
    }

    #[test]
    fn test_with_options() {
        let pkt = Ipv4Builder::new()
            .src(Ipv4Addr::new(10, 0, 0, 1))
            .dst(Ipv4Addr::new(10, 0, 0, 2))
            .router_alert(0)
            .build();

        // Router Alert adds one 32-bit word to the header.
        assert_eq!(pkt.len(), 24);

        let ip = Ipv4Layer::at_offset(0);
        assert_eq!(ip.ihl(&pkt).unwrap(), 6);
        assert!(ip.verify_checksum(&pkt).unwrap());
    }

    #[test]
    fn test_flags() {
        let pkt = Ipv4Builder::new()
            .dst(Ipv4Addr::new(8, 8, 8, 8))
            .dont_fragment()
            .build();

        let flags = Ipv4Layer::at_offset(0).flags(&pkt).unwrap();
        assert!(flags.df);
        assert!(!flags.mf);
    }

    #[test]
    fn test_fragment() {
        let pkt = Ipv4Builder::new()
            .dst(Ipv4Addr::new(8, 8, 8, 8))
            .more_fragments()
            .frag_offset(100)
            .build();

        let ip = Ipv4Layer::at_offset(0);
        assert!(ip.flags(&pkt).unwrap().mf);
        assert_eq!(ip.frag_offset(&pkt).unwrap(), 100);
    }

    #[test]
    fn test_dscp_ecn() {
        // DSCP 46 (EF) with ECN 2 (ECT(0)).
        let pkt = Ipv4Builder::new()
            .dst(Ipv4Addr::new(8, 8, 8, 8))
            .dscp(46)
            .ecn(2)
            .build();

        let ip = Ipv4Layer::at_offset(0);
        assert_eq!(ip.dscp(&pkt).unwrap(), 46);
        assert_eq!(ip.ecn(&pkt).unwrap(), 2);
    }

    #[test]
    fn test_from_bytes() {
        // Payload bytes resemble an ICMP echo request.
        let original = Ipv4Builder::new()
            .src(Ipv4Addr::new(192, 168, 1, 100))
            .dst(Ipv4Addr::new(192, 168, 1, 200))
            .ttl(128)
            .id(0xABCD)
            .protocol(protocol::ICMP)
            .payload(vec![8, 0, 0, 0, 0, 1, 0, 1])
            .build();

        let rebuilt = Ipv4Builder::from_bytes(&original)
            .unwrap()
            .auto_checksum(true)
            .build();

        assert_eq!(original.len(), rebuilt.len());

        let ip = Ipv4Layer::at_offset(0);
        assert_eq!(ip.src(&original).unwrap(), ip.src(&rebuilt).unwrap());
        assert_eq!(ip.dst(&original).unwrap(), ip.dst(&rebuilt).unwrap());
        assert_eq!(ip.ttl(&original).unwrap(), ip.ttl(&rebuilt).unwrap());
        assert_eq!(ip.id(&original).unwrap(), ip.id(&rebuilt).unwrap());
    }

    #[test]
    fn test_convenience_constructors() {
        let cases = [
            (Ipv4Builder::icmp(), protocol::ICMP),
            (Ipv4Builder::tcp(), protocol::TCP),
            (Ipv4Builder::udp(), protocol::UDP),
        ];

        let ip = Ipv4Layer::at_offset(0);
        for (builder, expected) in cases {
            let pkt = builder.build();
            assert_eq!(ip.protocol(&pkt).unwrap(), expected);
        }
    }

    #[test]
    fn test_manual_fields() {
        let pkt = Ipv4Builder::new()
            .dst(Ipv4Addr::new(8, 8, 8, 8))
            .total_len(100)
            .checksum(0x1234)
            .ihl(5)
            .build();

        // Manual values are written verbatim, even when inconsistent.
        let ip = Ipv4Layer::at_offset(0);
        assert_eq!(ip.total_len(&pkt).unwrap(), 100);
        assert_eq!(ip.checksum(&pkt).unwrap(), 0x1234);
        assert_eq!(ip.ihl(&pkt).unwrap(), 5);
    }

    #[test]
    fn test_source_route_option() {
        let route = vec![
            Ipv4Addr::new(10, 0, 0, 1),
            Ipv4Addr::new(10, 0, 0, 2),
            Ipv4Addr::new(10, 0, 0, 3),
        ];

        let pkt = Ipv4Builder::new()
            .dst(Ipv4Addr::new(10, 0, 0, 4))
            .lsrr(route.clone())
            .build();

        let parsed = Ipv4Layer::at_offset(0).options(&pkt).unwrap();
        let parsed_route = parsed
            .options
            .iter()
            .find_map(|opt| match opt {
                Ipv4Option::Lsrr { route, .. } => Some(route),
                _ => None,
            })
            .expect("Expected LSRR option");

        assert_eq!(parsed_route, &route);
    }
}