1use std::net::Ipv4Addr;
7
8use crate::layer::field::FieldError;
9
10use super::checksum::ipv4_checksum;
11use super::header::{IPV4_MIN_HEADER_LEN, Ipv4Flags, Ipv4Layer, offsets};
12use super::options::Ipv4Options;
13
/// MTU, in bytes, used by [`Ipv4Fragmenter::default`] (the standard Ethernet MTU).
pub const DEFAULT_MTU: usize = 1500;

/// Smallest payload, in bytes, that every fragment must be able to carry. Fragment offsets are
/// expressed in 8-byte units, so this is one unit.
pub const MIN_FRAGMENT_PAYLOAD: usize = 8;

/// Largest value representable in the 13-bit Fragment Offset field (counted in 8-byte units).
pub const MAX_FRAGMENT_OFFSET: u16 = (1 << 13) - 1;
22
/// Bookkeeping for one received fragment of a datagram being reassembled.
#[derive(Debug, Clone)]
pub struct FragmentInfo {
    /// Byte offset of this fragment's payload within the original datagram.
    pub offset: u32,
    /// Number of payload bytes this fragment contributes.
    pub length: usize,
    /// Whether this fragment had the More Fragments flag clear.
    pub last: bool,
    /// Raw bytes of the fragment packet as received.
    pub data: Vec<u8>,
}

