
stackforge_core/flow/tcp_reassembly.rs

use std::collections::BTreeMap;
use std::path::Path;

use super::config::FlowConfig;
use super::error::FlowError;
use super::spill::ReassemblyStorage;

/// Result of processing a TCP segment through the reassembly engine.
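///
/// Callers typically match on the variant to decide whether new contiguous
/// bytes became available (`DataReady`, `OverlapTrimmed`) or the segment is
/// still pending in the out-of-order cache (`Buffered`).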
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReassemblyAction {
    /// Segment was in-order and appended; value is bytes added to reassembled buffer.
    DataReady(usize),
    /// Segment was out-of-order and cached in the `BTreeMap`.
    Buffered,
    /// Segment was a total duplicate (already fully received).
    Duplicate,
    /// Segment had partial overlap; value is the number of new bytes appended
    /// after trimming the overlapping prefix.
    OverlapTrimmed(usize),
    /// No payload in this segment.
    Empty,
}

/// TCP stream reassembly engine using a `BTreeMap` for out-of-order segment management.
///
/// Mirrors Wireshark's reassemble.c logic: segments are keyed by absolute TCP
/// sequence number. In-order segments are immediately appended to the contiguous
/// reassembled buffer, while out-of-order segments are cached until gaps are filled.
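///
/// A minimal usage sketch (the `use` paths assume this module's items are
/// re-exported at `stackforge_core::flow`; adjust to the actual crate layout):
///
/// ```ignore
/// use stackforge_core::flow::{FlowConfig, ReassemblyAction, TcpReassembler};
///
/// let config = FlowConfig::default();
/// let mut r = TcpReassembler::new();
/// r.initialize(1000);
///
/// // An out-of-order segment is cached until the gap before it is filled.
/// assert_eq!(r.process_segment(1005, b" world", &config).unwrap(),
///            ReassemblyAction::Buffered);
/// assert_eq!(r.process_segment(1000, b"hello", &config).unwrap(),
///            ReassemblyAction::DataReady(5));
/// assert_eq!(r.reassembled_data(), b"hello world");
/// ```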
#[derive(Debug)]
pub struct TcpReassembler {
    /// Out-of-order segment cache: sequence number → payload.
    segments: BTreeMap<u32, Vec<u8>>,
    /// Next expected sequence number (advanced as data arrives in-order).
    next_expected_seq: u32,
    /// Contiguous reassembled byte stream (may be in memory or on disk).
    reassembled: ReassemblyStorage,
    /// Total bytes currently buffered in out-of-order cache.
    total_buffered: usize,
    /// Number of distinct out-of-order fragments.
    fragment_count: usize,
    /// Whether the reassembler has been initialized with an ISN.
    initialized: bool,
}

impl TcpReassembler {
    /// Create a new uninitialized reassembler.
    #[must_use]
    pub fn new() -> Self {
        Self {
            segments: BTreeMap::new(),
            next_expected_seq: 0,
            reassembled: ReassemblyStorage::new(),
            total_buffered: 0,
            fragment_count: 0,
            initialized: false,
        }
    }

    /// Initialize with the first observed sequence number (ISN + 1 for data after SYN).
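    ///
    /// For a stream observed from its SYN, callers typically pass
    /// `isn.wrapping_add(1)`, since the SYN itself consumes one sequence
    /// number (`isn` here being the handshake's initial sequence number).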
    pub fn initialize(&mut self, initial_seq: u32) {
        self.next_expected_seq = initial_seq;
        self.initialized = true;
    }

    /// Whether this reassembler has been initialized.
    #[must_use]
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// Get the contiguous reassembled data if it's in memory.
    ///
    /// Returns an empty slice if the data has been spilled to disk.
    /// Use [`Self::read_reassembled`] for guaranteed access regardless of
    /// storage location.
    #[must_use]
    pub fn reassembled_data(&self) -> &[u8] {
        self.reassembled.as_slice().unwrap_or(&[])
    }

    /// Read the reassembled data regardless of storage location.
    ///
    /// Works for both in-memory and spilled-to-disk data.
    pub fn read_reassembled(&self) -> std::io::Result<Vec<u8>> {
        self.reassembled.read_all()
    }

    /// Drain and return the reassembled data, resetting the buffer.
    pub fn drain_reassembled(&mut self) -> std::io::Result<Vec<u8>> {
        self.reassembled.drain()
    }

    /// Total bytes in the out-of-order buffer.
    #[must_use]
    pub fn buffered_bytes(&self) -> usize {
        self.total_buffered
    }

    /// Number of out-of-order fragments.
    #[must_use]
    pub fn fragment_count(&self) -> usize {
        self.fragment_count
    }

    /// Total bytes currently held in memory (reassembled + OOO segments).
    #[must_use]
    pub fn in_memory_bytes(&self) -> usize {
        self.reassembled.in_memory_bytes() + self.total_buffered
    }

    /// Spill reassembled data to a temporary file on disk.
    ///
    /// Returns the number of bytes freed from memory.
    pub fn spill(&mut self, spill_dir: Option<&Path>) -> std::io::Result<usize> {
        self.reassembled.spill_to_disk(spill_dir)
    }

    /// Total bytes in the contiguous reassembled stream.
    #[must_use]
    pub fn reassembled_len(&self) -> usize {
        self.reassembled.len()
    }

    /// Truncate the reassembled stream to at most `max_len` bytes.
    pub fn truncate_reassembled(&mut self, max_len: usize) {
        self.reassembled.truncate(max_len);
    }

    /// Whether the reassembled data has been spilled to disk.
    #[must_use]
    pub fn is_spilled(&self) -> bool {
        self.reassembled.is_spilled()
    }

    /// Process an incoming TCP segment.
    ///
    /// Handles in-order, out-of-order, overlapping, and duplicate segments
    /// according to the algorithm described in the architectural blueprint.
    ///
    /// # Errors
    ///
    /// Returns [`FlowError::TooManyFragments`] or
    /// [`FlowError::ReassemblyBufferFull`] if buffering an out-of-order
    /// segment would exceed the configured limits.
    pub fn process_segment(
        &mut self,
        seq: u32,
        payload: &[u8],
        config: &FlowConfig,
    ) -> Result<ReassemblyAction, FlowError> {
        if payload.is_empty() {
            return Ok(ReassemblyAction::Empty);
        }

        // Auto-initialize on first data segment if not yet initialized
        if !self.initialized {
            self.initialize(seq);
        }

        let seg_end = seq.wrapping_add(payload.len() as u32);

        // Case 1: Total duplicate - segment is entirely before next_expected_seq
        if self.seq_before_or_equal(seg_end, self.next_expected_seq) {
            return Ok(ReassemblyAction::Duplicate);
        }

        // Case 2: Partial overlap - segment starts before next_expected_seq
        // but extends beyond it
        if self.seq_before(seq, self.next_expected_seq) {
            let overlap = self.next_expected_seq.wrapping_sub(seq) as usize;
            if overlap >= payload.len() {
                return Ok(ReassemblyAction::Duplicate);
            }
            let trimmed = &payload[overlap..];
            self.reassembled.extend_from_slice(trimmed);
            self.next_expected_seq = self.next_expected_seq.wrapping_add(trimmed.len() as u32);
            self.try_drain_buffered();
            return Ok(ReassemblyAction::OverlapTrimmed(trimmed.len()));
        }

        // Case 3: In-order arrival - seq == next_expected_seq
        if seq == self.next_expected_seq {
            self.reassembled.extend_from_slice(payload);
            self.next_expected_seq = self.next_expected_seq.wrapping_add(payload.len() as u32);
            self.try_drain_buffered();
            return Ok(ReassemblyAction::DataReady(payload.len()));
        }

        // Case 4: Out-of-order - seq > next_expected_seq (gap exists)
        // Check limits before buffering
        if self.fragment_count >= config.max_ooo_fragments {
            return Err(FlowError::TooManyFragments {
                count: self.fragment_count,
                limit: config.max_ooo_fragments,
            });
        }
        if self.total_buffered + payload.len() > config.max_reassembly_buffer {
            return Err(FlowError::ReassemblyBufferFull {
                limit: config.max_reassembly_buffer,
            });
        }

        // `insert` replaces any existing fragment at the same sequence number;
        // undo the old entry's accounting so the counters stay consistent.
        if let Some(old) = self.segments.insert(seq, payload.to_vec()) {
            self.total_buffered -= old.len();
            self.fragment_count -= 1;
        }
        self.total_buffered += payload.len();
        self.fragment_count += 1;
        Ok(ReassemblyAction::Buffered)
    }

    /// Drain contiguous segments from the `BTreeMap` that can now be appended.
    fn try_drain_buffered(&mut self) {
        loop {
            let key = {
                let entry = self.segments.range(..=self.next_expected_seq).next_back();
                match entry {
                    Some((&k, _)) => k,
                    None => break,
                }
            };

            // `range` compares keys in plain `u32` order, so near the 2^32
            // wrap it may yield a key that is numerically at or below
            // `next_expected_seq` yet logically after it. Such a fragment is
            // not drainable yet; stop rather than splice in non-contiguous data.
            if self.seq_after(key, self.next_expected_seq) {
                break;
            }

            if let Some(data) = self.segments.remove(&key) {
                let seg_end = key.wrapping_add(data.len() as u32);

                self.total_buffered -= data.len();
                self.fragment_count -= 1;

                if self.seq_after(seg_end, self.next_expected_seq) {
                    if self.seq_before(key, self.next_expected_seq) {
                        let overlap = self.next_expected_seq.wrapping_sub(key) as usize;
                        if overlap < data.len() {
                            self.reassembled.extend_from_slice(&data[overlap..]);
                            self.next_expected_seq = seg_end;
                        }
                    } else {
                        self.reassembled.extend_from_slice(&data);
                        self.next_expected_seq = seg_end;
                    }
                }
            }
        }
    }

    /// Check if `a` is strictly before `b` in the sequence space (handles wrapping).
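    ///
    /// This is the standard wrap-aware TCP sequence comparison. For example,
    /// `seq_before(0xFFFF_FFF0, 0x10)` is true, since
    /// `0xFFFF_FFF0_u32.wrapping_sub(0x10) as i32` is -32: `a` sits just
    /// before `b` across the 2^32 wrap.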
    fn seq_before(&self, a: u32, b: u32) -> bool {
        (a.wrapping_sub(b) as i32) < 0
    }

    /// Check if `a` is before or equal to `b` in the sequence space.
    fn seq_before_or_equal(&self, a: u32, b: u32) -> bool {
        (a.wrapping_sub(b) as i32) <= 0
    }

    /// Check if `a` is strictly after `b` in the sequence space.
    fn seq_after(&self, a: u32, b: u32) -> bool {
        (a.wrapping_sub(b) as i32) > 0
    }
}

impl Default for TcpReassembler {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn default_config() -> FlowConfig {
        FlowConfig::default()
    }

    #[test]
    fn test_in_order_reassembly() {
        let config = default_config();
        let mut r = TcpReassembler::new();
        r.initialize(1000);

        let action = r.process_segment(1000, b"hello", &config).unwrap();
        assert_eq!(action, ReassemblyAction::DataReady(5));
        assert_eq!(r.reassembled_data(), b"hello");
        assert_eq!(r.next_expected_seq, 1005);

        let action = r.process_segment(1005, b" world", &config).unwrap();
        assert_eq!(action, ReassemblyAction::DataReady(6));
        assert_eq!(r.reassembled_data(), b"hello world");
    }

    #[test]
    fn test_out_of_order_then_fill_gap() {
        let config = default_config();
        let mut r = TcpReassembler::new();
        r.initialize(1000);

        let action = r.process_segment(1005, b" world", &config).unwrap();
        assert_eq!(action, ReassemblyAction::Buffered);
        assert_eq!(r.fragment_count(), 1);

        let action = r.process_segment(1000, b"hello", &config).unwrap();
        assert_eq!(action, ReassemblyAction::DataReady(5));
        assert_eq!(r.reassembled_data(), b"hello world");
        assert_eq!(r.fragment_count(), 0);
    }

    #[test]
    fn test_total_duplicate() {
        let config = default_config();
        let mut r = TcpReassembler::new();
        r.initialize(1000);

        r.process_segment(1000, b"hello", &config).unwrap();
        let action = r.process_segment(1000, b"hello", &config).unwrap();
        assert_eq!(action, ReassemblyAction::Duplicate);
        assert_eq!(r.reassembled_data(), b"hello");
    }

    #[test]
    fn test_partial_overlap() {
        let config = default_config();
        let mut r = TcpReassembler::new();
        r.initialize(1000);

        r.process_segment(1000, b"hello", &config).unwrap();
        let action = r.process_segment(1003, b"lo wo", &config).unwrap();
        assert_eq!(action, ReassemblyAction::OverlapTrimmed(3));
        assert_eq!(r.reassembled_data(), b"hello wo");
    }

    #[test]
    fn test_empty_payload() {
        let config = default_config();
        let mut r = TcpReassembler::new();
        r.initialize(1000);

        let action = r.process_segment(1000, b"", &config).unwrap();
        assert_eq!(action, ReassemblyAction::Empty);
    }

    #[test]
    fn test_fragment_limit() {
        let mut config = default_config();
        config.max_ooo_fragments = 2;

        let mut r = TcpReassembler::new();
        r.initialize(1000);

        r.process_segment(1010, b"a", &config).unwrap();
        r.process_segment(1020, b"b", &config).unwrap();
        let err = r.process_segment(1030, b"c", &config);
        assert!(matches!(err, Err(FlowError::TooManyFragments { .. })));
    }

    #[test]
    fn test_buffer_size_limit() {
        let mut config = default_config();
        config.max_reassembly_buffer = 10;

        let mut r = TcpReassembler::new();
        r.initialize(1000);

        r.process_segment(1010, b"12345", &config).unwrap();
        let err = r.process_segment(1020, b"123456", &config);
        assert!(matches!(err, Err(FlowError::ReassemblyBufferFull { .. })));
    }

    #[test]
    fn test_multiple_ooo_segments_drain() {
        let config = default_config();
        let mut r = TcpReassembler::new();
        r.initialize(100);

        r.process_segment(110, b"ccc", &config).unwrap();
        r.process_segment(105, b"bbbbb", &config).unwrap();
        assert_eq!(r.fragment_count(), 2);

        r.process_segment(100, b"aaaaa", &config).unwrap();
        assert_eq!(r.reassembled_data(), b"aaaaabbbbbccc");
        assert_eq!(r.fragment_count(), 0);
    }

    #[test]
    fn test_auto_initialize() {
        let config = default_config();
        let mut r = TcpReassembler::new();

        let action = r.process_segment(5000, b"data", &config).unwrap();
        assert_eq!(action, ReassemblyAction::DataReady(4));
        assert!(r.is_initialized());
        assert_eq!(r.reassembled_data(), b"data");
    }

    #[test]
    fn test_drain_reassembled() {
        let config = default_config();
        let mut r = TcpReassembler::new();
        r.initialize(0);

        r.process_segment(0, b"hello", &config).unwrap();
        let data = r.drain_reassembled().unwrap();
        assert_eq!(data, b"hello");
        assert!(r.reassembled_data().is_empty());
    }

    #[test]
    fn test_spill_and_read() {
        let config = default_config();
        let mut r = TcpReassembler::new();
        r.initialize(0);

        r.process_segment(0, b"spill test data", &config).unwrap();
        assert_eq!(r.in_memory_bytes(), 15);

        let freed = r.spill(None).unwrap();
        assert_eq!(freed, 15);
        assert!(r.is_spilled());
        assert_eq!(r.in_memory_bytes(), 0);

        let data = r.read_reassembled().unwrap();
        assert_eq!(data, b"spill test data");
    }
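
    // A minimal sketch of a wrap-boundary case: an in-order stream whose
    // sequence numbers cross 2^32. Relies only on the wrap-aware helpers
    // above; the expected values were derived from wrapping arithmetic.
    #[test]
    fn test_in_order_across_sequence_wrap() {
        let config = default_config();
        let mut r = TcpReassembler::new();
        r.initialize(0xFFFF_FFFE);

        // "abcd" occupies 0xFFFF_FFFE..=0xFFFF_FFFF, then wraps to 0..=1.
        let action = r.process_segment(0xFFFF_FFFE, b"abcd", &config).unwrap();
        assert_eq!(action, ReassemblyAction::DataReady(4));
        assert_eq!(r.next_expected_seq, 2);

        let action = r.process_segment(2, b"ef", &config).unwrap();
        assert_eq!(action, ReassemblyAction::DataReady(2));
        assert_eq!(r.reassembled_data(), b"abcdef");
    }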
}