stackforge_core/layer/ipv4/
fragmentation.rs1use std::net::Ipv4Addr;
7
8use crate::layer::field::FieldError;
9
10use super::checksum::ipv4_checksum;
11use super::header::{IPV4_MIN_HEADER_LEN, Ipv4Flags, Ipv4Layer, offsets};
12use super::options::Ipv4Options;
13
/// MTU in bytes used by `Ipv4Fragmenter::default`.
pub const DEFAULT_MTU: usize = 1_500;

/// Smallest payload, in bytes, that every fragment must be able to carry;
/// fragment payloads are sized in 8-byte units.
pub const MIN_FRAGMENT_PAYLOAD: usize = 8;

/// Largest value representable in the 13-bit fragment-offset field
/// (the offset is counted in 8-byte units).
pub const MAX_FRAGMENT_OFFSET: u16 = (1 << 13) - 1;
22
/// One received fragment, as recorded by `FragmentGroup::add_fragment`.
#[derive(Debug, Clone)]
pub struct FragmentInfo {
    /// Byte offset of this fragment's payload inside the original datagram.
    pub offset: u32,
    /// Payload length in bytes (the total-length field minus the header length).
    pub length: usize,
    /// `true` when the More Fragments flag was clear on this fragment.
    pub last: bool,
    /// The complete fragment packet, header included.
    pub data: Vec<u8>,
}