impl FragmentInfo {
    /// Returns the datagram offset one past this fragment's final payload byte.
    pub fn end_offset(&self) -> u32 {
        let span = self.length as u32;
        span + self.offset
    }
}
42
/// One output fragment produced by [`Ipv4Fragmenter::fragment`].
#[derive(Debug, Clone)]
pub struct Fragment {
    /// The complete IPv4 packet for this fragment, header included.
    pub packet: Vec<u8>,
    /// Byte offset of this fragment's payload within the original datagram.
    pub offset: u32,
    /// `true` for the fragment emitted with the More Fragments flag clear.
    pub last: bool,
}
53
54#[derive(Debug, Clone)]
56pub struct Ipv4Fragmenter {
57 pub mtu: usize,
59 pub copy_options: bool,
61}
62
63impl Default for Ipv4Fragmenter {
64 fn default() -> Self {
65 Self {
66 mtu: DEFAULT_MTU,
67 copy_options: true,
68 }
69 }
70}
71
72impl Ipv4Fragmenter {
73 pub fn new() -> Self {
75 Self::default()
76 }
77
78 pub fn with_mtu(mtu: usize) -> Self {
80 Self {
81 mtu,
82 ..Self::default()
83 }
84 }
85
86 pub fn mtu(mut self, mtu: usize) -> Self {
88 self.mtu = mtu;
89 self
90 }
91
92 pub fn copy_options(mut self, copy: bool) -> Self {
94 self.copy_options = copy;
95 self
96 }
97
98 pub fn needs_fragmentation(&self, packet: &[u8]) -> bool {
100 packet.len() > self.mtu
101 }
102
103 pub fn fragment(&self, packet: &[u8]) -> Result<Vec<Fragment>, FragmentError> {
108 let layer = Ipv4Layer::at_offset_dynamic(packet, 0)
109 .map_err(|e| FragmentError::ParseError(e.to_string()))?;
110
111 let flags = layer.flags(packet).unwrap_or(Ipv4Flags::NONE);
113 if flags.df && packet.len() > self.mtu {
114 return Err(FragmentError::DontFragmentSet {
115 packet_size: packet.len(),
116 mtu: self.mtu,
117 });
118 }
119
120 if packet.len() <= self.mtu {
122 return Ok(vec![Fragment {
123 packet: packet.to_vec(),
124 offset: 0,
125 last: true,
126 }]);
127 }
128
129 let header_len = layer.calculate_header_len(packet);
131 let total_len = layer.total_len(packet).unwrap_or(packet.len() as u16) as usize;
132 let _payload_start = header_len;
133 let payload_len = total_len.saturating_sub(header_len);
134
135 let options = if header_len > IPV4_MIN_HEADER_LEN {
137 layer.options(packet).ok()
138 } else {
139 None
140 };
141
142 let first_header_len = header_len; let other_header_len = if self.copy_options {
146 if let Some(ref opts) = options {
147 IPV4_MIN_HEADER_LEN + opts.copied_options().padded_len()
148 } else {
149 IPV4_MIN_HEADER_LEN
150 }
151 } else {
152 IPV4_MIN_HEADER_LEN
153 };
154
155 let first_payload_max = ((self.mtu - first_header_len) / 8) * 8;
156 let other_payload_max = ((self.mtu - other_header_len) / 8) * 8;
157
158 if first_payload_max < MIN_FRAGMENT_PAYLOAD || other_payload_max < MIN_FRAGMENT_PAYLOAD {
159 return Err(FragmentError::MtuTooSmall {
160 mtu: self.mtu,
161 min_required: other_header_len + MIN_FRAGMENT_PAYLOAD,
162 });
163 }
164
165 let mut fragments = Vec::new();
166 let mut offset: u32 = 0;
167 let mut remaining = payload_len;
168 let mut is_first = true;
169
170 let original_offset = layer.frag_offset(packet).unwrap_or(0) as u32 * 8;
172 let original_mf = flags.mf;
173
174 while remaining > 0 {
175 let _header_size = if is_first {
176 first_header_len
177 } else {
178 other_header_len
179 };
180 let max_payload = if is_first {
181 first_payload_max
182 } else {
183 other_payload_max
184 };
185
186 let frag_payload_len = remaining.min(max_payload);
187 let is_last = frag_payload_len == remaining && !original_mf;
188
189 let actual_payload_len = if !is_last {
191 (frag_payload_len / 8) * 8
192 } else {
193 frag_payload_len
194 };
195
196 if actual_payload_len == 0 {
197 break;
198 }
199
200 let frag_packet = self.build_fragment(
202 packet,
203 &layer,
204 &options,
205 offset,
206 actual_payload_len,
207 !is_last,
208 is_first,
209 original_offset,
210 )?;
211
212 fragments.push(Fragment {
213 packet: frag_packet,
214 offset: original_offset + offset,
215 last: is_last,
216 });
217
218 offset += actual_payload_len as u32;
219 remaining -= actual_payload_len;
220 is_first = false;
221 }
222
223 Ok(fragments)
224 }
225
226 fn build_fragment(
228 &self,
229 original: &[u8],
230 layer: &Ipv4Layer,
231 options: &Option<Ipv4Options>,
232 offset: u32,
233 payload_len: usize,
234 more_fragments: bool,
235 is_first: bool,
236 original_offset: u32,
237 ) -> Result<Vec<u8>, FragmentError> {
238 let original_header_len = layer.calculate_header_len(original);
239
240 let frag_options = if is_first {
242 options.clone()
243 } else if self.copy_options {
244 options.as_ref().map(|o| o.copied_options())
245 } else {
246 None
247 };
248
249 let frag_header_len = if let Some(ref opts) = frag_options {
250 IPV4_MIN_HEADER_LEN + opts.padded_len()
251 } else {
252 IPV4_MIN_HEADER_LEN
253 };
254
255 let total_len = frag_header_len + payload_len;
256 let mut buf = vec![0u8; total_len];
257
258 buf[..IPV4_MIN_HEADER_LEN].copy_from_slice(&original[..IPV4_MIN_HEADER_LEN]);
260
261 let ihl = (frag_header_len / 4) as u8;
263 buf[offsets::VERSION_IHL] = (buf[offsets::VERSION_IHL] & 0xF0) | (ihl & 0x0F);
264
265 buf[offsets::TOTAL_LEN] = (total_len >> 8) as u8;
267 buf[offsets::TOTAL_LEN + 1] = (total_len & 0xFF) as u8;
268
269 let frag_offset_units = ((original_offset + offset) / 8) as u16;
271 let mut flags_byte = if more_fragments { 0x20 } else { 0x00 }; let orig_flags = layer.flags(original).unwrap_or(Ipv4Flags::NONE);
276 if orig_flags.reserved {
277 flags_byte |= 0x80;
278 }
279
280 let flags_frag = ((flags_byte as u16) << 8) | frag_offset_units;
281 buf[offsets::FLAGS_FRAG] = (flags_frag >> 8) as u8;
282 buf[offsets::FLAGS_FRAG + 1] = (flags_frag & 0xFF) as u8;
283
284 if let Some(ref opts) = frag_options {
286 let opts_bytes = opts.to_bytes();
287 buf[offsets::OPTIONS..offsets::OPTIONS + opts_bytes.len()].copy_from_slice(&opts_bytes);
288 }
289
290 let payload_start = original_header_len + offset as usize;
292 let payload_end = payload_start + payload_len;
293 if payload_end <= original.len() {
294 buf[frag_header_len..].copy_from_slice(&original[payload_start..payload_end]);
295 }
296
297 buf[offsets::CHECKSUM] = 0;
299 buf[offsets::CHECKSUM + 1] = 0;
300 let checksum = ipv4_checksum(&buf[..frag_header_len]);
301 buf[offsets::CHECKSUM] = (checksum >> 8) as u8;
302 buf[offsets::CHECKSUM + 1] = (checksum & 0xFF) as u8;
303
304 Ok(buf)
305 }
306}
307
/// Failure modes of IPv4 fragmentation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FragmentError {
    /// The packet exceeds the MTU but carries the Don't Fragment flag.
    DontFragmentSet { packet_size: usize, mtu: usize },
    /// The MTU cannot hold a header plus the minimum fragment payload.
    MtuTooSmall { mtu: usize, min_required: usize },
    /// The packet could not be parsed or processed; the string carries the details.
    ParseError(String),
}

