1use serde::{Serialize, Deserialize};
29use std::collections::HashMap;
30use tracing::{debug, warn};
31
/// One piece of a payload, produced by [`Fragmenter::split`] and consumed by
/// [`Reassembler::add`].
///
/// Because this type derives `Deserialize`, values may be built from external input;
/// nothing in the type itself enforces `fragment_index < total_fragments`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Fragment {
    /// Identifier shared by every fragment of the same payload.
    pub fragment_id: u64,
    /// Zero-based position of this fragment within its payload.
    pub fragment_index: u16,
    /// Number of fragments the payload was split into.
    pub total_fragments: u16,
    /// This fragment's slice of the payload bytes.
    pub data: Vec<u8>,
}
44
45impl Fragment {
46 pub fn is_last(&self) -> bool {
48 self.fragment_index == self.total_fragments - 1
49 }
50
51 pub fn is_single(&self) -> bool {
53 self.total_fragments == 1
54 }
55}
56
/// Stateless helper that splits byte payloads into [`Fragment`]s.
///
/// All functionality is exposed as associated functions; the type carries no data.
pub struct Fragmenter;
59
60impl Fragmenter {
61 pub fn split(data: &[u8], fragment_size: usize, fragment_id: u64) -> Vec<Fragment> {
73 assert!(fragment_size > 0, "fragment_size must be > 0");
74
75 let chunks: Vec<&[u8]> = data.chunks(fragment_size).collect();
76 let total = chunks.len() as u16;
77
78 debug!(
79 fragment_id,
80 total_size = data.len(),
81 fragment_size,
82 total_fragments = total,
83 "Splitting payload into fragments"
84 );
85
86 chunks
87 .into_iter()
88 .enumerate()
89 .map(|(i, chunk)| Fragment {
90 fragment_id,
91 fragment_index: i as u16,
92 total_fragments: total,
93 data: chunk.to_vec(),
94 })
95 .collect()
96 }
97
98 pub fn needs_split(data: &[u8], fragment_size: usize) -> bool {
100 data.len() > fragment_size
101 }
102}
103
/// Collects [`Fragment`]s and reassembles them into complete payloads.
pub struct Reassembler {
    /// In-progress messages keyed by `fragment_id`.
    pending: HashMap<u64, ReassemblyBuffer>,
    /// Maximum number of in-progress messages. Fragments that would start another
    /// message once this many are pending are dropped by `add`.
    max_pending: usize,
}
116
/// Fragments received so far for a single message.
struct ReassemblyBuffer {
    /// Received payload chunks keyed by `fragment_index`.
    fragments: HashMap<u16, Vec<u8>>,
    /// Fragment count taken from the fragment that created this buffer.
    total: u16,
}
121
122impl Reassembler {
123 pub fn new() -> Self {
125 Reassembler {
126 pending: HashMap::new(),
127 max_pending: 256,
128 }
129 }
130
131 pub fn with_max_pending(max_pending: usize) -> Self {
133 Reassembler {
134 pending: HashMap::new(),
135 max_pending,
136 }
137 }
138
139 pub fn add(&mut self, fragment: Fragment) -> Option<Vec<u8>> {
144 if !self.pending.contains_key(&fragment.fragment_id)
146 && self.pending.len() >= self.max_pending
147 {
148 warn!(
149 fragment_id = fragment.fragment_id,
150 max_pending = self.max_pending,
151 "Reassembler at capacity, dropping fragment"
152 );
153 return None;
154 }
155
156 let id = fragment.fragment_id;
157 let total = fragment.total_fragments;
158 let index = fragment.fragment_index;
159
160 if fragment.is_single() {
162 debug!(fragment_id = id, "Single fragment, no reassembly needed");
163 return Some(fragment.data);
164 }
165
166 let buffer = self.pending.entry(id).or_insert_with(|| {
167 debug!(fragment_id = id, total_fragments = total, "New reassembly buffer");
168 ReassemblyBuffer {
169 fragments: HashMap::new(),
170 total,
171 }
172 });
173
174 if buffer.fragments.contains_key(&index) {
176 warn!(fragment_id = id, index, "Duplicate fragment ignored");
177 return None;
178 }
179
180 buffer.fragments.insert(index, fragment.data);
181 debug!(
182 fragment_id = id,
183 received = buffer.fragments.len(),
184 total = buffer.total,
185 "Fragment received"
186 );
187
188 if buffer.fragments.len() == buffer.total as usize {
190 let buf = self.pending.remove(&id).unwrap();
191 debug!(fragment_id = id, "Reassembly complete");
192 return Some(Self::assemble(buf));
193 }
194
195 None
196 }
197
198 pub fn pending_count(&self) -> usize {
200 self.pending.len()
201 }
202
203 pub fn cleanup(&mut self) {
205 let count = self.pending.len();
206 self.pending.clear();
207 if count > 0 {
208 warn!(dropped = count, "Reassembler cleanup: dropped incomplete messages");
209 }
210 }
211
212 fn assemble(buf: ReassemblyBuffer) -> Vec<u8> {
213 let mut indices: Vec<u16> = buf.fragments.keys().copied().collect();
214 indices.sort_unstable();
215 indices
216 .into_iter()
217 .flat_map(|i| buf.fragments[&i].clone())
218 .collect()
219 }
220}
221
222impl Default for Reassembler {
223 fn default() -> Self {
224 Self::new()
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds every fragment into `r` in the given order and returns the reassembled
    /// payload. Also checks that only the final fragment yields a result.
    fn feed_all(r: &mut Reassembler, frags: Vec<Fragment>) -> Vec<u8> {
        let count = frags.len();
        let mut result = None;
        for (i, f) in frags.into_iter().enumerate() {
            result = r.add(f);
            assert_eq!(
                result.is_some(),
                i + 1 == count,
                "unexpected reassembly result at fragment {}",
                i
            );
        }
        result.expect("last fragment should complete the message")
    }

    #[test]
    fn test_no_fragmentation_needed() {
        let data = vec![1u8; 100];
        assert!(!Fragmenter::needs_split(&data, 1200));
    }

    #[test]
    fn test_fragmentation_needed() {
        let data = vec![1u8; 5000];
        assert!(Fragmenter::needs_split(&data, 1200));
    }

    #[test]
    fn test_single_fragment() {
        let data = vec![42u8; 500];
        let frags = Fragmenter::split(&data, 1200, 1);
        assert_eq!(frags.len(), 1);
        assert!(frags[0].is_single());
        assert!(frags[0].is_last());
        assert_eq!(frags[0].data, data);
    }

    #[test]
    fn test_split_exact() {
        let data = vec![0u8; 2400];
        let frags = Fragmenter::split(&data, 1200, 2);
        assert_eq!(frags.len(), 2);
        assert_eq!(frags[0].fragment_index, 0);
        assert_eq!(frags[1].fragment_index, 1);
        assert!(!frags[0].is_last());
        assert!(frags[1].is_last());
    }

    #[test]
    fn test_split_remainder() {
        let data = vec![0u8; 2500];
        let frags = Fragmenter::split(&data, 1200, 3);
        assert_eq!(frags.len(), 3);
        assert_eq!(frags[2].data.len(), 100);
        assert!(frags
            .iter()
            .all(|f| f.fragment_id == 3 && f.total_fragments == 3));
    }

    #[test]
    fn test_reassemble_in_order() {
        let data: Vec<u8> = (0..255).collect();
        let frags = Fragmenter::split(&data, 50, 10);
        let mut r = Reassembler::new();
        assert_eq!(feed_all(&mut r, frags), data);
        assert_eq!(r.pending_count(), 0);
    }

    #[test]
    fn test_reassemble_out_of_order() {
        let data: Vec<u8> = (0..255).collect();
        let mut frags = Fragmenter::split(&data, 50, 11);
        frags.reverse();
        let mut r = Reassembler::new();
        assert_eq!(feed_all(&mut r, frags), data);
        assert_eq!(r.pending_count(), 0);
    }

    #[test]
    fn test_reassemble_single() {
        let data = vec![1u8; 100];
        let frags = Fragmenter::split(&data, 1200, 5);
        assert_eq!(frags.len(), 1);
        let mut r = Reassembler::new();
        let result = r.add(frags.into_iter().next().unwrap());
        assert_eq!(result.unwrap(), data);
        assert_eq!(r.pending_count(), 0);
    }

    #[test]
    fn test_duplicate_fragment_ignored() {
        let data: Vec<u8> = (0..200).collect();
        let frags = Fragmenter::split(&data, 50, 20);
        let mut r = Reassembler::new();
        let dup = frags[0].clone();
        assert!(r.add(frags[0].clone()).is_none());
        assert!(r.add(dup).is_none());
        assert_eq!(r.pending_count(), 1);
    }

    #[test]
    fn test_multiple_messages() {
        let data1: Vec<u8> = vec![1u8; 3000];
        let data2: Vec<u8> = vec![2u8; 2500];
        let frags1 = Fragmenter::split(&data1, 1200, 100);
        let frags2 = Fragmenter::split(&data2, 1200, 101);

        let mut r = Reassembler::new();
        let mut result1 = None;
        let mut result2 = None;

        // Interleave the two messages fragment by fragment.
        for i in 0..frags1.len().max(frags2.len()) {
            if let Some(f) = frags1.get(i) {
                result1 = r.add(f.clone());
            }
            if let Some(f) = frags2.get(i) {
                result2 = r.add(f.clone());
            }
        }

        assert_eq!(result1.unwrap(), data1);
        assert_eq!(result2.unwrap(), data2);
        assert_eq!(r.pending_count(), 0);
    }

    #[test]
    fn test_cleanup() {
        let data: Vec<u8> = vec![0u8; 3000];
        let frags = Fragmenter::split(&data, 1200, 99);
        let mut r = Reassembler::new();
        assert!(r.add(frags[0].clone()).is_none());
        assert_eq!(r.pending_count(), 1);
        r.cleanup();
        assert_eq!(r.pending_count(), 0);
    }

    #[test]
    fn test_large_payload() {
        let data: Vec<u8> = (0..=255).cycle().take(65000).collect();
        let frags = Fragmenter::split(&data, 1200, 42);
        assert!(frags.len() > 1);
        let mut r = Reassembler::new();
        assert_eq!(feed_all(&mut r, frags), data);
    }
}