impl FragmentInfo {
    /// Byte offset one past the last payload byte carried by this fragment.
    #[must_use]
    pub fn end_offset(&self) -> u32 {
        let payload_len = self.length as u32;
        payload_len + self.offset
    }
}
43
/// One packet produced by `Ipv4Fragmenter::fragment`.
#[derive(Debug, Clone)]
pub struct Fragment {
    /// The serialized fragment: IPv4 header followed by its slice of payload.
    pub packet: Vec<u8>,
    /// Byte offset of this fragment's payload within the original datagram.
    pub offset: u32,
    /// `true` when this fragment is reported as the final one.
    pub last: bool,
}
54
55#[derive(Debug, Clone)]
57pub struct Ipv4Fragmenter {
58 pub mtu: usize,
60 pub copy_options: bool,
62}
63
64impl Default for Ipv4Fragmenter {
65 fn default() -> Self {
66 Self {
67 mtu: DEFAULT_MTU,
68 copy_options: true,
69 }
70 }
71}
72
73impl Ipv4Fragmenter {
74 #[must_use]
76 pub fn new() -> Self {
77 Self::default()
78 }
79
80 #[must_use]
82 pub fn with_mtu(mtu: usize) -> Self {
83 Self {
84 mtu,
85 ..Self::default()
86 }
87 }
88
89 #[must_use]
91 pub fn mtu(mut self, mtu: usize) -> Self {
92 self.mtu = mtu;
93 self
94 }
95
96 #[must_use]
98 pub fn copy_options(mut self, copy: bool) -> Self {
99 self.copy_options = copy;
100 self
101 }
102
103 #[must_use]
105 pub fn needs_fragmentation(&self, packet: &[u8]) -> bool {
106 packet.len() > self.mtu
107 }
108
109 pub fn fragment(&self, packet: &[u8]) -> Result<Vec<Fragment>, FragmentError> {
114 let layer = Ipv4Layer::at_offset_dynamic(packet, 0)
115 .map_err(|e| FragmentError::ParseError(e.to_string()))?;
116
117 let flags = layer.flags(packet).unwrap_or(Ipv4Flags::NONE);
119 if flags.df && packet.len() > self.mtu {
120 return Err(FragmentError::DontFragmentSet {
121 packet_size: packet.len(),
122 mtu: self.mtu,
123 });
124 }
125
126 if packet.len() <= self.mtu {
128 return Ok(vec![Fragment {
129 packet: packet.to_vec(),
130 offset: 0,
131 last: true,
132 }]);
133 }
134
135 let header_len = layer.calculate_header_len(packet);
137 let total_len = layer.total_len(packet).unwrap_or(packet.len() as u16) as usize;
138 let _payload_start = header_len;
139 let payload_len = total_len.saturating_sub(header_len);
140
141 let options = if header_len > IPV4_MIN_HEADER_LEN {
143 layer.options(packet).ok()
144 } else {
145 None
146 };
147
148 let first_header_len = header_len; let other_header_len = if self.copy_options {
152 if let Some(ref opts) = options {
153 IPV4_MIN_HEADER_LEN + opts.copied_options().padded_len()
154 } else {
155 IPV4_MIN_HEADER_LEN
156 }
157 } else {
158 IPV4_MIN_HEADER_LEN
159 };
160
161 let first_payload_max = ((self.mtu - first_header_len) / 8) * 8;
162 let other_payload_max = ((self.mtu - other_header_len) / 8) * 8;
163
164 if first_payload_max < MIN_FRAGMENT_PAYLOAD || other_payload_max < MIN_FRAGMENT_PAYLOAD {
165 return Err(FragmentError::MtuTooSmall {
166 mtu: self.mtu,
167 min_required: other_header_len + MIN_FRAGMENT_PAYLOAD,
168 });
169 }
170
171 let mut fragments = Vec::new();
172 let mut offset: u32 = 0;
173 let mut remaining = payload_len;
174 let mut is_first = true;
175
176 let original_offset = u32::from(layer.frag_offset(packet).unwrap_or(0)) * 8;
178 let original_mf = flags.mf;
179
180 while remaining > 0 {
181 let _header_size = if is_first {
182 first_header_len
183 } else {
184 other_header_len
185 };
186 let max_payload = if is_first {
187 first_payload_max
188 } else {
189 other_payload_max
190 };
191
192 let frag_payload_len = remaining.min(max_payload);
193 let is_last = frag_payload_len == remaining && !original_mf;
194
195 let actual_payload_len = if is_last {
197 frag_payload_len
198 } else {
199 (frag_payload_len / 8) * 8
200 };
201
202 if actual_payload_len == 0 {
203 break;
204 }
205
206 let frag_packet = self.build_fragment(
208 packet,
209 &layer,
210 &options,
211 offset,
212 actual_payload_len,
213 !is_last,
214 is_first,
215 original_offset,
216 )?;
217
218 fragments.push(Fragment {
219 packet: frag_packet,
220 offset: original_offset + offset,
221 last: is_last,
222 });
223
224 offset += actual_payload_len as u32;
225 remaining -= actual_payload_len;
226 is_first = false;
227 }
228
229 Ok(fragments)
230 }
231
232 fn build_fragment(
234 &self,
235 original: &[u8],
236 layer: &Ipv4Layer,
237 options: &Option<Ipv4Options>,
238 offset: u32,
239 payload_len: usize,
240 more_fragments: bool,
241 is_first: bool,
242 original_offset: u32,
243 ) -> Result<Vec<u8>, FragmentError> {
244 let original_header_len = layer.calculate_header_len(original);
245
246 let frag_options = if is_first {
248 options.clone()
249 } else if self.copy_options {
250 options
251 .as_ref()
252 .map(super::options::Ipv4Options::copied_options)
253 } else {
254 None
255 };
256
257 let frag_header_len = if let Some(ref opts) = frag_options {
258 IPV4_MIN_HEADER_LEN + opts.padded_len()
259 } else {
260 IPV4_MIN_HEADER_LEN
261 };
262
263 let total_len = frag_header_len + payload_len;
264 let mut buf = vec![0u8; total_len];
265
266 buf[..IPV4_MIN_HEADER_LEN].copy_from_slice(&original[..IPV4_MIN_HEADER_LEN]);
268
269 let ihl = (frag_header_len / 4) as u8;
271 buf[offsets::VERSION_IHL] = (buf[offsets::VERSION_IHL] & 0xF0) | (ihl & 0x0F);
272
273 buf[offsets::TOTAL_LEN] = (total_len >> 8) as u8;
275 buf[offsets::TOTAL_LEN + 1] = (total_len & 0xFF) as u8;
276
277 let frag_offset_units = ((original_offset + offset) / 8) as u16;
279 let mut flags_byte = if more_fragments { 0x20 } else { 0x00 }; let orig_flags = layer.flags(original).unwrap_or(Ipv4Flags::NONE);
284 if orig_flags.reserved {
285 flags_byte |= 0x80;
286 }
287
288 let flags_frag = ((flags_byte as u16) << 8) | frag_offset_units;
289 buf[offsets::FLAGS_FRAG] = (flags_frag >> 8) as u8;
290 buf[offsets::FLAGS_FRAG + 1] = (flags_frag & 0xFF) as u8;
291
292 if let Some(ref opts) = frag_options {
294 let opts_bytes = opts.to_bytes();
295 buf[offsets::OPTIONS..offsets::OPTIONS + opts_bytes.len()].copy_from_slice(&opts_bytes);
296 }
297
298 let payload_start = original_header_len + offset as usize;
300 let payload_end = payload_start + payload_len;
301 if payload_end <= original.len() {
302 buf[frag_header_len..].copy_from_slice(&original[payload_start..payload_end]);
303 }
304
305 buf[offsets::CHECKSUM] = 0;
307 buf[offsets::CHECKSUM + 1] = 0;
308 let checksum = ipv4_checksum(&buf[..frag_header_len]);
309 buf[offsets::CHECKSUM] = (checksum >> 8) as u8;
310 buf[offsets::CHECKSUM + 1] = (checksum & 0xFF) as u8;
311
312 Ok(buf)
313 }
314}
315
/// Reasons `Ipv4Fragmenter::fragment` can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FragmentError {
    /// The packet is larger than the MTU but has the Don't Fragment flag set.
    DontFragmentSet { packet_size: usize, mtu: usize },
    /// The MTU leaves too little room for payload after the header.
    MtuTooSmall { mtu: usize, min_required: usize },
    /// The packet could not be parsed as, or is malformed for, IPv4.
    ParseError(String),
}

