Skip to main content

walrus_core/protocol/
codec.rs

//! Length-prefixed framing codec for the walrus wire protocol.
//!
//! Wire format: `[u32 BE length][protobuf payload]`. The length is the byte
//! count of the payload only (not including the 4-byte header). Generic over
//! `AsyncRead`/`AsyncWrite` — used by both UDS and TCP transports.
6
7use prost::Message;
8use std::io;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10
/// Upper bound, in bytes, on a single frame's payload (16 MiB).
///
/// Both the reader and the writer enforce this limit. An oversized frame is
/// rejected before it is written to the wire, or before a buffer is
/// allocated for it on the read side.
const MAX_FRAME_SIZE: u32 = 16 << 20;
13
/// Failure modes of reading or writing a length-prefixed frame.
#[derive(Debug)]
pub enum FrameError {
    /// The underlying transport reported an I/O failure.
    Io(io::Error),
    /// A frame's payload length is above [`MAX_FRAME_SIZE`].
    TooLarge {
        /// Offending payload length in bytes.
        size: u32,
    },
    /// Protobuf encoding/decoding failed; holds the error's display text.
    Codec(String),
    /// The peer closed the stream (EOF) while a frame header was being read.
    ConnectionClosed,
}
26
27impl std::fmt::Display for FrameError {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        match self {
30            Self::Io(e) => write!(f, "io error: {e}"),
31            Self::TooLarge { size } => {
32                write!(f, "frame too large: {size} bytes (max {MAX_FRAME_SIZE})")
33            }
34            Self::Codec(e) => write!(f, "codec error: {e}"),
35            Self::ConnectionClosed => write!(f, "connection closed"),
36        }
37    }
38}
39
40impl std::error::Error for FrameError {
41    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
42        match self {
43            Self::Io(e) => Some(e),
44            _ => None,
45        }
46    }
47}
48
49impl From<io::Error> for FrameError {
50    fn from(e: io::Error) -> Self {
51        Self::Io(e)
52    }
53}
54
55impl From<prost::DecodeError> for FrameError {
56    fn from(e: prost::DecodeError) -> Self {
57        Self::Codec(e.to_string())
58    }
59}
60
61/// Write a typed message as a length-prefixed protobuf frame.
62pub async fn write_message<W, T>(writer: &mut W, msg: &T) -> Result<(), FrameError>
63where
64    W: tokio::io::AsyncWrite + Unpin,
65    T: Message,
66{
67    let data = msg.encode_to_vec();
68    let len = data.len() as u32;
69    if len > MAX_FRAME_SIZE {
70        return Err(FrameError::TooLarge { size: len });
71    }
72    writer.write_all(&len.to_be_bytes()).await?;
73    writer.write_all(&data).await?;
74    writer.flush().await?;
75    Ok(())
76}
77
78/// Read a length-prefixed protobuf frame and deserialize into a typed message.
79pub async fn read_message<R, T>(reader: &mut R) -> Result<T, FrameError>
80where
81    R: tokio::io::AsyncRead + Unpin,
82    T: Message + Default,
83{
84    let mut len_buf = [0u8; 4];
85    match reader.read_exact(&mut len_buf).await {
86        Ok(_) => {}
87        Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
88            return Err(FrameError::ConnectionClosed);
89        }
90        Err(e) => return Err(FrameError::Io(e)),
91    }
92
93    let len = u32::from_be_bytes(len_buf);
94    if len > MAX_FRAME_SIZE {
95        return Err(FrameError::TooLarge { size: len });
96    }
97
98    let mut buf = vec![0u8; len as usize];
99    reader.read_exact(&mut buf).await?;
100    let msg = T::decode(&buf[..])?;
101    Ok(msg)
102}