1use std::collections::BTreeMap;
2
3use super::config::FlowConfig;
4use super::error::FlowError;
5
/// Outcome of feeding one TCP segment to the reassembler.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReassemblyAction {
    /// In-order payload was appended to the reassembled stream; carries
    /// the number of new bytes delivered from this segment.
    DataReady(usize),
    /// Segment arrived out of order and was stored for later delivery.
    Buffered,
    /// Segment contained no data beyond what was already delivered.
    Duplicate,
    /// Segment partially overlapped already-delivered data; carries the
    /// number of new (non-overlapping) bytes that were appended.
    OverlapTrimmed(usize),
    /// Segment had an empty payload; nothing to do.
    Empty,
}
20
/// Reassembles one direction of a TCP byte stream from possibly
/// out-of-order, overlapping, or duplicated segments.
#[derive(Debug)]
pub struct TcpReassembler {
    // Out-of-order segments, keyed by their starting sequence number.
    segments: BTreeMap<u32, Vec<u8>>,
    // Sequence number of the next byte expected in order.
    next_expected_seq: u32,
    // Contiguous, in-order payload accumulated so far.
    reassembled: Vec<u8>,
    // Total payload bytes currently held in `segments`.
    total_buffered: usize,
    // Number of entries currently held in `segments`.
    fragment_count: usize,
    // Whether `next_expected_seq` has been set from an initial sequence.
    initialized: bool,
}
41
42impl TcpReassembler {
43 pub fn new() -> Self {
45 Self {
46 segments: BTreeMap::new(),
47 next_expected_seq: 0,
48 reassembled: Vec::new(),
49 total_buffered: 0,
50 fragment_count: 0,
51 initialized: false,
52 }
53 }
54
55 pub fn initialize(&mut self, initial_seq: u32) {
57 self.next_expected_seq = initial_seq;
58 self.initialized = true;
59 }
60
61 pub fn is_initialized(&self) -> bool {
63 self.initialized
64 }
65
66 pub fn reassembled_data(&self) -> &[u8] {
68 &self.reassembled
69 }
70
71 pub fn drain_reassembled(&mut self) -> Vec<u8> {
73 std::mem::take(&mut self.reassembled)
74 }
75
76 pub fn buffered_bytes(&self) -> usize {
78 self.total_buffered
79 }
80
81 pub fn fragment_count(&self) -> usize {
83 self.fragment_count
84 }
85
86 pub fn process_segment(
91 &mut self,
92 seq: u32,
93 payload: &[u8],
94 config: &FlowConfig,
95 ) -> Result<ReassemblyAction, FlowError> {
96 if payload.is_empty() {
97 return Ok(ReassemblyAction::Empty);
98 }
99
100 if !self.initialized {
102 self.initialize(seq);
103 }
104
105 let seg_end = seq.wrapping_add(payload.len() as u32);
106
107 if self.seq_before_or_equal(seg_end, self.next_expected_seq) {
109 return Ok(ReassemblyAction::Duplicate);
110 }
111
112 if self.seq_before(seq, self.next_expected_seq) {
115 let overlap = self.next_expected_seq.wrapping_sub(seq) as usize;
116 if overlap >= payload.len() {
117 return Ok(ReassemblyAction::Duplicate);
118 }
119 let trimmed = &payload[overlap..];
120 self.reassembled.extend_from_slice(trimmed);
121 self.next_expected_seq = self.next_expected_seq.wrapping_add(trimmed.len() as u32);
122 self.try_drain_buffered();
123 return Ok(ReassemblyAction::OverlapTrimmed(trimmed.len()));
124 }
125
126 if seq == self.next_expected_seq {
128 self.reassembled.extend_from_slice(payload);
129 self.next_expected_seq = self.next_expected_seq.wrapping_add(payload.len() as u32);
130 self.try_drain_buffered();
131 return Ok(ReassemblyAction::DataReady(payload.len()));
132 }
133
134 if self.fragment_count >= config.max_ooo_fragments {
137 return Err(FlowError::TooManyFragments {
138 count: self.fragment_count,
139 limit: config.max_ooo_fragments,
140 });
141 }
142 if self.total_buffered + payload.len() > config.max_reassembly_buffer {
143 return Err(FlowError::ReassemblyBufferFull {
144 limit: config.max_reassembly_buffer,
145 });
146 }
147
148 self.segments.insert(seq, payload.to_vec());
149 self.total_buffered += payload.len();
150 self.fragment_count += 1;
151 Ok(ReassemblyAction::Buffered)
152 }
153
154 fn try_drain_buffered(&mut self) {
156 loop {
158 let key = {
160 let entry = self.segments.range(..=self.next_expected_seq).next_back();
161 match entry {
162 Some((&k, _)) => k,
163 None => break,
164 }
165 };
166
167 if let Some(data) = self.segments.remove(&key) {
169 let seg_end = key.wrapping_add(data.len() as u32);
170
171 self.total_buffered -= data.len();
172 self.fragment_count -= 1;
173
174 if self.seq_after(seg_end, self.next_expected_seq) {
176 if self.seq_before(key, self.next_expected_seq) {
177 let overlap = self.next_expected_seq.wrapping_sub(key) as usize;
179 if overlap < data.len() {
180 self.reassembled.extend_from_slice(&data[overlap..]);
181 self.next_expected_seq = seg_end;
182 }
183 } else {
184 self.reassembled.extend_from_slice(&data);
186 self.next_expected_seq = seg_end;
187 }
188 }
189 }
191 }
192 }
193
194 fn seq_before(&self, a: u32, b: u32) -> bool {
196 (a.wrapping_sub(b) as i32) < 0
197 }
198
199 fn seq_before_or_equal(&self, a: u32, b: u32) -> bool {
201 (a.wrapping_sub(b) as i32) <= 0
202 }
203
204 fn seq_after(&self, a: u32, b: u32) -> bool {
206 (a.wrapping_sub(b) as i32) > 0
207 }
208}
209
210impl Default for TcpReassembler {
211 fn default() -> Self {
212 Self::new()
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared baseline configuration for all tests.
    fn default_config() -> FlowConfig {
        FlowConfig::default()
    }

    #[test]
    fn test_in_order_reassembly() {
        let cfg = default_config();
        let mut reasm = TcpReassembler::new();
        reasm.initialize(1000);

        assert_eq!(
            reasm.process_segment(1000, b"hello", &cfg).unwrap(),
            ReassemblyAction::DataReady(5)
        );
        assert_eq!(reasm.reassembled_data(), b"hello");
        assert_eq!(reasm.next_expected_seq, 1005);

        assert_eq!(
            reasm.process_segment(1005, b" world", &cfg).unwrap(),
            ReassemblyAction::DataReady(6)
        );
        assert_eq!(reasm.reassembled_data(), b"hello world");
    }

    #[test]
    fn test_out_of_order_then_fill_gap() {
        let cfg = default_config();
        let mut reasm = TcpReassembler::new();
        reasm.initialize(1000);

        // Arrives ahead of the expected sequence: must be buffered.
        assert_eq!(
            reasm.process_segment(1005, b" world", &cfg).unwrap(),
            ReassemblyAction::Buffered
        );
        assert_eq!(reasm.fragment_count(), 1);

        // Filling the gap delivers both the new and the buffered data.
        assert_eq!(
            reasm.process_segment(1000, b"hello", &cfg).unwrap(),
            ReassemblyAction::DataReady(5)
        );
        assert_eq!(reasm.reassembled_data(), b"hello world");
        assert_eq!(reasm.fragment_count(), 0);
    }

    #[test]
    fn test_total_duplicate() {
        let cfg = default_config();
        let mut reasm = TcpReassembler::new();
        reasm.initialize(1000);

        reasm.process_segment(1000, b"hello", &cfg).unwrap();
        // Re-sending the same bytes must not duplicate output.
        assert_eq!(
            reasm.process_segment(1000, b"hello", &cfg).unwrap(),
            ReassemblyAction::Duplicate
        );
        assert_eq!(reasm.reassembled_data(), b"hello");
    }

    #[test]
    fn test_partial_overlap() {
        let cfg = default_config();
        let mut reasm = TcpReassembler::new();
        reasm.initialize(1000);

        reasm.process_segment(1000, b"hello", &cfg).unwrap();
        // First 3 bytes overlap delivered data; only " wo" is new.
        assert_eq!(
            reasm.process_segment(1003, b"lo wo", &cfg).unwrap(),
            ReassemblyAction::OverlapTrimmed(3)
        );
        assert_eq!(reasm.reassembled_data(), b"hello wo");
    }

    #[test]
    fn test_empty_payload() {
        let cfg = default_config();
        let mut reasm = TcpReassembler::new();
        reasm.initialize(1000);

        assert_eq!(
            reasm.process_segment(1000, b"", &cfg).unwrap(),
            ReassemblyAction::Empty
        );
    }

    #[test]
    fn test_fragment_limit() {
        let mut cfg = default_config();
        cfg.max_ooo_fragments = 2;

        let mut reasm = TcpReassembler::new();
        reasm.initialize(1000);

        reasm.process_segment(1010, b"a", &cfg).unwrap();
        reasm.process_segment(1020, b"b", &cfg).unwrap();
        // Third distinct out-of-order fragment exceeds the limit.
        let result = reasm.process_segment(1030, b"c", &cfg);
        assert!(matches!(result, Err(FlowError::TooManyFragments { .. })));
    }

    #[test]
    fn test_buffer_size_limit() {
        let mut cfg = default_config();
        cfg.max_reassembly_buffer = 10;

        let mut reasm = TcpReassembler::new();
        reasm.initialize(1000);

        reasm.process_segment(1010, b"12345", &cfg).unwrap();
        // 5 buffered + 6 new bytes would exceed the 10-byte cap.
        let result = reasm.process_segment(1020, b"123456", &cfg);
        assert!(matches!(result, Err(FlowError::ReassemblyBufferFull { .. })));
    }

    #[test]
    fn test_multiple_ooo_segments_drain() {
        let cfg = default_config();
        let mut reasm = TcpReassembler::new();
        reasm.initialize(100);

        // Buffer two fragments in reverse arrival order.
        reasm.process_segment(110, b"ccc", &cfg).unwrap();
        reasm.process_segment(105, b"bbbbb", &cfg).unwrap();
        assert_eq!(reasm.fragment_count(), 2);

        // Filling the head gap drains the whole chain.
        reasm.process_segment(100, b"aaaaa", &cfg).unwrap();
        assert_eq!(reasm.reassembled_data(), b"aaaaabbbbbccc");
        assert_eq!(reasm.fragment_count(), 0);
    }

    #[test]
    fn test_auto_initialize() {
        let cfg = default_config();
        let mut reasm = TcpReassembler::new();

        // The first segment implicitly sets the initial sequence number.
        assert_eq!(
            reasm.process_segment(5000, b"data", &cfg).unwrap(),
            ReassemblyAction::DataReady(4)
        );
        assert!(reasm.is_initialized());
        assert_eq!(reasm.reassembled_data(), b"data");
    }

    #[test]
    fn test_drain_reassembled() {
        let cfg = default_config();
        let mut reasm = TcpReassembler::new();
        reasm.initialize(0);

        reasm.process_segment(0, b"hello", &cfg).unwrap();
        let drained = reasm.drain_reassembled();
        assert_eq!(drained, b"hello");
        assert!(reasm.reassembled_data().is_empty());
    }
}