impl std::fmt::Display for FragmentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ParseError(msg) => write!(f, "parse error: {msg}"),
            Self::MtuTooSmall { mtu, min_required } => write!(
                f,
                "MTU {mtu} is too small, minimum required is {min_required}"
            ),
            Self::DontFragmentSet { packet_size, mtu } => write!(
                f,
                "packet size {packet_size} exceeds MTU {mtu} but DF flag is set"
            ),
        }
    }
}

impl std::error::Error for FragmentError {}
348
349#[derive(Debug, Clone, PartialEq, Eq, Hash)]
351pub struct FragmentKey {
352 pub src: Ipv4Addr,
353 pub dst: Ipv4Addr,
354 pub id: u16,
355 pub protocol: u8,
356}
357
358impl FragmentKey {
359 pub fn from_packet(packet: &[u8]) -> Result<Self, FieldError> {
361 let layer = Ipv4Layer::at_offset_dynamic(packet, 0)?;
362 Ok(Self {
363 src: layer.src(packet)?,
364 dst: layer.dst(packet)?,
365 id: layer.id(packet)?,
366 protocol: layer.protocol(packet)?,
367 })
368 }
369}
370
371#[derive(Debug, Clone)]
373pub struct FragmentGroup {
374 pub key: FragmentKey,
376 pub fragments: Vec<FragmentInfo>,
378 pub total_length: Option<u32>,
380 pub first_header: Option<Vec<u8>>,
382 pub first_received: std::time::Instant,
384}
385
386impl FragmentGroup {
387 #[must_use]
389 pub fn new(key: FragmentKey) -> Self {
390 Self {
391 key,
392 fragments: Vec::new(),
393 total_length: None,
394 first_header: None,
395 first_received: std::time::Instant::now(),
396 }
397 }
398
399 pub fn add_fragment(&mut self, packet: &[u8]) -> Result<(), FieldError> {
401 let layer = Ipv4Layer::at_offset_dynamic(packet, 0)?;
402 let header_len = layer.calculate_header_len(packet);
403 let total_len = layer.total_len(packet)? as usize;
404 let flags = layer.flags(packet)?;
405 let offset = u32::from(layer.frag_offset(packet)?) * 8;
406 let payload_len = total_len.saturating_sub(header_len);
407
408 if offset == 0 {
410 self.first_header = Some(packet[..header_len].to_vec());
411 }
412
413 if !flags.mf {
415 self.total_length = Some(offset + payload_len as u32);
416 }
417
418 self.fragments.push(FragmentInfo {
420 offset,
421 length: payload_len,
422 last: !flags.mf,
423 data: packet.to_vec(),
424 });
425
426 Ok(())
427 }
428
429 #[must_use]
431 pub fn is_complete(&self) -> bool {
432 let total = match self.total_length {
433 Some(t) => t,
434 None => return false,
435 };
436
437 let mut sorted: Vec<_> = self.fragments.iter().collect();
439 sorted.sort_by_key(|f| f.offset);
440
441 let mut expected_offset = 0u32;
443 for frag in sorted {
444 if frag.offset != expected_offset {
445 return false;
446 }
447 expected_offset = frag.end_offset();
448 }
449
450 expected_offset >= total
451 }
452
453 pub fn reassemble(&self) -> Result<Vec<u8>, ReassemblyError> {
455 if !self.is_complete() {
456 return Err(ReassemblyError::Incomplete);
457 }
458
459 let total_length = self.total_length.ok_or(ReassemblyError::Incomplete)?;
460 let first_header = self
461 .first_header
462 .as_ref()
463 .ok_or(ReassemblyError::MissingFirstFragment)?;
464
465 let header_len = first_header.len();
466 let mut result = vec![0u8; header_len + total_length as usize];
467
468 result[..header_len].copy_from_slice(first_header);
470
471 let mut sorted: Vec<_> = self.fragments.iter().collect();
473 sorted.sort_by_key(|f| f.offset);
474
475 for frag in sorted {
476 let layer = Ipv4Layer::at_offset_dynamic(&frag.data, 0)
477 .map_err(|e| ReassemblyError::ParseError(e.to_string()))?;
478 let frag_header_len = layer.calculate_header_len(&frag.data);
479
480 let src_start = frag_header_len;
481 let src_end = src_start + frag.length;
482 let dst_start = header_len + frag.offset as usize;
483 let dst_end = dst_start + frag.length;
484
485 if src_end <= frag.data.len() && dst_end <= result.len() {
486 result[dst_start..dst_end].copy_from_slice(&frag.data[src_start..src_end]);
487 }
488 }
489
490 let new_total_len = (header_len + total_length as usize) as u16;
492 result[offsets::TOTAL_LEN] = (new_total_len >> 8) as u8;
493 result[offsets::TOTAL_LEN + 1] = (new_total_len & 0xFF) as u8;
494
495 result[offsets::FLAGS_FRAG] &= 0xC0; result[offsets::FLAGS_FRAG + 1] = 0;
498
499 result[offsets::CHECKSUM] = 0;
501 result[offsets::CHECKSUM + 1] = 0;
502 let checksum = ipv4_checksum(&result[..header_len]);
503 result[offsets::CHECKSUM] = (checksum >> 8) as u8;
504 result[offsets::CHECKSUM + 1] = (checksum & 0xFF) as u8;
505
506 Ok(result)
507 }
508}
509
/// Reasons reassembling a fragmented datagram can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReassemblyError {
    /// Not every part of the datagram has been received.
    Incomplete,
    /// No fragment at offset 0 was received.
    MissingFirstFragment,
    /// Fragments overlap.
    Overlap,
    /// A fragment could not be parsed.
    ParseError(String),
    /// Timed out waiting for the remaining fragments.
    Timeout,
}

