1use anyhow::{Result, bail};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5use uuid::Uuid;
6
7const TAG_CONTROL: u8 = 0x01;
8const TAG_DATA: u8 = 0x02;
9const MAX_FRAME_SIZE: usize = 1_048_576; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
12#[serde(tag = "type")]
13pub enum ClientMessage {
14 CreateSession { shell: Option<String> },
15 ListSessions,
16 AttachSession { id: Uuid, cols: u16, rows: u16 },
17 DetachSession,
18 ResizeSession { id: Uuid, cols: u16, rows: u16 },
19 KillSession { id: Uuid },
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
23#[serde(tag = "type")]
24pub enum ServerMessage {
25 SessionCreated { id: Uuid },
26 Sessions { sessions: Vec<SessionInfo> },
27 Attached { id: Uuid },
28 Detached,
29 SessionEnded { id: Uuid, exit_code: Option<i32> },
30 ClientJoined { session_id: Uuid, client_id: Uuid },
31 ClientLeft { session_id: Uuid, client_id: Uuid },
32 Error { message: String },
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
36pub struct SessionInfo {
37 pub id: Uuid,
38 pub cols: u16,
39 pub rows: u16,
40 pub created_at: DateTime<Utc>,
41 pub client_count: usize,
42}
43
/// One decoded frame, as produced by `read_frame`; the variant is chosen by the
/// frame's tag byte and holds the payload that followed it.
#[derive(Debug)]
pub enum Frame {
    /// Payload of a frame tagged `TAG_CONTROL`. `send_client_message` and
    /// `send_server_message` put JSON-encoded messages in these frames.
    Control(Vec<u8>),
    /// Payload of a frame tagged `TAG_DATA`; opaque bytes at this layer.
    Data(Vec<u8>),
}
49
50pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
51 let len = (1 + payload.len()) as u32;
52 w.write_all(&len.to_be_bytes()).await?;
53 w.write_u8(TAG_CONTROL).await?;
54 w.write_all(payload).await?;
55 w.flush().await?;
56 Ok(())
57}
58
59pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
60 let len = (1 + payload.len()) as u32;
61 w.write_all(&len.to_be_bytes()).await?;
62 w.write_u8(TAG_DATA).await?;
63 w.write_all(payload).await?;
64 w.flush().await?;
65 Ok(())
66}
67
68pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
69 let mut len_buf = [0u8; 4];
70 match r.read_exact(&mut len_buf).await {
71 Ok(_) => {}
72 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
73 Err(e) => return Err(e.into()),
74 }
75 let len = u32::from_be_bytes(len_buf) as usize;
76 if len == 0 {
77 bail!("invalid frame: zero length");
78 }
79 if len > MAX_FRAME_SIZE {
80 bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
81 }
82 let tag = {
83 let mut tag_buf = [0u8; 1];
84 r.read_exact(&mut tag_buf).await?;
85 tag_buf[0]
86 };
87 let payload_len = len - 1;
88 let mut payload = vec![0u8; payload_len];
89 if payload_len > 0 {
90 r.read_exact(&mut payload).await?;
91 }
92 match tag {
93 TAG_CONTROL => Ok(Some(Frame::Control(payload))),
94 TAG_DATA => Ok(Some(Frame::Data(payload))),
95 other => bail!("unknown frame tag: 0x{:02x}", other),
96 }
97}
98
99pub async fn send_client_message<W: AsyncWrite + Unpin>(
101 w: &mut W,
102 msg: &ClientMessage,
103) -> Result<()> {
104 let json = serde_json::to_vec(msg)?;
105 write_control(w, &json).await
106}
107
108pub async fn send_server_message<W: AsyncWrite + Unpin>(
110 w: &mut W,
111 msg: &ServerMessage,
112) -> Result<()> {
113 let json = serde_json::to_vec(msg)?;
114 write_control(w, &json).await
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ClientMessage` variant survives a JSON encode/decode cycle.
    #[test]
    fn serde_round_trip_client() {
        let msgs = vec![
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession {
                id: Uuid::nil(),
                cols: 120,
                rows: 40,
            },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: Uuid::nil() },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    /// Every `ServerMessage` variant survives a JSON encode/decode cycle.
    #[test]
    fn serde_round_trip_server() {
        let msgs = vec![
            ServerMessage::SessionCreated { id: Uuid::nil() },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::ClientLeft {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    /// A frame with no payload bytes (length prefix == 1, tag only) decodes.
    #[tokio::test]
    async fn frame_empty_payload_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_data(&mut client, b"").await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        assert!(matches!(frame, Frame::Data(ref data) if data.is_empty()));
    }

    /// Several frames on one stream come back in order, followed by a clean EOF.
    #[tokio::test]
    async fn frames_are_read_back_in_order() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"first").await.unwrap();
        write_data(&mut client, b"second").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap() {
            Some(Frame::Control(data)) => assert_eq!(data, b"first"),
            other => panic!("expected control frame, got {:?}", other),
        }
        match read_frame(&mut server).await.unwrap() {
            Some(Frame::Data(data)) => assert_eq!(data, b"second"),
            other => panic!("expected data frame, got {:?}", other),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len: u32 = 2;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("unknown frame tag")
        );
    }

    #[tokio::test]
    async fn frame_zero_length_rejected() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("zero length"));
    }

    #[tokio::test]
    async fn frame_too_large() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len: u32 = 2 * 1024 * 1024;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        drop(client);
        let result = read_frame(&mut server).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("frame too large"));
    }

    /// A stream that ends before the promised payload is complete is an error,
    /// not a clean EOF.
    #[tokio::test]
    async fn frame_truncated_payload_is_error() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        // Promises the tag plus 9 payload bytes, but only 3 are sent.
        let len: u32 = 10;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(TAG_DATA).await.unwrap();
        client.write_all(b"abc").await.unwrap();
        drop(client);
        assert!(read_frame(&mut server).await.is_err());
    }

    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn send_server_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ServerMessage::Error {
            message: "boom".into(),
        };
        send_server_message(&mut server, &msg).await.unwrap();
        drop(server);
        let frame = read_frame(&mut client).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ServerMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}