1use anyhow::{Result, bail};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5use uuid::Uuid;
6
7const TAG_CONTROL: u8 = 0x01;
8const TAG_DATA: u8 = 0x02;
9const MAX_FRAME_SIZE: usize = 1_048_576; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
12#[serde(tag = "type")]
13pub enum ClientMessage {
14 Authenticate { token: String },
15 CreateSession { shell: Option<String> },
16 ListSessions,
17 AttachSession { id: Uuid },
18 DetachSession,
19 ResizeSession { id: Uuid, cols: u16, rows: u16 },
20 KillSession { id: Uuid },
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
24#[serde(tag = "type")]
25pub enum ServerMessage {
26 Authenticated,
27 SessionCreated { id: Uuid },
28 Sessions { sessions: Vec<SessionInfo> },
29 Attached { id: Uuid },
30 Detached,
31 SessionEnded { id: Uuid, exit_code: Option<i32> },
32 ClientJoined { session_id: Uuid, client_id: Uuid },
33 ClientLeft { session_id: Uuid, client_id: Uuid },
34 Error { message: String },
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
38pub struct SessionInfo {
39 pub id: Uuid,
40 pub cols: u16,
41 pub rows: u16,
42 pub created_at: DateTime<Utc>,
43 pub client_count: usize,
44}
45
/// One decoded wire frame, as returned by [`read_frame`].
///
/// The variant is selected by the frame's tag byte; the payload is the frame body that
/// follows the tag.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Tag byte `0x01`. [`send_client_message`] and [`send_server_message`] send
    /// JSON-encoded messages in this kind of frame.
    Control(Vec<u8>),
    /// Tag byte `0x02`. The payload is passed through as opaque bytes.
    Data(Vec<u8>),
}
51
52pub async fn write_control<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
53 let len = (1 + payload.len()) as u32;
54 w.write_all(&len.to_be_bytes()).await?;
55 w.write_u8(TAG_CONTROL).await?;
56 w.write_all(payload).await?;
57 w.flush().await?;
58 Ok(())
59}
60
61pub async fn write_data<W: AsyncWrite + Unpin>(w: &mut W, payload: &[u8]) -> Result<()> {
62 let len = (1 + payload.len()) as u32;
63 w.write_all(&len.to_be_bytes()).await?;
64 w.write_u8(TAG_DATA).await?;
65 w.write_all(payload).await?;
66 w.flush().await?;
67 Ok(())
68}
69
70pub async fn read_frame<R: AsyncRead + Unpin>(r: &mut R) -> Result<Option<Frame>> {
71 let mut len_buf = [0u8; 4];
72 match r.read_exact(&mut len_buf).await {
73 Ok(_) => {}
74 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
75 Err(e) => return Err(e.into()),
76 }
77 let len = u32::from_be_bytes(len_buf) as usize;
78 if len == 0 {
79 bail!("invalid frame: zero length");
80 }
81 if len > MAX_FRAME_SIZE {
82 bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
83 }
84 let tag = {
85 let mut tag_buf = [0u8; 1];
86 r.read_exact(&mut tag_buf).await?;
87 tag_buf[0]
88 };
89 let payload_len = len - 1;
90 let mut payload = vec![0u8; payload_len];
91 if payload_len > 0 {
92 r.read_exact(&mut payload).await?;
93 }
94 match tag {
95 TAG_CONTROL => Ok(Some(Frame::Control(payload))),
96 TAG_DATA => Ok(Some(Frame::Data(payload))),
97 other => bail!("unknown frame tag: 0x{:02x}", other),
98 }
99}
100
101pub async fn send_client_message<W: AsyncWrite + Unpin>(
103 w: &mut W,
104 msg: &ClientMessage,
105) -> Result<()> {
106 let json = serde_json::to_vec(msg)?;
107 write_control(w, &json).await
108}
109
110pub async fn send_server_message<W: AsyncWrite + Unpin>(
112 w: &mut W,
113 msg: &ServerMessage,
114) -> Result<()> {
115 let json = serde_json::to_vec(msg)?;
116 write_control(w, &json).await
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serde_round_trip_client() {
        let msgs = vec![
            ClientMessage::Authenticate {
                token: "test-token".into(),
            },
            ClientMessage::CreateSession {
                shell: Some("bash".into()),
            },
            ClientMessage::ListSessions,
            ClientMessage::AttachSession { id: Uuid::nil() },
            ClientMessage::DetachSession,
            ClientMessage::ResizeSession {
                id: Uuid::nil(),
                cols: 80,
                rows: 24,
            },
            ClientMessage::KillSession { id: Uuid::nil() },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[test]
    fn serde_round_trip_server() {
        let msgs = vec![
            ServerMessage::Authenticated,
            ServerMessage::SessionCreated { id: Uuid::nil() },
            ServerMessage::Sessions {
                sessions: vec![SessionInfo {
                    id: Uuid::nil(),
                    cols: 80,
                    rows: 24,
                    created_at: Utc::now(),
                    client_count: 2,
                }],
            },
            ServerMessage::Attached { id: Uuid::nil() },
            ServerMessage::Detached,
            ServerMessage::SessionEnded {
                id: Uuid::nil(),
                exit_code: Some(0),
            },
            ServerMessage::ClientJoined {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::ClientLeft {
                session_id: Uuid::nil(),
                client_id: Uuid::nil(),
            },
            ServerMessage::Error {
                message: "fail".into(),
            },
        ];
        for msg in msgs {
            let json = serde_json::to_string(&msg).unwrap();
            let decoded: ServerMessage = serde_json::from_str(&json).unwrap();
            assert_eq!(msg, decoded);
        }
    }

    #[tokio::test]
    async fn frame_round_trip_control() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello control";
        write_control(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => assert_eq!(data, payload),
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn frame_round_trip_data() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let payload = b"hello data";
        write_data(&mut client, payload).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Data(data) => assert_eq!(data, payload),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    /// A frame whose length prefix covers only the tag byte carries an empty payload.
    #[tokio::test]
    async fn frame_round_trip_empty_payload() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_data(&mut client, b"").await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Data(data) => assert!(data.is_empty()),
            Frame::Control(_) => panic!("expected data frame"),
        }
    }

    /// Consecutive frames are decoded in order, and the stream then reports a clean EOF.
    #[tokio::test]
    async fn frames_read_in_order_then_eof() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        write_control(&mut client, b"first").await.unwrap();
        write_data(&mut client, b"second").await.unwrap();
        drop(client);
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Control(data) => assert_eq!(data, b"first"),
            Frame::Data(_) => panic!("expected control frame"),
        }
        match read_frame(&mut server).await.unwrap().unwrap() {
            Frame::Data(data) => assert_eq!(data, b"second"),
            Frame::Control(_) => panic!("expected data frame"),
        }
        assert!(read_frame(&mut server).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn frame_eof_returns_none() {
        let (client, mut server) = tokio::io::duplex(1024);
        drop(client);
        let frame = read_frame(&mut server).await.unwrap();
        assert!(frame.is_none());
    }

    #[tokio::test]
    async fn frame_zero_length() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("zero length"));
    }

    #[tokio::test]
    async fn frame_bad_tag() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len: u32 = 2;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        client.write_u8(0xFF).await.unwrap();
        client.write_u8(0x00).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("unknown frame tag"));
    }

    #[tokio::test]
    async fn frame_too_large() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let len: u32 = 2 * 1024 * 1024;
        client.write_all(&len.to_be_bytes()).await.unwrap();
        drop(client);
        let err = read_frame(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("frame too large"));
    }

    #[tokio::test]
    async fn send_client_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ClientMessage::CreateSession {
            shell: Some("zsh".into()),
        };
        send_client_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ClientMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }

    #[tokio::test]
    async fn send_server_message_round_trip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = ServerMessage::SessionEnded {
            id: Uuid::nil(),
            exit_code: None,
        };
        send_server_message(&mut client, &msg).await.unwrap();
        drop(client);
        let frame = read_frame(&mut server).await.unwrap().unwrap();
        match frame {
            Frame::Control(data) => {
                let decoded: ServerMessage = serde_json::from_slice(&data).unwrap();
                assert_eq!(decoded, msg);
            }
            Frame::Data(_) => panic!("expected control frame"),
        }
    }
}
270}