theater_server/
fragmenting_codec.rs

1//! # Fragmenting Codec
2//!
3//! A codec that transparently handles fragmentation of large messages.
4//! Built on top of LengthDelimitedCodec to maintain compatibility with existing Theater infrastructure.
5
6use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use bytes::{Bytes, BytesMut};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::io;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::{Duration, Instant};
13use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
14use tracing::{debug, warn};
15
/// Upper bound on the raw payload carried by a single fragment (12MB).
///
/// Fragment payloads are base64-encoded and wrapped in JSON before being framed, so this value
/// is kept well below the 32MB frame limit to leave room for that overhead.
const MAX_FRAGMENT_DATA_SIZE: usize = 12 * (1 << 20);

/// Maximum age of a partially received message (30 seconds) before it counts as expired.
const FRAGMENT_TIMEOUT: Duration = Duration::from_secs(30);
22
23/// A single fragment of a larger message
24#[derive(Debug, Clone, Serialize, Deserialize)]
25struct Fragment {
26    /// Unique identifier for the complete message
27    message_id: u64,
28    /// Index of this fragment (0-based)
29    fragment_index: u32,
30    /// Total number of fragments for this message
31    total_fragments: u32,
32    /// The actual data chunk (base64 encoded for efficient JSON serialization)
33    data: String,
34}
35
36/// Internal wrapper to distinguish between complete messages and fragments
37#[derive(Debug, Clone, Serialize, Deserialize)]
38enum FrameType {
39    /// A complete message that doesn't need fragmentation
40    Complete(Vec<u8>),
41    /// A fragment of a larger message
42    Fragment(Fragment),
43}
44
45/// Partial message being reassembled
46#[derive(Debug)]
47struct PartialMessage {
48    /// When this partial message was first created
49    created_at: Instant,
50    /// Total number of fragments expected
51    total_fragments: u32,
52    /// Fragments received so far, indexed by fragment_index
53    fragments: HashMap<u32, Vec<u8>>,
54}
55
56impl PartialMessage {
57    fn new(total_fragments: u32) -> Self {
58        Self {
59            created_at: Instant::now(),
60            total_fragments,
61            fragments: HashMap::new(),
62        }
63    }
64
65    fn add_fragment(&mut self, index: u32, data: Vec<u8>) {
66        self.fragments.insert(index, data);
67    }
68
69    fn is_complete(&self) -> bool {
70        self.fragments.len() == self.total_fragments as usize
71    }
72
73    fn is_expired(&self) -> bool {
74        self.created_at.elapsed() > FRAGMENT_TIMEOUT
75    }
76
77    fn reassemble(self) -> io::Result<Vec<u8>> {
78        if !self.is_complete() {
79            return Err(io::Error::new(
80                io::ErrorKind::InvalidData,
81                "Cannot reassemble incomplete message",
82            ));
83        }
84
85        let mut result = Vec::new();
86
87        // Reassemble fragments in order
88        for i in 0..self.total_fragments {
89            if let Some(fragment_data) = self.fragments.get(&i) {
90                result.extend_from_slice(fragment_data);
91            } else {
92                return Err(io::Error::new(
93                    io::ErrorKind::InvalidData,
94                    format!("Missing fragment {}", i),
95                ));
96            }
97        }
98
99        Ok(result)
100    }
101}
102
103/// A codec that transparently handles message fragmentation
104#[derive(Debug)]
105pub struct FragmentingCodec {
106    /// Underlying length-delimited codec
107    inner: LengthDelimitedCodec,
108    /// Counter for generating unique message IDs
109    next_message_id: AtomicU64,
110    /// Partial messages being reassembled (keyed by message_id)
111    partial_messages: HashMap<u64, PartialMessage>,
112    /// Last time we cleaned up expired partial messages
113    last_cleanup: Instant,
114}
115
116impl FragmentingCodec {
117    /// Create a new fragmenting codec with the same configuration as Theater's current setup
118    pub fn new() -> Self {
119        let mut inner = LengthDelimitedCodec::new();
120        inner.set_max_frame_length(32 * 1024 * 1024); // 32MB max frame
121
122        Self {
123            inner,
124            next_message_id: AtomicU64::new(1),
125            partial_messages: HashMap::new(),
126            last_cleanup: Instant::now(),
127        }
128    }
129
130    /// Generate the next unique message ID
131    fn next_message_id(&self) -> u64 {
132        self.next_message_id.fetch_add(1, Ordering::Relaxed)
133    }
134
135    /// Clean up expired partial messages
136    fn cleanup_expired(&mut self) {
137        // Only cleanup periodically to avoid overhead
138        if self.last_cleanup.elapsed() < Duration::from_secs(10) {
139            return;
140        }
141
142        let before_count = self.partial_messages.len();
143        self.partial_messages.retain(|message_id, partial| {
144            if partial.is_expired() {
145                warn!("Cleaning up expired partial message {}", message_id);
146                false
147            } else {
148                true
149            }
150        });
151
152        let cleaned = before_count - self.partial_messages.len();
153        if cleaned > 0 {
154            debug!("Cleaned up {} expired partial messages", cleaned);
155        }
156
157        self.last_cleanup = Instant::now();
158    }
159
160    /// Fragment a large message into chunks
161    fn fragment_message(&self, data: &[u8]) -> Vec<Fragment> {
162        let message_id = self.next_message_id();
163        let total_size = data.len();
164
165        // Use the defined chunk size constant
166        let chunk_size = MAX_FRAGMENT_DATA_SIZE;
167
168        // Calculate how many fragments we need
169        let total_fragments = (total_size + chunk_size - 1) / chunk_size;
170
171        debug!(
172            "Fragmenting message {} into {} fragments (total size: {} bytes, chunk size: {} bytes)",
173            message_id, total_fragments, total_size, chunk_size
174        );
175
176        let mut fragments = Vec::new();
177
178        for (i, chunk) in data.chunks(chunk_size).enumerate() {
179            let fragment = Fragment {
180                message_id,
181                fragment_index: i as u32,
182                total_fragments: total_fragments as u32,
183                data: BASE64.encode(chunk),
184            };
185
186            // Debug: check serialized size to ensure it's under the frame limit
187            if let Ok(serialized) = serde_json::to_vec(&FrameType::Fragment(fragment.clone())) {
188                debug!("Fragment {} serialized size: {} bytes", i, serialized.len());
189                if serialized.len() > 31 * 1024 * 1024 {
190                    // Close to 32MB limit
191                    warn!("Fragment {} serialized size ({} bytes) is dangerously close to frame limit", i, serialized.len());
192                }
193            }
194
195            fragments.push(fragment);
196        }
197
198        fragments
199    }
200}
201
202impl Default for FragmentingCodec {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
208impl Encoder<Bytes> for FragmentingCodec {
209    type Error = io::Error;
210
211    fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
212        let data = item.to_vec();
213
214        // Check if we need to fragment this message
215        if data.len() <= MAX_FRAGMENT_DATA_SIZE {
216            // Small message - send as complete
217            let frame = FrameType::Complete(data);
218            let serialized = serde_json::to_vec(&frame).map_err(|e| {
219                io::Error::new(
220                    io::ErrorKind::InvalidData,
221                    format!("Failed to serialize frame: {}", e),
222                )
223            })?;
224
225            self.inner.encode(Bytes::from(serialized), dst)
226        } else {
227            // Large message - fragment it
228            let fragments = self.fragment_message(&data);
229
230            // Encode each fragment into the destination buffer
231            for fragment in fragments {
232                let frame = FrameType::Fragment(fragment);
233                let serialized = serde_json::to_vec(&frame).map_err(|e| {
234                    io::Error::new(
235                        io::ErrorKind::InvalidData,
236                        format!("Failed to serialize fragment: {}", e),
237                    )
238                })?;
239
240                // Create a temporary buffer for this fragment
241                let mut fragment_buf = BytesMut::new();
242                self.inner
243                    .encode(Bytes::from(serialized), &mut fragment_buf)?;
244
245                // Append to the main destination buffer
246                dst.extend_from_slice(&fragment_buf);
247            }
248
249            Ok(())
250        }
251    }
252}
253
254impl Decoder for FragmentingCodec {
255    type Item = Bytes;
256    type Error = io::Error;
257
258    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
259        // Clean up expired messages periodically
260        self.cleanup_expired();
261
262        // Try to decode a frame from the underlying codec
263        if let Some(frame_bytes) = self.inner.decode(src)? {
264            // Deserialize the frame
265            let frame: FrameType = serde_json::from_slice(&frame_bytes).map_err(|e| {
266                io::Error::new(
267                    io::ErrorKind::InvalidData,
268                    format!("Failed to deserialize frame: {}", e),
269                )
270            })?;
271
272            match frame {
273                FrameType::Complete(data) => {
274                    // Complete message - return immediately
275                    Ok(Some(Bytes::from(data)))
276                }
277                FrameType::Fragment(fragment) => {
278                    // Fragment - add to partial message
279                    let message_id = fragment.message_id;
280                    let fragment_index = fragment.fragment_index;
281                    let total_fragments = fragment.total_fragments;
282
283                    debug!(
284                        "Received fragment {}/{} for message {}",
285                        fragment_index + 1,
286                        total_fragments,
287                        message_id
288                    );
289
290                    // Decode the base64 data
291                    let fragment_data = BASE64.decode(&fragment.data).map_err(|e| {
292                        io::Error::new(
293                            io::ErrorKind::InvalidData,
294                            format!("Failed to decode fragment data: {}", e),
295                        )
296                    })?;
297
298                    // Get or create partial message
299                    let partial = self
300                        .partial_messages
301                        .entry(message_id)
302                        .or_insert_with(|| PartialMessage::new(total_fragments));
303
304                    // Add this fragment
305                    partial.add_fragment(fragment_index, fragment_data);
306
307                    // Check if message is complete
308                    if partial.is_complete() {
309                        debug!("Message {} is complete, reassembling", message_id);
310
311                        // Remove from partial messages and reassemble
312                        let partial = self.partial_messages.remove(&message_id).unwrap();
313                        let complete_data = partial.reassemble()?;
314
315                        Ok(Some(Bytes::from(complete_data)))
316                    } else {
317                        // Still waiting for more fragments
318                        debug!(
319                            "Message {} still incomplete ({}/{} fragments)",
320                            message_id,
321                            partial.fragments.len(),
322                            total_fragments
323                        );
324                        Ok(None)
325                    }
326                }
327            }
328        } else {
329            // No complete frame available yet
330            Ok(None)
331        }
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{SinkExt, StreamExt};
    use tokio::io::duplex;
    use tokio_util::codec::{FramedRead, FramedWrite};

    /// A message below the fragmentation threshold round-trips unchanged.
    #[tokio::test]
    async fn test_small_message_no_fragmentation() {
        let (write_half, read_half) = duplex(1024);

        let mut writer = FramedWrite::new(write_half, FragmentingCodec::new());
        let mut reader = FramedRead::new(read_half, FragmentingCodec::new());

        let payload = b"Hello, World!";

        // Send the small message, then close the writing side
        writer.send(Bytes::from(&payload[..])).await.unwrap();
        drop(writer);

        // The reader should yield exactly the bytes that were sent
        let received = reader.next().await.unwrap().unwrap();
        assert_eq!(&received[..], &payload[..]);
    }

    /// A message just above the fragmentation threshold is split, sent, and reassembled.
    #[tokio::test]
    async fn test_large_message_fragmentation() {
        // The duplex buffer is large enough to hold every encoded fragment at once
        let (write_half, read_half) = duplex(64 * 1024 * 1024);

        let mut writer = FramedWrite::new(write_half, FragmentingCodec::new());
        let mut reader = FramedRead::new(read_half, FragmentingCodec::new());

        // Payload exceeds MAX_FRAGMENT_DATA_SIZE, so it must be fragmented
        let payload = vec![0xAB; MAX_FRAGMENT_DATA_SIZE + 1000];

        // Send the large message
        let send_result = writer.send(Bytes::from(payload.clone())).await;
        if let Err(e) = send_result {
            println!("Error sending: {:?}", e);
            panic!("Failed to send: {}", e);
        }
        println!("Successfully sent large message");
        drop(writer); // Close writer

        // The reader should yield the reassembled original payload
        let received = reader.next().await.unwrap().unwrap();
        assert_eq!(&received[..], &payload[..]);
    }

    /// `fragment_message` produces consistent metadata and lossless chunks.
    #[test]
    fn test_fragment_message() {
        let codec = FragmentingCodec::new();
        let payload = vec![0x42; MAX_FRAGMENT_DATA_SIZE + 500];

        let fragments = codec.fragment_message(&payload);

        assert_eq!(fragments.len(), 2);
        for (expected_index, fragment) in fragments.iter().enumerate() {
            assert_eq!(fragment.fragment_index as usize, expected_index);
            assert_eq!(fragment.total_fragments, 2);
        }
        assert_eq!(fragments[0].message_id, fragments[1].message_id);

        // Decoding and concatenating the chunks must reproduce the input
        let reassembled: Vec<u8> = fragments
            .iter()
            .flat_map(|fragment| BASE64.decode(&fragment.data).unwrap())
            .collect();
        assert_eq!(reassembled, payload);
    }

    /// Fragments added out of order are reassembled in index order.
    #[test]
    fn test_partial_message_assembly() {
        let mut partial = PartialMessage::new(3);
        assert!(!partial.is_complete());

        // First and last fragment arrive; the middle one is still missing
        partial.add_fragment(0, vec![1, 2, 3]);
        partial.add_fragment(2, vec![7, 8, 9]);
        assert!(!partial.is_complete());

        // The middle fragment completes the message
        partial.add_fragment(1, vec![4, 5, 6]);
        assert!(partial.is_complete());

        let expected: Vec<u8> = (1..=9).collect();
        assert_eq!(partial.reassemble().unwrap(), expected);
    }
}