// runtara_protocol/frame.rs

// Copyright (C) 2025 SyncMyOrders Sp. z o.o.
// SPDX-License-Identifier: AGPL-3.0-or-later
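//! Wire format for QUIC stream framing.
//!
//! Each QUIC stream carries one RPC call, exchanged as a sequence of frames.
//! Every frame has the following layout:
//! - 4 bytes: payload length N (big-endian `u32`; does not include this 6-byte header)
//! - 2 bytes: message type (big-endian `u16`, see [`MessageType`])
//! - N bytes: protobuf-encoded payload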
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11use prost::Message;
12use thiserror::Error;
13use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
14
/// Maximum payload size, in bytes, of a single frame (64 MiB).
///
/// Sized to accommodate large compiled workflow binaries. The limit applies to
/// the payload only; the [`HEADER_SIZE`]-byte header is not counted.
pub const MAX_FRAME_SIZE: usize = 64 * 1024 * 1024;

// The payload length travels in a `u32` header field, so the limit must be
// representable there; fail the build rather than emit truncated lengths.
const _: () = assert!(MAX_FRAME_SIZE <= u32::MAX as usize);

/// Frame header size in bytes: a `u32` payload length followed by a `u16`
/// message type.
pub const HEADER_SIZE: usize = std::mem::size_of::<u32>() + std::mem::size_of::<u16>();
21
/// Kind of frame, carried in the 2-byte type field of every frame header.
///
/// The explicit discriminants are the values written on the wire, so they must
/// stay stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum MessageType {
    /// A request sent by the caller.
    Request = 0x0001,
    /// A single, non-streaming reply to a request.
    Response = 0x0002,
    /// Opens a streaming reply.
    StreamStart = 0x0003,
    /// One chunk of a streaming reply.
    StreamData = 0x0004,
    /// Closes a streaming reply.
    StreamEnd = 0x0005,
    /// An error reply.
    Error = 0x0006,
}
39
40impl TryFrom<u16> for MessageType {
41    type Error = FrameError;
42
43    fn try_from(value: u16) -> Result<Self, <Self as TryFrom<u16>>::Error> {
44        match value {
45            1 => Ok(MessageType::Request),
46            2 => Ok(MessageType::Response),
47            3 => Ok(MessageType::StreamStart),
48            4 => Ok(MessageType::StreamData),
49            5 => Ok(MessageType::StreamEnd),
50            6 => Ok(MessageType::Error),
51            _ => Err(FrameError::InvalidMessageType(value)),
52        }
53    }
54}
55
56/// Errors that can occur during frame encoding/decoding
57#[derive(Debug, Error)]
58pub enum FrameError {
59    #[error("frame too large: {0} bytes (max: {MAX_FRAME_SIZE})")]
60    FrameTooLarge(usize),
61
62    #[error("invalid message type: {0}")]
63    InvalidMessageType(u16),
64
65    #[error("IO error: {0}")]
66    Io(#[from] std::io::Error),
67
68    #[error("protobuf decode error: {0}")]
69    Decode(#[from] prost::DecodeError),
70
71    #[error("connection closed")]
72    ConnectionClosed,
73}
74
75/// A framed message with type and payload
76#[derive(Debug, Clone)]
77pub struct Frame {
78    pub message_type: MessageType,
79    pub payload: Bytes,
80}
81
82impl Frame {
83    /// Create a new request frame
84    pub fn request<M: Message>(msg: &M) -> Result<Self, FrameError> {
85        Self::new(MessageType::Request, msg)
86    }
87
88    /// Create a new response frame
89    pub fn response<M: Message>(msg: &M) -> Result<Self, FrameError> {
90        Self::new(MessageType::Response, msg)
91    }
92
93    /// Create a new error frame
94    pub fn error<M: Message>(msg: &M) -> Result<Self, FrameError> {
95        Self::new(MessageType::Error, msg)
96    }
97
98    /// Create a new stream data frame
99    pub fn stream_data<M: Message>(msg: &M) -> Result<Self, FrameError> {
100        Self::new(MessageType::StreamData, msg)
101    }
102
103    /// Create a new frame with the given type and message
104    pub fn new<M: Message>(message_type: MessageType, msg: &M) -> Result<Self, FrameError> {
105        let payload = msg.encode_to_vec();
106        if payload.len() > MAX_FRAME_SIZE {
107            return Err(FrameError::FrameTooLarge(payload.len()));
108        }
109        Ok(Self {
110            message_type,
111            payload: Bytes::from(payload),
112        })
113    }
114
115    /// Decode the payload as a protobuf message
116    pub fn decode<M: Message + Default>(&self) -> Result<M, FrameError> {
117        Ok(M::decode(self.payload.clone())?)
118    }
119
120    /// Encode the frame to bytes for wire transmission
121    pub fn encode(&self) -> Bytes {
122        let mut buf = BytesMut::with_capacity(HEADER_SIZE + self.payload.len());
123        buf.put_u32(self.payload.len() as u32);
124        buf.put_u16(self.message_type as u16);
125        buf.put(self.payload.clone());
126        buf.freeze()
127    }
128
129    /// Decode a frame from bytes
130    pub fn decode_from_bytes(mut bytes: Bytes) -> Result<Self, FrameError> {
131        if bytes.len() < HEADER_SIZE {
132            return Err(FrameError::Io(std::io::Error::new(
133                std::io::ErrorKind::UnexpectedEof,
134                "incomplete frame header",
135            )));
136        }
137
138        let length = bytes.get_u32() as usize;
139        let message_type = MessageType::try_from(bytes.get_u16())?;
140
141        if length > MAX_FRAME_SIZE {
142            return Err(FrameError::FrameTooLarge(length));
143        }
144
145        if bytes.len() < length {
146            return Err(FrameError::Io(std::io::Error::new(
147                std::io::ErrorKind::UnexpectedEof,
148                "incomplete frame payload",
149            )));
150        }
151
152        let payload = bytes.split_to(length);
153        Ok(Self {
154            message_type,
155            payload,
156        })
157    }
158}
159
160/// Write a frame to an async writer
161pub async fn write_frame<W: AsyncWrite + Unpin>(
162    writer: &mut W,
163    frame: &Frame,
164) -> Result<(), FrameError> {
165    let encoded = frame.encode();
166    writer.write_all(&encoded).await?;
167    Ok(())
168}
169
170/// Read a frame from an async reader
171pub async fn read_frame<R: AsyncRead + Unpin>(reader: &mut R) -> Result<Frame, FrameError> {
172    // Read header
173    let mut header = [0u8; HEADER_SIZE];
174    match reader.read_exact(&mut header).await {
175        Ok(_) => {}
176        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
177            return Err(FrameError::ConnectionClosed);
178        }
179        Err(e) => return Err(e.into()),
180    }
181
182    let length = u32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize;
183    let message_type = MessageType::try_from(u16::from_be_bytes([header[4], header[5]]))?;
184
185    if length > MAX_FRAME_SIZE {
186        return Err(FrameError::FrameTooLarge(length));
187    }
188
189    // Read payload
190    let mut payload = vec![0u8; length];
191    reader.read_exact(&mut payload).await?;
192
193    Ok(Frame {
194        message_type,
195        payload: Bytes::from(payload),
196    })
197}
198
/// Wrapper that reads and writes whole [`Frame`]s over an underlying stream.
///
/// The wrapper holds nothing but the stream itself.
pub struct FramedStream<S> {
    stream: S,
}

impl<S> FramedStream<S> {
    /// Wrap `stream` for frame-level I/O.
    pub fn new(stream: S) -> Self {
        FramedStream { stream }
    }

    /// Consume the wrapper and hand back the underlying stream.
    pub fn into_inner(self) -> S {
        let FramedStream { stream } = self;
        stream
    }
}
213
214impl<S: AsyncRead + Unpin> FramedStream<S> {
215    /// Read the next frame from the stream
216    pub async fn read_frame(&mut self) -> Result<Frame, FrameError> {
217        read_frame(&mut self.stream).await
218    }
219}
220
221impl<S: AsyncWrite + Unpin> FramedStream<S> {
222    /// Write a frame to the stream
223    pub async fn write_frame(&mut self, frame: &Frame) -> Result<(), FrameError> {
224        write_frame(&mut self.stream, frame).await
225    }
226}
227
228impl<S: AsyncRead + AsyncWrite + Unpin> FramedStream<S> {
229    /// Send a request and wait for a response
230    pub async fn request<Req: Message, Resp: Message + Default>(
231        &mut self,
232        request: &Req,
233    ) -> Result<Resp, FrameError> {
234        let frame = Frame::request(request)?;
235        self.write_frame(&frame).await?;
236
237        let response_frame = self.read_frame().await?;
238        match response_frame.message_type {
239            MessageType::Response => response_frame.decode(),
240            MessageType::Error => {
241                // Try to decode as error message
242                Err(FrameError::Io(std::io::Error::other(
243                    "received error response",
244                )))
245            }
246            _ => Err(FrameError::Io(std::io::Error::new(
247                std::io::ErrorKind::InvalidData,
248                "unexpected message type",
249            ))),
250        }
251    }
252
253    /// Send a response
254    pub async fn respond<Resp: Message>(&mut self, response: &Resp) -> Result<(), FrameError> {
255        let frame = Frame::response(response)?;
256        self.write_frame(&frame).await
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Every message type, in discriminant order.
    const ALL_MESSAGE_TYPES: [MessageType; 6] = [
        MessageType::Request,
        MessageType::Response,
        MessageType::StreamStart,
        MessageType::StreamData,
        MessageType::StreamEnd,
        MessageType::Error,
    ];

    #[test]
    fn test_message_type_round_trip() {
        for mt in ALL_MESSAGE_TYPES {
            let decoded = MessageType::try_from(mt as u16).unwrap();
            assert_eq!(mt, decoded);
        }
    }

    #[test]
    fn test_frame_encode_decode() {
        use crate::management_proto::HealthCheckRequest;

        let msg = HealthCheckRequest {};
        let frame = Frame::request(&msg).unwrap();
        let encoded = frame.encode();
        let decoded = Frame::decode_from_bytes(encoded).unwrap();

        assert_eq!(frame.message_type, decoded.message_type);
        assert_eq!(frame.payload, decoded.payload);
    }

    // ========== Constants Tests ==========

    #[test]
    fn test_max_frame_size_constant() {
        // 64 MiB, and it must fit the u32 length field of the header.
        assert_eq!(MAX_FRAME_SIZE, 64 * 1024 * 1024);
        assert!(u32::try_from(MAX_FRAME_SIZE).is_ok());
    }

    #[test]
    fn test_header_size_constant() {
        // 4 bytes length + 2 bytes type.
        assert_eq!(HEADER_SIZE, 6);
    }

    // ========== MessageType Tests ==========

    #[test]
    fn test_message_type_values() {
        assert_eq!(MessageType::Request as u16, 1);
        assert_eq!(MessageType::Response as u16, 2);
        assert_eq!(MessageType::StreamStart as u16, 3);
        assert_eq!(MessageType::StreamData as u16, 4);
        assert_eq!(MessageType::StreamEnd as u16, 5);
        assert_eq!(MessageType::Error as u16, 6);
    }

    #[test]
    fn test_message_type_conversions() {
        assert_eq!(MessageType::try_from(1u16).unwrap(), MessageType::Request);
        assert_eq!(MessageType::try_from(2u16).unwrap(), MessageType::Response);
        assert_eq!(
            MessageType::try_from(3u16).unwrap(),
            MessageType::StreamStart
        );
        assert_eq!(
            MessageType::try_from(4u16).unwrap(),
            MessageType::StreamData
        );
        assert_eq!(MessageType::try_from(5u16).unwrap(), MessageType::StreamEnd);
        assert_eq!(MessageType::try_from(6u16).unwrap(), MessageType::Error);
    }

    #[test]
    fn test_message_type_invalid_conversion() {
        for value in [0u16, 7, 100, u16::MAX] {
            match MessageType::try_from(value) {
                Err(FrameError::InvalidMessageType(v)) => assert_eq!(v, value),
                other => panic!("expected InvalidMessageType({value}), got {other:?}"),
            }
        }
    }

    #[test]
    fn test_message_type_debug() {
        assert_eq!(format!("{:?}", MessageType::Request), "Request");
        assert_eq!(format!("{:?}", MessageType::Response), "Response");
        assert_eq!(format!("{:?}", MessageType::StreamStart), "StreamStart");
        assert_eq!(format!("{:?}", MessageType::StreamData), "StreamData");
        assert_eq!(format!("{:?}", MessageType::StreamEnd), "StreamEnd");
        assert_eq!(format!("{:?}", MessageType::Error), "Error");
    }

    #[test]
    #[allow(clippy::clone_on_copy)] // deliberately exercises the derived `Clone` impl
    fn test_message_type_clone_and_copy() {
        let mt = MessageType::Request;
        let cloned = mt.clone();
        let copied: MessageType = mt;
        assert_eq!(mt, cloned);
        assert_eq!(mt, copied);
    }

    #[test]
    fn test_message_type_equality() {
        assert_eq!(MessageType::Request, MessageType::Request);
        assert_ne!(MessageType::Request, MessageType::Response);
    }

    // ========== FrameError Tests ==========

    #[test]
    fn test_frame_error_display_frame_too_large() {
        let err = FrameError::FrameTooLarge(100_000_000);
        let msg = err.to_string();
        assert!(msg.contains("frame too large"));
        assert!(msg.contains("100000000"));
        assert!(msg.contains(&MAX_FRAME_SIZE.to_string()));
    }

    #[test]
    fn test_frame_error_display_invalid_message_type() {
        let err = FrameError::InvalidMessageType(42);
        let msg = err.to_string();
        assert!(msg.contains("invalid message type"));
        assert!(msg.contains("42"));
    }

    #[test]
    fn test_frame_error_display_io() {
        let err = FrameError::Io(std::io::Error::other("test error"));
        let msg = err.to_string();
        assert!(msg.contains("IO error"));
        assert!(msg.contains("test error"));
    }

    #[test]
    fn test_frame_error_display_connection_closed() {
        let err = FrameError::ConnectionClosed;
        assert!(err.to_string().contains("connection closed"));
    }

    #[test]
    fn test_frame_error_from_io_error() {
        let io_err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "pipe broken");
        let frame_err: FrameError = io_err.into();
        assert!(
            matches!(&frame_err, FrameError::Io(e) if e.kind() == std::io::ErrorKind::BrokenPipe),
            "Expected FrameError::Io(BrokenPipe), got {frame_err:?}"
        );
    }

    // ========== Frame Creation Tests ==========

    #[test]
    fn test_frame_request_creation() {
        use crate::management_proto::HealthCheckRequest;
        let msg = HealthCheckRequest {};
        let frame = Frame::request(&msg).unwrap();
        assert_eq!(frame.message_type, MessageType::Request);
    }

    #[test]
    fn test_frame_response_creation() {
        use crate::management_proto::HealthCheckResponse;
        let msg = HealthCheckResponse {
            healthy: true,
            version: "1.0.0".to_string(),
            uptime_ms: 1000,
            active_instances: 5,
        };
        let frame = Frame::response(&msg).unwrap();
        assert_eq!(frame.message_type, MessageType::Response);
    }

    #[test]
    fn test_frame_error_creation() {
        use crate::management_proto::HealthCheckResponse;
        let msg = HealthCheckResponse {
            healthy: false,
            version: "1.0.0".to_string(),
            uptime_ms: 0,
            active_instances: 0,
        };
        let frame = Frame::error(&msg).unwrap();
        assert_eq!(frame.message_type, MessageType::Error);
    }

    #[test]
    fn test_frame_stream_data_creation() {
        use crate::management_proto::HealthCheckResponse;
        let msg = HealthCheckResponse {
            healthy: true,
            version: "1.0.0".to_string(),
            uptime_ms: 500,
            active_instances: 2,
        };
        let frame = Frame::stream_data(&msg).unwrap();
        assert_eq!(frame.message_type, MessageType::StreamData);
    }

    #[test]
    fn test_frame_new_all_types() {
        use crate::management_proto::HealthCheckRequest;
        let msg = HealthCheckRequest {};

        for mt in ALL_MESSAGE_TYPES {
            let frame = Frame::new(mt, &msg).unwrap();
            assert_eq!(frame.message_type, mt);
        }
    }

    #[test]
    fn test_frame_decode_payload() {
        use crate::management_proto::HealthCheckResponse;
        let original = HealthCheckResponse {
            healthy: true,
            version: "2.0.0".to_string(),
            uptime_ms: 12345,
            active_instances: 10,
        };
        let frame = Frame::response(&original).unwrap();
        let decoded: HealthCheckResponse = frame.decode().unwrap();
        assert!(decoded.healthy);
        assert_eq!(decoded.version, "2.0.0");
        assert_eq!(decoded.uptime_ms, 12345);
        assert_eq!(decoded.active_instances, 10);
    }

    #[test]
    fn test_frame_clone() {
        use crate::management_proto::HealthCheckRequest;
        let msg = HealthCheckRequest {};
        let frame = Frame::request(&msg).unwrap();
        let cloned = frame.clone();
        assert_eq!(frame.message_type, cloned.message_type);
        assert_eq!(frame.payload, cloned.payload);
    }

    #[test]
    fn test_frame_debug() {
        use crate::management_proto::HealthCheckRequest;
        let msg = HealthCheckRequest {};
        let frame = Frame::request(&msg).unwrap();
        let debug_str = format!("{:?}", frame);
        assert!(debug_str.contains("Frame"));
        assert!(debug_str.contains("message_type"));
        assert!(debug_str.contains("payload"));
    }

    // ========== Frame Encoding Tests ==========

    #[test]
    fn test_frame_encode_structure() {
        use crate::management_proto::HealthCheckRequest;
        let msg = HealthCheckRequest {};
        let frame = Frame::request(&msg).unwrap();
        let encoded = frame.encode();

        // Header: 4 bytes length + 2 bytes type
        assert!(encoded.len() >= HEADER_SIZE);

        // First 4 bytes are the payload length (big-endian)
        let length = u32::from_be_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]) as usize;
        assert_eq!(length, frame.payload.len());

        // Bytes 4-5 are the message type (big-endian)
        let msg_type = u16::from_be_bytes([encoded[4], encoded[5]]);
        assert_eq!(msg_type, MessageType::Request as u16);

        // Total length is header + payload
        assert_eq!(encoded.len(), HEADER_SIZE + frame.payload.len());
    }

    #[test]
    fn test_frame_with_large_payload() {
        use crate::instance_proto::CheckpointRequest;
        let msg = CheckpointRequest {
            instance_id: "test-instance".to_string(),
            checkpoint_id: "checkpoint-123".to_string(),
            state: vec![0u8; 1024 * 1024], // 1 MiB of data
        };
        let frame = Frame::request(&msg).unwrap();
        assert!(frame.payload.len() > 1024 * 1024);

        let encoded = frame.encode();
        let decoded = Frame::decode_from_bytes(encoded).unwrap();
        assert_eq!(frame.payload, decoded.payload);
    }

    #[test]
    fn test_frame_with_empty_payload() {
        use crate::management_proto::HealthCheckRequest;
        let msg = HealthCheckRequest {};
        let frame = Frame::request(&msg).unwrap();
        // A protobuf message with no fields encodes to zero bytes.
        assert!(frame.payload.is_empty());
    }

    // ========== decode_from_bytes Tests ==========

    #[test]
    fn test_decode_from_bytes_incomplete_header() {
        let bytes = Bytes::from_static(&[0, 0, 0]); // Only 3 bytes, need 6
        match Frame::decode_from_bytes(bytes) {
            Err(FrameError::Io(e)) => {
                assert!(e.to_string().contains("incomplete frame header"));
            }
            other => panic!("Expected Io error with incomplete header message, got {other:?}"),
        }
    }

    #[test]
    fn test_decode_from_bytes_incomplete_payload() {
        // Header says 100 bytes of payload, but only 10 follow
        let mut bytes = BytesMut::new();
        bytes.put_u32(100);
        bytes.put_u16(1); // Request
        bytes.put(&[0u8; 10][..]);

        match Frame::decode_from_bytes(bytes.freeze()) {
            Err(FrameError::Io(e)) => {
                assert!(e.to_string().contains("incomplete frame payload"));
            }
            other => panic!("Expected Io error with incomplete payload message, got {other:?}"),
        }
    }

    #[test]
    fn test_decode_from_bytes_invalid_message_type() {
        let mut bytes = BytesMut::new();
        bytes.put_u32(0);
        bytes.put_u16(99); // invalid type

        match Frame::decode_from_bytes(bytes.freeze()) {
            Err(FrameError::InvalidMessageType(99)) => {}
            other => panic!("Expected InvalidMessageType(99), got {other:?}"),
        }
    }

    #[test]
    fn test_decode_from_bytes_frame_too_large() {
        let mut bytes = BytesMut::new();
        bytes.put_u32((MAX_FRAME_SIZE + 1) as u32);
        bytes.put_u16(1); // Request

        match Frame::decode_from_bytes(bytes.freeze()) {
            Err(FrameError::FrameTooLarge(size)) => assert_eq!(size, MAX_FRAME_SIZE + 1),
            other => panic!("Expected FrameTooLarge, got {other:?}"),
        }
    }

    #[test]
    fn test_decode_from_bytes_empty_payload() {
        let mut bytes = BytesMut::new();
        bytes.put_u32(0);
        bytes.put_u16(1); // Request

        let frame = Frame::decode_from_bytes(bytes.freeze()).unwrap();
        assert_eq!(frame.message_type, MessageType::Request);
        assert!(frame.payload.is_empty());
    }

    #[test]
    fn test_decode_from_bytes_with_extra_data() {
        // A valid frame followed by trailing bytes, which must be ignored
        let mut bytes = BytesMut::new();
        bytes.put_u32(5);
        bytes.put_u16(2); // Response
        bytes.put(&[1, 2, 3, 4, 5][..]);
        bytes.put(&[99, 99, 99][..]);

        let frame = Frame::decode_from_bytes(bytes.freeze()).unwrap();
        assert_eq!(frame.message_type, MessageType::Response);
        assert_eq!(&frame.payload[..], &[1, 2, 3, 4, 5]);
    }

    // ========== Async read/write frame tests ==========

    #[tokio::test]
    async fn test_read_write_frame() {
        use crate::management_proto::HealthCheckRequest;
        use tokio::io::duplex;

        let frame = Frame::request(&HealthCheckRequest {}).unwrap();
        let (mut writer, mut reader) = duplex(1024);

        write_frame(&mut writer, &frame).await.unwrap();

        let received = read_frame(&mut reader).await.unwrap();
        assert_eq!(frame.message_type, received.message_type);
        assert_eq!(frame.payload, received.payload);
    }

    #[tokio::test]
    async fn test_read_frame_connection_closed() {
        use tokio::io::duplex;

        // The writer half is dropped immediately, so the reader sees EOF.
        let (_, mut reader) = duplex(1024);

        match read_frame(&mut reader).await {
            Err(FrameError::ConnectionClosed) => {}
            other => panic!("Expected ConnectionClosed, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_write_read_multiple_frames() {
        use crate::management_proto::{HealthCheckRequest, HealthCheckResponse};
        use tokio::io::duplex;

        let (mut writer, mut reader) = duplex(4096);

        let req = HealthCheckRequest {};
        let resp = HealthCheckResponse {
            healthy: true,
            version: "1.0.0".to_string(),
            uptime_ms: 100,
            active_instances: 1,
        };

        let frame1 = Frame::request(&req).unwrap();
        let frame2 = Frame::response(&resp).unwrap();

        write_frame(&mut writer, &frame1).await.unwrap();
        write_frame(&mut writer, &frame2).await.unwrap();
        drop(writer); // Signal EOF

        let read1 = read_frame(&mut reader).await.unwrap();
        let read2 = read_frame(&mut reader).await.unwrap();

        assert_eq!(read1.message_type, MessageType::Request);
        assert_eq!(read2.message_type, MessageType::Response);
        assert_eq!(read2.payload, frame2.payload);

        // Both frames consumed and the writer gone: a clean close.
        assert!(matches!(
            read_frame(&mut reader).await,
            Err(FrameError::ConnectionClosed)
        ));
    }

    // ========== FramedStream Tests ==========

    #[test]
    fn test_framed_stream_new() {
        let framed = FramedStream::new(vec![0u8; 10]);
        let inner = framed.into_inner();
        assert_eq!(inner.len(), 10);
    }

    #[test]
    fn test_framed_stream_into_inner() {
        let data = "test data".to_string();
        let framed = FramedStream::new(data.clone());
        assert_eq!(framed.into_inner(), data);
    }

    #[tokio::test]
    async fn test_framed_stream_read_write() {
        use crate::management_proto::HealthCheckRequest;
        use tokio::io::duplex;

        let (writer, reader) = duplex(1024);
        let mut writer_framed = FramedStream::new(writer);
        let mut reader_framed = FramedStream::new(reader);

        let frame = Frame::request(&HealthCheckRequest {}).unwrap();

        writer_framed.write_frame(&frame).await.unwrap();
        drop(writer_framed); // Close the writing end

        let received = reader_framed.read_frame().await.unwrap();
        assert_eq!(frame.message_type, received.message_type);
    }

    #[tokio::test]
    async fn test_framed_stream_request_response() {
        use crate::management_proto::{HealthCheckRequest, HealthCheckResponse};
        use tokio::io::duplex;

        let (client_io, server_io) = duplex(4096);
        let mut client = FramedStream::new(client_io);
        let mut server = FramedStream::new(server_io);
        let request = HealthCheckRequest {};

        let serve = async {
            let incoming = server.read_frame().await.unwrap();
            assert_eq!(incoming.message_type, MessageType::Request);
            let _decoded: HealthCheckRequest = incoming.decode().unwrap();
            let reply = HealthCheckResponse {
                healthy: true,
                version: "1.0.0".to_string(),
                uptime_ms: 42,
                active_instances: 3,
            };
            server.respond(&reply).await.unwrap();
        };
        let call = client.request::<_, HealthCheckResponse>(&request);

        let ((), response) = tokio::join!(serve, call);
        let response = response.unwrap();
        assert!(response.healthy);
        assert_eq!(response.version, "1.0.0");
        assert_eq!(response.uptime_ms, 42);
        assert_eq!(response.active_instances, 3);
    }

    #[tokio::test]
    async fn test_framed_stream_request_error_reply() {
        use crate::management_proto::{HealthCheckRequest, HealthCheckResponse};
        use tokio::io::duplex;

        let (client_io, server_io) = duplex(4096);
        let mut client = FramedStream::new(client_io);
        let mut server = FramedStream::new(server_io);
        let request = HealthCheckRequest {};

        let serve = async {
            server.read_frame().await.unwrap();
            let details = HealthCheckResponse {
                healthy: false,
                version: "1.0.0".to_string(),
                uptime_ms: 0,
                active_instances: 0,
            };
            let reply = Frame::error(&details).unwrap();
            server.write_frame(&reply).await.unwrap();
        };
        let call = client.request::<_, HealthCheckResponse>(&request);

        let ((), result) = tokio::join!(serve, call);
        assert!(matches!(result, Err(FrameError::Io(_))));
    }

    // ========== Edge Cases and Boundary Tests ==========

    #[test]
    fn test_frame_at_max_size() {
        use crate::instance_proto::CheckpointRequest;
        // 100 bytes of headroom covers the field tags, length prefixes and the
        // two one-byte strings, so the encoding stays under the limit.
        let msg = CheckpointRequest {
            instance_id: "i".to_string(),
            checkpoint_id: "c".to_string(),
            state: vec![0u8; MAX_FRAME_SIZE - 100],
        };
        let frame = Frame::request(&msg).expect("payload under MAX_FRAME_SIZE must be accepted");
        assert!(frame.payload.len() <= MAX_FRAME_SIZE);
    }

    #[test]
    fn test_frame_over_max_size() {
        use crate::instance_proto::CheckpointRequest;
        // `state` alone fills the limit, so the full encoding exceeds it.
        let msg = CheckpointRequest {
            instance_id: "i".to_string(),
            checkpoint_id: "c".to_string(),
            state: vec![0u8; MAX_FRAME_SIZE],
        };
        assert!(matches!(
            Frame::request(&msg),
            Err(FrameError::FrameTooLarge(n)) if n > MAX_FRAME_SIZE
        ));
    }

    #[test]
    fn test_message_type_exhaustive_matching() {
        // Adding a variant without updating this match is a compile error.
        for mt in ALL_MESSAGE_TYPES {
            match mt {
                MessageType::Request => assert_eq!(mt as u16, 1),
                MessageType::Response => assert_eq!(mt as u16, 2),
                MessageType::StreamStart => assert_eq!(mt as u16, 3),
                MessageType::StreamData => assert_eq!(mt as u16, 4),
                MessageType::StreamEnd => assert_eq!(mt as u16, 5),
                MessageType::Error => assert_eq!(mt as u16, 6),
            }
        }
    }

    #[test]
    fn test_frame_error_debug() {
        let debug = format!("{:?}", FrameError::FrameTooLarge(100));
        assert!(debug.contains("FrameTooLarge"));

        let debug = format!("{:?}", FrameError::InvalidMessageType(42));
        assert!(debug.contains("InvalidMessageType"));

        let debug = format!("{:?}", FrameError::ConnectionClosed);
        assert!(debug.contains("ConnectionClosed"));
    }
}