steamroom_cli/daemon/
framing.rs1use rkyv::rancor;
13use rkyv::util::AlignedVec;
14use tokio::io::AsyncReadExt;
15use tokio::io::AsyncWriteExt;
16
17use crate::daemon::proto::Frame;
18use crate::daemon::proto::PROTO_VERSION;
19use crate::errors::CliError;
20
/// Upper bound, in bytes, on one frame's serialized payload (16 MiB).
///
/// Checked on both sides of the connection: [`write_frame`] refuses to send a
/// larger archive, and [`read_frame`] rejects a larger announced length before
/// allocating a buffer for it.
pub const MAX_FRAME_BYTES: u32 = 16 << 20;
22
23async fn read_exact_or_closed<R>(r: &mut R, buf: &mut [u8]) -> Result<(), CliError>
28where
29 R: AsyncReadExt + Unpin,
30{
31 match r.read_exact(buf).await {
32 Ok(_) => Ok(()),
33 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Err(CliError::SocketClosed),
34 Err(e) => Err(CliError::Io(e)),
35 }
36}
37
38pub async fn write_frame<W>(w: &mut W, frame: &Frame) -> Result<(), CliError>
39where
40 W: AsyncWriteExt + Unpin,
41{
42 let bytes = rkyv::to_bytes::<rancor::Error>(frame)
43 .map_err(|e| CliError::MalformedFrame(e.to_string()))?;
44 let len_usize = bytes.len();
45 if len_usize > MAX_FRAME_BYTES as usize {
46 return Err(CliError::FrameTooLarge {
47 len_bytes: len_usize as u64,
48 limit_bytes: MAX_FRAME_BYTES as u64,
49 });
50 }
51 let len: u32 = len_usize as u32;
53 w.write_all(&PROTO_VERSION.to_le_bytes())
54 .await
55 .map_err(CliError::Io)?;
56 w.write_all(&len.to_le_bytes())
57 .await
58 .map_err(CliError::Io)?;
59 w.write_all(&bytes).await.map_err(CliError::Io)?;
60 w.flush().await.map_err(CliError::Io)?;
61 Ok(())
62}
63
64pub async fn read_frame<R>(r: &mut R) -> Result<Frame, CliError>
75where
76 R: AsyncReadExt + Unpin,
77{
78 let mut ver_buf = [0u8; 2];
79 read_exact_or_closed(r, &mut ver_buf).await?;
80 let peer = u16::from_le_bytes(ver_buf);
81 if peer != PROTO_VERSION {
82 return Err(CliError::ProtocolVersionMismatch {
83 peer,
84 ours: PROTO_VERSION,
85 });
86 }
87
88 let mut len_buf = [0u8; 4];
89 read_exact_or_closed(r, &mut len_buf).await?;
90 let len = u32::from_le_bytes(len_buf);
91 if len > MAX_FRAME_BYTES {
92 return Err(CliError::FrameTooLarge {
93 len_bytes: len as u64,
94 limit_bytes: MAX_FRAME_BYTES as u64,
95 });
96 }
97
98 let mut buf = AlignedVec::<16>::with_capacity(len as usize);
100 buf.resize(len as usize, 0);
101 read_exact_or_closed(r, &mut buf).await?;
102 rkyv::from_bytes::<Frame, rancor::Error>(&buf)
103 .map_err(|e| CliError::MalformedFrame(e.to_string()))
104}
105
#[cfg(test)]
mod tests {
    // Round-trip and rejection tests for the framing layer, driven over an
    // in-memory tokio duplex pipe.
    use super::*;
    use crate::daemon::proto::Event;
    use crate::daemon::proto::JobId;
    use crate::daemon::proto::LogLevel;
    use crate::daemon::proto::Response;
    use tokio::io::duplex;

    #[tokio::test]
    async fn round_trip_through_duplex() {
        let (mut a, mut b) = duplex(64 * 1024);
        let frame = Frame::Response(Response::Stopping);

        write_frame(&mut a, &frame).await.unwrap();
        let back = read_frame(&mut b).await.unwrap();
        assert!(matches!(back, Frame::Response(Response::Stopping)));
    }

    #[tokio::test]
    async fn rejects_mismatched_version() {
        let (mut a, mut b) = duplex(64);
        a.write_all(&999u16.to_le_bytes()).await.unwrap();
        a.write_all(&0u32.to_le_bytes()).await.unwrap();
        a.flush().await.unwrap();
        let err = read_frame(&mut b).await.unwrap_err();
        match err {
            CliError::ProtocolVersionMismatch { peer, ours } => {
                assert_eq!(peer, 999);
                assert_eq!(ours, PROTO_VERSION);
            }
            other => panic!("wrong error: {other:?}"),
        }
    }

    #[tokio::test]
    async fn rejects_oversized_length() {
        let (mut a, mut b) = duplex(64);
        a.write_all(&PROTO_VERSION.to_le_bytes()).await.unwrap();
        a.write_all(&(MAX_FRAME_BYTES + 1).to_le_bytes())
            .await
            .unwrap();
        a.flush().await.unwrap();
        let err = read_frame(&mut b).await.unwrap_err();
        assert!(matches!(err, CliError::FrameTooLarge { .. }));
    }

    #[tokio::test]
    async fn event_with_log_round_trips() {
        let (mut a, mut b) = duplex(64 * 1024);
        let frame = Frame::Event(Event::Log {
            job_id: Some(JobId(1)),
            level: LogLevel::Info,
            target: "t".into(),
            message: "hello".into(),
        });
        write_frame(&mut a, &frame).await.unwrap();
        let back = read_frame(&mut b).await.unwrap();
        match back {
            Frame::Event(Event::Log { message, .. }) => assert_eq!(message, "hello"),
            other => panic!("wrong: {other:?}"),
        }
    }

    // The peer hangs up before sending any header bytes.
    #[tokio::test]
    async fn eof_before_header_is_socket_closed() {
        let (a, mut b) = duplex(64);
        drop(a);
        let err = read_frame(&mut b).await.unwrap_err();
        assert!(matches!(err, CliError::SocketClosed), "got {err:?}");
    }

    // The header promises 16 payload bytes, but the peer hangs up after 3.
    #[tokio::test]
    async fn truncated_payload_is_socket_closed() {
        let (mut a, mut b) = duplex(64);
        a.write_all(&PROTO_VERSION.to_le_bytes()).await.unwrap();
        a.write_all(&16u32.to_le_bytes()).await.unwrap();
        a.write_all(&[0u8; 3]).await.unwrap();
        drop(a);
        let err = read_frame(&mut b).await.unwrap_err();
        assert!(matches!(err, CliError::SocketClosed), "got {err:?}");
    }

    // A well-formed header followed by bytes that are not a valid archive.
    #[tokio::test]
    async fn garbage_payload_is_malformed() {
        let (mut a, mut b) = duplex(64);
        a.write_all(&PROTO_VERSION.to_le_bytes()).await.unwrap();
        a.write_all(&4u32.to_le_bytes()).await.unwrap();
        a.write_all(&[0xFFu8; 4]).await.unwrap();
        a.flush().await.unwrap();
        let err = read_frame(&mut b).await.unwrap_err();
        assert!(matches!(err, CliError::MalformedFrame(_)), "got {err:?}");
    }
}