1use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
4use bytes::{Bytes, BytesMut};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::io;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, Instant};
11use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
12use tracing::{debug, warn};
13
/// Largest number of raw payload bytes sent without fragmentation; it is also the
/// chunk size used when a longer message is split into fragments.
const MAX_FRAGMENT_DATA_SIZE: usize = 8 * 1024 * 1024;
/// Age after which a partially received message counts as expired and may be
/// discarded by the periodic cleanup.
const FRAGMENT_TIMEOUT: Duration = Duration::from_secs(30);
20
/// One slice of a message that was too large to send in a single frame.
///
/// Serialized with serde as the payload of [`FrameType::Fragment`].
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Fragment {
    /// Identifier shared by every fragment of the same message.
    message_id: u64,
    /// Zero-based position of this fragment within the message.
    fragment_index: u32,
    /// Number of fragments the message was split into.
    total_fragments: u32,
    /// Base64-encoded bytes of this fragment's chunk of the message.
    data: String,
}
33
/// Content of a single length-delimited frame; the codec serializes it as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum FrameType {
    /// An entire message carried in one frame.
    Complete(Vec<u8>),
    /// One piece of a message that was split across several frames.
    Fragment(Fragment),
}
42
/// Reassembly state for one fragmented message that has not fully arrived yet.
#[derive(Debug)]
struct PartialMessage {
    /// When tracking of this message began; used to decide expiry.
    created_at: Instant,
    /// Number of fragments announced for the message.
    total_fragments: u32,
    /// Fragment payloads received so far, keyed by fragment index.
    fragments: HashMap<u32, Vec<u8>>,
}
53
54impl PartialMessage {
55 fn new(total_fragments: u32) -> Self {
56 Self {
57 created_at: Instant::now(),
58 total_fragments,
59 fragments: HashMap::new(),
60 }
61 }
62
63 fn add_fragment(&mut self, index: u32, data: Vec<u8>) {
64 self.fragments.insert(index, data);
65 }
66
67 fn is_complete(&self) -> bool {
68 self.fragments.len() == self.total_fragments as usize
69 }
70
71 fn is_expired(&self) -> bool {
72 self.created_at.elapsed() > FRAGMENT_TIMEOUT
73 }
74
75 fn reassemble(self) -> io::Result<Vec<u8>> {
76 if !self.is_complete() {
77 return Err(io::Error::new(
78 io::ErrorKind::InvalidData,
79 "Cannot reassemble incomplete message",
80 ));
81 }
82
83 let mut result = Vec::new();
84
85 for i in 0..self.total_fragments {
87 if let Some(fragment_data) = self.fragments.get(&i) {
88 result.extend_from_slice(fragment_data);
89 } else {
90 return Err(io::Error::new(
91 io::ErrorKind::InvalidData,
92 format!("Missing fragment {}", i),
93 ));
94 }
95 }
96
97 Ok(result)
98 }
99}
100
/// State that a codec shares with its clones through an `Arc`.
#[derive(Debug)]
struct SharedState {
    /// Counter that hands out ids for outgoing fragmented messages.
    next_message_id: AtomicU64,
    /// Messages whose fragments are still arriving, keyed by message id.
    partial_messages: Mutex<HashMap<u64, PartialMessage>>,
    /// When expired partial messages were last purged; used to throttle cleanup.
    last_cleanup: Mutex<Instant>,
}
111
112impl SharedState {
113 fn new() -> Self {
114 Self {
115 next_message_id: AtomicU64::new(1),
116 partial_messages: Mutex::new(HashMap::new()),
117 last_cleanup: Mutex::new(Instant::now()),
118 }
119 }
120
121 fn next_message_id(&self) -> u64 {
122 self.next_message_id.fetch_add(1, Ordering::Relaxed)
123 }
124
125 fn cleanup_expired(&self) {
126 {
128 let last_cleanup = self.last_cleanup.lock().unwrap();
129 if last_cleanup.elapsed() < Duration::from_secs(10) {
130 return;
131 }
132 }
133
134 let mut partial_messages = self.partial_messages.lock().unwrap();
135 let before_count = partial_messages.len();
136
137 partial_messages.retain(|message_id, partial| {
138 if partial.is_expired() {
139 warn!("Cleaning up expired partial message {}", message_id);
140 false
141 } else {
142 true
143 }
144 });
145
146 let cleaned = before_count - partial_messages.len();
147 if cleaned > 0 {
148 debug!("Cleaned up {} expired partial messages", cleaned);
149 }
150
151 *self.last_cleanup.lock().unwrap() = Instant::now();
152 }
153}
154
/// Codec that sends payloads as length-delimited JSON frames, splitting payloads
/// larger than `MAX_FRAGMENT_DATA_SIZE` into fragments on encode and reassembling
/// them on decode.
#[derive(Debug)]
pub struct FragmentingCodec {
    /// Underlying length-prefixed framing.
    inner: LengthDelimitedCodec,
    /// Message-id counter and in-progress reassembly state; clones of this codec
    /// point at the same instance.
    shared_state: Arc<SharedState>,
}
163
164impl FragmentingCodec {
165 pub fn new() -> Self {
167 let mut inner = LengthDelimitedCodec::new();
168 inner.set_max_frame_length(32 * 1024 * 1024); Self {
171 inner,
172 shared_state: Arc::new(SharedState::new()),
173 }
174 }
175
176 fn fragment_message(&self, data: &[u8]) -> Vec<Fragment> {
178 let message_id = self.shared_state.next_message_id();
179 let total_size = data.len();
180
181 let chunk_size = MAX_FRAGMENT_DATA_SIZE;
183
184 let total_fragments = total_size.div_ceil(chunk_size);
186
187 debug!(
188 "Fragmenting message {} into {} fragments (total size: {} bytes, chunk size: {} bytes)",
189 message_id, total_fragments, total_size, chunk_size
190 );
191
192 let mut fragments = Vec::new();
193
194 for (i, chunk) in data.chunks(chunk_size).enumerate() {
195 let fragment = Fragment {
196 message_id,
197 fragment_index: i as u32,
198 total_fragments: total_fragments as u32,
199 data: BASE64.encode(chunk),
200 };
201
202 if let Ok(serialized) = serde_json::to_vec(&FrameType::Fragment(fragment.clone())) {
204 debug!("Fragment {} serialized size: {} bytes", i, serialized.len());
205 if serialized.len() > 31 * 1024 * 1024 {
206 warn!("Fragment {} serialized size ({} bytes) is dangerously close to frame limit", i, serialized.len());
208 }
209 }
210
211 fragments.push(fragment);
212 }
213
214 fragments
215 }
216}
217
218impl Default for FragmentingCodec {
219 fn default() -> Self {
220 Self::new()
221 }
222}
223
224impl Clone for FragmentingCodec {
225 fn clone(&self) -> Self {
226 let mut inner = LengthDelimitedCodec::new();
227 inner.set_max_frame_length(32 * 1024 * 1024); Self {
230 inner,
231 shared_state: Arc::clone(&self.shared_state),
232 }
233 }
234}
235
236impl Encoder<Bytes> for FragmentingCodec {
237 type Error = io::Error;
238
239 fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
240 let data = item.to_vec();
241
242 if data.len() <= MAX_FRAGMENT_DATA_SIZE {
244 let frame = FrameType::Complete(data);
246 let serialized = serde_json::to_vec(&frame).map_err(|e| {
247 io::Error::new(
248 io::ErrorKind::InvalidData,
249 format!("Failed to serialize frame: {}", e),
250 )
251 })?;
252
253 self.inner.encode(Bytes::from(serialized), dst)
254 } else {
255 let fragments = self.fragment_message(&data);
257
258 for fragment in fragments {
260 let frame = FrameType::Fragment(fragment);
261 let serialized = serde_json::to_vec(&frame).map_err(|e| {
262 io::Error::new(
263 io::ErrorKind::InvalidData,
264 format!("Failed to serialize fragment: {}", e),
265 )
266 })?;
267
268 let mut fragment_buf = BytesMut::new();
270 self.inner
271 .encode(Bytes::from(serialized), &mut fragment_buf)?;
272
273 dst.extend_from_slice(&fragment_buf);
275 }
276
277 Ok(())
278 }
279 }
280}
281
282impl Decoder for FragmentingCodec {
283 type Item = Bytes;
284 type Error = io::Error;
285
286 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
287 self.shared_state.cleanup_expired();
289
290 if let Some(frame_bytes) = self.inner.decode(src)? {
292 let frame: FrameType = serde_json::from_slice(&frame_bytes).map_err(|e| {
294 io::Error::new(
295 io::ErrorKind::InvalidData,
296 format!("Failed to deserialize frame: {}", e),
297 )
298 })?;
299
300 match frame {
301 FrameType::Complete(data) => {
302 Ok(Some(Bytes::from(data)))
304 }
305 FrameType::Fragment(fragment) => {
306 let message_id = fragment.message_id;
308 let fragment_index = fragment.fragment_index;
309 let total_fragments = fragment.total_fragments;
310
311 debug!(
312 "Received fragment {}/{} for message {}",
313 fragment_index + 1,
314 total_fragments,
315 message_id
316 );
317
318 let fragment_data = BASE64.decode(&fragment.data).map_err(|e| {
320 io::Error::new(
321 io::ErrorKind::InvalidData,
322 format!("Failed to decode fragment data: {}", e),
323 )
324 })?;
325
326 let mut partial_messages = self.shared_state.partial_messages.lock().unwrap();
328 let partial = partial_messages
329 .entry(message_id)
330 .or_insert_with(|| PartialMessage::new(total_fragments));
331
332 partial.add_fragment(fragment_index, fragment_data);
334
335 if partial.is_complete() {
337 debug!("Message {} is complete, reassembling", message_id);
338
339 let partial = partial_messages.remove(&message_id).unwrap();
341 drop(partial_messages); let complete_data = partial.reassemble()?;
344 Ok(Some(Bytes::from(complete_data)))
345 } else {
346 debug!(
348 "Message {} still incomplete ({}/{} fragments)",
349 message_id,
350 partial.fragments.len(),
351 total_fragments
352 );
353 Ok(None)
354 }
355 }
356 }
357 } else {
358 Ok(None)
360 }
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{SinkExt, StreamExt};
    use tokio::io::duplex;
    use tokio_util::codec::{FramedRead, FramedWrite};

    /// A payload below the fragmentation threshold round-trips unchanged.
    #[tokio::test]
    async fn test_small_message_no_fragmentation() {
        let (client, server) = duplex(1024);

        let mut writer = FramedWrite::new(client, FragmentingCodec::new());
        let mut reader = FramedRead::new(server, FragmentingCodec::new());

        let test_data = b"Hello, World!";

        writer.send(Bytes::from(&test_data[..])).await.unwrap();
        // Closing the write half lets the reader observe end of stream.
        drop(writer);

        let received = reader.next().await.unwrap().unwrap();
        assert_eq!(received.as_ref(), test_data);
    }

    /// A payload just over the threshold is fragmented and reassembled.
    #[tokio::test]
    async fn test_large_message_fragmentation() {
        let (client, server) = duplex(64 * 1024 * 1024);

        let mut writer = FramedWrite::new(client, FragmentingCodec::new());
        let mut reader = FramedRead::new(server, FragmentingCodec::new());

        let test_data = vec![0xAB; MAX_FRAGMENT_DATA_SIZE + 1000];

        let send_result = writer.send(Bytes::from(test_data.clone())).await;
        if let Err(e) = send_result {
            println!("Error sending: {:?}", e);
            panic!("Failed to send: {}", e);
        }
        println!("Successfully sent large message");
        drop(writer);

        let received = reader.next().await.unwrap().unwrap();
        assert_eq!(received.as_ref(), &test_data[..]);
    }

    /// Fragment metadata is consistent and the decoded chunks rebuild the input.
    #[test]
    fn test_fragment_message() {
        let codec = FragmentingCodec::new();
        let data = vec![0x42; MAX_FRAGMENT_DATA_SIZE + 500];

        let fragments = codec.fragment_message(&data);
        assert_eq!(fragments.len(), 2);

        let first_id = fragments[0].message_id;
        for (position, fragment) in fragments.iter().enumerate() {
            assert_eq!(fragment.fragment_index, position as u32);
            assert_eq!(fragment.total_fragments, 2);
            assert_eq!(fragment.message_id, first_id);
        }

        let reassembled: Vec<u8> = fragments
            .iter()
            .flat_map(|fragment| BASE64.decode(&fragment.data).unwrap())
            .collect();
        assert_eq!(reassembled, data);
    }

    /// Fragments added out of order are reassembled in index order.
    #[test]
    fn test_partial_message_assembly() {
        let mut partial = PartialMessage::new(3);
        assert!(!partial.is_complete());

        for (index, payload) in [(0, vec![1, 2, 3]), (2, vec![7, 8, 9])] {
            partial.add_fragment(index, payload);
        }
        assert!(!partial.is_complete());

        partial.add_fragment(1, vec![4, 5, 6]);
        assert!(partial.is_complete());

        let expected: Vec<u8> = (1..=9).collect();
        let reassembled = partial.reassemble().unwrap();
        assert_eq!(reassembled, expected);
    }
}