1use std::collections::BTreeMap;
2use std::path::Path;
3
4use super::config::FlowConfig;
5use super::error::FlowError;
6use super::spill::ReassemblyStorage;
7
/// Outcome of feeding one segment to `TcpReassembler::process_segment`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReassemblyAction {
    /// The segment arrived exactly in order; its payload (this many bytes)
    /// was appended to the reassembled stream. Buffered segments made
    /// contiguous by it may have been appended as well.
    DataReady(usize),
    /// The segment arrived ahead of the expected sequence number and was
    /// buffered until the gap before it is filled.
    Buffered,
    /// The segment carried no bytes beyond what was already delivered.
    Duplicate,
    /// The segment partially overlapped already-delivered data; this many
    /// trailing (new) bytes were appended after trimming the prefix.
    OverlapTrimmed(usize),
    /// The segment had an empty payload; nothing was done.
    Empty,
}
22
/// Reassembles one direction of a TCP byte stream from segments that may
/// arrive out of order, duplicated, or partially overlapping.
#[derive(Debug)]
pub struct TcpReassembler {
    /// Out-of-order segments keyed by starting sequence number, held until
    /// the gap before them is filled.
    segments: BTreeMap<u32, Vec<u8>>,
    /// Sequence number of the next in-order byte we expect to deliver.
    next_expected_seq: u32,
    /// The contiguous reassembled stream; may be spilled to disk.
    reassembled: ReassemblyStorage,
    /// Total payload bytes currently held in `segments`.
    total_buffered: usize,
    /// Number of entries currently held in `segments`.
    fragment_count: usize,
    /// Whether `next_expected_seq` has been set (explicitly via
    /// `initialize`, or implicitly from the first observed segment).
    initialized: bool,
}
43
44impl TcpReassembler {
45 #[must_use]
47 pub fn new() -> Self {
48 Self {
49 segments: BTreeMap::new(),
50 next_expected_seq: 0,
51 reassembled: ReassemblyStorage::new(),
52 total_buffered: 0,
53 fragment_count: 0,
54 initialized: false,
55 }
56 }
57
58 pub fn initialize(&mut self, initial_seq: u32) {
60 self.next_expected_seq = initial_seq;
61 self.initialized = true;
62 }
63
64 #[must_use]
66 pub fn is_initialized(&self) -> bool {
67 self.initialized
68 }
69
70 #[must_use]
75 pub fn reassembled_data(&self) -> &[u8] {
76 self.reassembled.as_slice().unwrap_or(&[])
77 }
78
79 pub fn read_reassembled(&self) -> std::io::Result<Vec<u8>> {
83 self.reassembled.read_all()
84 }
85
86 pub fn drain_reassembled(&mut self) -> std::io::Result<Vec<u8>> {
88 self.reassembled.drain()
89 }
90
91 #[must_use]
93 pub fn buffered_bytes(&self) -> usize {
94 self.total_buffered
95 }
96
97 #[must_use]
99 pub fn fragment_count(&self) -> usize {
100 self.fragment_count
101 }
102
103 #[must_use]
105 pub fn in_memory_bytes(&self) -> usize {
106 self.reassembled.in_memory_bytes() + self.total_buffered
107 }
108
109 pub fn spill(&mut self, spill_dir: Option<&Path>) -> std::io::Result<usize> {
113 self.reassembled.spill_to_disk(spill_dir)
114 }
115
116 #[must_use]
118 pub fn reassembled_len(&self) -> usize {
119 self.reassembled.len()
120 }
121
122 pub fn truncate_reassembled(&mut self, max_len: usize) {
124 self.reassembled.truncate(max_len);
125 }
126
127 #[must_use]
129 pub fn is_spilled(&self) -> bool {
130 self.reassembled.is_spilled()
131 }
132
133 pub fn process_segment(
138 &mut self,
139 seq: u32,
140 payload: &[u8],
141 config: &FlowConfig,
142 ) -> Result<ReassemblyAction, FlowError> {
143 if payload.is_empty() {
144 return Ok(ReassemblyAction::Empty);
145 }
146
147 if !self.initialized {
149 self.initialize(seq);
150 }
151
152 let seg_end = seq.wrapping_add(payload.len() as u32);
153
154 if self.seq_before_or_equal(seg_end, self.next_expected_seq) {
156 return Ok(ReassemblyAction::Duplicate);
157 }
158
159 if self.seq_before(seq, self.next_expected_seq) {
162 let overlap = self.next_expected_seq.wrapping_sub(seq) as usize;
163 if overlap >= payload.len() {
164 return Ok(ReassemblyAction::Duplicate);
165 }
166 let trimmed = &payload[overlap..];
167 self.reassembled.extend_from_slice(trimmed);
168 self.next_expected_seq = self.next_expected_seq.wrapping_add(trimmed.len() as u32);
169 self.try_drain_buffered();
170 return Ok(ReassemblyAction::OverlapTrimmed(trimmed.len()));
171 }
172
173 if seq == self.next_expected_seq {
175 self.reassembled.extend_from_slice(payload);
176 self.next_expected_seq = self.next_expected_seq.wrapping_add(payload.len() as u32);
177 self.try_drain_buffered();
178 return Ok(ReassemblyAction::DataReady(payload.len()));
179 }
180
181 if self.fragment_count >= config.max_ooo_fragments {
184 return Err(FlowError::TooManyFragments {
185 count: self.fragment_count,
186 limit: config.max_ooo_fragments,
187 });
188 }
189 if self.total_buffered + payload.len() > config.max_reassembly_buffer {
190 return Err(FlowError::ReassemblyBufferFull {
191 limit: config.max_reassembly_buffer,
192 });
193 }
194
195 self.segments.insert(seq, payload.to_vec());
196 self.total_buffered += payload.len();
197 self.fragment_count += 1;
198 Ok(ReassemblyAction::Buffered)
199 }
200
201 fn try_drain_buffered(&mut self) {
203 loop {
204 let key = {
205 let entry = self.segments.range(..=self.next_expected_seq).next_back();
206 match entry {
207 Some((&k, _)) => k,
208 None => break,
209 }
210 };
211
212 if let Some(data) = self.segments.remove(&key) {
213 let seg_end = key.wrapping_add(data.len() as u32);
214
215 self.total_buffered -= data.len();
216 self.fragment_count -= 1;
217
218 if self.seq_after(seg_end, self.next_expected_seq) {
219 if self.seq_before(key, self.next_expected_seq) {
220 let overlap = self.next_expected_seq.wrapping_sub(key) as usize;
221 if overlap < data.len() {
222 self.reassembled.extend_from_slice(&data[overlap..]);
223 self.next_expected_seq = seg_end;
224 }
225 } else {
226 self.reassembled.extend_from_slice(&data);
227 self.next_expected_seq = seg_end;
228 }
229 }
230 }
231 }
232 }
233
234 fn seq_before(&self, a: u32, b: u32) -> bool {
236 (a.wrapping_sub(b) as i32) < 0
237 }
238
239 fn seq_before_or_equal(&self, a: u32, b: u32) -> bool {
241 (a.wrapping_sub(b) as i32) <= 0
242 }
243
244 fn seq_after(&self, a: u32, b: u32) -> bool {
246 (a.wrapping_sub(b) as i32) > 0
247 }
248}
249
250impl Default for TcpReassembler {
251 fn default() -> Self {
252 Self::new()
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration shared by the tests below.
    fn default_config() -> FlowConfig {
        FlowConfig::default()
    }

    #[test]
    fn test_in_order_reassembly() {
        let cfg = default_config();
        let mut asm = TcpReassembler::new();
        asm.initialize(1000);

        assert_eq!(
            asm.process_segment(1000, b"hello", &cfg).unwrap(),
            ReassemblyAction::DataReady(5)
        );
        assert_eq!(asm.reassembled_data(), b"hello");
        assert_eq!(asm.next_expected_seq, 1005);

        assert_eq!(
            asm.process_segment(1005, b" world", &cfg).unwrap(),
            ReassemblyAction::DataReady(6)
        );
        assert_eq!(asm.reassembled_data(), b"hello world");
    }

    #[test]
    fn test_out_of_order_then_fill_gap() {
        let cfg = default_config();
        let mut asm = TcpReassembler::new();
        asm.initialize(1000);

        assert_eq!(
            asm.process_segment(1005, b" world", &cfg).unwrap(),
            ReassemblyAction::Buffered
        );
        assert_eq!(asm.fragment_count(), 1);

        assert_eq!(
            asm.process_segment(1000, b"hello", &cfg).unwrap(),
            ReassemblyAction::DataReady(5)
        );
        assert_eq!(asm.reassembled_data(), b"hello world");
        assert_eq!(asm.fragment_count(), 0);
    }

    #[test]
    fn test_total_duplicate() {
        let cfg = default_config();
        let mut asm = TcpReassembler::new();
        asm.initialize(1000);

        asm.process_segment(1000, b"hello", &cfg).unwrap();
        assert_eq!(
            asm.process_segment(1000, b"hello", &cfg).unwrap(),
            ReassemblyAction::Duplicate
        );
        assert_eq!(asm.reassembled_data(), b"hello");
    }

    #[test]
    fn test_partial_overlap() {
        let cfg = default_config();
        let mut asm = TcpReassembler::new();
        asm.initialize(1000);

        asm.process_segment(1000, b"hello", &cfg).unwrap();
        assert_eq!(
            asm.process_segment(1003, b"lo wo", &cfg).unwrap(),
            ReassemblyAction::OverlapTrimmed(3)
        );
        assert_eq!(asm.reassembled_data(), b"hello wo");
    }

    #[test]
    fn test_empty_payload() {
        let cfg = default_config();
        let mut asm = TcpReassembler::new();
        asm.initialize(1000);

        assert_eq!(
            asm.process_segment(1000, b"", &cfg).unwrap(),
            ReassemblyAction::Empty
        );
    }

    #[test]
    fn test_fragment_limit() {
        let mut cfg = default_config();
        cfg.max_ooo_fragments = 2;

        let mut asm = TcpReassembler::new();
        asm.initialize(1000);

        asm.process_segment(1010, b"a", &cfg).unwrap();
        asm.process_segment(1020, b"b", &cfg).unwrap();
        assert!(matches!(
            asm.process_segment(1030, b"c", &cfg),
            Err(FlowError::TooManyFragments { .. })
        ));
    }

    #[test]
    fn test_buffer_size_limit() {
        let mut cfg = default_config();
        cfg.max_reassembly_buffer = 10;

        let mut asm = TcpReassembler::new();
        asm.initialize(1000);

        asm.process_segment(1010, b"12345", &cfg).unwrap();
        assert!(matches!(
            asm.process_segment(1020, b"123456", &cfg),
            Err(FlowError::ReassemblyBufferFull { .. })
        ));
    }

    #[test]
    fn test_multiple_ooo_segments_drain() {
        let cfg = default_config();
        let mut asm = TcpReassembler::new();
        asm.initialize(100);

        asm.process_segment(110, b"ccc", &cfg).unwrap();
        asm.process_segment(105, b"bbbbb", &cfg).unwrap();
        assert_eq!(asm.fragment_count(), 2);

        asm.process_segment(100, b"aaaaa", &cfg).unwrap();
        assert_eq!(asm.reassembled_data(), b"aaaaabbbbbccc");
        assert_eq!(asm.fragment_count(), 0);
    }

    #[test]
    fn test_auto_initialize() {
        let cfg = default_config();
        let mut asm = TcpReassembler::new();

        assert_eq!(
            asm.process_segment(5000, b"data", &cfg).unwrap(),
            ReassemblyAction::DataReady(4)
        );
        assert!(asm.is_initialized());
        assert_eq!(asm.reassembled_data(), b"data");
    }

    #[test]
    fn test_drain_reassembled() {
        let cfg = default_config();
        let mut asm = TcpReassembler::new();
        asm.initialize(0);

        asm.process_segment(0, b"hello", &cfg).unwrap();
        assert_eq!(asm.drain_reassembled().unwrap(), b"hello");
        assert!(asm.reassembled_data().is_empty());
    }

    #[test]
    fn test_spill_and_read() {
        let cfg = default_config();
        let mut asm = TcpReassembler::new();
        asm.initialize(0);

        asm.process_segment(0, b"spill test data", &cfg).unwrap();
        assert_eq!(asm.in_memory_bytes(), 15);

        assert_eq!(asm.spill(None).unwrap(), 15);
        assert!(asm.is_spilled());
        assert_eq!(asm.in_memory_bytes(), 0);

        assert_eq!(asm.read_reassembled().unwrap(), b"spill test data");
    }
}