1use std::collections::BTreeMap;
2
3use super::config::FlowConfig;
4use super::error::FlowError;
5
/// Outcome of feeding a single segment to the TCP reassembler.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ReassemblyAction {
    /// The segment began exactly at the next expected sequence number and was
    /// appended. Carries the number of bytes taken from *this* segment; bytes
    /// released from the out-of-order buffer as a consequence are not counted.
    DataReady(usize),
    /// The segment began beyond the next expected sequence number and was
    /// stored until the gap in front of it is filled.
    Buffered,
    /// The segment carried no new data (e.g. a retransmission of bytes that
    /// were already delivered); nothing was appended.
    Duplicate,
    /// The segment began inside already-delivered data but extended past it.
    /// Carries the number of new bytes appended after the overlapping prefix
    /// was discarded.
    OverlapTrimmed(usize),
    /// The segment had an empty payload and was ignored.
    Empty,
}
20
21#[derive(Debug)]
27pub struct TcpReassembler {
28 segments: BTreeMap<u32, Vec<u8>>,
30 next_expected_seq: u32,
32 reassembled: Vec<u8>,
34 total_buffered: usize,
36 fragment_count: usize,
38 initialized: bool,
40}
41
42impl TcpReassembler {
43 #[must_use]
45 pub fn new() -> Self {
46 Self {
47 segments: BTreeMap::new(),
48 next_expected_seq: 0,
49 reassembled: Vec::new(),
50 total_buffered: 0,
51 fragment_count: 0,
52 initialized: false,
53 }
54 }
55
56 pub fn initialize(&mut self, initial_seq: u32) {
58 self.next_expected_seq = initial_seq;
59 self.initialized = true;
60 }
61
62 #[must_use]
64 pub fn is_initialized(&self) -> bool {
65 self.initialized
66 }
67
68 #[must_use]
70 pub fn reassembled_data(&self) -> &[u8] {
71 &self.reassembled
72 }
73
74 pub fn drain_reassembled(&mut self) -> Vec<u8> {
76 std::mem::take(&mut self.reassembled)
77 }
78
79 #[must_use]
81 pub fn buffered_bytes(&self) -> usize {
82 self.total_buffered
83 }
84
85 #[must_use]
87 pub fn fragment_count(&self) -> usize {
88 self.fragment_count
89 }
90
91 pub fn process_segment(
96 &mut self,
97 seq: u32,
98 payload: &[u8],
99 config: &FlowConfig,
100 ) -> Result<ReassemblyAction, FlowError> {
101 if payload.is_empty() {
102 return Ok(ReassemblyAction::Empty);
103 }
104
105 if !self.initialized {
107 self.initialize(seq);
108 }
109
110 let seg_end = seq.wrapping_add(payload.len() as u32);
111
112 if self.seq_before_or_equal(seg_end, self.next_expected_seq) {
114 return Ok(ReassemblyAction::Duplicate);
115 }
116
117 if self.seq_before(seq, self.next_expected_seq) {
120 let overlap = self.next_expected_seq.wrapping_sub(seq) as usize;
121 if overlap >= payload.len() {
122 return Ok(ReassemblyAction::Duplicate);
123 }
124 let trimmed = &payload[overlap..];
125 self.reassembled.extend_from_slice(trimmed);
126 self.next_expected_seq = self.next_expected_seq.wrapping_add(trimmed.len() as u32);
127 self.try_drain_buffered();
128 return Ok(ReassemblyAction::OverlapTrimmed(trimmed.len()));
129 }
130
131 if seq == self.next_expected_seq {
133 self.reassembled.extend_from_slice(payload);
134 self.next_expected_seq = self.next_expected_seq.wrapping_add(payload.len() as u32);
135 self.try_drain_buffered();
136 return Ok(ReassemblyAction::DataReady(payload.len()));
137 }
138
139 if self.fragment_count >= config.max_ooo_fragments {
142 return Err(FlowError::TooManyFragments {
143 count: self.fragment_count,
144 limit: config.max_ooo_fragments,
145 });
146 }
147 if self.total_buffered + payload.len() > config.max_reassembly_buffer {
148 return Err(FlowError::ReassemblyBufferFull {
149 limit: config.max_reassembly_buffer,
150 });
151 }
152
153 self.segments.insert(seq, payload.to_vec());
154 self.total_buffered += payload.len();
155 self.fragment_count += 1;
156 Ok(ReassemblyAction::Buffered)
157 }
158
159 fn try_drain_buffered(&mut self) {
161 loop {
163 let key = {
165 let entry = self.segments.range(..=self.next_expected_seq).next_back();
166 match entry {
167 Some((&k, _)) => k,
168 None => break,
169 }
170 };
171
172 if let Some(data) = self.segments.remove(&key) {
174 let seg_end = key.wrapping_add(data.len() as u32);
175
176 self.total_buffered -= data.len();
177 self.fragment_count -= 1;
178
179 if self.seq_after(seg_end, self.next_expected_seq) {
181 if self.seq_before(key, self.next_expected_seq) {
182 let overlap = self.next_expected_seq.wrapping_sub(key) as usize;
184 if overlap < data.len() {
185 self.reassembled.extend_from_slice(&data[overlap..]);
186 self.next_expected_seq = seg_end;
187 }
188 } else {
189 self.reassembled.extend_from_slice(&data);
191 self.next_expected_seq = seg_end;
192 }
193 }
194 }
196 }
197 }
198
199 fn seq_before(&self, a: u32, b: u32) -> bool {
201 (a.wrapping_sub(b) as i32) < 0
202 }
203
204 fn seq_before_or_equal(&self, a: u32, b: u32) -> bool {
206 (a.wrapping_sub(b) as i32) <= 0
207 }
208
209 fn seq_after(&self, a: u32, b: u32) -> bool {
211 (a.wrapping_sub(b) as i32) > 0
212 }
213}
214
215impl Default for TcpReassembler {
216 fn default() -> Self {
217 Self::new()
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// The project's default flow limits.
    fn default_config() -> FlowConfig {
        FlowConfig::default()
    }

    /// A reassembler whose next expected sequence number is `isn`.
    fn anchored_at(isn: u32) -> TcpReassembler {
        let mut reasm = TcpReassembler::new();
        reasm.initialize(isn);
        reasm
    }

    #[test]
    fn test_in_order_reassembly() {
        let cfg = default_config();
        let mut reasm = anchored_at(1000);

        let first = reasm.process_segment(1000, b"hello", &cfg).unwrap();
        assert_eq!(first, ReassemblyAction::DataReady(5));
        assert_eq!(reasm.reassembled_data(), b"hello");
        assert_eq!(reasm.next_expected_seq, 1005);

        let second = reasm.process_segment(1005, b" world", &cfg).unwrap();
        assert_eq!(second, ReassemblyAction::DataReady(6));
        assert_eq!(reasm.reassembled_data(), b"hello world");
    }

    #[test]
    fn test_out_of_order_then_fill_gap() {
        let cfg = default_config();
        let mut reasm = anchored_at(1000);

        // The tail arrives first and must wait for the head.
        let early = reasm.process_segment(1005, b" world", &cfg).unwrap();
        assert_eq!(early, ReassemblyAction::Buffered);
        assert_eq!(reasm.fragment_count(), 1);

        // Filling the gap releases the buffered tail as well.
        let head = reasm.process_segment(1000, b"hello", &cfg).unwrap();
        assert_eq!(head, ReassemblyAction::DataReady(5));
        assert_eq!(reasm.reassembled_data(), b"hello world");
        assert_eq!(reasm.fragment_count(), 0);
    }

    #[test]
    fn test_total_duplicate() {
        let cfg = default_config();
        let mut reasm = anchored_at(1000);

        reasm.process_segment(1000, b"hello", &cfg).unwrap();
        let again = reasm.process_segment(1000, b"hello", &cfg).unwrap();
        assert_eq!(again, ReassemblyAction::Duplicate);
        assert_eq!(reasm.reassembled_data(), b"hello");
    }

    #[test]
    fn test_partial_overlap() {
        let cfg = default_config();
        let mut reasm = anchored_at(1000);

        reasm.process_segment(1000, b"hello", &cfg).unwrap();
        // "lo" was already delivered; only " wo" is new.
        let overlapping = reasm.process_segment(1003, b"lo wo", &cfg).unwrap();
        assert_eq!(overlapping, ReassemblyAction::OverlapTrimmed(3));
        assert_eq!(reasm.reassembled_data(), b"hello wo");
    }

    #[test]
    fn test_empty_payload() {
        let cfg = default_config();
        let mut reasm = anchored_at(1000);

        let nothing = reasm.process_segment(1000, b"", &cfg).unwrap();
        assert_eq!(nothing, ReassemblyAction::Empty);
    }

    #[test]
    fn test_fragment_limit() {
        let mut cfg = default_config();
        cfg.max_ooo_fragments = 2;
        let mut reasm = anchored_at(1000);

        reasm.process_segment(1010, b"a", &cfg).unwrap();
        reasm.process_segment(1020, b"b", &cfg).unwrap();
        let overflow = reasm.process_segment(1030, b"c", &cfg);
        assert!(matches!(overflow, Err(FlowError::TooManyFragments { .. })));
    }

    #[test]
    fn test_buffer_size_limit() {
        let mut cfg = default_config();
        cfg.max_reassembly_buffer = 10;
        let mut reasm = anchored_at(1000);

        reasm.process_segment(1010, b"12345", &cfg).unwrap();
        let overflow = reasm.process_segment(1020, b"123456", &cfg);
        assert!(matches!(overflow, Err(FlowError::ReassemblyBufferFull { .. })));
    }

    #[test]
    fn test_multiple_ooo_segments_drain() {
        let cfg = default_config();
        let mut reasm = anchored_at(100);

        // Two out-of-order pieces, delivered in reverse order.
        reasm.process_segment(110, b"ccc", &cfg).unwrap();
        reasm.process_segment(105, b"bbbbb", &cfg).unwrap();
        assert_eq!(reasm.fragment_count(), 2);

        // The head unlocks both buffered pieces in sequence.
        reasm.process_segment(100, b"aaaaa", &cfg).unwrap();
        assert_eq!(reasm.reassembled_data(), b"aaaaabbbbbccc");
        assert_eq!(reasm.fragment_count(), 0);
    }

    #[test]
    fn test_auto_initialize() {
        let cfg = default_config();
        let mut reasm = TcpReassembler::new();

        // No explicit initialize(): the first data segment anchors the stream.
        let first = reasm.process_segment(5000, b"data", &cfg).unwrap();
        assert_eq!(first, ReassemblyAction::DataReady(4));
        assert!(reasm.is_initialized());
        assert_eq!(reasm.reassembled_data(), b"data");
    }

    #[test]
    fn test_drain_reassembled() {
        let cfg = default_config();
        let mut reasm = anchored_at(0);

        reasm.process_segment(0, b"hello", &cfg).unwrap();
        let taken = reasm.drain_reassembled();
        assert_eq!(taken, b"hello");
        assert!(reasm.reassembled_data().is_empty());
    }
}