impl std::fmt::Display for FragmentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ParseError(detail) => write!(f, "parse error: {}", detail),
            Self::MtuTooSmall { mtu: limit, min_required: needed } => write!(
                f,
                "MTU {} is too small, minimum required is {}",
                limit, needed
            ),
            Self::DontFragmentSet { packet_size: size, mtu: limit } => write!(
                f,
                "packet size {} exceeds MTU {} but DF flag is set",
                size, limit
            ),
        }
    }
}

impl std::error::Error for FragmentError {}
342
343#[derive(Debug, Clone, PartialEq, Eq, Hash)]
345pub struct FragmentKey {
346 pub src: Ipv4Addr,
347 pub dst: Ipv4Addr,
348 pub id: u16,
349 pub protocol: u8,
350}
351
352impl FragmentKey {
353 pub fn from_packet(packet: &[u8]) -> Result<Self, FieldError> {
355 let layer = Ipv4Layer::at_offset_dynamic(packet, 0)?;
356 Ok(Self {
357 src: layer.src(packet)?,
358 dst: layer.dst(packet)?,
359 id: layer.id(packet)?,
360 protocol: layer.protocol(packet)?,
361 })
362 }
363}
364
365#[derive(Debug, Clone)]
367pub struct FragmentGroup {
368 pub key: FragmentKey,
370 pub fragments: Vec<FragmentInfo>,
372 pub total_length: Option<u32>,
374 pub first_header: Option<Vec<u8>>,
376 pub first_received: std::time::Instant,
378}
379
380impl FragmentGroup {
381 pub fn new(key: FragmentKey) -> Self {
383 Self {
384 key,
385 fragments: Vec::new(),
386 total_length: None,
387 first_header: None,
388 first_received: std::time::Instant::now(),
389 }
390 }
391
392 pub fn add_fragment(&mut self, packet: &[u8]) -> Result<(), FieldError> {
394 let layer = Ipv4Layer::at_offset_dynamic(packet, 0)?;
395 let header_len = layer.calculate_header_len(packet);
396 let total_len = layer.total_len(packet)? as usize;
397 let flags = layer.flags(packet)?;
398 let offset = layer.frag_offset(packet)? as u32 * 8;
399 let payload_len = total_len.saturating_sub(header_len);
400
401 if offset == 0 {
403 self.first_header = Some(packet[..header_len].to_vec());
404 }
405
406 if !flags.mf {
408 self.total_length = Some(offset + payload_len as u32);
409 }
410
411 self.fragments.push(FragmentInfo {
413 offset,
414 length: payload_len,
415 last: !flags.mf,
416 data: packet.to_vec(),
417 });
418
419 Ok(())
420 }
421
422 pub fn is_complete(&self) -> bool {
424 let total = match self.total_length {
425 Some(t) => t,
426 None => return false,
427 };
428
429 let mut sorted: Vec<_> = self.fragments.iter().collect();
431 sorted.sort_by_key(|f| f.offset);
432
433 let mut expected_offset = 0u32;
435 for frag in sorted {
436 if frag.offset != expected_offset {
437 return false;
438 }
439 expected_offset = frag.end_offset();
440 }
441
442 expected_offset >= total
443 }
444
445 pub fn reassemble(&self) -> Result<Vec<u8>, ReassemblyError> {
447 if !self.is_complete() {
448 return Err(ReassemblyError::Incomplete);
449 }
450
451 let total_length = self.total_length.ok_or(ReassemblyError::Incomplete)?;
452 let first_header = self
453 .first_header
454 .as_ref()
455 .ok_or(ReassemblyError::MissingFirstFragment)?;
456
457 let header_len = first_header.len();
458 let mut result = vec![0u8; header_len + total_length as usize];
459
460 result[..header_len].copy_from_slice(first_header);
462
463 let mut sorted: Vec<_> = self.fragments.iter().collect();
465 sorted.sort_by_key(|f| f.offset);
466
467 for frag in sorted {
468 let layer = Ipv4Layer::at_offset_dynamic(&frag.data, 0)
469 .map_err(|e| ReassemblyError::ParseError(e.to_string()))?;
470 let frag_header_len = layer.calculate_header_len(&frag.data);
471
472 let src_start = frag_header_len;
473 let src_end = src_start + frag.length;
474 let dst_start = header_len + frag.offset as usize;
475 let dst_end = dst_start + frag.length;
476
477 if src_end <= frag.data.len() && dst_end <= result.len() {
478 result[dst_start..dst_end].copy_from_slice(&frag.data[src_start..src_end]);
479 }
480 }
481
482 let new_total_len = (header_len + total_length as usize) as u16;
484 result[offsets::TOTAL_LEN] = (new_total_len >> 8) as u8;
485 result[offsets::TOTAL_LEN + 1] = (new_total_len & 0xFF) as u8;
486
487 result[offsets::FLAGS_FRAG] &= 0xC0; result[offsets::FLAGS_FRAG + 1] = 0;
490
491 result[offsets::CHECKSUM] = 0;
493 result[offsets::CHECKSUM + 1] = 0;
494 let checksum = ipv4_checksum(&result[..header_len]);
495 result[offsets::CHECKSUM] = (checksum >> 8) as u8;
496 result[offsets::CHECKSUM + 1] = (checksum & 0xFF) as u8;
497
498 Ok(result)
499 }
500}
501
/// Failure modes of IPv4 reassembly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReassemblyError {
    /// Not all of the datagram has been received (or no fragments were supplied).
    Incomplete,
    /// The fragment at offset 0, which supplies the header, was not received.
    MissingFirstFragment,
    /// Fragments overlap.
    Overlap,
    /// A fragment could not be parsed or processed; the string carries the details.
    ParseError(String),
    /// Fragments did not all arrive in time.
    Timeout,
}

