1use std::collections::BTreeMap;
2use std::path::Path;
3
4use super::config::FlowConfig;
5use super::error::FlowError;
6use super::spill::ReassemblyStorage;
7
/// Outcome of feeding one segment to [`TcpReassembler::process_segment`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReassemblyAction {
    /// Segment was in order; this many bytes were appended to the stream
    /// (not counting previously buffered segments drained afterwards).
    DataReady(usize),
    /// Segment arrived ahead of the delivery point and was buffered.
    Buffered,
    /// Segment lies entirely at or before the delivery point; dropped.
    Duplicate,
    /// Segment partially overlapped delivered data; this many new tail
    /// bytes were kept after trimming the already-delivered prefix.
    OverlapTrimmed(usize),
    /// Segment carried no payload; nothing to do.
    Empty,
}
22
/// Reassembles one direction of a TCP byte stream from possibly
/// out-of-order, overlapping, or duplicated segments.
#[derive(Debug)]
pub struct TcpReassembler {
    // Out-of-order segments keyed by their starting sequence number.
    segments: BTreeMap<u32, Vec<u8>>,
    // Sequence number of the next byte expected for in-order delivery.
    next_expected_seq: u32,
    // Contiguous in-order payload; may be spilled to disk (see `spill`).
    reassembled: ReassemblyStorage,
    // Total payload bytes currently held in `segments`.
    total_buffered: usize,
    // Number of entries in `segments`.
    fragment_count: usize,
    // True once `next_expected_seq` has been set, explicitly via
    // `initialize` or lazily from the first processed segment.
    initialized: bool,
}
43
44impl TcpReassembler {
45 #[must_use]
47 pub fn new() -> Self {
48 Self {
49 segments: BTreeMap::new(),
50 next_expected_seq: 0,
51 reassembled: ReassemblyStorage::new(),
52 total_buffered: 0,
53 fragment_count: 0,
54 initialized: false,
55 }
56 }
57
58 pub fn initialize(&mut self, initial_seq: u32) {
60 self.next_expected_seq = initial_seq;
61 self.initialized = true;
62 }
63
64 #[must_use]
66 pub fn is_initialized(&self) -> bool {
67 self.initialized
68 }
69
70 #[must_use]
75 pub fn reassembled_data(&self) -> &[u8] {
76 self.reassembled.as_slice().unwrap_or(&[])
77 }
78
79 pub fn read_reassembled(&self) -> std::io::Result<Vec<u8>> {
83 self.reassembled.read_all()
84 }
85
86 pub fn drain_reassembled(&mut self) -> std::io::Result<Vec<u8>> {
88 self.reassembled.drain()
89 }
90
91 #[must_use]
93 pub fn buffered_bytes(&self) -> usize {
94 self.total_buffered
95 }
96
97 #[must_use]
99 pub fn fragment_count(&self) -> usize {
100 self.fragment_count
101 }
102
103 #[must_use]
105 pub fn in_memory_bytes(&self) -> usize {
106 self.reassembled.in_memory_bytes() + self.total_buffered
107 }
108
109 pub fn spill(&mut self, spill_dir: Option<&Path>) -> std::io::Result<usize> {
113 self.reassembled.spill_to_disk(spill_dir)
114 }
115
116 #[must_use]
118 pub fn is_spilled(&self) -> bool {
119 self.reassembled.is_spilled()
120 }
121
122 pub fn process_segment(
127 &mut self,
128 seq: u32,
129 payload: &[u8],
130 config: &FlowConfig,
131 ) -> Result<ReassemblyAction, FlowError> {
132 if payload.is_empty() {
133 return Ok(ReassemblyAction::Empty);
134 }
135
136 if !self.initialized {
138 self.initialize(seq);
139 }
140
141 let seg_end = seq.wrapping_add(payload.len() as u32);
142
143 if self.seq_before_or_equal(seg_end, self.next_expected_seq) {
145 return Ok(ReassemblyAction::Duplicate);
146 }
147
148 if self.seq_before(seq, self.next_expected_seq) {
151 let overlap = self.next_expected_seq.wrapping_sub(seq) as usize;
152 if overlap >= payload.len() {
153 return Ok(ReassemblyAction::Duplicate);
154 }
155 let trimmed = &payload[overlap..];
156 self.reassembled.extend_from_slice(trimmed);
157 self.next_expected_seq = self.next_expected_seq.wrapping_add(trimmed.len() as u32);
158 self.try_drain_buffered();
159 return Ok(ReassemblyAction::OverlapTrimmed(trimmed.len()));
160 }
161
162 if seq == self.next_expected_seq {
164 self.reassembled.extend_from_slice(payload);
165 self.next_expected_seq = self.next_expected_seq.wrapping_add(payload.len() as u32);
166 self.try_drain_buffered();
167 return Ok(ReassemblyAction::DataReady(payload.len()));
168 }
169
170 if self.fragment_count >= config.max_ooo_fragments {
173 return Err(FlowError::TooManyFragments {
174 count: self.fragment_count,
175 limit: config.max_ooo_fragments,
176 });
177 }
178 if self.total_buffered + payload.len() > config.max_reassembly_buffer {
179 return Err(FlowError::ReassemblyBufferFull {
180 limit: config.max_reassembly_buffer,
181 });
182 }
183
184 self.segments.insert(seq, payload.to_vec());
185 self.total_buffered += payload.len();
186 self.fragment_count += 1;
187 Ok(ReassemblyAction::Buffered)
188 }
189
190 fn try_drain_buffered(&mut self) {
192 loop {
193 let key = {
194 let entry = self.segments.range(..=self.next_expected_seq).next_back();
195 match entry {
196 Some((&k, _)) => k,
197 None => break,
198 }
199 };
200
201 if let Some(data) = self.segments.remove(&key) {
202 let seg_end = key.wrapping_add(data.len() as u32);
203
204 self.total_buffered -= data.len();
205 self.fragment_count -= 1;
206
207 if self.seq_after(seg_end, self.next_expected_seq) {
208 if self.seq_before(key, self.next_expected_seq) {
209 let overlap = self.next_expected_seq.wrapping_sub(key) as usize;
210 if overlap < data.len() {
211 self.reassembled.extend_from_slice(&data[overlap..]);
212 self.next_expected_seq = seg_end;
213 }
214 } else {
215 self.reassembled.extend_from_slice(&data);
216 self.next_expected_seq = seg_end;
217 }
218 }
219 }
220 }
221 }
222
223 fn seq_before(&self, a: u32, b: u32) -> bool {
225 (a.wrapping_sub(b) as i32) < 0
226 }
227
228 fn seq_before_or_equal(&self, a: u32, b: u32) -> bool {
230 (a.wrapping_sub(b) as i32) <= 0
231 }
232
233 fn seq_after(&self, a: u32, b: u32) -> bool {
235 (a.wrapping_sub(b) as i32) > 0
236 }
237}
238
239impl Default for TcpReassembler {
240 fn default() -> Self {
241 Self::new()
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    // Shared default config; individual tests override limits as needed.
    fn default_config() -> FlowConfig {
        FlowConfig::default()
    }

    // Two back-to-back in-order segments are delivered immediately and
    // advance next_expected_seq by their payload lengths.
    #[test]
    fn test_in_order_reassembly() {
        let config = default_config();
        let mut r = TcpReassembler::new();
        r.initialize(1000);

        let action = r.process_segment(1000, b"hello", &config).unwrap();
        assert_eq!(action, ReassemblyAction::DataReady(5));
        assert_eq!(r.reassembled_data(), b"hello");
        assert_eq!(r.next_expected_seq, 1005);

        let action = r.process_segment(1005, b" world", &config).unwrap();
        assert_eq!(action, ReassemblyAction::DataReady(6));
        assert_eq!(r.reassembled_data(), b"hello world");
    }

    // An out-of-order segment is buffered, then drained once the gap
    // before it is filled.
    #[test]
    fn test_out_of_order_then_fill_gap() {
        let config = default_config();
        let mut r = TcpReassembler::new();
        r.initialize(1000);

        let action = r.process_segment(1005, b" world", &config).unwrap();
        assert_eq!(action, ReassemblyAction::Buffered);
        assert_eq!(r.fragment_count(), 1);

        let action = r.process_segment(1000, b"hello", &config).unwrap();
        assert_eq!(action, ReassemblyAction::DataReady(5));
        assert_eq!(r.reassembled_data(), b"hello world");
        assert_eq!(r.fragment_count(), 0);
    }

    // An exact retransmission of delivered data is reported as Duplicate
    // and does not change the stream.
    #[test]
    fn test_total_duplicate() {
        let config = default_config();
        let mut r = TcpReassembler::new();
        r.initialize(1000);

        r.process_segment(1000, b"hello", &config).unwrap();
        let action = r.process_segment(1000, b"hello", &config).unwrap();
        assert_eq!(action, ReassemblyAction::Duplicate);
        assert_eq!(r.reassembled_data(), b"hello");
    }

    // A segment overlapping the delivered prefix has the overlap trimmed
    // and only the new tail bytes appended.
    #[test]
    fn test_partial_overlap() {
        let config = default_config();
        let mut r = TcpReassembler::new();
        r.initialize(1000);

        r.process_segment(1000, b"hello", &config).unwrap();
        let action = r.process_segment(1003, b"lo wo", &config).unwrap();
        assert_eq!(action, ReassemblyAction::OverlapTrimmed(3));
        assert_eq!(r.reassembled_data(), b"hello wo");
    }

    // Zero-length payloads are a no-op.
    #[test]
    fn test_empty_payload() {
        let config = default_config();
        let mut r = TcpReassembler::new();
        r.initialize(1000);

        let action = r.process_segment(1000, b"", &config).unwrap();
        assert_eq!(action, ReassemblyAction::Empty);
    }

    // Buffering a third distinct out-of-order fragment past the
    // max_ooo_fragments limit errors.
    #[test]
    fn test_fragment_limit() {
        let mut config = default_config();
        config.max_ooo_fragments = 2;

        let mut r = TcpReassembler::new();
        r.initialize(1000);

        r.process_segment(1010, b"a", &config).unwrap();
        r.process_segment(1020, b"b", &config).unwrap();
        let err = r.process_segment(1030, b"c", &config);
        assert!(matches!(err, Err(FlowError::TooManyFragments { .. })));
    }

    // Exceeding max_reassembly_buffer in total buffered bytes errors.
    #[test]
    fn test_buffer_size_limit() {
        let mut config = default_config();
        config.max_reassembly_buffer = 10;

        let mut r = TcpReassembler::new();
        r.initialize(1000);

        r.process_segment(1010, b"12345", &config).unwrap();
        let err = r.process_segment(1020, b"123456", &config);
        assert!(matches!(err, Err(FlowError::ReassemblyBufferFull { .. })));
    }

    // Filling the gap drains multiple buffered segments in one pass.
    #[test]
    fn test_multiple_ooo_segments_drain() {
        let config = default_config();
        let mut r = TcpReassembler::new();
        r.initialize(100);

        r.process_segment(110, b"ccc", &config).unwrap();
        r.process_segment(105, b"bbbbb", &config).unwrap();
        assert_eq!(r.fragment_count(), 2);

        r.process_segment(100, b"aaaaa", &config).unwrap();
        assert_eq!(r.reassembled_data(), b"aaaaabbbbbccc");
        assert_eq!(r.fragment_count(), 0);
    }

    // Without an explicit initialize(), the first segment's seq is
    // adopted as the initial sequence number.
    #[test]
    fn test_auto_initialize() {
        let config = default_config();
        let mut r = TcpReassembler::new();

        let action = r.process_segment(5000, b"data", &config).unwrap();
        assert_eq!(action, ReassemblyAction::DataReady(4));
        assert!(r.is_initialized());
        assert_eq!(r.reassembled_data(), b"data");
    }

    // drain_reassembled returns the bytes and empties the storage.
    #[test]
    fn test_drain_reassembled() {
        let config = default_config();
        let mut r = TcpReassembler::new();
        r.initialize(0);

        r.process_segment(0, b"hello", &config).unwrap();
        let data = r.drain_reassembled().unwrap();
        assert_eq!(data, b"hello");
        assert!(r.reassembled_data().is_empty());
    }

    // Spilling frees memory, flips is_spilled, and read_reassembled still
    // returns the full stream from disk.
    #[test]
    fn test_spill_and_read() {
        let config = default_config();
        let mut r = TcpReassembler::new();
        r.initialize(0);

        r.process_segment(0, b"spill test data", &config).unwrap();
        assert_eq!(r.in_memory_bytes(), 15);

        let freed = r.spill(None).unwrap();
        assert_eq!(freed, 15);
        assert!(r.is_spilled());
        assert_eq!(r.in_memory_bytes(), 0);

        let data = r.read_reassembled().unwrap();
        assert_eq!(data, b"spill test data");
    }
}