1use serde::{Serialize, de::DeserializeOwned};
7use std::io;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9
/// Largest JSON payload, in bytes (16 MiB), that `write_message` will send or
/// `read_message` will accept. The 4-byte length prefix is not counted.
const MAX_FRAME_SIZE: u32 = 16 << 20;
12
/// Errors produced while writing or reading a length-prefixed JSON frame.
#[derive(Debug)]
pub enum FrameError {
    /// The underlying reader or writer returned an I/O error.
    Io(io::Error),
    /// A frame's payload length exceeded `MAX_FRAME_SIZE`.
    ///
    /// `size` is the payload length, in bytes, that was rejected.
    TooLarge { size: u32 },
    /// The payload could not be serialized to, or deserialized from, JSON.
    Json(serde_json::Error),
    /// Returned by `read_message` when the stream ends before a frame's
    /// 4-byte length prefix has been fully read.
    ConnectionClosed,
}
25
26impl std::fmt::Display for FrameError {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 match self {
29 Self::Io(e) => write!(f, "io error: {e}"),
30 Self::TooLarge { size } => {
31 write!(f, "frame too large: {size} bytes (max {MAX_FRAME_SIZE})")
32 }
33 Self::Json(e) => write!(f, "json error: {e}"),
34 Self::ConnectionClosed => write!(f, "connection closed"),
35 }
36 }
37}
38
39impl std::error::Error for FrameError {
40 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
41 match self {
42 Self::Io(e) => Some(e),
43 Self::Json(e) => Some(e),
44 _ => None,
45 }
46 }
47}
48
49impl From<io::Error> for FrameError {
50 fn from(e: io::Error) -> Self {
51 Self::Io(e)
52 }
53}
54
55impl From<serde_json::Error> for FrameError {
56 fn from(e: serde_json::Error) -> Self {
57 Self::Json(e)
58 }
59}
60
61pub async fn write_message<W, T>(writer: &mut W, msg: &T) -> Result<(), FrameError>
63where
64 W: tokio::io::AsyncWrite + Unpin,
65 T: Serialize,
66{
67 let data = serde_json::to_vec(msg)?;
68 let len = data.len() as u32;
69 if len > MAX_FRAME_SIZE {
70 return Err(FrameError::TooLarge { size: len });
71 }
72 writer.write_all(&len.to_be_bytes()).await?;
73 writer.write_all(&data).await?;
74 writer.flush().await?;
75 Ok(())
76}
77
78pub async fn read_message<R, T>(reader: &mut R) -> Result<T, FrameError>
80where
81 R: tokio::io::AsyncRead + Unpin,
82 T: DeserializeOwned,
83{
84 let mut len_buf = [0u8; 4];
85 match reader.read_exact(&mut len_buf).await {
86 Ok(_) => {}
87 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
88 return Err(FrameError::ConnectionClosed);
89 }
90 Err(e) => return Err(FrameError::Io(e)),
91 }
92
93 let len = u32::from_be_bytes(len_buf);
94 if len > MAX_FRAME_SIZE {
95 return Err(FrameError::TooLarge { size: len });
96 }
97
98 let mut buf = vec![0u8; len as usize];
99 reader.read_exact(&mut buf).await?;
100 let msg = serde_json::from_slice(&buf)?;
101 Ok(msg)
102}