Skip to main content

walrus_socket/
codec.rs

//! Length-prefixed framing codec for Unix domain socket transport.
//!
//! Wire format: `[u32 BE length][JSON payload]`. The length is the byte count
//! of the JSON payload only (not including the 4-byte header). Frames whose
//! payload exceeds 16 MiB are rejected on both the read and the write side.
5
6use serde::{Serialize, de::DeserializeOwned};
7use std::io;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9
/// Largest accepted JSON payload, in bytes: 16 MiB (2^24).
///
/// Only the payload is counted; the 4-byte length header is not included.
const MAX_FRAME_SIZE: u32 = 1 << 24;
12
13/// Errors that can occur during frame read/write.
14#[derive(Debug)]
15pub enum FrameError {
16    /// Underlying I/O error.
17    Io(io::Error),
18    /// Frame exceeds the maximum allowed size.
19    TooLarge { size: u32 },
20    /// JSON serialization/deserialization error.
21    Json(serde_json::Error),
22    /// The connection was closed (EOF during read).
23    ConnectionClosed,
24}
25
26impl std::fmt::Display for FrameError {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        match self {
29            Self::Io(e) => write!(f, "io error: {e}"),
30            Self::TooLarge { size } => {
31                write!(f, "frame too large: {size} bytes (max {MAX_FRAME_SIZE})")
32            }
33            Self::Json(e) => write!(f, "json error: {e}"),
34            Self::ConnectionClosed => write!(f, "connection closed"),
35        }
36    }
37}
38
39impl std::error::Error for FrameError {
40    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
41        match self {
42            Self::Io(e) => Some(e),
43            Self::Json(e) => Some(e),
44            _ => None,
45        }
46    }
47}
48
49impl From<io::Error> for FrameError {
50    fn from(e: io::Error) -> Self {
51        Self::Io(e)
52    }
53}
54
55impl From<serde_json::Error> for FrameError {
56    fn from(e: serde_json::Error) -> Self {
57        Self::Json(e)
58    }
59}
60
61/// Write a typed message as a length-prefixed JSON frame.
62pub async fn write_message<W, T>(writer: &mut W, msg: &T) -> Result<(), FrameError>
63where
64    W: tokio::io::AsyncWrite + Unpin,
65    T: Serialize,
66{
67    let data = serde_json::to_vec(msg)?;
68    let len = data.len() as u32;
69    if len > MAX_FRAME_SIZE {
70        return Err(FrameError::TooLarge { size: len });
71    }
72    writer.write_all(&len.to_be_bytes()).await?;
73    writer.write_all(&data).await?;
74    writer.flush().await?;
75    Ok(())
76}
77
78/// Read a length-prefixed JSON frame and deserialize into a typed message.
79pub async fn read_message<R, T>(reader: &mut R) -> Result<T, FrameError>
80where
81    R: tokio::io::AsyncRead + Unpin,
82    T: DeserializeOwned,
83{
84    let mut len_buf = [0u8; 4];
85    match reader.read_exact(&mut len_buf).await {
86        Ok(_) => {}
87        Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
88            return Err(FrameError::ConnectionClosed);
89        }
90        Err(e) => return Err(FrameError::Io(e)),
91    }
92
93    let len = u32::from_be_bytes(len_buf);
94    if len > MAX_FRAME_SIZE {
95        return Err(FrameError::TooLarge { size: len });
96    }
97
98    let mut buf = vec![0u8; len as usize];
99    reader.read_exact(&mut buf).await?;
100    let msg = serde_json::from_slice(&buf)?;
101    Ok(msg)
102}