impl std::fmt::Display for ReassemblyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ParseError(detail) => write!(f, "parse error: {}", detail),
            Self::Incomplete => f.write_str("not all fragments received"),
            Self::MissingFirstFragment => f.write_str("first fragment not received"),
            Self::Overlap => f.write_str("fragment overlap detected"),
            Self::Timeout => f.write_str("timeout waiting for fragments"),
        }
    }
}

impl std::error::Error for ReassemblyError {}
530
531pub fn reassemble_fragments(fragments: &[Vec<u8>]) -> Result<Vec<u8>, ReassemblyError> {
536 if fragments.is_empty() {
537 return Err(ReassemblyError::Incomplete);
538 }
539
540 let key = FragmentKey::from_packet(&fragments[0])
542 .map_err(|e| ReassemblyError::ParseError(e.to_string()))?;
543
544 let mut group = FragmentGroup::new(key);
545
546 for frag in fragments {
547 group
548 .add_fragment(frag)
549 .map_err(|e| ReassemblyError::ParseError(e.to_string()))?;
550 }
551
552 group.reassemble()
553}
554
555pub fn fragment_packet(packet: &[u8], mtu: usize) -> Result<Vec<Vec<u8>>, FragmentError> {
559 let fragmenter = Ipv4Fragmenter::with_mtu(mtu);
560 let fragments = fragmenter.fragment(packet)?;
561 Ok(fragments.into_iter().map(|f| f.packet).collect())
562}
563
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Ipv4Builder;

    /// Builds a packet from 192.168.1.1 to 192.168.1.2 with ID 0x1234 and protocol 17 (UDP).
    /// Its payload is `payload_size` bytes of 0xAA.
    fn build_large_packet(payload_size: usize) -> Vec<u8> {
        Ipv4Builder::new()
            .src(Ipv4Addr::new(192, 168, 1, 1))
            .dst(Ipv4Addr::new(192, 168, 1, 2))
            .id(0x1234)
            .protocol(17)
            .payload(vec![0xAA; payload_size])
            .build()
    }

    #[test]
    fn test_no_fragmentation_needed() {
        let packet = build_large_packet(100);
        let fragmenter = Ipv4Fragmenter::with_mtu(1500);
        assert!(!fragmenter.needs_fragmentation(&packet));

        let frags = fragmenter.fragment(&packet).unwrap();
        match frags.as_slice() {
            [only] => {
                assert!(only.last);
                assert_eq!(only.offset, 0);
            },
            other => panic!("expected exactly one fragment, got {}", other.len()),
        }
    }

    #[test]
    fn test_basic_fragmentation() {
        let packet = build_large_packet(3000);
        let fragmenter = Ipv4Fragmenter::with_mtu(1500);
        assert!(fragmenter.needs_fragmentation(&packet));

        let frags = fragmenter.fragment(&packet).unwrap();
        assert!(frags.len() >= 2);

        // The first fragment starts the datagram and is not the final one.
        let head = frags.first().unwrap();
        assert_eq!(head.offset, 0);
        assert!(!head.last);

        // The final fragment is marked last.
        assert!(frags.last().unwrap().last);

        // Every fragment respects the MTU.
        assert!(frags.iter().all(|frag| frag.packet.len() <= 1500));
    }

    #[test]
    fn test_dont_fragment_flag() {
        let packet = Ipv4Builder::new()
            .src(Ipv4Addr::new(192, 168, 1, 1))
            .dst(Ipv4Addr::new(192, 168, 1, 2))
            .dont_fragment()
            .payload(vec![0; 2000])
            .build();

        let outcome = Ipv4Fragmenter::with_mtu(1500).fragment(&packet);
        assert!(matches!(outcome, Err(FragmentError::DontFragmentSet { .. })));
    }

    #[test]
    fn test_reassembly() {
        let original = build_large_packet(3000);
        let frag_packets: Vec<Vec<u8>> = Ipv4Fragmenter::with_mtu(1000)
            .fragment(&original)
            .unwrap()
            .into_iter()
            .map(|frag| frag.packet)
            .collect();

        let reassembled = reassemble_fragments(&frag_packets).unwrap();

        let orig_view = Ipv4Layer::at_offset(0);
        let reasm_view = Ipv4Layer::at_offset(0);
        assert_eq!(
            orig_view.payload(&original).unwrap(),
            reasm_view.payload(&reassembled).unwrap()
        );
    }

    #[test]
    fn test_fragment_key() {
        let FragmentKey {
            src,
            dst,
            id,
            protocol,
        } = FragmentKey::from_packet(&build_large_packet(100)).unwrap();

        assert_eq!(src, Ipv4Addr::new(192, 168, 1, 1));
        assert_eq!(dst, Ipv4Addr::new(192, 168, 1, 2));
        assert_eq!(id, 0x1234);
        assert_eq!(protocol, 17);
    }

    #[test]
    fn test_fragment_group_complete() {
        let packet = build_large_packet(2000);
        let frags = Ipv4Fragmenter::with_mtu(1000).fragment(&packet).unwrap();
        let mut group = FragmentGroup::new(FragmentKey::from_packet(&frags[0].packet).unwrap());

        // Feed the fragments in reverse to exercise out-of-order arrival.
        frags
            .iter()
            .rev()
            .for_each(|frag| group.add_fragment(&frag.packet).unwrap());

        assert!(group.is_complete());
    }

    #[test]
    fn test_fragment_group_incomplete() {
        let packet = build_large_packet(2000);
        let frags = Ipv4Fragmenter::with_mtu(1000).fragment(&packet).unwrap();
        let mut group = FragmentGroup::new(FragmentKey::from_packet(&frags[0].packet).unwrap());

        // Only the first fragment arrives.
        group.add_fragment(&frags[0].packet).unwrap();

        assert!(!group.is_complete());
    }

    #[test]
    fn test_small_mtu() {
        let packet = build_large_packet(1000);
        let frags = Ipv4Fragmenter::with_mtu(100).fragment(&packet).unwrap();

        // A 100-byte MTU forces many small fragments.
        assert!(frags.len() > 10);
        for frag in &frags {
            assert!(frag.packet.len() <= 100);
        }
    }

    #[test]
    fn test_fragment_offset_alignment() {
        let packet = build_large_packet(1000);
        let frags = Ipv4Fragmenter::with_mtu(500).fragment(&packet).unwrap();

        // Every fragment except the final one must carry a multiple of 8 payload bytes.
        let layer = Ipv4Layer::at_offset(0);
        let (_, leading) = frags.split_last().expect("at least one fragment");
        for frag in leading {
            let payload_len = frag.packet.len() - layer.calculate_header_len(&frag.packet);
            assert_eq!(
                payload_len % 8,
                0,
                "payload len {} not multiple of 8",
                payload_len
            );
        }
    }

    #[test]
    fn test_mtu_too_small() {
        let packet = build_large_packet(100);
        // A 20-byte MTU leaves no room for payload after the header.
        let result = Ipv4Fragmenter::with_mtu(20).fragment(&packet);
        assert!(matches!(result, Err(FragmentError::MtuTooSmall { .. })));
    }
}