1use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use bytes::{Bytes, BytesMut};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::io;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::{Duration, Instant};
13use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
14use tracing::{debug, warn};
15
/// Maximum number of raw payload bytes carried by a single fragment (12 MiB), measured
/// before base64 encoding. Payloads larger than this are split across several fragments.
const MAX_FRAGMENT_DATA_SIZE: usize = 12 * 1024 * 1024;

/// How long a partially received message may wait for its remaining fragments before it
/// is considered expired and becomes eligible for cleanup.
const FRAGMENT_TIMEOUT: Duration = Duration::from_secs(30);
22
/// One piece of a message that was split into several frames.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Fragment {
    /// Identifier shared by every fragment of the same message.
    message_id: u64,
    /// Zero-based position of this fragment within the message.
    fragment_index: u32,
    /// Number of fragments the message was split into.
    total_fragments: u32,
    /// Base64-encoded bytes of this fragment's slice of the message.
    data: String,
}
35
/// Payload of a single length-delimited frame; serialized as JSON by the codec.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum FrameType {
    /// A whole message carried in one frame.
    Complete(Vec<u8>),
    /// One piece of a message that spans several frames.
    Fragment(Fragment),
}
44
45#[derive(Debug)]
47struct PartialMessage {
48 created_at: Instant,
50 total_fragments: u32,
52 fragments: HashMap<u32, Vec<u8>>,
54}
55
56impl PartialMessage {
57 fn new(total_fragments: u32) -> Self {
58 Self {
59 created_at: Instant::now(),
60 total_fragments,
61 fragments: HashMap::new(),
62 }
63 }
64
65 fn add_fragment(&mut self, index: u32, data: Vec<u8>) {
66 self.fragments.insert(index, data);
67 }
68
69 fn is_complete(&self) -> bool {
70 self.fragments.len() == self.total_fragments as usize
71 }
72
73 fn is_expired(&self) -> bool {
74 self.created_at.elapsed() > FRAGMENT_TIMEOUT
75 }
76
77 fn reassemble(self) -> io::Result<Vec<u8>> {
78 if !self.is_complete() {
79 return Err(io::Error::new(
80 io::ErrorKind::InvalidData,
81 "Cannot reassemble incomplete message",
82 ));
83 }
84
85 let mut result = Vec::new();
86
87 for i in 0..self.total_fragments {
89 if let Some(fragment_data) = self.fragments.get(&i) {
90 result.extend_from_slice(fragment_data);
91 } else {
92 return Err(io::Error::new(
93 io::ErrorKind::InvalidData,
94 format!("Missing fragment {}", i),
95 ));
96 }
97 }
98
99 Ok(result)
100 }
101}
102
/// Codec that sends byte messages as length-delimited JSON frames, splitting large
/// messages into fragments on encode and reassembling them on decode.
#[derive(Debug)]
pub struct FragmentingCodec {
    /// Underlying length-delimited framing used for every frame.
    inner: LengthDelimitedCodec,
    /// Counter that supplies the id for the next fragmented message.
    next_message_id: AtomicU64,
    /// Messages whose fragments are still being collected, keyed by message id.
    partial_messages: HashMap<u64, PartialMessage>,
    /// When expired partial messages were last swept.
    last_cleanup: Instant,
}
115
116impl FragmentingCodec {
117 pub fn new() -> Self {
119 let mut inner = LengthDelimitedCodec::new();
120 inner.set_max_frame_length(32 * 1024 * 1024); Self {
123 inner,
124 next_message_id: AtomicU64::new(1),
125 partial_messages: HashMap::new(),
126 last_cleanup: Instant::now(),
127 }
128 }
129
130 fn next_message_id(&self) -> u64 {
132 self.next_message_id.fetch_add(1, Ordering::Relaxed)
133 }
134
135 fn cleanup_expired(&mut self) {
137 if self.last_cleanup.elapsed() < Duration::from_secs(10) {
139 return;
140 }
141
142 let before_count = self.partial_messages.len();
143 self.partial_messages.retain(|message_id, partial| {
144 if partial.is_expired() {
145 warn!("Cleaning up expired partial message {}", message_id);
146 false
147 } else {
148 true
149 }
150 });
151
152 let cleaned = before_count - self.partial_messages.len();
153 if cleaned > 0 {
154 debug!("Cleaned up {} expired partial messages", cleaned);
155 }
156
157 self.last_cleanup = Instant::now();
158 }
159
160 fn fragment_message(&self, data: &[u8]) -> Vec<Fragment> {
162 let message_id = self.next_message_id();
163 let total_size = data.len();
164
165 let chunk_size = MAX_FRAGMENT_DATA_SIZE;
167
168 let total_fragments = (total_size + chunk_size - 1) / chunk_size;
170
171 debug!(
172 "Fragmenting message {} into {} fragments (total size: {} bytes, chunk size: {} bytes)",
173 message_id, total_fragments, total_size, chunk_size
174 );
175
176 let mut fragments = Vec::new();
177
178 for (i, chunk) in data.chunks(chunk_size).enumerate() {
179 let fragment = Fragment {
180 message_id,
181 fragment_index: i as u32,
182 total_fragments: total_fragments as u32,
183 data: BASE64.encode(chunk),
184 };
185
186 if let Ok(serialized) = serde_json::to_vec(&FrameType::Fragment(fragment.clone())) {
188 debug!("Fragment {} serialized size: {} bytes", i, serialized.len());
189 if serialized.len() > 31 * 1024 * 1024 {
190 warn!("Fragment {} serialized size ({} bytes) is dangerously close to frame limit", i, serialized.len());
192 }
193 }
194
195 fragments.push(fragment);
196 }
197
198 fragments
199 }
200}
201
202impl Default for FragmentingCodec {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208impl Encoder<Bytes> for FragmentingCodec {
209 type Error = io::Error;
210
211 fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
212 let data = item.to_vec();
213
214 if data.len() <= MAX_FRAGMENT_DATA_SIZE {
216 let frame = FrameType::Complete(data);
218 let serialized = serde_json::to_vec(&frame).map_err(|e| {
219 io::Error::new(
220 io::ErrorKind::InvalidData,
221 format!("Failed to serialize frame: {}", e),
222 )
223 })?;
224
225 self.inner.encode(Bytes::from(serialized), dst)
226 } else {
227 let fragments = self.fragment_message(&data);
229
230 for fragment in fragments {
232 let frame = FrameType::Fragment(fragment);
233 let serialized = serde_json::to_vec(&frame).map_err(|e| {
234 io::Error::new(
235 io::ErrorKind::InvalidData,
236 format!("Failed to serialize fragment: {}", e),
237 )
238 })?;
239
240 let mut fragment_buf = BytesMut::new();
242 self.inner
243 .encode(Bytes::from(serialized), &mut fragment_buf)?;
244
245 dst.extend_from_slice(&fragment_buf);
247 }
248
249 Ok(())
250 }
251 }
252}
253
254impl Decoder for FragmentingCodec {
255 type Item = Bytes;
256 type Error = io::Error;
257
258 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
259 self.cleanup_expired();
261
262 if let Some(frame_bytes) = self.inner.decode(src)? {
264 let frame: FrameType = serde_json::from_slice(&frame_bytes).map_err(|e| {
266 io::Error::new(
267 io::ErrorKind::InvalidData,
268 format!("Failed to deserialize frame: {}", e),
269 )
270 })?;
271
272 match frame {
273 FrameType::Complete(data) => {
274 Ok(Some(Bytes::from(data)))
276 }
277 FrameType::Fragment(fragment) => {
278 let message_id = fragment.message_id;
280 let fragment_index = fragment.fragment_index;
281 let total_fragments = fragment.total_fragments;
282
283 debug!(
284 "Received fragment {}/{} for message {}",
285 fragment_index + 1,
286 total_fragments,
287 message_id
288 );
289
290 let fragment_data = BASE64.decode(&fragment.data).map_err(|e| {
292 io::Error::new(
293 io::ErrorKind::InvalidData,
294 format!("Failed to decode fragment data: {}", e),
295 )
296 })?;
297
298 let partial = self
300 .partial_messages
301 .entry(message_id)
302 .or_insert_with(|| PartialMessage::new(total_fragments));
303
304 partial.add_fragment(fragment_index, fragment_data);
306
307 if partial.is_complete() {
309 debug!("Message {} is complete, reassembling", message_id);
310
311 let partial = self.partial_messages.remove(&message_id).unwrap();
313 let complete_data = partial.reassemble()?;
314
315 Ok(Some(Bytes::from(complete_data)))
316 } else {
317 debug!(
319 "Message {} still incomplete ({}/{} fragments)",
320 message_id,
321 partial.fragments.len(),
322 total_fragments
323 );
324 Ok(None)
325 }
326 }
327 }
328 } else {
329 Ok(None)
331 }
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{SinkExt, StreamExt};
    use tokio::io::duplex;
    use tokio_util::codec::{FramedRead, FramedWrite};

    /// A payload well under the fragment size travels as a single frame and arrives intact.
    #[tokio::test]
    async fn test_small_message_no_fragmentation() {
        let (tx_io, rx_io) = duplex(1024);

        let mut sink = FramedWrite::new(tx_io, FragmentingCodec::new());
        let mut stream = FramedRead::new(rx_io, FragmentingCodec::new());

        let payload = b"Hello, World!";

        sink.send(Bytes::from(&payload[..])).await.unwrap();
        // Close the write half so the reader sees EOF after the message.
        drop(sink);

        let received = stream.next().await.unwrap().unwrap();
        assert_eq!(received.as_ref(), payload);
    }

    /// A payload just over the fragment size is split on send and reassembled on receive.
    #[tokio::test]
    async fn test_large_message_fragmentation() {
        let (tx_io, rx_io) = duplex(64 * 1024 * 1024);

        let mut sink = FramedWrite::new(tx_io, FragmentingCodec::new());
        let mut stream = FramedRead::new(rx_io, FragmentingCodec::new());

        let payload = vec![0xAB; MAX_FRAGMENT_DATA_SIZE + 1000];

        if let Err(e) = sink.send(Bytes::from(payload.clone())).await {
            println!("Error sending: {:?}", e);
            panic!("Failed to send: {}", e);
        }
        println!("Successfully sent large message");
        drop(sink);

        let received = stream.next().await.unwrap().unwrap();
        assert_eq!(received.as_ref(), &payload[..]);
    }

    /// Fragmenting yields consistently numbered pieces whose payloads concatenate back
    /// to the input.
    #[test]
    fn test_fragment_message() {
        let codec = FragmentingCodec::new();
        let data = vec![0x42; MAX_FRAGMENT_DATA_SIZE + 500];

        let fragments = codec.fragment_message(&data);

        assert_eq!(fragments.len(), 2);
        for (position, fragment) in fragments.iter().enumerate() {
            assert_eq!(fragment.fragment_index, position as u32);
            assert_eq!(fragment.total_fragments, 2);
        }
        assert_eq!(fragments[0].message_id, fragments[1].message_id);

        let reassembled: Vec<u8> = fragments
            .iter()
            .flat_map(|fragment| BASE64.decode(&fragment.data).unwrap())
            .collect();
        assert_eq!(reassembled, data);
    }

    /// Fragments may arrive out of order; the message completes only once all are present.
    #[test]
    fn test_partial_message_assembly() {
        let mut partial = PartialMessage::new(3);
        assert!(!partial.is_complete());

        for (index, bytes) in [(0, vec![1, 2, 3]), (2, vec![7, 8, 9])] {
            partial.add_fragment(index, bytes);
        }
        assert!(!partial.is_complete());

        partial.add_fragment(1, vec![4, 5, 6]);
        assert!(partial.is_complete());

        let reassembled = partial.reassemble().unwrap();
        assert_eq!(reassembled, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }
}