//! TCP stream reassembly for the flow engine.

use std::collections::BTreeMap;
use std::path::Path;

use super::config::FlowConfig;
use super::error::FlowError;
use super::spill::ReassemblyStorage;
/// Result of processing a TCP segment through the reassembly engine.
///
/// Returned by `TcpReassembler::process_segment` to tell the caller what
/// happened to the segment's payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReassemblyAction {
    /// Segment was in-order and appended; value is bytes added to the
    /// contiguous reassembled buffer (may exceed the segment's own payload
    /// length is NOT possible here — it is exactly the payload length).
    DataReady(usize),
    /// Segment was out-of-order and cached in the `BTreeMap` until the gap
    /// before it is filled.
    Buffered,
    /// Segment was a total duplicate (already fully received).
    Duplicate,
    /// Segment had partial overlap with already-delivered data; value is the
    /// number of trimmed (new) bytes appended.
    OverlapTrimmed(usize),
    /// No payload in this segment (pure ACK, SYN, etc.).
    Empty,
}
/// TCP stream reassembly engine using a `BTreeMap` for out-of-order segment management.
///
/// Mirrors Wireshark's reassemble.c logic: segments are keyed by absolute TCP
/// sequence number. In-order segments are immediately appended to the contiguous
/// reassembled buffer, while out-of-order segments are cached until gaps are filled.
///
/// All sequence-number comparisons use wrapping 32-bit arithmetic (RFC 793
/// style), so streams that cross the 2^32 sequence boundary are handled.
#[derive(Debug)]
pub struct TcpReassembler {
    /// Out-of-order segment cache: sequence number → payload.
    segments: BTreeMap<u32, Vec<u8>>,
    /// Next expected sequence number (advanced as data arrives in-order).
    next_expected_seq: u32,
    /// Contiguous reassembled byte stream (may be in memory or on disk).
    reassembled: ReassemblyStorage,
    /// Total bytes currently buffered in out-of-order cache.
    total_buffered: usize,
    /// Number of distinct out-of-order fragments.
    fragment_count: usize,
    /// Whether the reassembler has been initialized with an ISN.
    initialized: bool,
}
44impl TcpReassembler {
45    /// Create a new uninitialized reassembler.
46    #[must_use]
47    pub fn new() -> Self {
48        Self {
49            segments: BTreeMap::new(),
50            next_expected_seq: 0,
51            reassembled: ReassemblyStorage::new(),
52            total_buffered: 0,
53            fragment_count: 0,
54            initialized: false,
55        }
56    }
57
58    /// Initialize with the first observed sequence number (ISN + 1 for data after SYN).
59    pub fn initialize(&mut self, initial_seq: u32) {
60        self.next_expected_seq = initial_seq;
61        self.initialized = true;
62    }
63
64    /// Whether this reassembler has been initialized.
65    #[must_use]
66    pub fn is_initialized(&self) -> bool {
67        self.initialized
68    }
69
70    /// Get the contiguous reassembled data if it's in memory.
71    ///
72    /// Returns `None` if the data has been spilled to disk.
73    /// Use [`read_reassembled`] for guaranteed access regardless of storage location.
74    #[must_use]
75    pub fn reassembled_data(&self) -> &[u8] {
76        self.reassembled.as_slice().unwrap_or(&[])
77    }
78
79    /// Read the reassembled data regardless of storage location.
80    ///
81    /// Works for both in-memory and spilled-to-disk data.
82    pub fn read_reassembled(&self) -> std::io::Result<Vec<u8>> {
83        self.reassembled.read_all()
84    }
85
86    /// Drain and return the reassembled data, resetting the buffer.
87    pub fn drain_reassembled(&mut self) -> std::io::Result<Vec<u8>> {
88        self.reassembled.drain()
89    }
90
91    /// Total bytes in the out-of-order buffer.
92    #[must_use]
93    pub fn buffered_bytes(&self) -> usize {
94        self.total_buffered
95    }
96
97    /// Number of out-of-order fragments.
98    #[must_use]
99    pub fn fragment_count(&self) -> usize {
100        self.fragment_count
101    }
102
103    /// Total bytes currently held in memory (reassembled + OOO segments).
104    #[must_use]
105    pub fn in_memory_bytes(&self) -> usize {
106        self.reassembled.in_memory_bytes() + self.total_buffered
107    }
108
109    /// Spill reassembled data to a temporary file on disk.
110    ///
111    /// Returns the number of bytes freed from memory.
112    pub fn spill(&mut self, spill_dir: Option<&Path>) -> std::io::Result<usize> {
113        self.reassembled.spill_to_disk(spill_dir)
114    }
115
116    /// Whether the reassembled data has been spilled to disk.
117    #[must_use]
118    pub fn is_spilled(&self) -> bool {
119        self.reassembled.is_spilled()
120    }
121
122    /// Process an incoming TCP segment.
123    ///
124    /// Handles in-order, out-of-order, overlapping, and duplicate segments
125    /// according to the algorithm described in the architectural blueprint.
126    pub fn process_segment(
127        &mut self,
128        seq: u32,
129        payload: &[u8],
130        config: &FlowConfig,
131    ) -> Result<ReassemblyAction, FlowError> {
132        if payload.is_empty() {
133            return Ok(ReassemblyAction::Empty);
134        }
135
136        // Auto-initialize on first data segment if not yet initialized
137        if !self.initialized {
138            self.initialize(seq);
139        }
140
141        let seg_end = seq.wrapping_add(payload.len() as u32);
142
143        // Case 1: Total duplicate — segment is entirely before next_expected_seq
144        if self.seq_before_or_equal(seg_end, self.next_expected_seq) {
145            return Ok(ReassemblyAction::Duplicate);
146        }
147
148        // Case 2: Partial overlap — segment starts before next_expected_seq
149        // but extends beyond it
150        if self.seq_before(seq, self.next_expected_seq) {
151            let overlap = self.next_expected_seq.wrapping_sub(seq) as usize;
152            if overlap >= payload.len() {
153                return Ok(ReassemblyAction::Duplicate);
154            }
155            let trimmed = &payload[overlap..];
156            self.reassembled.extend_from_slice(trimmed);
157            self.next_expected_seq = self.next_expected_seq.wrapping_add(trimmed.len() as u32);
158            self.try_drain_buffered();
159            return Ok(ReassemblyAction::OverlapTrimmed(trimmed.len()));
160        }
161
162        // Case 3: In-order arrival — seq == next_expected_seq
163        if seq == self.next_expected_seq {
164            self.reassembled.extend_from_slice(payload);
165            self.next_expected_seq = self.next_expected_seq.wrapping_add(payload.len() as u32);
166            self.try_drain_buffered();
167            return Ok(ReassemblyAction::DataReady(payload.len()));
168        }
169
170        // Case 4: Out-of-order — seq > next_expected_seq (gap exists)
171        // Check limits before buffering
172        if self.fragment_count >= config.max_ooo_fragments {
173            return Err(FlowError::TooManyFragments {
174                count: self.fragment_count,
175                limit: config.max_ooo_fragments,
176            });
177        }
178        if self.total_buffered + payload.len() > config.max_reassembly_buffer {
179            return Err(FlowError::ReassemblyBufferFull {
180                limit: config.max_reassembly_buffer,
181            });
182        }
183
184        self.segments.insert(seq, payload.to_vec());
185        self.total_buffered += payload.len();
186        self.fragment_count += 1;
187        Ok(ReassemblyAction::Buffered)
188    }
189
190    /// Drain contiguous segments from the `BTreeMap` that can now be appended.
191    fn try_drain_buffered(&mut self) {
192        loop {
193            let key = {
194                let entry = self.segments.range(..=self.next_expected_seq).next_back();
195                match entry {
196                    Some((&k, _)) => k,
197                    None => break,
198                }
199            };
200
201            if let Some(data) = self.segments.remove(&key) {
202                let seg_end = key.wrapping_add(data.len() as u32);
203
204                self.total_buffered -= data.len();
205                self.fragment_count -= 1;
206
207                if self.seq_after(seg_end, self.next_expected_seq) {
208                    if self.seq_before(key, self.next_expected_seq) {
209                        let overlap = self.next_expected_seq.wrapping_sub(key) as usize;
210                        if overlap < data.len() {
211                            self.reassembled.extend_from_slice(&data[overlap..]);
212                            self.next_expected_seq = seg_end;
213                        }
214                    } else {
215                        self.reassembled.extend_from_slice(&data);
216                        self.next_expected_seq = seg_end;
217                    }
218                }
219            }
220        }
221    }
222
223    /// Check if `a` is strictly before `b` in the sequence space (handles wrapping).
224    fn seq_before(&self, a: u32, b: u32) -> bool {
225        (a.wrapping_sub(b) as i32) < 0
226    }
227
228    /// Check if `a` is before or equal to `b` in the sequence space.
229    fn seq_before_or_equal(&self, a: u32, b: u32) -> bool {
230        (a.wrapping_sub(b) as i32) <= 0
231    }
232
233    /// Check if `a` is strictly after `b` in the sequence space.
234    fn seq_after(&self, a: u32, b: u32) -> bool {
235        (a.wrapping_sub(b) as i32) > 0
236    }
237}
238
239impl Default for TcpReassembler {
240    fn default() -> Self {
241        Self::new()
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_in_order_reassembly() {
        let cfg = FlowConfig::default();
        let mut eng = TcpReassembler::new();
        eng.initialize(1000);

        // Two back-to-back in-order segments append directly.
        assert_eq!(
            eng.process_segment(1000, b"hello", &cfg).unwrap(),
            ReassemblyAction::DataReady(5)
        );
        assert_eq!(eng.next_expected_seq, 1005);
        assert_eq!(eng.reassembled_data(), b"hello");

        assert_eq!(
            eng.process_segment(1005, b" world", &cfg).unwrap(),
            ReassemblyAction::DataReady(6)
        );
        assert_eq!(eng.reassembled_data(), b"hello world");
    }

    #[test]
    fn test_out_of_order_then_fill_gap() {
        let cfg = FlowConfig::default();
        let mut eng = TcpReassembler::new();
        eng.initialize(1000);

        // Arrives ahead of the gap: must be cached.
        assert_eq!(
            eng.process_segment(1005, b" world", &cfg).unwrap(),
            ReassemblyAction::Buffered
        );
        assert_eq!(eng.fragment_count(), 1);

        // Filling the gap drains the cached segment too.
        assert_eq!(
            eng.process_segment(1000, b"hello", &cfg).unwrap(),
            ReassemblyAction::DataReady(5)
        );
        assert_eq!(eng.fragment_count(), 0);
        assert_eq!(eng.reassembled_data(), b"hello world");
    }

    #[test]
    fn test_total_duplicate() {
        let cfg = FlowConfig::default();
        let mut eng = TcpReassembler::new();
        eng.initialize(1000);

        eng.process_segment(1000, b"hello", &cfg).unwrap();
        let second = eng.process_segment(1000, b"hello", &cfg).unwrap();
        assert_eq!(second, ReassemblyAction::Duplicate);
        assert_eq!(eng.reassembled_data(), b"hello");
    }

    #[test]
    fn test_partial_overlap() {
        let cfg = FlowConfig::default();
        let mut eng = TcpReassembler::new();
        eng.initialize(1000);

        eng.process_segment(1000, b"hello", &cfg).unwrap();
        // Overlaps the tail of "hello" by 2 bytes; 3 new bytes remain.
        let action = eng.process_segment(1003, b"lo wo", &cfg).unwrap();
        assert_eq!(action, ReassemblyAction::OverlapTrimmed(3));
        assert_eq!(eng.reassembled_data(), b"hello wo");
    }

    #[test]
    fn test_empty_payload() {
        let cfg = FlowConfig::default();
        let mut eng = TcpReassembler::new();
        eng.initialize(1000);

        assert_eq!(
            eng.process_segment(1000, b"", &cfg).unwrap(),
            ReassemblyAction::Empty
        );
    }

    #[test]
    fn test_fragment_limit() {
        let mut cfg = FlowConfig::default();
        cfg.max_ooo_fragments = 2;

        let mut eng = TcpReassembler::new();
        eng.initialize(1000);

        eng.process_segment(1010, b"a", &cfg).unwrap();
        eng.process_segment(1020, b"b", &cfg).unwrap();
        // Third distinct fragment exceeds the configured cap.
        let result = eng.process_segment(1030, b"c", &cfg);
        assert!(matches!(result, Err(FlowError::TooManyFragments { .. })));
    }

    #[test]
    fn test_buffer_size_limit() {
        let mut cfg = FlowConfig::default();
        cfg.max_reassembly_buffer = 10;

        let mut eng = TcpReassembler::new();
        eng.initialize(1000);

        eng.process_segment(1010, b"12345", &cfg).unwrap();
        // 5 + 6 bytes would exceed the 10-byte cap.
        let result = eng.process_segment(1020, b"123456", &cfg);
        assert!(matches!(result, Err(FlowError::ReassemblyBufferFull { .. })));
    }

    #[test]
    fn test_multiple_ooo_segments_drain() {
        let cfg = FlowConfig::default();
        let mut eng = TcpReassembler::new();
        eng.initialize(100);

        eng.process_segment(110, b"ccc", &cfg).unwrap();
        eng.process_segment(105, b"bbbbb", &cfg).unwrap();
        assert_eq!(eng.fragment_count(), 2);

        // Filling the leading gap drains both cached segments in order.
        eng.process_segment(100, b"aaaaa", &cfg).unwrap();
        assert_eq!(eng.fragment_count(), 0);
        assert_eq!(eng.reassembled_data(), b"aaaaabbbbbccc");
    }

    #[test]
    fn test_auto_initialize() {
        let cfg = FlowConfig::default();
        let mut eng = TcpReassembler::new();

        // First data segment establishes the ISN implicitly.
        let action = eng.process_segment(5000, b"data", &cfg).unwrap();
        assert_eq!(action, ReassemblyAction::DataReady(4));
        assert!(eng.is_initialized());
        assert_eq!(eng.reassembled_data(), b"data");
    }

    #[test]
    fn test_drain_reassembled() {
        let cfg = FlowConfig::default();
        let mut eng = TcpReassembler::new();
        eng.initialize(0);

        eng.process_segment(0, b"hello", &cfg).unwrap();
        assert_eq!(eng.drain_reassembled().unwrap(), b"hello");
        assert!(eng.reassembled_data().is_empty());
    }

    #[test]
    fn test_spill_and_read() {
        let cfg = FlowConfig::default();
        let mut eng = TcpReassembler::new();
        eng.initialize(0);

        eng.process_segment(0, b"spill test data", &cfg).unwrap();
        assert_eq!(eng.in_memory_bytes(), 15);

        // Spilling moves all reassembled bytes out of memory...
        let freed = eng.spill(None).unwrap();
        assert_eq!(freed, 15);
        assert!(eng.is_spilled());
        assert_eq!(eng.in_memory_bytes(), 0);

        // ...but they remain readable from disk.
        assert_eq!(eng.read_reassembled().unwrap(), b"spill test data");
    }
}