impl std::fmt::Display for ReassemblyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::ParseError(msg) => return write!(f, "parse error: {msg}"),
            Self::Incomplete => "not all fragments received",
            Self::MissingFirstFragment => "first fragment not received",
            Self::Overlap => "fragment overlap detected",
            Self::Timeout => "timeout waiting for fragments",
        };
        f.write_str(text)
    }
}

impl std::error::Error for ReassemblyError {}
538
539pub fn reassemble_fragments(fragments: &[Vec<u8>]) -> Result<Vec<u8>, ReassemblyError> {
544 if fragments.is_empty() {
545 return Err(ReassemblyError::Incomplete);
546 }
547
548 let key = FragmentKey::from_packet(&fragments[0])
550 .map_err(|e| ReassemblyError::ParseError(e.to_string()))?;
551
552 let mut group = FragmentGroup::new(key);
553
554 for frag in fragments {
555 group
556 .add_fragment(frag)
557 .map_err(|e| ReassemblyError::ParseError(e.to_string()))?;
558 }
559
560 group.reassemble()
561}
562
563pub fn fragment_packet(packet: &[u8], mtu: usize) -> Result<Vec<Vec<u8>>, FragmentError> {
567 let fragmenter = Ipv4Fragmenter::with_mtu(mtu);
568 let fragments = fragmenter.fragment(packet)?;
569 Ok(fragments.into_iter().map(|f| f.packet).collect())
570}
571
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Ipv4Builder;

    /// Builds an IPv4 packet 192.168.1.1 -> 192.168.1.2 (id 0x1234, protocol 17)
    /// whose payload is `payload_size` bytes of 0xAA.
    fn build_large_packet(payload_size: usize) -> Vec<u8> {
        Ipv4Builder::new()
            .src(Ipv4Addr::new(192, 168, 1, 1))
            .dst(Ipv4Addr::new(192, 168, 1, 2))
            .id(0x1234)
            .protocol(17)
            .payload(vec![0xAA; payload_size])
            .build()
    }

    /// Fragments a 2000-byte-payload packet at MTU 1000 and returns the pieces
    /// together with an empty group keyed on the first one.
    fn fragments_and_group() -> (Vec<Fragment>, FragmentGroup) {
        let packet = build_large_packet(2000);
        let frags = Ipv4Fragmenter::with_mtu(1000).fragment(&packet).unwrap();
        let key = FragmentKey::from_packet(&frags[0].packet).unwrap();
        (frags, FragmentGroup::new(key))
    }

    #[test]
    fn test_no_fragmentation_needed() {
        let fragmenter = Ipv4Fragmenter::with_mtu(1500);
        let packet = build_large_packet(100);
        assert!(!fragmenter.needs_fragmentation(&packet));

        let frags = fragmenter.fragment(&packet).unwrap();
        assert_eq!(frags.len(), 1);
        let only = &frags[0];
        assert!(only.last);
        assert_eq!(only.offset, 0);
    }

    #[test]
    fn test_basic_fragmentation() {
        let fragmenter = Ipv4Fragmenter::with_mtu(1500);
        let packet = build_large_packet(3000);
        assert!(fragmenter.needs_fragmentation(&packet));

        let frags = fragmenter.fragment(&packet).unwrap();
        assert!(frags.len() >= 2);

        let first = frags.first().unwrap();
        assert_eq!(first.offset, 0);
        assert!(!first.last);
        assert!(frags.last().unwrap().last);

        assert!(frags.iter().all(|frag| frag.packet.len() <= 1500));
    }

    #[test]
    fn test_dont_fragment_flag() {
        let packet = Ipv4Builder::new()
            .src(Ipv4Addr::new(192, 168, 1, 1))
            .dst(Ipv4Addr::new(192, 168, 1, 2))
            .dont_fragment()
            .payload(vec![0; 2000])
            .build();

        let outcome = Ipv4Fragmenter::with_mtu(1500).fragment(&packet);
        assert!(matches!(outcome, Err(FragmentError::DontFragmentSet { .. })));
    }

    #[test]
    fn test_reassembly() {
        let original = build_large_packet(3000);
        let frag_packets: Vec<Vec<u8>> = Ipv4Fragmenter::with_mtu(1000)
            .fragment(&original)
            .unwrap()
            .into_iter()
            .map(|frag| frag.packet)
            .collect();

        let reassembled = reassemble_fragments(&frag_packets).unwrap();

        let original_view = Ipv4Layer::at_offset(0);
        let reassembled_view = Ipv4Layer::at_offset(0);
        assert_eq!(
            original_view.payload(&original).unwrap(),
            reassembled_view.payload(&reassembled).unwrap()
        );
    }

    #[test]
    fn test_fragment_key() {
        let key = FragmentKey::from_packet(&build_large_packet(100)).unwrap();

        assert_eq!(key.src, Ipv4Addr::new(192, 168, 1, 1));
        assert_eq!(key.dst, Ipv4Addr::new(192, 168, 1, 2));
        assert_eq!(key.id, 0x1234);
        assert_eq!(key.protocol, 17);
    }

    #[test]
    fn test_fragment_group_complete() {
        let (frags, mut group) = fragments_and_group();

        // Feed the fragments in reverse to exercise out-of-order arrival.
        frags
            .iter()
            .rev()
            .for_each(|frag| group.add_fragment(&frag.packet).unwrap());

        assert!(group.is_complete());
    }

    #[test]
    fn test_fragment_group_incomplete() {
        let (frags, mut group) = fragments_and_group();

        group.add_fragment(&frags[0].packet).unwrap();

        assert!(!group.is_complete());
    }

    #[test]
    fn test_small_mtu() {
        let packet = build_large_packet(1000);
        let frags = Ipv4Fragmenter::with_mtu(100).fragment(&packet).unwrap();

        assert!(frags.len() > 10);
        assert!(frags.iter().all(|frag| frag.packet.len() <= 100));
    }

    #[test]
    fn test_fragment_offset_alignment() {
        let packet = build_large_packet(1000);
        let frags = Ipv4Fragmenter::with_mtu(500).fragment(&packet).unwrap();

        // Every fragment except the last must carry a multiple of 8 payload bytes.
        let (_final_piece, leading) = frags.split_last().unwrap();
        for frag in leading {
            let header_len = Ipv4Layer::at_offset(0).calculate_header_len(&frag.packet);
            let payload_len = frag.packet.len() - header_len;
            assert_eq!(
                payload_len % 8,
                0,
                "payload len {} not multiple of 8",
                payload_len
            );
        }
    }

    #[test]
    fn test_mtu_too_small() {
        let packet = build_large_packet(100);
        let outcome = Ipv4Fragmenter::with_mtu(20).fragment(&packet);
        assert!(matches!(outcome, Err(FragmentError::MtuTooSmall { .. })));
    }
}