walrus_core/protocol/
codec.rs1use prost::Message;
8use std::io;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10
/// Largest payload, in bytes, that a single frame may carry (16 MiB).
///
/// Enforced on both the encode and decode side of the length-prefixed framing.
const MAX_FRAME_SIZE: u32 = 16 << 20;
13
/// Failure modes of reading or writing a length-prefixed protobuf frame.
#[derive(Debug)]
pub enum FrameError {
    /// The underlying reader or writer returned an I/O error.
    Io(io::Error),
    /// The frame's payload length exceeds the configured maximum frame size.
    TooLarge {
        /// Payload length in bytes.
        size: u32,
    },
    /// The payload could not be encoded or decoded; holds the codec's error text.
    Codec(String),
    /// The stream ended before the header of the next frame arrived.
    ConnectionClosed,
}
26
27impl std::fmt::Display for FrameError {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 match self {
30 Self::Io(e) => write!(f, "io error: {e}"),
31 Self::TooLarge { size } => {
32 write!(f, "frame too large: {size} bytes (max {MAX_FRAME_SIZE})")
33 }
34 Self::Codec(e) => write!(f, "codec error: {e}"),
35 Self::ConnectionClosed => write!(f, "connection closed"),
36 }
37 }
38}
39
40impl std::error::Error for FrameError {
41 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
42 match self {
43 Self::Io(e) => Some(e),
44 _ => None,
45 }
46 }
47}
48
49impl From<io::Error> for FrameError {
50 fn from(e: io::Error) -> Self {
51 Self::Io(e)
52 }
53}
54
55impl From<prost::DecodeError> for FrameError {
56 fn from(e: prost::DecodeError) -> Self {
57 Self::Codec(e.to_string())
58 }
59}
60
61pub async fn write_message<W, T>(writer: &mut W, msg: &T) -> Result<(), FrameError>
63where
64 W: tokio::io::AsyncWrite + Unpin,
65 T: Message,
66{
67 let data = msg.encode_to_vec();
68 let len = data.len() as u32;
69 if len > MAX_FRAME_SIZE {
70 return Err(FrameError::TooLarge { size: len });
71 }
72 writer.write_all(&len.to_be_bytes()).await?;
73 writer.write_all(&data).await?;
74 writer.flush().await?;
75 Ok(())
76}
77
78pub async fn read_message<R, T>(reader: &mut R) -> Result<T, FrameError>
80where
81 R: tokio::io::AsyncRead + Unpin,
82 T: Message + Default,
83{
84 let mut len_buf = [0u8; 4];
85 match reader.read_exact(&mut len_buf).await {
86 Ok(_) => {}
87 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
88 return Err(FrameError::ConnectionClosed);
89 }
90 Err(e) => return Err(FrameError::Io(e)),
91 }
92
93 let len = u32::from_be_bytes(len_buf);
94 if len > MAX_FRAME_SIZE {
95 return Err(FrameError::TooLarge { size: len });
96 }
97
98 let mut buf = vec![0u8; len as usize];
99 reader.read_exact(&mut buf).await?;
100 let msg = T::decode(&buf[..])?;
101 Ok(msg)
102}